# Adversarial Loss - Musterlösung (sample solution)

import torch
import torch.nn.init as init
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from torchvision.models import vgg16
from tqdm import tqdm

from misc import (
    ResNetBlock,
    TVLoss,
    VGG16PerceptualLoss,
    get_dataloader,
    train,
)


 16class Generator(nn.Module):
 17    def __init__(self):
 18        """Initialize the Upscale4x model.
 19
 20        This model performs 4x upscaling using a series of ResNet blocks and an upsampling layer.
 21
 22        **TODO**:
 23
 24        - Call the `__init__` method of the base class `nn.Module`.
 25
 26        - Define an upsampling layer using `nn.Upsample(scale_factor=4, mode="bilinear", align_corners=True) <https://docs.pytorch.org/docs/stable/generated/torch.nn.Upsample.html>`_.
 27
 28        - Define a sequential model consisting of:
 29
 30        - Five `ResNetBlock` layers with 3->16, 16->32, 32->64, 64->128 and 128->256 channels as well as kernel sizes 7.
 31
 32        - A PixelShuffle layer with an upscale factor of 4.
 33
 34        - A final convolutional layer with 16 input channels, 3 output channels and kernel size 5 with padding 2.
 35        """
 36        super(Generator, self).__init__()
 37
 38        self.upBilinear = nn.Upsample(
 39            scale_factor=4, mode="bilinear", align_corners=True
 40        )
 41
 42        self.model = nn.Sequential(
 43            ResNetBlock(3, 16, kernel_size=7),
 44            ResNetBlock(16, 32, kernel_size=7),
 45            ResNetBlock(32, 64, kernel_size=7),
 46            ResNetBlock(64, 128, kernel_size=7),
 47            ResNetBlock(128, 256, kernel_size=7),
 48            nn.PixelShuffle(upscale_factor=4),  # First upsample
 49            nn.Conv2d(16, 3, kernel_size=7, padding=3),  # Final conv to reduce channels
 50        )
 51
 52    def forward(self, x):
 53        """Perform the forward pass of the Upscale2x model.
 54
 55        Parameters:
 56        -----------
 57            x (torch.Tensor):
 58              The input tensor to be upscaled.
 59
 60        Returns:
 61        --------
 62            torch.Tensor:
 63              The upscaled output tensor.
 64
 65        **TODO**:
 66
 67        - Pass the input tensor through the model.
 68
 69        - Also, apply the upsampling layer to the input tensor `x`.
 70
 71        - Add the upsampled tensor to the output of the model.
 72        """
 73        x = self.upBilinear(x) + self.model(x)
 74
 75        return x
 76
 77
 78class Critic(nn.Module):
 79    def __init__(self):
 80        """Initialize the Critic model.
 81
 82        This model is a convolutional neural network that takes an image as input and outputs a single score indicating the quality of the image.
 83
 84        **TODO**:
 85
 86        - Call the `__init__` method of the base class `nn.Module`.
 87
 88        - Define a sequential model consisting of:
 89            - A convolutional layer with 3 input channels, 32 output channels, kernel size 9, stride 2, and padding 4.
 90            - A LeakyReLU activation function with an inplace operation.
 91            - A convolutional layer with 32 input channels, 64 output channels, kernel size 5, stride 2, and padding 2.
 92            - A LeakyReLU activation function with an inplace operation.
 93            - A convolutional layer with 64 input channels, 128 output channels, kernel size 5, stride 2, and padding 2.
 94            - A LeakyReLU activation function with an inplace operation.
 95            - A convolutional layer with 128 input channels, 256 output channels, kernel size 5, stride 2, and padding 2.
 96            - A LeakyReLU activation function with an inplace operation.
 97            - A convolutional layer with 256 input channels, 512 output channels, kernel size 5, stride 2, and padding 2.
 98            - A LeakyReLU activation function with an inplace operation.
 99            - A convolutional layer with 512 input channels, 1024 output channels, kernel size 5, stride 2, and padding 2.
100            - A LeakyReLU activation function with an inplace operation.
101            - An average pooling layer with kernel size (4, 4) to reduce the spatial dimensions.
102            - A flattening layer to convert the output to a 1D tensor.
103            - A linear layer with 1024 input features and 1 output feature (no bias).
104        """
105        super(Critic, self).__init__()
106        self.model = nn.Sequential(
107            nn.Conv2d(3, 32, kernel_size=9, stride=2, padding=4),
108            nn.LeakyReLU(inplace=True),  # 32x128x128
109            nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=2),
110            nn.LeakyReLU(inplace=True),  # 64x64x64
111            nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
112            nn.LeakyReLU(inplace=True),  # 128x32x32
113            nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
114            nn.LeakyReLU(inplace=True),  # 256x16x16
115            nn.Conv2d(256, 512, kernel_size=5, stride=2, padding=2),
116            nn.LeakyReLU(inplace=True),  # 512x8x8
117            nn.Conv2d(512, 1024, kernel_size=5, stride=2, padding=2),
118            nn.LeakyReLU(inplace=True),  # 1024x4x4
119            nn.AvgPool2d(kernel_size=(4, 4)),  # 1024x1x1
120            nn.Flatten(),
121            nn.Linear(1024, 1, bias=False),  # Final output layer
122        )
123
124    def forward(self, x):
125        """
126        Perform the forward pass of the Critic model.
127        Parameters:
128        -----------
129            x (torch.Tensor):
130              The input tensor to be processed by the Critic model.
131
132        Returns:
133        --------
134            torch.Tensor: The output score from the Critic model.
135
136        **TODO**:
137
138        - Pass the input tensor through the model.
139
140        - Return the output score from the model.
141        """
142        return self.model(x)
143
144
class GeneratorLoss(nn.Module):
    """Generator objective: content loss minus a warmed-up adversarial term."""

    def __init__(self, critic):
        """Initialize the GeneratorLoss module.

        Parameters:
        -----------
            critic (nn.Module):
              The critic model used for adversarial loss computation. It is
              stored by reference, not copied.
        """
        super().__init__()
        self.perceptualLoss = VGG16PerceptualLoss()
        self.tvLoss = TVLoss()
        self.critic = critic

    def forward(self, output, target, epoch):
        """Compute the generator loss.

        The content loss is the perceptual loss between `output` and
        `target` plus 0.1 times the total variation loss of `output`.

        The adversarial loss is the mean critic score of the generated
        images, scaled by 0.01. The generator wants the critic to score its
        images highly, so this term is subtracted from the content loss.

        Since the critic is not yet well trained during the first epochs,
        the adversarial term is weighted by `epoch / 5`, clamped to [0, 1]:
        it is switched off at epoch 0 and reaches full weight at epoch 5.

        Parameters:
        -----------
            output (torch.Tensor):
              The output tensor from the generator.

            target (torch.Tensor):
              The target tensor for comparison.

            epoch (int):
              The current training epoch.

        Returns:
        --------
            Dictionary with the following keys:

            - "generator_loss": Content loss minus the weighted adversarial loss.

            - "content_loss": Perceptual loss plus 0.1 * TV loss.

            - "adversarial_loss": 0.01 * mean critic score (before the epoch weighting).
        """
        adversarial_loss = 0.01 * self.critic(output).mean()

        # Linear warm-up; the lower clamp keeps a negative epoch from
        # flipping the sign of the adversarial term.
        adversarial_lambda = max(0.0, min(1.0, epoch / 5.0))

        content_loss = self.perceptualLoss(output, target) + 0.1 * self.tvLoss(output)

        return {
            "generator_loss": content_loss - adversarial_lambda * adversarial_loss,
            "content_loss": content_loss,
            "adversarial_loss": adversarial_loss,
        }


