1import torch
2from torch import nn
3from torchvision.models import vgg16
4import torch.nn.init as init
5from misc import (
6 get_dataloader,
7 ResNetBlock,
8 VGG16PerceptualLoss,
9 train,
10 TVLoss,
11)
12from tqdm import tqdm
13from torch.utils.tensorboard import SummaryWriter
14
15
class Generator(nn.Module):
    def __init__(self):
        """Initialize the 4x upscaling generator.

        The network predicts a residual on top of a bilinear 4x upsampling of
        the input. A stack of ``ResNetBlock`` layers followed by a
        ``PixelShuffle`` produces the high-resolution detail, and
        :meth:`forward` adds it to the bilinearly upsampled image.

        Layers:

        - ``upBilinear``: ``nn.Upsample(scale_factor=4, mode="bilinear",
          align_corners=True)``.
        - ``model``: an ``nn.Sequential`` of

          - five ``ResNetBlock`` layers with 3->16, 16->32, 32->64, 64->128 and
            128->256 channels, all with kernel size 7;
          - ``nn.PixelShuffle(upscale_factor=4)``, which trades channels for
            resolution: 256 channels become 256 / 4**2 = 16 channels at 4x the
            height and width;
          - a final ``nn.Conv2d(16, 3, kernel_size=7, padding=3)`` mapping the
            16 channels back to 3 (RGB).

        NOTE(review): the exercise text asked for a final convolution with
        kernel size 5 and padding 2. This implementation uses kernel size 7
        with padding 3 (both keep the spatial size unchanged). The code is kept
        as-is so that existing checkpoints still load.
        """
        super(Generator, self).__init__()

        self.upBilinear = nn.Upsample(
            scale_factor=4, mode="bilinear", align_corners=True
        )

        self.model = nn.Sequential(
            # Assumes ResNetBlock preserves spatial size (required for the
            # residual sum in forward) -- defined in misc, TODO confirm.
            ResNetBlock(3, 16, kernel_size=7),
            ResNetBlock(16, 32, kernel_size=7),
            ResNetBlock(32, 64, kernel_size=7),
            ResNetBlock(64, 128, kernel_size=7),
            ResNetBlock(128, 256, kernel_size=7),
            nn.PixelShuffle(upscale_factor=4),  # 256ch @ HxW -> 16ch @ 4Hx4W (the only learned upsample)
            nn.Conv2d(16, 3, kernel_size=7, padding=3),  # back to 3 channels, same spatial size
        )

    def forward(self, x):
        """Upscale ``x`` by a factor of 4.

        Parameters:
        -----------
        x (torch.Tensor):
            The low-resolution input batch. Bilinear ``nn.Upsample`` requires a
            4-D tensor; the first ResNetBlock expects 3 channels, so presumably
            (batch, 3, H, W).

        Returns:
        --------
        torch.Tensor:
            The bilinear 4x upsampling of ``x`` plus the residual predicted by
            ``self.model``.
        """
        # Residual formulation: the network only has to learn the detail that
        # bilinear interpolation misses.
        return self.upBilinear(x) + self.model(x)
