1import torch
2from torch import nn
3import torchvision
4import os
5from misc import (
6 CNNNetwork,
7 load_data,
8 epoch,
9 save_checkpoint,
10 load_checkpoint,
11 DEVICE,
12 LR,
13)
14
15from torch.utils.tensorboard import SummaryWriter
16
17
18class TensorBoardLogger:
19 def __init__(self):
20 """
21 Initialisiert den TensorBoard-Logger.
22 """
23 self.create_writer()
24 self._reset_samples_statistics()
25 self._reset_metrics()
26
27 def create_writer(self):
28 """
29 Erstellt einen TensorBoard-SummaryWriter, der die Logs in einem Verzeichnis speichert.
30
31 **TODO**:
32 Erstellen Sie einen `SummaryWriter <https://docs.pytorch.org/docs/stable/tensorboard.html#torch.utils.tensorboard.writer.SummaryWriter>`_, der die Logs in einem Verzeichnis namens "runs"
33 speichert. Das Verzeichnis sollte im gleichen Verzeichnis wie dieses Skript liegen.
34 Verwenden Sie `os.path.dirname <https://www.tutorialspoint.com/python/os_path_dirname.htm#:~:text=The%20Python%20os.,the%20specified%20file%20or%20directory.>`_ `(os.path.abspath <https://www.geeksforgeeks.org/python/python-os-path-abspath-method-with-example/>`_ `(__file__))`, um den Pfad zum aktuellen Verzeichnis zu erhalten,
35 und `os.path.join() <https://www.geeksforgeeks.org/python/python-os-path-join-method/>`_, um den Pfad zum "runs"-Verzeichnis zu erstellen.
36 """
37 dirname = os.path.dirname(os.path.abspath(__file__))
38 board_path = os.path.join(dirname, "runs")
39
40 self.writer = SummaryWriter(log_dir=board_path)
41
42 def _reset_metrics(self):
43 """Setzt die Metriken zurück."""
44 self.metrics = {"total_loss": 0.0, "total_correct": 0.0, "total_samples": 0}
45
46 def _reset_samples_statistics(self):
47 """Setzt die Statistik der Samples zurück."""
48 self.sample_statistics = {}
49 for i in range(10):
50 self.sample_statistics[i] = {
51 "samples": torch.tensor([], device=DEVICE),
52 "loss": torch.tensor([], device=DEVICE),
53 }
54
55 def log_graph(self, model, input_tensor):
56 """Loggt den Graphen des Modells in TensorBoard.
57
58 Parameters:
59 -----------
60 model (nn.Module):
61 Das Modell, dessen Graph geloggt werden soll.
62
63 input_tensor (torch.Tensor):
64 Ein Beispiel-Eingabetensor, der die Form des Eingabedaten repräsentiert.
65
66 **TODO**:
67 Verwenden Sie `writer.add_graph() <https://pytorch.org/docs/stable/tensorboard.html#torch.utils.tensorboard.writer.SummaryWriter.add_graph>`_, um den Graphen
68 des Modells zu loggen.
69 """
70 self.writer.add_graph(model, input_tensor)
71
72 def update_metrics(self, logits, labels):
73 """Aktualisiert die Metriken für Trainings- oder Validierungsdaten.
74
75 Parameters:
76 -----------
77 logits (torch.Tensor):
78 Die zum aktuell verarbeiteten Batch gehörenden Logits (Vorhersagen) des Modells.
79
80 labels (torch.Tensor):
81 Die zugehörigen Labels für den Batch.
82
83 Updates:
84 --------
85 self.metrics (dict):
86 Diese Variable speichert die Metriken `total_loss`, `total_correct` und `total_samples`.
87
88 self.metrics["total_loss"] (float):
89 Der kumulierte summarische Verlust über alle bisherigen Batches.
90
91 self.metrics["total_correct"] (int):
92 Die Anzahl der korrekten Vorhersagen über alle bisherigen Batches.
93
94 self.metrics["total_samples"] (int):
95 Die Gesamtzahl der verarbeiteten Samples über alle bisherigen Batches.
96
97 **TODO**:
98 Aktualisiere die Metriken `total_loss`, `total_correct` und `total_samples` für den aktuellen Batch.
99 - Berechne den Verlust für den Batch mit `nn.CrossEntropyLoss()`.
100 - Zähle die Anzahl der korrekten Vorhersagen im Batch. Hinweis: Verwende `torch.argmax(logits, 1) <https://docs.pytorch.org/docs/stable/generated/torch.argmax.html>`_ um die Vorhersagen zu erhalten und vergleiche sie mit den Labels.
101 - Aktualisiere die Metriken in `self.metrics["total_loss"]`, `self.metrics["total_correct"]` und `self.metrics["total_samples"]` entsprechend.
102 """
103 criterion = nn.CrossEntropyLoss()
104 loss = criterion(logits, labels)
105
106 # Berechne die Anzahl der korrekten Vorhersagen
107 predicted = torch.argmax(logits, 1)
108 correct = (predicted == labels).sum().item()
109
110 # Aktualisiere die Metriken
111 self.metrics["total_loss"] += loss.item()
112 self.metrics["total_correct"] += correct
113 self.metrics["total_samples"] += labels.size(0)
114
115 def log_metrics(self, step, train=True):
116 """Loggt Metriken in TensorBoard.
117
118 Parameters:
119 -----------
120 step (int):
121 Der aktuelle Schritt, der für das Logging verwendet wird.
122
123 train (bool):
124 Gibt an, ob die Metriken aus dem Trainings- oder Validierungsdatensatz stammen.
125
126 **TODO**:
127 Logge die skalaren Metriken `loss` und `accuracy` über den `SummaryWriter` ins TensorBoard.
128 Die Metriken sollten in zwei verschiedenen Tags gespeichert werden: "train" für Trainingsmetriken und "validation" für
129 Validierungsmetriken. Überprüfe, ob `train` wahr ist, um zu entscheiden, ob es sich um Trainings- oder Validierungsmetriken handelt.
130
131 Die Metriken sollten mit dem aktuellen Schritt `step` geloggt werden.
132 - Berechnen Sie den Verlust als `self.metrics["total_loss"] / self.metrics["total_samples"]`.
133 - Berechnen Sie die Genauigkeit als `self.metrics["total_correct"] / self.metrics["total_samples"]`.
134 - Verwenden Sie `self.writer.add_scalar() <https://docs.pytorch.org/docs/stable/tensorboard.html#torch.utils.tensorboard.writer.SummaryWriter.add_scalar>`_ um die Metriken zu loggen.
135 - Rufen Sie zum Schluß `self._reset_metrics()` auf, um die Metriken zurückzusetzen.
136 """
137 loss = self.metrics["total_loss"] / self.metrics["total_samples"]
138 accuracy = self.metrics["total_correct"] / self.metrics["total_samples"]
139
140 tag = "train" if train else "validation"
141 self.writer.add_scalar(f"{tag}/loss", loss, step)
142 self.writer.add_scalar(f"{tag}/accuracy", accuracy, step)
143
144 self._reset_metrics()
145
146 def log_sample_statistics(self, train, step):
147 """Loggt die am schlechtesten klassifizierten Samples in TensorBoard.
148
149 Parameters:
150 -----------
151 train (bool):
152 Gibt an, ob die Samples aus dem Trainings- oder Validierungsdatensatz stammen.
153
154 step (int):
155 Der aktuelle Schritt, der für das Logging verwendet wird.
156
157 **TODO**:
158 Logge die am schlechtesten klassifizierten Samples für jede Klasse in TensorBoard.
159 Die Samples sollten in einem Grid-Format geloggt werden, wobei jede Klasse in einem eigenen Tag gespeichert wird.
160
161 - Iteriere über die Klassen-IDs (0-9) und logge die Samples für jede Klasse.
162 - Verwende `self.sample_statistics[cls_id]["samples"]` um die Samples für die Klasse `cls_id` zu erhalten.
163 - Verwende `torchvision.utils.make_grid() <https://docs.pytorch.org/vision/stable/generated/torchvision.utils.make_grid.html>`_ um die Samples in einem Grid zu formatieren
164 - Übergeben Sie `normalize=True` um die Samples zu normalisieren.
165 - Logge die Samples mit `self.writer.add_image() <https://docs.pytorch.org/docs/stable//tensorboard.html#torch.utils.tensorboard.writer.SummaryWriter.add_image>`_ unter dem Tag `f"{tag}/worst_samples/class_{cls_id}"`, wobei `tag` entweder "train" oder "validation" ist.
166 - Rufen Sie ganz zum Schluß `self._reset_samples_statistics()` auf, um die Statistik der Samples zurückzusetzen, nachdem die Samples geloggt wurden.
167 """
168 # Logge die schlechtesten Samples, wenn die Epoche abgeschlossen ist
169 tag = "train" if train else "validation"
170
171 # Iteriere über die Klassen-IDs (0-9) und logge die Samples für jede Klasse
172 for cls_id in range(10):
173 # Erstelle ein Grid aus den Samples der Klasse
174 grid = torchvision.utils.make_grid(
175 self.sample_statistics[cls_id]["samples"],
176 normalize=True,
177 )
178
179 # Logge das Grid der Samples in TensorBoard
180 self.writer.add_image(
181 f"{tag}/worst_samples/class_{cls_id}",
182 grid,
183 global_step=step,
184 )
185
186 # Setze die Statistik der Samples zurück
187 self._reset_samples_statistics()
188
189 def update_sample_statistics(self, batch, labels, loss):
190 """Aggregiere die am schlechtesten klassifizierten Samples für jede Klasse.
191
192 Parameters:
193 -----------
194 batch (torch.Tensor):
195 Der Batch von Eingabedaten.
196
197 labels (torch.Tensor):
198 Die zugehörigen Labels für den Batch.
199
200 loss (torch.Tensor):
201 Der Verlust für den Batch, berechnet mit `nn.CrossEntropyLoss(reduction="none")`.
202
203 Updates:
204 --------
205 self.sample_statistics (dict):
206 Diese Variable speichert die Samples und deren zugehörigen Verlusten für jede Klasse.
207
208 self.sample_statistics[cls_id]["samples"] (torch.Tensor):
209 Enthält die bisher schwierigsten Samples für die Klasse `cls_id`.
210
211 self.sample_statistics[cls_id]["loss"] (torch.Tensor):
212 Enthält die bisher größten Loss-Werte für die Klasse `cls_id`.
213
214 Diese Methode aggregiert die am schlechtesten klassifizierten Samples für jede Klasse.
215 Die Samples werden in `self.sample_statistics` gespeichert, die für jede Klasse eine Liste von Samples
216 und deren zugehörigen Verlusten enthält.
217 Die Methode iteriert über die Klassen-IDs (0-9) und speichert die 64 Samples mit dem höchsten Verlust
218 für jede Klasse.
219
220 **TODO**:
221 - Iterieren Sie über die Klassen-IDs (0-9) und speichern Sie die 64 Samples mit dem höchsten Verlust für jede Klasse.
222 - Verwenden Sie `torch.cat() <https://docs.pytorch.org/docs/stable/generated/torch.cat.html>`_ um die Samples und Verluste für jede Klasse zu aggregieren.
223 Verwenden Sie `torch.clone() <https://docs.pytorch.org/docs/stable/generated/torch.clone.html>`_ `.detach() <https://docs.pytorch.org/docs/stable/generated/torch.Tensor.detach.html>`_ um sicherzustellen, dass die Samples und Verluste nicht mehr mit der
224 Gradientenberechnung von AutoGrad verbunden sind.
225 - Sortieren Sie die Samples nach Verlust in absteigender Reihenfolge und behalten Sie nur die 64 schlechtesten Samples.
226 Verwenden Sie `torch.argsort() <https://docs.pytorch.org/docs/stable/generated/torch.argsort.html>`_ um die Indizes der Samples nach Verlust zu sortieren.
227 - Aktualisieren Sie `self.sample_statistics` für jede Klasse mit den aggregierten Samples und Verlusten.
228 """
229 # Iteriert über die Klassen-IDs (0-9) und speichert die schlechtesten Samples
230 for cls_id in range(10):
231 # Filtere die Samples für die aktuelle Klasse
232 ids = labels == cls_id
233
234 # Konkatenieren der Samples und Verluste für die aktuelle Klasse
235 self.sample_statistics[cls_id]["samples"] = torch.cat(
236 [
237 self.sample_statistics[cls_id]["samples"],
238 batch[ids].clone().detach(),
239 ]
240 )
241 self.sample_statistics[cls_id]["loss"] = torch.cat(
242 [
243 self.sample_statistics[cls_id]["loss"],
244 loss[ids].clone().detach(),
245 ]
246 )
247
248 # Sortiere die Samples nach Verlust in absteigender Reihenfolge
249 sorted_indices = torch.argsort(
250 self.sample_statistics[cls_id]["loss"], descending=True
251 )
252
253 # Behalte nur die 64 schlechtesten Samples
254 sorted_indices = sorted_indices[:64]
255
256 # Aktualisiere die Samples und Verluste für die aktuelle Klasse
257 self.sample_statistics[cls_id]["samples"] = self.sample_statistics[cls_id][
258 "samples"
259 ][sorted_indices]
260
261 self.sample_statistics[cls_id]["loss"] = self.sample_statistics[cls_id][
262 "loss"
263 ][sorted_indices]
264
265
266if __name__ == "__main__":
267 training_set, validation_set = load_data()
268
269 # Initialisierung des Modells, Loss-Kriteriums und Optimierers
270 model = CNNNetwork().to(DEVICE)
271 criterion = nn.CrossEntropyLoss(reduction="none")
272 optimizer = torch.optim.Adam(
273 model.parameters(), lr=LR
274 ) # Checkpoint laden, falls vorhanden
275
276 # Checkpoint laden, falls vorhanden
277 dirname = os.path.dirname(os.path.abspath(__file__))
278 chkpt_path = os.path.join(dirname, "checkpoint.pth")
279
280 ep = load_checkpoint(model, optimizer, chkpt_path)
281 if ep > 0:
282 print(f"Checkpoint geladen, fortsetzen bei Epoche {ep}.")
283
284 # Das Modell trainieren
285 logger = TensorBoardLogger()
286
287 # Logge den Graphen des Modells
288 input_tensor = torch.randn(1, 3, 32, 32).to(DEVICE) # Beispiel-Eingabetensor
289 logger.log_graph(model, input_tensor)
290
291 for n in range(ep, ep + 30):
292 epoch(model, n, True, training_set, criterion, optimizer, logger=logger)
293 epoch(model, n, False, validation_set, criterion, optimizer, logger=logger)
294
295 # Checkpoint nach jeder Epoche speichern
296 save_checkpoint(model, optimizer, n + 1, chkpt_path)