class CriticLoss(nn.Module):
    """WGAN critic loss with a gradient penalty."""

    def __init__(self, critic):
        """Initialize the CriticLoss module.

        Parameters:
        -----------
            critic (nn.Module):
              The critic model whose scores and input gradients are used.
        """
        super().__init__()
        self.critic = critic

    def compute_gradient_penalty(self, real, fake, lambda_gp=30):
        """Compute the gradient penalty for the critic.

        The critic is evaluated on random per-sample interpolations between
        `real` and `fake`; the penalty pushes the L2 norm of the critic's
        gradient with respect to these interpolated inputs towards 1, which
        is used to enforce the Lipschitz constraint on the critic.

        Parameters:
        -----------
            real (torch.Tensor):
              The real images, a 4-D tensor with the batch in dimension 0.

            fake (torch.Tensor):
              The generated images, same shape as `real`.

            lambda_gp (float):
              Weight of the gradient penalty.

        Returns:
        --------
            Tuple `(gradient_penalty, gradient_norm)`: the weighted scalar
            penalty and the per-sample gradient norms (for logging).
        """
        # One interpolation coefficient per sample, broadcast over C, H, W.
        batch_size = real.size(0)
        epsilon = torch.rand(batch_size, 1, 1, 1, device=real.device)
        interpolated = (epsilon * real + (1 - epsilon) * fake).requires_grad_(True)

        critic_output = self.critic(interpolated)

        # create_graph=True keeps the gradient computation differentiable so
        # the penalty itself can be backpropagated into the critic weights.
        gradients = torch.autograd.grad(
            outputs=critic_output,
            inputs=interpolated,
            grad_outputs=torch.ones_like(critic_output),
            create_graph=True,
            retain_graph=True,
        )[0]

        # flatten() also handles non-contiguous gradients, unlike view().
        gradient_norm = gradients.flatten(start_dim=1).norm(2, dim=1)
        gradient_penalty = ((gradient_norm - 1) ** 2).mean() * lambda_gp

        return gradient_penalty, gradient_norm

    def forward(self, real, fake):
        """Compute the critic loss, including the gradient penalty.

        Parameters:
        -----------
            real (torch.Tensor):
              The real images from the dataset.

            fake (torch.Tensor):
              The generated images from the generator model.

        Returns:
        --------
            Dictionary with the following keys:

            - "loss_c": The total critic loss, i.e. the WGAN loss plus the gradient penalty (torch.Tensor).

            - "gradient_norm": The per-sample gradient norms for logging purposes (torch.Tensor).

            - "pure_wgan_loss": The WGAN loss without gradient penalty,
              `mean(critic(fake)) - mean(critic(real))` (torch.Tensor).
        """
        gp, gradient_norm = self.compute_gradient_penalty(real, fake)

        # Minimizing this raises the scores of real images and lowers the
        # scores of generated ones.
        wgan_loss = self.critic(fake).mean() - self.critic(real).mean()

        return {
            "loss_c": wgan_loss + gp,
            "gradient_norm": gradient_norm,
            "pure_wgan_loss": wgan_loss,
        }


