TensorBoard - Musterlösung

  1import torch
  2from torch import nn
  3import torchvision
  4import os
  5from misc import (
  6    CNNNetwork,
  7    load_data,
  8    epoch,
  9    save_checkpoint,
 10    load_checkpoint,
 11    DEVICE,
 12    LR,
 13)
 14
 15from torch.utils.tensorboard import SummaryWriter
 16
 17
 18class TensorBoardLogger:
 19    def __init__(self):
 20        """
 21        Initialisiert den TensorBoard-Logger.
 22        """
 23        self.create_writer()
 24        self._reset_samples_statistics()
 25        self._reset_metrics()
 26
 27    def create_writer(self):
 28        """
 29        Erstellt einen TensorBoard-SummaryWriter, der die Logs in einem Verzeichnis speichert.
 30
 31        **TODO**:
 32        Erstellen Sie einen `SummaryWriter <https://docs.pytorch.org/docs/stable/tensorboard.html#torch.utils.tensorboard.writer.SummaryWriter>`_, der die Logs in einem Verzeichnis namens "runs"
 33        speichert. Das Verzeichnis sollte im gleichen Verzeichnis wie dieses Skript liegen.
 34        Verwenden Sie `os.path.dirname <https://www.tutorialspoint.com/python/os_path_dirname.htm#:~:text=The%20Python%20os.,the%20specified%20file%20or%20directory.>`_ `(os.path.abspath <https://www.geeksforgeeks.org/python/python-os-path-abspath-method-with-example/>`_ `(__file__))`, um den Pfad zum aktuellen Verzeichnis zu erhalten,
 35        und `os.path.join() <https://www.geeksforgeeks.org/python/python-os-path-join-method/>`_, um den Pfad zum "runs"-Verzeichnis zu erstellen.
 36        """
 37        dirname = os.path.dirname(os.path.abspath(__file__))
 38        board_path = os.path.join(dirname, "runs")
 39
 40        self.writer = SummaryWriter(log_dir=board_path)
 41
 42    def _reset_metrics(self):
 43        """Setzt die Metriken zurück."""
 44        self.metrics = {"total_loss": 0.0, "total_correct": 0.0, "total_samples": 0}
 45
 46    def _reset_samples_statistics(self):
 47        """Setzt die Statistik der Samples zurück."""
 48        self.sample_statistics = {}
 49        for i in range(10):
 50            self.sample_statistics[i] = {
 51                "samples": torch.tensor([], device=DEVICE),
 52                "loss": torch.tensor([], device=DEVICE),
 53            }
 54
 55    def log_graph(self, model, input_tensor):
 56        """Loggt den Graphen des Modells in TensorBoard.
 57
 58        Parameters:
 59        -----------
 60        model (nn.Module):
 61          Das Modell, dessen Graph geloggt werden soll.
 62
 63        input_tensor (torch.Tensor):
 64          Ein Beispiel-Eingabetensor, der die Form des Eingabedaten repräsentiert.
 65
 66        **TODO**:
 67        Verwenden Sie `writer.add_graph() <https://pytorch.org/docs/stable/tensorboard.html#torch.utils.tensorboard.writer.SummaryWriter.add_graph>`_, um den Graphen
 68        des Modells zu loggen.
 69        """
 70        self.writer.add_graph(model, input_tensor)
 71
 72    def update_metrics(self, logits, labels):
 73        """Aktualisiert die Metriken für Trainings- oder Validierungsdaten.
 74
 75        Parameters:
 76        -----------
 77        logits (torch.Tensor):
 78          Die zum aktuell verarbeiteten Batch gehörenden Logits (Vorhersagen) des Modells.
 79
 80        labels (torch.Tensor):
 81          Die zugehörigen Labels für den Batch.
 82
 83        Updates:
 84        --------
 85        self.metrics (dict):
 86          Diese Variable speichert die Metriken `total_loss`, `total_correct` und `total_samples`.
 87
 88        self.metrics["total_loss"] (float):
 89          Der kumulierte summarische Verlust über alle bisherigen Batches.
 90
 91        self.metrics["total_correct"] (int):
 92          Die Anzahl der korrekten Vorhersagen über alle bisherigen Batches.
 93
 94        self.metrics["total_samples"] (int):
 95          Die Gesamtzahl der verarbeiteten Samples über alle bisherigen Batches.
 96
 97        **TODO**:
 98        Aktualisiere die Metriken `total_loss`, `total_correct` und `total_samples` für den aktuellen Batch.
 99        - Berechne den Verlust für den Batch mit `nn.CrossEntropyLoss()`.
100        - Zähle die Anzahl der korrekten Vorhersagen im Batch. Hinweis: Verwende `torch.argmax(logits, 1) <https://docs.pytorch.org/docs/stable/generated/torch.argmax.html>`_ um die Vorhersagen zu erhalten und vergleiche sie mit den Labels.
101        - Aktualisiere die Metriken in `self.metrics["total_loss"]`, `self.metrics["total_correct"]` und `self.metrics["total_samples"]` entsprechend.
102        """
103        criterion = nn.CrossEntropyLoss()
104        loss = criterion(logits, labels)
105
106        # Berechne die Anzahl der korrekten Vorhersagen
107        predicted = torch.argmax(logits, 1)
108        correct = (predicted == labels).sum().item()
109
110        # Aktualisiere die Metriken
111        self.metrics["total_loss"] += loss.item()
112        self.metrics["total_correct"] += correct
113        self.metrics["total_samples"] += labels.size(0)
114
115    def log_metrics(self, step, train=True):
116        """Loggt Metriken in TensorBoard.
117
118        Parameters:
119        -----------
120        step (int):
121          Der aktuelle Schritt, der für das Logging verwendet wird.
122
123        train (bool):
124          Gibt an, ob die Metriken aus dem Trainings- oder Validierungsdatensatz stammen.
125
126        **TODO**:
127        Logge die skalaren Metriken `loss` und `accuracy` über den `SummaryWriter` ins TensorBoard.
128        Die Metriken sollten in zwei verschiedenen Tags gespeichert werden: "train" für Trainingsmetriken und "validation" für
129        Validierungsmetriken. Überprüfe, ob `train` wahr ist, um zu entscheiden, ob es sich um Trainings- oder Validierungsmetriken handelt.
130
131        Die Metriken sollten mit dem aktuellen Schritt `step` geloggt werden.
132        - Berechnen Sie den Verlust als `self.metrics["total_loss"] / self.metrics["total_samples"]`.
133        - Berechnen Sie die Genauigkeit als `self.metrics["total_correct"] / self.metrics["total_samples"]`.
134        - Verwenden Sie `self.writer.add_scalar() <https://docs.pytorch.org/docs/stable/tensorboard.html#torch.utils.tensorboard.writer.SummaryWriter.add_scalar>`_ um die Metriken zu loggen.
135        - Rufen Sie zum Schluß `self._reset_metrics()` auf, um die Metriken zurückzusetzen.
136        """
137        loss = self.metrics["total_loss"] / self.metrics["total_samples"]
138        accuracy = self.metrics["total_correct"] / self.metrics["total_samples"]
139
140        tag = "train" if train else "validation"
141        self.writer.add_scalar(f"{tag}/loss", loss, step)
142        self.writer.add_scalar(f"{tag}/accuracy", accuracy, step)
143
144        self._reset_metrics()
145
146    def log_sample_statistics(self, train, step):
147        """Loggt die am schlechtesten klassifizierten Samples in TensorBoard.
148
149        Parameters:
150        -----------
151        train (bool):
152          Gibt an, ob die Samples aus dem Trainings- oder Validierungsdatensatz stammen.
153
154        step (int):
155          Der aktuelle Schritt, der für das Logging verwendet wird.
156
157        **TODO**:
158        Logge die am schlechtesten klassifizierten Samples für jede Klasse in TensorBoard.
159        Die Samples sollten in einem Grid-Format geloggt werden, wobei jede Klasse in einem eigenen Tag gespeichert wird.
160
161        - Iteriere über die Klassen-IDs (0-9) und logge die Samples für jede Klasse.
162        - Verwende `self.sample_statistics[cls_id]["samples"]` um die Samples für die Klasse `cls_id` zu erhalten.
163        - Verwende `torchvision.utils.make_grid() <https://docs.pytorch.org/vision/stable/generated/torchvision.utils.make_grid.html>`_ um die Samples in einem Grid zu formatieren
164        - Übergeben Sie `normalize=True` um die Samples zu normalisieren.
165        - Logge die Samples mit `self.writer.add_image() <https://docs.pytorch.org/docs/stable//tensorboard.html#torch.utils.tensorboard.writer.SummaryWriter.add_image>`_ unter dem Tag `f"{tag}/worst_samples/class_{cls_id}"`, wobei `tag` entweder "train" oder "validation" ist.
166        - Rufen Sie ganz zum Schluß `self._reset_samples_statistics()` auf, um die Statistik der Samples zurückzusetzen, nachdem die Samples geloggt wurden.
167        """
168        # Logge die schlechtesten Samples, wenn die Epoche abgeschlossen ist
169        tag = "train" if train else "validation"
170
171        # Iteriere über die Klassen-IDs (0-9) und logge die Samples für jede Klasse
172        for cls_id in range(10):
173            # Erstelle ein Grid aus den Samples der Klasse
174            grid = torchvision.utils.make_grid(
175                self.sample_statistics[cls_id]["samples"],
176                normalize=True,
177            )
178
179            # Logge das Grid der Samples in TensorBoard
180            self.writer.add_image(
181                f"{tag}/worst_samples/class_{cls_id}",
182                grid,
183                global_step=step,
184            )
185
186        # Setze die Statistik der Samples zurück
187        self._reset_samples_statistics()
188
189    def update_sample_statistics(self, batch, labels, loss):
190        """Aggregiere die am schlechtesten klassifizierten Samples für jede Klasse.
191
192        Parameters:
193        -----------
194        batch (torch.Tensor):
195          Der Batch von Eingabedaten.
196
197        labels (torch.Tensor):
198          Die zugehörigen Labels für den Batch.
199
200        loss (torch.Tensor):
201          Der Verlust für den Batch, berechnet mit `nn.CrossEntropyLoss(reduction="none")`.
202
203        Updates:
204        --------
205        self.sample_statistics (dict):
206          Diese Variable speichert die Samples und deren zugehörigen Verlusten für jede Klasse.
207
208        self.sample_statistics[cls_id]["samples"] (torch.Tensor):
209          Enthält die bisher schwierigsten Samples für die Klasse `cls_id`.
210
211        self.sample_statistics[cls_id]["loss"] (torch.Tensor):
212          Enthält die bisher größten Loss-Werte für die Klasse `cls_id`.
213
214        Diese Methode aggregiert die am schlechtesten klassifizierten Samples für jede Klasse.
215        Die Samples werden in `self.sample_statistics` gespeichert, die für jede Klasse eine Liste von Samples
216        und deren zugehörigen Verlusten enthält.
217        Die Methode iteriert über die Klassen-IDs (0-9) und speichert die 64 Samples mit dem höchsten Verlust
218        für jede Klasse.
219
220        **TODO**:
221        - Iterieren Sie über die Klassen-IDs (0-9) und speichern Sie die 64 Samples mit dem höchsten Verlust für jede Klasse.
222        - Verwenden Sie `torch.cat() <https://docs.pytorch.org/docs/stable/generated/torch.cat.html>`_ um die Samples und Verluste für jede Klasse zu aggregieren.
223          Verwenden Sie `torch.clone() <https://docs.pytorch.org/docs/stable/generated/torch.clone.html>`_ `.detach() <https://docs.pytorch.org/docs/stable/generated/torch.Tensor.detach.html>`_ um sicherzustellen, dass die Samples und Verluste nicht mehr mit der
224           Gradientenberechnung von AutoGrad verbunden sind.
225        - Sortieren Sie die Samples nach Verlust in absteigender Reihenfolge und behalten Sie nur die 64 schlechtesten Samples.
226          Verwenden Sie `torch.argsort() <https://docs.pytorch.org/docs/stable/generated/torch.argsort.html>`_ um die Indizes der Samples nach Verlust zu sortieren.
227        - Aktualisieren Sie `self.sample_statistics` für jede Klasse mit den aggregierten Samples und Verlusten.
228        """
229        # Iteriert über die Klassen-IDs (0-9) und speichert die schlechtesten Samples
230        for cls_id in range(10):
231            # Filtere die Samples für die aktuelle Klasse
232            ids = labels == cls_id
233
234            # Konkatenieren der Samples und Verluste für die aktuelle Klasse
235            self.sample_statistics[cls_id]["samples"] = torch.cat(
236                [
237                    self.sample_statistics[cls_id]["samples"],
238                    batch[ids].clone().detach(),
239                ]
240            )
241            self.sample_statistics[cls_id]["loss"] = torch.cat(
242                [
243                    self.sample_statistics[cls_id]["loss"],
244                    loss[ids].clone().detach(),
245                ]
246            )
247
248            # Sortiere die Samples nach Verlust in absteigender Reihenfolge
249            sorted_indices = torch.argsort(
250                self.sample_statistics[cls_id]["loss"], descending=True
251            )
252
253            # Behalte nur die 64 schlechtesten Samples
254            sorted_indices = sorted_indices[:64]
255
256            # Aktualisiere die Samples und Verluste für die aktuelle Klasse
257            self.sample_statistics[cls_id]["samples"] = self.sample_statistics[cls_id][
258                "samples"
259            ][sorted_indices]
260
261            self.sample_statistics[cls_id]["loss"] = self.sample_statistics[cls_id][
262                "loss"
263            ][sorted_indices]
264
265
266if __name__ == "__main__":
267    training_set, validation_set = load_data()
268
269    # Initialisierung des Modells, Loss-Kriteriums und Optimierers
270    model = CNNNetwork().to(DEVICE)
271    criterion = nn.CrossEntropyLoss(reduction="none")
272    optimizer = torch.optim.Adam(
273        model.parameters(), lr=LR
274    )  # Checkpoint laden, falls vorhanden
275
276    # Checkpoint laden, falls vorhanden
277    dirname = os.path.dirname(os.path.abspath(__file__))
278    chkpt_path = os.path.join(dirname, "checkpoint.pth")
279
280    ep = load_checkpoint(model, optimizer, chkpt_path)
281    if ep > 0:
282        print(f"Checkpoint geladen, fortsetzen bei Epoche {ep}.")
283
284    # Das Modell trainieren
285    logger = TensorBoardLogger()
286
287    # Logge den Graphen des Modells
288    input_tensor = torch.randn(1, 3, 32, 32).to(DEVICE)  # Beispiel-Eingabetensor
289    logger.log_graph(model, input_tensor)
290
291    for n in range(ep, ep + 30):
292        epoch(model, n, True, training_set, criterion, optimizer, logger=logger)
293        epoch(model, n, False, validation_set, criterion, optimizer, logger=logger)
294
295        # Checkpoint nach jeder Epoche speichern
296        save_checkpoint(model, optimizer, n + 1, chkpt_path)