1import os
2import torch
3import torch.nn as nn
4from sklearn.manifold import TSNE
5
6from scipy.linalg import orthogonal_procrustes
7
8import pandas as pd
9import seaborn as sns
10import matplotlib.pyplot as plt
11import io
12import PIL.Image
13import numpy as np
14
15from tqdm import tqdm
16from misc import (
17 DEVICE,
18 load_data,
19 epoch,
20 load_checkpoint,
21 TensorBoardLogger,
22 save_checkpoint,
23 LR,
24 ResNet,
25)
26
27
28class EmbeddingLogger(TensorBoardLogger):
29 def __init__(self, validation_set):
30 super().__init__()
31 self.validation_set = validation_set
32 self.previous_embeddings_2d = None
33
34 self.frames = []
35 self.step = 1
36
37 def calculate_embeddings(self, model):
38 """Berechnet alle Embeddings für die Daten im Dataloader.
39
40 Parameters:
41 -----------
42 model (nn.Module):
43 Das Modell, das die Embeddings berechnet.
44
45 Returns:
46 --------
47 embeddings (np.ndarray):
48 Die berechneten Embeddings als NumPy-Array.
49
50 labels (np.ndarray):
51 Die zugehörigen Labels als NumPy-Array.
52
53 **TODO**:
54
55 - Setzen Sie das Modell in den Evaluationsmodus (`model.eval() <https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.eval>`_), um sicherzustellen, dass die Batch-Normalisierung deaktiviert ist.
56
57 - Erstellen Sie leere Listen für `embeddings` und `labels`, um die Ergebnisse zu speichern.
58
59 - Verwenden Sie `torch.no_grad() <https://docs.pytorch.org/docs/stable/generated/torch.no_grad.html>`_, um den Gradientenfluss zu deaktivieren, da wir nur die Embeddings berechnen und nicht trainieren.
60
61 - Iterieren Sie über `self.validation_set` und berechnen Sie die Embeddings für jedes Batch indem Sie die Eingaben auf das Gerät (`DEVICE`) verschieben und das Modell aufrufen.
62
63 - Das Modell liefert ein Tupel zurück, wobei der zweite Wert die Embeddings sind.
64
65 - Verschieben Sie die Embeddings und Labels auf die CPU (rufen Sie `tensor.cpu() <https://docs.pytorch.org/docs/stable/generated/torch.Tensor.cpu.html>`_ auf ) und speichern Sie sie in den Listen `embeddings` und `labels`.
66
67 - Konvertieren Sie die Listen `embeddings` und `labels` in NumPy-Arrays, indem Sie `torch.cat(embeddings, dim=0) <https://docs.pytorch.org/docs/stable/generated/torch.cat.html>`_ `.numpy() <https://docs.pytorch.org/docs/stable/generated/torch.Tensor.numpy.html>`_ und `torch.cat(labels, dim=0) <https://docs.pytorch.org/docs/stable/generated/torch.cat.html>`_ `.numpy() <https://docs.pytorch.org/docs/stable/generated/torch.Tensor.numpy.html>`_ verwenden.
68
69 - Setzen Sie das Modell wieder in den Trainingsmodus (`model.train() <https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.train>`_), um sicherzustellen, dass es für zukünftige Trainingsschritte bereit ist.
70
71 - Geben Sie die berechneten Embeddings und Labels zurück.
72 """
73 model.eval()
74 embeddings = []
75 labels = []
76 bar = tqdm(self.validation_set, desc="Berechne Embeddings")
77 with torch.no_grad():
78 for inputs, l in bar:
79 inputs = inputs.to(DEVICE)
80 _, emb = model(inputs)
81 embeddings.append(emb.cpu())
82 labels.append(l.cpu())
83
84 bar.close()
85
86 model.train()
87
88 return torch.cat(embeddings, dim=0).numpy(), torch.cat(labels, dim=0).numpy()
89
90 def calculate_tsne(self, embeddings, previous_embeddings_2d=None):
91 """Berechnet das t-SNE-Modell für die Embeddings.
92
93 Parameters:
94 -----------
95 embeddings (np.ndarray):
96 Die Embeddings, die in 2D projiziert werden sollen.
97
98 Returns:
99 --------
100 embeddings_2d (np.ndarray):
101 Die 2D-Projektion der Embeddings.
102
103 **TODO**:
104
105 - Verwenden Sie `sklearn.manifold.TSNE <https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html>`_ um die Embeddings in 2D zu projizieren.
106 Setzen Sie `n_components=2` und verwenden Sie `init="pca"` für die Initialisierung.
107 Wenn `self.previous_embeddings_2d` nicht `None` ist, verwenden Sie stattdessen diese als Initialisierung.
108
109 - Konvertieren Sie die 2D-Embeddings in ein `NumPy-Array <https://numpy.org/doc/stable/reference/generated/numpy.array.html>`_ mit `dtype=np.float32`.
110
111 - Normalisieren Sie die 2D-Embeddings, indem Sie den Mittelwert und die Standardabweichung berechnen und
112 die Embeddings so transformieren, dass sie einen Mittelwert von 0 und eine Standardabweichung von 1 haben.
113
114 - Verwenden Sie `np.mean(embeddings_2d, axis=0, keepdims=True) <https://numpy.org/doc/2.2/reference/generated/numpy.mean.html>`_ für den Mittelwert und `np.std(embeddings_2d, axis=0, keepdims=True) <https://numpy.org/doc/stable/reference/generated/numpy.std.html>`_ für die Standardabweichung.
115
116 - Normalisieren Sie die Embeddings mit `(embeddings_2d - m) / s`, wobei `m` der Mittelwert und `s` die Standardabweichung ist.
117
118 - Geben Sie die normalisierten 2D-Embeddings zurück.
119 """
120 if previous_embeddings_2d is not None:
121 tsne_model = TSNE(n_components=2, init=previous_embeddings_2d)
122 else:
123 tsne_model = TSNE(n_components=2, init="pca")
124
125 embeddings_2d = tsne_model.fit_transform(embeddings)
126
127 embeddings_2d = np.array(embeddings_2d, dtype=np.float32)
128 m = np.mean(embeddings_2d, axis=0, keepdims=True) # Normalize to zero mean
129 s = np.std(embeddings_2d, axis=0, keepdims=True) # Normalize to unit variance
130 embeddings_2d = (embeddings_2d - m) / s # Normalize the embeddings
131
132 return embeddings_2d
133
134 def register_embeddings_2d(self, embeddings_2d, previous_embeddings_2d=None):
135 """Registriert die 2D-Embeddings, um sie mit den vorherigen Embeddings zu vergleichen.
136
137 Parameters:
138 -----------
139 embeddings_2d (np.ndarray):
140 Die 2D-Embeddings, die registriert werden sollen.
141
142 previous_embeddings_2d (np.ndarray, optional):
143 Die vorherigen 2D-Embeddings, die für die Registrierung verwendet werden sollen. Standardmäßig None.
144
145 Returns:
146 --------
147 embeddings_2d (np.ndarray):
148 Die registrierten 2D-Embeddings.
149
150 **TODO**:
151
152 - Wenn `previous_embeddings_2d` nicht `None` ist, verwenden Sie `scipy.linalg.orthogonal_procrustes <https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.orthogonal_procrustes.html>`_ um die 2D-Embeddings zu registrieren.
153 Dies hilft, die Embeddings so zu transformieren, dass sie mit den vorherigen bestmöglich Embeddings übereinstimmen.
154
155 - Die Funktion liefert die orthogonale Rotationsmatrix `R` und die Skala `s`, aber wir verwenden nur `R`, um die 2D-Embeddings zu transformieren.
156
157 - Transformieren Sie die 2D-Embeddings mit `embeddings_2d @ R`, um sie an die vorherigen Embeddings anzupassen.
158
159 - Geben Sie die transformierten 2D-Embeddings zurück.
160 """
161 if previous_embeddings_2d is not None:
162 R, _ = orthogonal_procrustes(embeddings_2d, previous_embeddings_2d)
163 embeddings_2d = embeddings_2d @ R
164
165 return embeddings_2d
166
167 def visualize_embeddings(self, embeddings_2d, labels, step, axs):
168 """Visualisiert die 2D-Embeddings mit t-SNE und speichert das Bild.
169
170 Parameters:
171 -----------
172 embeddings_2d (np.ndarray):
173 Die 2D-Embeddings, die visualisiert werden sollen.
174
175 labels (np.ndarray):
176 Die zugehörigen Labels für die Embeddings.
177
178 step (int):
179 Der aktuelle Schritt oder die Epoche, die für den Titel des Plots verwendet wird.
180
181 axs (matplotlib.axes.Axes):
182 Die Achsen, auf denen die Embeddings visualisiert werden sollen.
183
184 **TODO**:
185
186 - Erstellen Sie mit Pandas ein DataFrame mit den 2D-Embeddings und den zugehörigen Labels,
187 um die Daten für die Visualisierung vorzubereiten.
188
189 - Verwenden Sie `seaborn.scatterplot <https://seaborn.pydata.org/generated/seaborn.scatterplot.html>`_ um die 2D-Embeddings zu visualisieren.
190
191 - Setzen Sie die Achsenlimits auf (-3.0, 3.0) für beide Achsen, um eine konsistente Darstellung zu gewährleisten.
192
193 - Entfernen Sie die Legende (`axs.get_legend().remove()`), um den Plot übersichtlicher zu gestalten.
194
195 - Setzen Sie den Titel des Plots sinnvoll.
196 """
197 df = pd.DataFrame(
198 {"x": embeddings_2d[:, 0], "y": embeddings_2d[:, 1], "label": labels}
199 )
200
201 sns.scatterplot(data=df, x="x", y="y", hue="label", palette="muted", ax=axs)
202 axs.set_xlim(-3.0, 3.0)
203 axs.set_ylim(-3.0, 3.0)
204 axs.get_legend().remove()
205 axs.set_title(f"t-SNE Embedding Projection - Step {step}")
206
207 def append_frame(self, image):
208 """Fügt ein Bild zu den Frames hinzu, die später als GIF gespeichert werden."""
209 self.writer.add_image(
210 f"embeddings", np.array(image), global_step=self.step, dataformats="HWC"
211 )
212
213 self.frames.append(image)
214
215 dirname = os.path.dirname(os.path.abspath(__file__))
216 image_path = os.path.join(dirname, "images")
217 os.makedirs(image_path, exist_ok=True)
218
219 self.frames[0].save(
220 os.path.join(image_path, "animation.gif"),
221 save_all=True,
222 append_images=self.frames[1:],
223 duration=300,
224 loop=0,
225 )
226
227 image.save(os.path.join(image_path, f"embeddings_{self.step}.png"))
228
229 def log_embeddings(self, model):
230 embeddings, labels = self.calculate_embeddings(model)
231 embeddings_2d = self.calculate_tsne(embeddings, self.previous_embeddings_2d)
232 embeddings_2d = self.register_embeddings_2d(
233 embeddings_2d, self.previous_embeddings_2d
234 )
235 self.previous_embeddings_2d = embeddings_2d
236
237 fig = plt.figure(figsize=(8, 6))
238 axs = fig.add_subplot(1, 1, 1)
239 image = self.visualize_embeddings(embeddings_2d, labels, self.step, axs)
240 fig.tight_layout()
241
242 buf = io.BytesIO()
243 fig.savefig(buf, format="png")
244 buf.seek(0)
245
246 image = PIL.Image.open(buf)
247
248 self.append_frame(image)
249 self.step += 1
250
251
252if __name__ == "__main__":
253 training_set, validation_set = load_data()
254
255 # Initialisierung des Modells, Loss-Kriteriums und Optimierers
256 model = ResNet().to(DEVICE)
257 criterion = nn.CrossEntropyLoss(reduction="none")
258 optimizer = torch.optim.Adam(
259 model.parameters(), lr=LR
260 ) # Checkpoint laden, falls vorhanden
261
262 # Checkpoint laden, falls vorhanden
263 dirname = os.path.dirname(os.path.abspath(__file__))
264 chkpt_path = os.path.join(dirname, "checkpoint.pth")
265
266 ep = load_checkpoint(model, optimizer, chkpt_path)
267 if ep > 0:
268 print(f"Checkpoint geladen, fortsetzen bei Epoche {ep}.")
269
270 # Das Modell trainieren
271 logger = EmbeddingLogger(validation_set)
272
273 # Logge den Graphen des Modells
274 input_tensor = torch.randn(1, 3, 32, 32).to(DEVICE) # Beispiel-Eingabetensor
275 logger.log_graph(model, input_tensor)
276
277 umap_model = None
278 for n in range(ep, ep + 200):
279 log_after = 100000
280 if n == 0:
281 log_after = 5000
282 if n == 1:
283 log_after = 10000
284 if n == 2:
285 log_after = 50000
286
287 epoch(
288 model,
289 n,
290 True,
291 training_set,
292 criterion,
293 optimizer,
294 logger=logger,
295 log_after_n_samples=log_after,
296 )
297 epoch(model, n, False, validation_set, criterion, optimizer, logger=logger)
298
299 # Checkpoint nach jeder Epoche speichern
300 save_checkpoint(model, optimizer, n + 1, chkpt_path)