class UpscaleTrainer:
    """Alternating WGAN-style trainer for the generator and the critic."""

    def __init__(self):
        """Create both models, their losses and optimizers on the GPU."""
        # Number of critic updates since the last generator update.
        self.criticUpdates = 0
        self.generator = Generator().cuda()
        self.critic = Critic().cuda()

        # Both loss modules share the same critic instance.
        self.generatorLoss = GeneratorLoss(self.critic).cuda()
        self.criticLoss = CriticLoss(self.critic).cuda()

        self.optimGenerator = torch.optim.Adam(self.generator.parameters(), lr=0.0005)
        self.optimCritic = torch.optim.Adam(self.critic.parameters(), lr=0.0001)

        # Count and print parameters
        gen_params = sum(p.numel() for p in self.generator.parameters())
        critic_params = sum(p.numel() for p in self.critic.parameters())
        print(f"Generator parameters: {gen_params:,}")
        print(f"Critic parameters: {critic_params:,}")

    def train_critic(self, input, target):
        """Run one optimization step of the critic on a batch.

        Parameters:
        -----------
            input (torch.Tensor):
              The input tensor containing the images to be processed by the generator.

            target (torch.Tensor):
              The target tensor containing the ground truth images for comparison.

        Returns:
        --------
            dict: A dictionary with the following keys:

                "gradient_norm": Mean of the gradient norms reported by the critic loss (float).
                "loss_c": The pure WGAN loss, i.e. without the gradient penalty (float).
        """
        # The generator is only used to produce fake samples here; running it
        # without autograd keeps the critic's backward pass from needlessly
        # backpropagating into (and accumulating gradients on) the generator.
        with torch.no_grad():
            output = self.generator(input)

        self.optimCritic.zero_grad()
        result = self.criticLoss(target, output)

        # Optimize the full loss (WGAN loss + gradient penalty) ...
        result["loss_c"].backward()
        # ... and clip the gradients of the model that is actually being
        # stepped (the critic) to prevent exploding gradients.
        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=5.0)
        self.optimCritic.step()

        return {
            "gradient_norm": result["gradient_norm"].mean().item(),
            "loss_c": result["pure_wgan_loss"].item(),
        }

    def train_generator(self, input, target, epoch):
        """Run one optimization step of the generator on a batch.

        Parameters:
        -----------
            input (torch.Tensor):
                The input tensor containing the images to be processed by the generator.

            target (torch.Tensor):
                The target tensor containing the ground truth images for comparison.

            epoch (int):
                The current training epoch, used to scale the adversarial loss.

        Returns:
        --------
            dict: A dictionary with the following keys:
                "loss": The total generator loss (float).
                "content_loss": The content loss (float).
                "adversarial_loss": The adversarial loss (float).
                "gradient_norm": The generator gradient norm, measured after clipping (float).
                "output": The detached output tensor from the generator (torch.Tensor).
        """
        self.optimGenerator.zero_grad()
        output = self.generator(input)

        result = self.generatorLoss(output, target, epoch)
        loss = result["generator_loss"]
        loss.backward()

        # Clip to a total norm of 1.0 to prevent exploding gradients.
        torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=1.0)
        # With a huge max_norm this call changes nothing; it is only used to
        # read back the total gradient norm (which is the post-clip value).
        gen_norm = torch.nn.utils.clip_grad_norm_(
            self.generator.parameters(), max_norm=1e9
        )

        self.optimGenerator.step()

        return {
            "loss": loss.item(),
            "content_loss": result["content_loss"].item(),
            "adversarial_loss": result["adversarial_loss"].item(),
            "gradient_norm": gen_norm.item(),
            "output": output.detach(),
        }

    def train_batch(self, input, target, epoch):
        """Train the critic, and on schedule the generator, on one batch.

        The critic is updated on every call. The generator is updated on
        every call while `epoch < 1`, and afterwards only when five critic
        updates have accumulated since its last update.

        Parameters:
        -----------
            input (torch.Tensor):
              The input tensor containing the images to be processed by the generator.

            target (torch.Tensor):
              The target tensor containing the ground truth images for comparison.

            epoch (int):
              The current training epoch, used to scale the adversarial loss.

        Returns:
        --------
            A dictionary with the following keys:
            - "critic": The dictionary returned by `train_critic`.
            - "generator": The dictionary returned by `train_generator`, or
              None when the generator was not trained on this batch.
        """
        # Train the critic on every step.
        scoresCritic = self.train_critic(input, target)
        self.criticUpdates += 1

        # Train the generator on every 5th critic update (or always during
        # the first epoch, where the adversarial term has zero weight).
        if self.criticUpdates == 5 or epoch < 1:
            scoresGenerator = self.train_generator(input, target, epoch)
            self.criticUpdates = 0
        else:
            scoresGenerator = None

        return {"critic": scoresCritic, "generator": scoresGenerator}


def main():
    """Train the 4x upscaling model with adversarial loss.

    Builds the dataloader (64px inputs, 256px targets, batch size 48) and
    the trainer, then hands both to the project's `train` routine under the
    run prefix "upscale4x_adversarialloss".
    """
    prefix = "upscale4x_adversarialloss"

    dataloader = get_dataloader(inputSize=64, outputSize=256, batch_size=48)

    trainer = UpscaleTrainer()

    train(prefix, trainer, dataloader)


if __name__ == "__main__":
    main()