ResNet - Musterlösung

  1import os
  2import torch
  3import torch.nn as nn
  4from misc import (
  5    DEVICE,
  6    load_data,
  7    epoch,
  8    load_checkpoint,
  9    TensorBoardLogger,
 10    save_checkpoint,
 11    LR,
 12)
 13
 14
 15class ResidualBlock(nn.Module):
 16    def __init__(self, in_channels, out_channels, stride=1):
 17        """Initialisiert einen Residual Block.
 18
 19        Parameters:
 20        -----------
 21        in_channels (int):
 22          Anzahl der Eingabekanäle.
 23
 24        out_channels (int):
 25          Anzahl der Ausgabekanäle.
 26
 27        stride (int):
 28          Schrittweite für die Faltung. Standard ist 1.
 29
 30        **TODO**:
 31
 32        - Rufen Sie die `__init__` Methode der Basisklasse `nn.Module` auf.
 33
 34        - Initialisieren Sie dann die Schichten des Residual Blocks.
 35
 36        - Verwenden Sie `nn.Conv2d <https://docs.pytorch.org/docs/stable/generated/torch.nn.Conv2d.html>`_ für die Faltungsschichten. Setzen Sie `kernel_size=3`, `padding=1` und `bias=False`.
 37
 38        - Die erste Faltungsschicht sollte `in_channels` zu `out_channels` transformieren, die zweite Faltungsschicht sollte `out_channels` zu `out_channels` transformieren.
 39
 40        - Die ersten Faltungsschicht sollte `stride` als Schrittweite verwenden.
 41
 42        - Fügen Sie `nn.BatchNorm2d <https://docs.pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html>`_ nach jeder Faltungsschicht hinzu. Achten Sie darauf, dass die Batch-Normalisierung die gleiche Anzahl an Ausgabekanälen wie die Faltungsschicht hat.
 43
 44        - Verwenden Sie `nn.ReLU <https://docs.pytorch.org/docs/stable/generated/torch.nn.ReLU.html>`_ als Aktivierungsfunktion.
 45
 46        - Implementieren Sie die Shortcut-Verbindung. Wenn `stride` nicht 1 ist oder `in_channels` nicht gleich `out_channels`, verwenden Sie eine 1x1 Faltung, um die Dimensionen anzupassen. Andernfalls verwenden Sie `nn.Identity()`.
 47        """
 48        super(ResidualBlock, self).__init__()
 49        self.conv1 = nn.Sequential(
 50            nn.Conv2d(
 51                in_channels,
 52                out_channels,
 53                kernel_size=3,
 54                padding=1,
 55                stride=stride,
 56                bias=False,
 57            ),
 58            nn.BatchNorm2d(out_channels),
 59        )
 60        self.conv2 = nn.Sequential(
 61            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
 62            nn.BatchNorm2d(out_channels),
 63        )
 64        self.relu = nn.ReLU(inplace=True)
 65
 66        # Shortcut connection
 67        if stride != 1 or in_channels != out_channels:
 68            self.shortcut = nn.Conv2d(
 69                in_channels, out_channels, kernel_size=1, stride=stride, bias=False
 70            )
 71        else:
 72            self.shortcut = nn.Identity()
 73
 74    def forward(self, x):
 75        """Führt den Vorwärtsdurchlauf des Residual Blocks aus.
 76
 77        Parameters:
 78        -----------
 79        x (torch.Tensor):
 80          Eingabetensor.
 81
 82        **TODO**:
 83        Implementieren Sie den Vorwärtsdurchlauf des Residual Blocks.
 84        Orientieren Sie sich an der in der Aufgabenstellung gegebenen Beschreibung sowie der Grafik.
 85        """
 86        residual = self.shortcut(x)
 87        out = self.relu(self.conv1(x))
 88        out = self.relu(self.conv2(out) + residual)
 89        return out
 90
 91
 92class ResNet(nn.Module):
 93    def __init__(self, num_classes=10):
 94        """Initialisiert das ResNet Modell.
 95
 96        Parameters:
 97        -----------
 98        num_classes (int):
 99          Anzahl der Klassen für die Klassifikation.
100
101        **TODO**:
102
103        - Rufen Sie die `__init__` Methode der Basisklasse `nn.Module` auf.
104
105        - Definieren Sie dann die Schichten des ResNet Modells.
106
107        - Verwenden Sie `nn.Conv2d <https://docs.pytorch.org/docs/stable/generated/torch.nn.Conv2d.html>`_ für die erste Faltungsschichten um von 3 auf 32 Kanäle zu transformieren. Setzen Sie `kernel_size=7`, `padding=3` und `stride=2` für diese Schicht.
108
109        - Fügen Sie `nn.BatchNorm2d <https://docs.pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html>`_ und `nn.ReLU <https://docs.pytorch.org/docs/stable/generated/torch.nn.ReLU.html>`_ nach der ersten Faltungsschicht hinzu.
110
111        - Hinweis: Sie können die `nn.Sequential` Klasse verwenden, um mehrere Schichten zu kombinieren.
112
113        - Erstellen Sie dann drei Ebenen mit der Methode `make_layer`.
114
115        - Die erste Ebene sollte 6 Residual Blocks mit `in_channels=32`, `out_channels=32` und `stride=1` enthalten.
116
117        - Die zweite Ebene sollte 6 Residual Blocks mit `in_channels=32`, `out_channels=64` und `stride=2` enthalten.
118
119        - Die dritte Ebene sollte 12 Residual Blocks mit `in_channels=64`, `out_channels=128` und `stride=2` enthalten.
120
121        - Fügen Sie eine `nn.AvgPool2d <https://docs.pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html>`_ Schicht mit einem Kernel von (4, 4) hinzu, um die räumliche Dimension der Feature-Maps zu reduzieren.
122
123        - Fügen Sie eine voll verbundene Schicht `nn.Linear <https://docs.pytorch.org/docs/stable/generated/torch.nn.Linear.html>`_ hinzu, die die Ausgabe der Durchschnittspooling-Schicht auf `num_classes` transformiert.
124
125        - Die Eingabegröße für die voll verbundene Schicht sollte 128 sein, da die letzte Residual Block Schicht 128 Kanäle hat.
126
127        - Verwenden Sie `torch.flatten <https://pytorch.org/docs/stable/generated/torch.flatten.html>`_ um die Ausgabe der Durchschnittspooling-Schicht in einen Vektor umzuwandeln, bevor Sie sie an die voll verbundene Schicht weitergeben.
128        """
129        super(ResNet, self).__init__()
130
131        # Initlal block
132        self.layer0 = nn.Sequential(
133            nn.Conv2d(3, 32, kernel_size=7, padding=3, stride=2, bias=False),
134            nn.BatchNorm2d(32),
135            nn.ReLU(inplace=True),
136        )
137
138        # Residual blocks
139        self.layer1 = self.make_layer(32, 32, 6, 1)
140        self.layer2 = self.make_layer(32, 64, 6, 2)
141        self.layer3 = self.make_layer(64, 128, 12, 2)
142
143        # Average pooling and fully connected layer
144        self.avgpool = nn.AvgPool2d((4, 4))
145        self.fc = nn.Linear(128, num_classes)
146
147    def make_layer(self, in_channels, out_channels, num_blocks, stride):
148        """Erstellt eine Sequenz von Residual Blocks.
149
150        Parameters:
151        -----------
152        in_channels (int):
153          Anzahl der Eingabekanäle.
154
155        out_channels (int):
156          Anzahl der Ausgabekanäle.
157
158        num_blocks (int):
159          Anzahl der Residual Blocks in dieser Schicht.
160
161        stride (int):
162          Schrittweite für die erste Faltungsschicht des ersten Blocks.
163
164        Returns:
165        --------
166        nn.Sequential:
167          Eine Sequenz von Residual Blocks.
168
169        **TODO**:
170
171        - Erstellen Sie eine Liste von Schichten, die die Residual Blocks enthalten.
172
173        - Die erste Schicht sollte einen Residual Block mit `in_channels`, `out_channels` und `stride` sein.
174
175        - Die folgenden Schichten sollten Residual Blocks mit gleichbleibender Kanalanzahl sein. Verwenden Sie `out_channels` sowohl für die Eingabe- als auch für die Ausgabekanäle.
176
177        - Verwenden Sie `nn.Sequential` um die Schichten zu kombinieren und zurückzugeben.
178
179        **Hinweis**:
180
181        - Die erste Schicht sollte die Schrittweite `stride` verwenden, während die anderen Schichten eine Schrittweite von 1 haben.
182
183        - Sie können die gewünschten Layer mit `nn.Sequential <https://docs.pytorch.org/docs/stable/generated/torch.nn.Sequential.html>`_ kombinieren.
184
185        - Dazu können Sie die Blöcke zunächst in einer Liste (z.B. `layers`) sammeln und dann `nn.Sequential(*layers)` verwenden, um sie zu kombinieren.
186        """
187        strides = [stride] + [1] * (num_blocks - 1)
188        layers = []
189        for s in strides:
190            layers.append(ResidualBlock(in_channels, out_channels, s))
191            in_channels = out_channels
192
193        return nn.Sequential(*layers)
194
195    def forward(self, x):
196        """Führt den Vorwärtsdurchlauf des ResNet Modells aus.
197
198        Parameters:
199        -----------
200        x (torch.Tensor):
201          Eingabetensor.
202
203        **TODO**:
204        Implementieren Sie den Vorwärtsdurchlauf des ResNet Modells.
205        Orientieren Sie sich an der in der Aufgabenstellung gegebenen Beschreibung sowie der Grafik.
206        """
207        x = self.layer0(x)
208        x = self.layer1(x)
209        x = self.layer2(x)
210        x = self.layer3(x)
211        x = self.avgpool(x)
212        x = torch.flatten(x, 1)
213        x = self.fc(x)
214        return x
215
216
217if __name__ == "__main__":
218    training_set, validation_set = load_data()
219
220    # Initialisierung des Modells, Loss-Kriteriums und Optimierers
221    model = ResNet().to(DEVICE)
222    criterion = nn.CrossEntropyLoss(reduction="none")
223    optimizer = torch.optim.Adam(
224        model.parameters(), lr=LR
225    )  # Checkpoint laden, falls vorhanden
226
227    # Checkpoint laden, falls vorhanden
228    dirname = os.path.dirname(os.path.abspath(__file__))
229    chkpt_path = os.path.join(dirname, "checkpoint.pth")
230
231    ep = load_checkpoint(model, optimizer, chkpt_path)
232    if ep > 0:
233        print(f"Checkpoint geladen, fortsetzen bei Epoche {ep}.")
234
235    # Das Modell trainieren
236    logger = TensorBoardLogger()
237
238    # Logge den Graphen des Modells
239    input_tensor = torch.randn(1, 3, 32, 32).to(DEVICE)  # Beispiel-Eingabetensor
240    logger.log_graph(model, input_tensor)
241
242    umap_model = None
243    for n in range(ep, ep + 200):
244        epoch(
245            model,
246            n,
247            True,
248            training_set,
249            criterion,
250            optimizer,
251            logger=logger,
252            log_after_n_samples=10000,
253        )
254        epoch(model, n, False, validation_set, criterion, optimizer, logger=logger)
255
256        # Checkpoint nach jeder Epoche speichern
257        save_checkpoint(model, optimizer, n + 1, chkpt_path)