76
77
class Critic(nn.Module):
    def __init__(self):
        """Initialize the Critic model.

        The critic is a convolutional network that maps a batch of RGB images
        to one unbounded score per image. ``CriticLoss`` trains it to score
        real images higher than generated ones.

        Architecture (``self.model``):

        - Six stride-2 convolutions, each followed by an in-place LeakyReLU:
          3->32 (kernel 9, padding 4), then 32->64, 64->128, 128->256,
          256->512 and 512->1024 (kernel 5, padding 2). Each one halves the
          spatial size (rounding up), so a 256x256 input ends at 1024x4x4.
        - Global average pooling to 1024x1x1, flattening, and a bias-free
          ``nn.Linear(1024, 1)``.

        The pooling layer is ``nn.AdaptiveAvgPool2d(1)``. For the 256x256
        training images this is exactly the ``AvgPool2d((4, 4))`` the exercise
        asked for, because the last feature map is 4x4. It also keeps the
        critic usable for other input sizes. With those sizes, a fixed 4x4
        window would fail (maps smaller than 4x4), silently ignore trailing
        rows/columns (5-7 px maps), or produce more than 1024 features for the
        linear layer (maps of 8x8 and larger).
        """
        super(Critic, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=9, stride=2, padding=4),
            nn.LeakyReLU(inplace=True),  # 32x128x128 for a 256x256 input
            nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(inplace=True),  # 64x64x64
            nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(inplace=True),  # 128x32x32
            nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(inplace=True),  # 256x16x16
            nn.Conv2d(256, 512, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(inplace=True),  # 512x8x8
            nn.Conv2d(512, 1024, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(inplace=True),  # 1024x4x4
            # Global average pool: identical to AvgPool2d((4, 4)) on the 4x4 map
            # of a 256x256 input, but independent of the input size.
            nn.AdaptiveAvgPool2d(1),  # 1024x1x1
            nn.Flatten(),
            nn.Linear(1024, 1, bias=False),  # one score per image
        )

    def forward(self, x):
        """Score a batch of images.

        Parameters:
        -----------
        x (torch.Tensor):
            Image batch; the first convolution expects 3 channels, so
            presumably (batch, 3, H, W).

        Returns:
        --------
        torch.Tensor:
            The critic scores, shape (batch, 1).
        """
        return self.model(x)
143
144
class GeneratorLoss(nn.Module):
    def __init__(self, critic):
        """Initialize the GeneratorLoss module.

        Parameters:
        -----------
        critic (nn.Module):
            The critic model used for the adversarial term. It is shared with
            the critic's own loss, not copied.
        """
        super(GeneratorLoss, self).__init__()
        self.perceptualLoss = VGG16PerceptualLoss()
        self.tvLoss = TVLoss()
        # Assigning an nn.Module registers it as a submodule, so moving this
        # loss with .cuda()/.to() also moves the shared critic.
        self.critic = critic

    def forward(self, output, target, epoch):
        """Compute the generator loss.

        The generator loss combines a content loss with an adversarial loss:

        - content loss = perceptual loss + 0.1 * total-variation loss. It
          measures how close the generated image is to the target.
        - adversarial loss = 0.01 * mean critic score of the generated images.
          The generator wants this score to be *high* (fool the critic), so it
          enters the total with a negative sign.

        The critic is still poor early in training, so the adversarial term is
        scaled by a linear warm-up factor ``min(1, max(0, epoch / 5))``. The
        factor is 0 at epoch 0 and grows linearly until it reaches 1 at epoch 5,
        where it stays. It is clamped at 0 so an out-of-range (negative) epoch
        cannot flip the sign of the adversarial term.

        Parameters:
        -----------
        output (torch.Tensor):
            The images produced by the generator.

        target (torch.Tensor):
            The ground-truth high-resolution images.

        epoch (int):
            The current training epoch (presumably counted from 0; the
            warm-up factor is 0 at epoch 0).

        Returns:
        --------
        Dictionary with the following keys:

        - "generator_loss": content_loss - warmup_factor * adversarial_loss.

        - "content_loss": perceptual loss + 0.1 * TV loss.

        - "adversarial_loss": 0.01 * mean critic score of ``output``
          (unscaled by the warm-up factor).
        """
        adversarial_loss = 0.01 * self.critic(output).mean()

        # Linear warm-up of the adversarial term, clamped to [0, 1].
        adversarial_lambda = min(1.0, max(0.0, epoch / 5.0))

        content_loss = self.perceptualLoss(output, target) + 0.1 * self.tvLoss(output)

        return {
            # Minus sign: minimizing this maximizes the critic score of `output`.
            "generator_loss": content_loss - adversarial_lambda * adversarial_loss,
            "content_loss": content_loss,
            "adversarial_loss": adversarial_loss,
        }
232
233
class CriticLoss(nn.Module):
    def __init__(self, critic):
        """Initialize the CriticLoss module.

        This module computes the WGAN-GP loss for the critic. It combines the
        Wasserstein term with a gradient penalty that pushes the norm of the
        critic's input gradient towards 1 (a soft Lipschitz constraint).

        Parameters:
        -----------
        critic (nn.Module):
            The critic model being trained.
        """
        super(CriticLoss, self).__init__()
        self.critic = critic

    def compute_gradient_penalty(self, real, fake, lambda_gp=30):
        """Compute the WGAN-GP gradient penalty.

        Parameters:
        -----------
        real (torch.Tensor):
            Batch of real images. The per-sample mixing weight has shape
            (batch, 1, 1, 1), so 4-D image batches are assumed.

        fake (torch.Tensor):
            Batch of generated images, same shape as ``real``. No gradient is
            propagated back into it.

        lambda_gp (float):
            Weight of the penalty.

        Returns:
        --------
        tuple[torch.Tensor, torch.Tensor]:
            - The scalar penalty ``lambda_gp * mean((||grad||_2 - 1) ** 2)``,
              differentiable w.r.t. the critic parameters.
            - The per-sample gradient norms, detached, for logging.
        """
        batch_size = real.size(0)
        # One random mixing coefficient per sample, broadcast over C, H, W.
        epsilon = torch.rand(batch_size, 1, 1, 1, device=real.device)
        # Detach so the interpolate is a leaf. The penalty must only train the
        # critic and never backpropagate into whatever produced `fake`.
        interpolated = (epsilon * real + (1 - epsilon) * fake).detach().requires_grad_(True)

        critic_output = self.critic(interpolated)

        # d(critic)/d(input). create_graph=True makes the penalty itself
        # differentiable w.r.t. the critic parameters.
        gradients = torch.autograd.grad(
            outputs=critic_output,
            inputs=interpolated,
            grad_outputs=torch.ones_like(critic_output),
            create_graph=True,
            retain_graph=True,
        )[0]

        # reshape (not view) also handles non-contiguous gradient layouts.
        gradient_norm = gradients.reshape(batch_size, -1).norm(2, dim=1)
        gradient_penalty = ((gradient_norm - 1) ** 2).mean() * lambda_gp

        return gradient_penalty, gradient_norm.detach()

    def forward(self, real, fake):
        """Compute the critic loss, including the gradient penalty.

        Parameters:
        -----------
        real (torch.Tensor):
            The real images from the dataset.

        fake (torch.Tensor):
            The generated images. They are detached: the critic loss never
            sends gradients into the generator.

        Returns:
        --------
        Dictionary with the following keys:

        - "loss_c": The total critic loss, i.e. the WGAN loss plus the gradient
          penalty (torch.Tensor).

        - "gradient_norm": Per-sample norms of the critic's input gradient on
          the interpolated images, detached, for logging (torch.Tensor).

        - "pure_wgan_loss": The WGAN loss without the penalty,
          ``mean(critic(fake)) - mean(critic(real))``. The critic minimizes it,
          i.e. it learns to score real images higher than fakes (torch.Tensor).
        """
        fake = fake.detach()

        gp, gradient_norm = self.compute_gradient_penalty(real, fake)

        loss_c = -self.critic(real).mean() + self.critic(fake).mean()

        return {
            "loss_c": loss_c + gp,
            "gradient_norm": gradient_norm,
            "pure_wgan_loss": loss_c,
        }
311
312
class UpscaleTrainer:
    """Alternating WGAN-GP trainer for the 4x upscaling generator.

    The critic is updated on every batch. After the first epoch the generator
    is updated once per 5 critic updates; during epoch 0 it is updated on every
    batch.
    """

    def __init__(self, device="cuda"):
        """Build the models, losses and optimizers and print parameter counts.

        Parameters:
        -----------
        device (str or torch.device):
            Device that the generator, critic and loss modules are moved to.
            Defaults to ``"cuda"``, which matches the previous hard-coded
            ``.cuda()`` calls.
        """
        self.criticUpdates = 0
        self.generator = Generator().to(device)
        self.critic = Critic().to(device)

        # Both losses hold a reference to the same critic instance.
        self.generatorLoss = GeneratorLoss(self.critic).to(device)
        self.criticLoss = CriticLoss(self.critic).to(device)

        self.optimGenerator = torch.optim.Adam(self.generator.parameters(), lr=0.0005)
        self.optimCritic = torch.optim.Adam(self.critic.parameters(), lr=0.0001)

        # Count and print parameters
        gen_params = sum(p.numel() for p in self.generator.parameters())
        critic_params = sum(p.numel() for p in self.critic.parameters())
        print(f"Generator parameters: {gen_params:,}")
        print(f"Critic parameters: {critic_params:,}")

    def train_critic(self, input, target):
        """Run one optimization step of the critic.

        Parameters:
        -----------
        input (torch.Tensor):
            The low-resolution images fed to the generator.

        target (torch.Tensor):
            The ground-truth high-resolution images (the "real" samples).

        Returns:
        --------
        dict with the following keys:

            "gradient_norm": Mean per-sample norm of the critic's input gradient
                from the gradient penalty (float).
            "loss_c": The pure WGAN critic loss, without the gradient penalty
                (float).
        """
        # The generated images are constants for the critic step. Running the
        # generator without autograd avoids building its graph and avoids
        # backpropagating the critic loss into the generator.
        with torch.no_grad():
            output = self.generator(input)

        # Also clears critic grads accumulated by the previous generator step
        # (the generator loss is backpropagated through the critic).
        self.optimCritic.zero_grad()
        result = self.criticLoss(target, output)
        critic_loss = result["loss_c"]
        gradient_norm = result["gradient_norm"]
        wgan_loss = result["pure_wgan_loss"]

        critic_loss.backward()
        # Clip the *critic's* gradients. The generator's were clipped here
        # before, which left the critic update unclipped.
        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=5.0)
        self.optimCritic.step()

        return {
            "gradient_norm": gradient_norm.mean().item(),
            "loss_c": wgan_loss.item(),
        }

    def train_generator(self, input, target, epoch):
        """Run one optimization step of the generator.

        Parameters:
        -----------
        input (torch.Tensor):
            The low-resolution images fed to the generator.

        target (torch.Tensor):
            The ground-truth high-resolution images.

        epoch (int):
            The current training epoch, used to scale the adversarial loss.

        Returns:
        --------
        dict with the following keys:
            "loss": The total generator loss (float).
            "content_loss": The content loss (float).
            "adversarial_loss": The (unwarmed) adversarial loss (float).
            "gradient_norm": The generator gradient norm *after* clipping to
                1.0 (float).
            "output": The generated images, detached (torch.Tensor).
        """
        self.optimGenerator.zero_grad()
        output = self.generator(input)

        result = self.generatorLoss(output, target, epoch)
        loss = result["generator_loss"]
        content_loss = result["content_loss"]
        adversarial_loss = result["adversarial_loss"]

        loss.backward()

        torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=1.0)
        # clip_grad_norm_ returns the norm measured *before* clipping. Calling it
        # again with an effectively infinite max_norm therefore reports the
        # post-clip norm (<= 1.0) and leaves the gradients unchanged.
        gen_norm = torch.nn.utils.clip_grad_norm_(
            self.generator.parameters(), max_norm=1e9
        )

        self.optimGenerator.step()

        return {
            "loss": loss.item(),
            "content_loss": content_loss.item(),
            "adversarial_loss": adversarial_loss.mean().item(),
            "gradient_norm": gen_norm.item(),
            "output": output.detach(),
        }

    def train_batch(self, input, target, epoch):
        """Train on one batch: always the critic, sometimes the generator.

        Parameters:
        -----------
        input (torch.Tensor):
            The low-resolution images fed to the generator.

        target (torch.Tensor):
            The ground-truth high-resolution images.

        epoch (int):
            The current training epoch. While ``epoch < 1`` the generator is
            trained on every batch.

        Returns:
        --------
        A dictionary with the following keys:
        - "critic": the dict returned by ``train_critic``.
        - "generator": the dict returned by ``train_generator``, or None if
          the generator was not updated on this batch.
        """
        # Train the critic on every batch.
        scoresCritic = self.train_critic(input, target)
        self.criticUpdates += 1

        # Train the generator once per 5 critic updates (every batch in epoch 0).
        if self.criticUpdates == 5 or epoch < 1:
            scoresGenerator = self.train_generator(input, target, epoch)
            self.criticUpdates = 0
        else:
            scoresGenerator = None

        return {"critic": scoresCritic, "generator": scoresGenerator}
498
499
if __name__ == "__main__":
    # Script entry point: build the data pipeline (input size 64, output size
    # 256, batches of 48) and the adversarial trainer, then hand both to the
    # shared training loop under the given run prefix.
    run_prefix = "upscale4x_adversarialloss"
    loader = get_dataloader(inputSize=64, outputSize=256, batch_size=48)
    upscale_trainer = UpscaleTrainer()
    train(run_prefix, upscale_trainer, loader)