Source code for adversarialloss.misc

import torch
from torch import nn
from torchvision.models import vgg16


class VGG16PerceptualLoss(nn.Module):
    def __init__(self):
        """Initialize the VGG16 perceptual loss model.

        It computes the perceptual loss as the mean squared error between the features.

        The model is set to evaluation mode and the parameters are frozen.

        **TODO**:

        - Load the VGG16 model with pretrained weights. Use `torchvision.models.vgg16(pretrained=True)`.

        - Restrict the VGG16 model to the first 16 layers by using `self.vgg = vgg16(pretrained=True).features[:16]`.

        - Set the model to evaluation mode using `.eval()`.

        - Freeze the parameters of the VGG16 model by setting `param.requires_grad = False` for all parameters.
          NOTE: Iterate over all parameters using the `self.vgg.parameters()` iterator.

        - Initialize the L2 loss function using `nn.MSELoss()`.
        """
        super(VGG16PerceptualLoss, self).__init__()
        self.vgg = vgg16(pretrained=True).features[:16].eval().cuda()

        for param in self.vgg.parameters():
            param.requires_grad = False

        self.l2_loss = nn.MSELoss()

    def forward(self, output, target):
        """Compute the perceptual loss between two images.

        Parameters:
        -----------
            output (torch.Tensor):
              The output image tensor from the upscaler network.

            target (torch.Tensor):
              The target image tensor from ground truth.

        Returns:
        --------
            torch.Tensor:
              The computed perceptual loss as the mean squared error between the features of the two images.

        **TODO**:

        - Resize `output` and `target` to 224x224 using `torch.nn.functional.interpolate()`. Use `mode='bilinear'` and `align_corners=False`.

        - Pass `output` through the VGG16 model to get the features `f1`.

        - Pass `target` through the VGG16 model to get the features `f2`. Note: You should use `torch.no_grad()` to avoid computing gradients for the target image.

        - Compute and return the L2 loss between `f1` and `f2` using `self.l2_loss(f1, f2)`.
        """
        # Resize both images to the 224x224 resolution VGG16 was trained on.
        output = torch.nn.functional.interpolate(
            output, size=(224, 224), mode="bilinear", align_corners=False
        )
        target = torch.nn.functional.interpolate(
            target, size=(224, 224), mode="bilinear", align_corners=False
        )

        f1 = self.vgg(output)

        with torch.no_grad():
            f2 = self.vgg(target)

        return self.l2_loss(f1, f2)
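
# Usage sketch (illustrative, not part of the exercise): comparing an
# upscaled batch against its ground truth. The tensor shapes are
# assumptions; the `.cuda()` calls match the GPU usage in __init__ above.
#
#     perceptual = VGG16PerceptualLoss()
#     fake = torch.rand(4, 3, 256, 256).cuda()
#     real = torch.rand(4, 3, 256, 256).cuda()
#     loss = perceptual(fake, real)  # scalar MSE over VGG16 feature maps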


import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from PIL import Image

from tqdm import tqdm
import lpips
from torch.utils.tensorboard import SummaryWriter


# Load the dataset from flowersSquared folder
class CustomImageDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, transformInput=None, transformOutput=None):
        self.root_dir = root_dir
        self.transformInput = transformInput
        self.transformOutput = transformOutput
        self.image_files = [
            f
            for f in os.listdir(root_dir)
            if f.lower().endswith((".png", ".jpg", ".jpeg", ".bmp", ".tiff"))
        ]

        # img_paths = [os.path.join(self.root_dir, self.image_files[idx]) for idx in range(len(self.image_files))]
        # self.images = [Image.open(img_path).convert("RGB") for img_path in img_paths]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.image_files[idx])
        image = Image.open(img_path).convert("RGB")
        # image = self.images[idx]

        # Fall back to the untransformed image if a transform is missing,
        # so both return values are always defined.
        imageInput = self.transformInput(image) if self.transformInput else image
        imageOutput = self.transformOutput(image) if self.transformOutput else image

        return imageInput, imageOutput  # Return both inputs and outputs
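
# Usage sketch (illustrative): loading a single input/target pair. The
# directory path is hypothetical; without transforms, PIL images are
# returned as-is.
#
#     dataset = CustomImageDataset("path/to/images")
#     low_res_pil, high_res_pil = dataset[0]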


def get_dataloader(inputSize=128, outputSize=256, batch_size=32):
    transformInput = transforms.Compose(
        [
            transforms.Resize((inputSize, inputSize)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    transformOutput = transforms.Compose(
        [
            transforms.Resize((outputSize, outputSize)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    root_dir = os.path.join(
        os.path.dirname(__file__), "../flower_dataset/flowersSquared"
    )

    dataset = CustomImageDataset(
        root_dir=root_dir,
        transformInput=transformInput,
        transformOutput=transformOutput,
    )

    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=True, pin_memory=True  # , num_workers=8
    )

    return dataloader
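
# Usage sketch (illustrative): iterating over the paired batches. The
# shapes follow from the default inputSize/outputSize arguments above.
#
#     loader = get_dataloader(inputSize=128, outputSize=256, batch_size=32)
#     for low_res, high_res in loader:
#         ...  # low_res: (32, 3, 128, 128), high_res: (32, 3, 256, 256)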


class ResNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=9, padding=None):
        """Initialize a ResNet block with two convolutional layers, batch
        normalization, and ReLU activation.

        Parameters:
        -----------
            in_channels (int):
              Number of input channels.

            out_channels (int):
              Number of output channels.

            kernel_size (int, optional):
              Size of the convolutional kernel. Defaults to 9.

            padding (int, optional):
              Padding for the convolutional layers. Defaults to None, in which
              case the padding is computed automatically so that the output
              has the same size as the input.
        """
        super(ResNetBlock, self).__init__()
        if padding is None:
            padding = (kernel_size - 1) // 2

        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=padding,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # 1x1 convolution on the skip path when the channel count changes.
        if in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        residual = self.shortcut(x)

        out = self.conv1(x)
        out = self.bn1(self.relu(out))
        out = self.conv2(out)
        out = out + residual
        return self.bn2(self.relu(out))


def save_checkpoint(checkpoint_dict, epoch, filename="checkpoint.pth"):
    """Save the current state of the model and the optimizer to a file."""
    checkpoint_dict = {
        key: value.state_dict() for key, value in checkpoint_dict.items()
    }
    checkpoint_dict["epoch"] = epoch
    torch.save(
        checkpoint_dict,
        filename,
    )


def load_checkpoint(checkpoint_dict, filename="checkpoint.pth"):
    """Load the state of the model and the optimizer from a file."""
    try:
        checkpoint = torch.load(filename, weights_only=True)
        for key in checkpoint_dict:
            if key in checkpoint:
                checkpoint_dict[key].load_state_dict(checkpoint[key])
            else:
                raise KeyError(f"Key {key} not found in checkpoint.")
        return checkpoint["epoch"]
    except Exception as e:
        print(f"Error loading checkpoint {filename}: {e}")
        print("Starting without a saved state.")
        return 0


def log_metrics(writer, metrics):
    for key, value in metrics.items():
        step = value.step
        if value.count <= 20:
            continue
        writer.add_scalar(key, value.compute(), global_step=step)
    writer.flush()


class TVLoss(nn.Module):
    def __init__(self):
        super(TVLoss, self).__init__()

    def forward(self, img):
        return torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :])) + torch.mean(
            torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:])
        )


# Denormalize images
def denormalize(tensor):
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1).cuda()
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1).cuda()
    return tensor * std + mean


def log_images(writer, model, dataloader, epoch):
    model.eval()
    with torch.no_grad():
        stitches = []
        for i, (input, target) in enumerate(dataloader):
            if i >= 3:  # Log only first 3 images
                break
            input = input.cuda()
            target = target.cuda()
            output = model(input)

            # Rescale to 256x256 and stitch together
            input_resized = torch.nn.functional.interpolate(
                input[0:1], size=(256, 256), mode="bilinear", align_corners=False
            )
            output_resized = torch.nn.functional.interpolate(
                output[0:1], size=(256, 256), mode="bilinear", align_corners=False
            )
            target_resized = torch.nn.functional.interpolate(
                target[0:1], size=(256, 256), mode="bilinear", align_corners=False
            )

            input_norm = denormalize(input_resized[0]).clamp(0, 1)
            output_norm = denormalize(output_resized[0]).clamp(0, 1)
            target_norm = denormalize(target_resized[0]).clamp(0, 1)

            # Stitch images horizontally
            stitched = torch.cat([input_norm, output_norm, target_norm], dim=2)
            stitches.append(stitched)

        # Convert to grid and log
        stitched = torch.cat(stitches, dim=1)
        writer.add_image("Images", stitched, epoch)
    model.train()


class PSNR(nn.Module):
    def __init__(self, max_val=1.0):
        super(PSNR, self).__init__()
        self.max_val = max_val

    def forward(self, output, target):
        mse = F.mse_loss(output, target)
        psnr = 20 * torch.log10(self.max_val / torch.sqrt(mse))
        return psnr


class Metric:
    def __init__(self, factor=1.0, abs=False):
        self.total = 0.0
        self.count = 0
        self.step = 0
        self.factor = factor
        self.abs = abs

    def state_dict(self):
        return {
            "total": self.total,
            "count": self.count,
            "step": self.step,
            "factor": self.factor,
            "abs": self.abs,
        }

    def load_state_dict(self, state_dict):
        self.total = state_dict["total"]
        self.count = state_dict["count"]
        self.step = state_dict["step"]
        self.factor = state_dict["factor"]
        self.abs = state_dict["abs"]

    def update(self, value):
        self.total += value
        self.count += 1
        self.step += 1

    def compute(self, reset=True):
        if self.count == 0:
            avg = 0.0
        else:
            avg = self.total / self.count
        if reset:
            self.total = 0.0
            self.count = 0
        if self.abs:
            avg = abs(avg)
        return self.factor * avg


def train(prefix, trainer, dataloader):
    print(f"Training {prefix} model...")
    metric = lpips.LPIPS(net="vgg").cuda()  # Using VGG for the LPIPS metric
    mseMetric = nn.MSELoss()
    psnrMetric = PSNR(max_val=6.0)
    writer = SummaryWriter(f"runs/{prefix}")

    lpips_score = Metric()
    mse_score = Metric()
    psnr_score = Metric()
    content_loss_score = Metric()
    loss_c_score = Metric(abs=True)
    generator_loss_score = Metric()
    adversarial_loss_score = Metric(abs=True)
    gradient_norm_score = Metric()
    generator_gradient_norm_score = Metric()

    scores = {
        "LPIPS": lpips_score,
        "MSE": mse_score,
        "PSNR": psnr_score,
        "Content": content_loss_score,
        "loss_C": loss_c_score,
        "Generator": generator_loss_score,
        "Adversarial": adversarial_loss_score,
        "Critic Gradient Norm": gradient_norm_score,
        "Generator Gradient Norm": generator_gradient_norm_score,
    }

    checkpoint_dict = {
        "generator": trainer.generator,
        "critic": trainer.critic,
        "optimGenerator": trainer.optimGenerator,
        "optimCritic": trainer.optimCritic,
    } | scores

    ep = load_checkpoint(checkpoint_dict, filename=f"{prefix}_checkpoint.pt")

    for epoch in range(ep, ep + 1000):
        bar = tqdm(dataloader)
        for input, target in bar:
            input = input.cuda()
            target = target.cuda()

            result = trainer.train_batch(input, target, epoch)
            scoresCritic, scoresGenerator = (result["critic"], result["generator"])
            output = scoresGenerator["output"] if scoresGenerator is not None else None

            if scoresCritic is not None:
                gradient_norm_score.update(scoresCritic["gradient_norm"])
                loss_c_score.update(scoresCritic["loss_c"])

            if scoresGenerator is not None and output is not None:
                generator_gradient_norm_score.update(scoresGenerator["gradient_norm"])
                content_loss_score.update(scoresGenerator["content_loss"])
                adversarial_loss_score.update(scoresGenerator["adversarial_loss"])
                generator_loss_score.update(scoresGenerator["loss"])

                lpips_score.update(
                    metric(
                        2.0 * denormalize(output) - 1.0, 2.0 * denormalize(target) - 1.0
                    )
                    .mean()
                    .item()
                )
                mse_score.update(mseMetric(output, target).item())
                psnr_score.update(psnrMetric(output, target).item())

            bar.set_description(
                f"[{epoch+1}], Content: {content_loss_score.compute(reset=False):.3f}, "
                f"loss_C: {loss_c_score.compute(reset=False):.3f}"
            )

        log_metrics(writer, scores)
        log_images(writer, trainer.generator, dataloader, epoch + 1)
        save_checkpoint(
            checkpoint_dict,
            epoch + 1,
            filename=f"{prefix}_checkpoint.pt",
        )
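
# Usage sketch (illustrative): accumulating a running average with Metric
# and scoring a batch with PSNR. The tensors are made up for the example.
#
#     psnr = PSNR(max_val=1.0)
#     score = Metric()
#     score.update(psnr(torch.rand(1, 3, 64, 64), torch.rand(1, 3, 64, 64)).item())
#     score.update(psnr(torch.rand(1, 3, 64, 64), torch.rand(1, 3, 64, 64)).item())
#     avg = score.compute()  # mean of the two updates; counters reset afterwards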