Musterlösung für die Embeddings

  1import os
  2import torch
  3import torch.nn as nn
  4from sklearn.manifold import TSNE
  5
  6from scipy.linalg import orthogonal_procrustes
  7
  8import pandas as pd
  9import seaborn as sns
 10import matplotlib.pyplot as plt
 11import io
 12import PIL.Image
 13import numpy as np
 14
 15from tqdm import tqdm
 16from misc import (
 17    DEVICE,
 18    load_data,
 19    epoch,
 20    load_checkpoint,
 21    TensorBoardLogger,
 22    save_checkpoint,
 23    LR,
 24    ResNet,
 25)
 26
 27
 28class EmbeddingLogger(TensorBoardLogger):
 29    def __init__(self, validation_set):
 30        super().__init__()
 31        self.validation_set = validation_set
 32        self.previous_embeddings_2d = None
 33
 34        self.frames = []
 35        self.step = 1
 36
 37    def calculate_embeddings(self, model):
 38        """Berechnet alle Embeddings für die Daten im Dataloader.
 39
 40        Parameters:
 41        -----------
 42        model (nn.Module):
 43            Das Modell, das die Embeddings berechnet.
 44
 45        Returns:
 46        --------
 47        embeddings (np.ndarray):
 48            Die berechneten Embeddings als NumPy-Array.
 49
 50        labels (np.ndarray):
 51            Die zugehörigen Labels als NumPy-Array.
 52
 53        **TODO**:
 54
 55        -  Setzen Sie das Modell in den Evaluationsmodus (`model.eval() <https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.eval>`_), um sicherzustellen, dass die Batch-Normalisierung deaktiviert ist.
 56
 57        -  Erstellen Sie leere Listen für `embeddings` und `labels`, um die Ergebnisse zu speichern.
 58
 59        -  Verwenden Sie `torch.no_grad() <https://docs.pytorch.org/docs/stable/generated/torch.no_grad.html>`_, um den Gradientenfluss zu deaktivieren, da wir nur die Embeddings berechnen und nicht trainieren.
 60
 61        -  Iterieren Sie über `self.validation_set` und berechnen Sie die Embeddings für jedes Batch indem Sie die Eingaben auf das Gerät (`DEVICE`) verschieben und das Modell aufrufen.
 62
 63        -  Das Modell liefert ein Tupel zurück, wobei der zweite Wert die Embeddings sind.
 64
 65        -  Verschieben Sie die Embeddings und Labels auf die CPU (rufen Sie `tensor.cpu() <https://docs.pytorch.org/docs/stable/generated/torch.Tensor.cpu.html>`_ auf ) und speichern Sie sie in den Listen `embeddings` und `labels`.
 66
 67        -  Konvertieren Sie die Listen `embeddings` und `labels` in NumPy-Arrays, indem Sie `torch.cat(embeddings, dim=0) <https://docs.pytorch.org/docs/stable/generated/torch.cat.html>`_ `.numpy() <https://docs.pytorch.org/docs/stable/generated/torch.Tensor.numpy.html>`_ und `torch.cat(labels, dim=0) <https://docs.pytorch.org/docs/stable/generated/torch.cat.html>`_ `.numpy() <https://docs.pytorch.org/docs/stable/generated/torch.Tensor.numpy.html>`_ verwenden.
 68
 69        -  Setzen Sie das Modell wieder in den Trainingsmodus (`model.train() <https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.train>`_), um sicherzustellen, dass es für zukünftige Trainingsschritte bereit ist.
 70
 71        -  Geben Sie die berechneten Embeddings und Labels zurück.
 72        """
 73        model.eval()
 74        embeddings = []
 75        labels = []
 76        bar = tqdm(self.validation_set, desc="Berechne Embeddings")
 77        with torch.no_grad():
 78            for inputs, l in bar:
 79                inputs = inputs.to(DEVICE)
 80                _, emb = model(inputs)
 81                embeddings.append(emb.cpu())
 82                labels.append(l.cpu())
 83
 84        bar.close()
 85
 86        model.train()
 87
 88        return torch.cat(embeddings, dim=0).numpy(), torch.cat(labels, dim=0).numpy()
 89
 90    def calculate_tsne(self, embeddings, previous_embeddings_2d=None):
 91        """Berechnet das t-SNE-Modell für die Embeddings.
 92
 93        Parameters:
 94        -----------
 95        embeddings (np.ndarray):
 96            Die Embeddings, die in 2D projiziert werden sollen.
 97
 98        Returns:
 99        --------
100        embeddings_2d (np.ndarray):
101            Die 2D-Projektion der Embeddings.
102
103        **TODO**:
104
105        -  Verwenden Sie `sklearn.manifold.TSNE <https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html>`_ um die Embeddings in 2D zu projizieren.
106           Setzen Sie `n_components=2` und verwenden Sie `init="pca"` für die Initialisierung.
107           Wenn `self.previous_embeddings_2d` nicht `None` ist, verwenden Sie stattdessen diese als Initialisierung.
108
109        -  Konvertieren Sie die 2D-Embeddings in ein `NumPy-Array <https://numpy.org/doc/stable/reference/generated/numpy.array.html>`_ mit `dtype=np.float32`.
110
111        -  Normalisieren Sie die 2D-Embeddings, indem Sie den Mittelwert und die Standardabweichung berechnen und
112           die Embeddings so transformieren, dass sie einen Mittelwert von 0 und eine Standardabweichung von 1 haben.
113
114        -  Verwenden Sie `np.mean(embeddings_2d, axis=0, keepdims=True) <https://numpy.org/doc/2.2/reference/generated/numpy.mean.html>`_ für den Mittelwert und `np.std(embeddings_2d, axis=0, keepdims=True) <https://numpy.org/doc/stable/reference/generated/numpy.std.html>`_ für die Standardabweichung.
115
116        -  Normalisieren Sie die Embeddings mit `(embeddings_2d - m) / s`, wobei `m` der Mittelwert und `s` die Standardabweichung ist.
117
118        -  Geben Sie die normalisierten 2D-Embeddings zurück.
119        """
120        if previous_embeddings_2d is not None:
121            tsne_model = TSNE(n_components=2, init=previous_embeddings_2d)
122        else:
123            tsne_model = TSNE(n_components=2, init="pca")
124
125        embeddings_2d = tsne_model.fit_transform(embeddings)
126
127        embeddings_2d = np.array(embeddings_2d, dtype=np.float32)
128        m = np.mean(embeddings_2d, axis=0, keepdims=True)  # Normalize to zero mean
129        s = np.std(embeddings_2d, axis=0, keepdims=True)  # Normalize to unit variance
130        embeddings_2d = (embeddings_2d - m) / s  # Normalize the embeddings
131
132        return embeddings_2d
133
134    def register_embeddings_2d(self, embeddings_2d, previous_embeddings_2d=None):
135        """Registriert die 2D-Embeddings, um sie mit den vorherigen Embeddings zu vergleichen.
136
137        Parameters:
138        -----------
139        embeddings_2d (np.ndarray):
140            Die 2D-Embeddings, die registriert werden sollen.
141
142        previous_embeddings_2d (np.ndarray, optional):
143              Die vorherigen 2D-Embeddings, die für die Registrierung verwendet werden sollen. Standardmäßig None.
144
145        Returns:
146        --------
147        embeddings_2d (np.ndarray):
148            Die registrierten 2D-Embeddings.
149
150        **TODO**:
151
152        - Wenn `previous_embeddings_2d` nicht `None` ist, verwenden Sie `scipy.linalg.orthogonal_procrustes <https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.orthogonal_procrustes.html>`_ um die 2D-Embeddings zu registrieren.
153          Dies hilft, die Embeddings so zu transformieren, dass sie mit den vorherigen bestmöglich Embeddings übereinstimmen.
154
155        - Die Funktion liefert die orthogonale Rotationsmatrix `R` und die Skala `s`, aber wir verwenden nur `R`, um die 2D-Embeddings zu transformieren.
156
157        - Transformieren Sie die 2D-Embeddings mit `embeddings_2d @ R`, um sie an die vorherigen Embeddings anzupassen.
158
159        - Geben Sie die transformierten 2D-Embeddings zurück.
160        """
161        if previous_embeddings_2d is not None:
162            R, _ = orthogonal_procrustes(embeddings_2d, previous_embeddings_2d)
163            embeddings_2d = embeddings_2d @ R
164
165        return embeddings_2d
166
167    def visualize_embeddings(self, embeddings_2d, labels, step, axs):
168        """Visualisiert die 2D-Embeddings mit t-SNE und speichert das Bild.
169
170        Parameters:
171        -----------
172        embeddings_2d (np.ndarray):
173            Die 2D-Embeddings, die visualisiert werden sollen.
174
175        labels (np.ndarray):
176            Die zugehörigen Labels für die Embeddings.
177
178        step (int):
179            Der aktuelle Schritt oder die Epoche, die für den Titel des Plots verwendet wird.
180
181        axs (matplotlib.axes.Axes):
182            Die Achsen, auf denen die Embeddings visualisiert werden sollen.
183
184        **TODO**:
185
186        - Erstellen Sie mit Pandas ein DataFrame mit den 2D-Embeddings und den zugehörigen Labels,
187          um die Daten für die Visualisierung vorzubereiten.
188
189        - Verwenden Sie `seaborn.scatterplot <https://seaborn.pydata.org/generated/seaborn.scatterplot.html>`_ um die 2D-Embeddings zu visualisieren.
190
191        - Setzen Sie die Achsenlimits auf (-3.0, 3.0) für beide Achsen, um eine konsistente Darstellung zu gewährleisten.
192
193        - Entfernen Sie die Legende (`axs.get_legend().remove()`), um den Plot übersichtlicher zu gestalten.
194
195        - Setzen Sie den Titel des Plots sinnvoll.
196        """
197        df = pd.DataFrame(
198            {"x": embeddings_2d[:, 0], "y": embeddings_2d[:, 1], "label": labels}
199        )
200
201        sns.scatterplot(data=df, x="x", y="y", hue="label", palette="muted", ax=axs)
202        axs.set_xlim(-3.0, 3.0)
203        axs.set_ylim(-3.0, 3.0)
204        axs.get_legend().remove()
205        axs.set_title(f"t-SNE Embedding Projection - Step {step}")
206
207    def append_frame(self, image):
208        """Fügt ein Bild zu den Frames hinzu, die später als GIF gespeichert werden."""
209        self.writer.add_image(
210            f"embeddings", np.array(image), global_step=self.step, dataformats="HWC"
211        )
212
213        self.frames.append(image)
214
215        dirname = os.path.dirname(os.path.abspath(__file__))
216        image_path = os.path.join(dirname, "images")
217        os.makedirs(image_path, exist_ok=True)
218
219        self.frames[0].save(
220            os.path.join(image_path, "animation.gif"),
221            save_all=True,
222            append_images=self.frames[1:],
223            duration=300,
224            loop=0,
225        )
226
227        image.save(os.path.join(image_path, f"embeddings_{self.step}.png"))
228
229    def log_embeddings(self, model):
230        embeddings, labels = self.calculate_embeddings(model)
231        embeddings_2d = self.calculate_tsne(embeddings, self.previous_embeddings_2d)
232        embeddings_2d = self.register_embeddings_2d(
233            embeddings_2d, self.previous_embeddings_2d
234        )
235        self.previous_embeddings_2d = embeddings_2d
236
237        fig = plt.figure(figsize=(8, 6))
238        axs = fig.add_subplot(1, 1, 1)
239        image = self.visualize_embeddings(embeddings_2d, labels, self.step, axs)
240        fig.tight_layout()
241
242        buf = io.BytesIO()
243        fig.savefig(buf, format="png")
244        buf.seek(0)
245
246        image = PIL.Image.open(buf)
247
248        self.append_frame(image)
249        self.step += 1
250
251
252if __name__ == "__main__":
253    training_set, validation_set = load_data()
254
255    # Initialisierung des Modells, Loss-Kriteriums und Optimierers
256    model = ResNet().to(DEVICE)
257    criterion = nn.CrossEntropyLoss(reduction="none")
258    optimizer = torch.optim.Adam(
259        model.parameters(), lr=LR
260    )  # Checkpoint laden, falls vorhanden
261
262    # Checkpoint laden, falls vorhanden
263    dirname = os.path.dirname(os.path.abspath(__file__))
264    chkpt_path = os.path.join(dirname, "checkpoint.pth")
265
266    ep = load_checkpoint(model, optimizer, chkpt_path)
267    if ep > 0:
268        print(f"Checkpoint geladen, fortsetzen bei Epoche {ep}.")
269
270    # Das Modell trainieren
271    logger = EmbeddingLogger(validation_set)
272
273    # Logge den Graphen des Modells
274    input_tensor = torch.randn(1, 3, 32, 32).to(DEVICE)  # Beispiel-Eingabetensor
275    logger.log_graph(model, input_tensor)
276
277    umap_model = None
278    for n in range(ep, ep + 200):
279        log_after = 100000
280        if n == 0:
281            log_after = 5000
282        if n == 1:
283            log_after = 10000
284        if n == 2:
285            log_after = 50000
286
287        epoch(
288            model,
289            n,
290            True,
291            training_set,
292            criterion,
293            optimizer,
294            logger=logger,
295            log_after_n_samples=log_after,
296        )
297        epoch(model, n, False, validation_set, criterion, optimizer, logger=logger)
298
299        # Checkpoint nach jeder Epoche speichern
300        save_checkpoint(model, optimizer, n + 1, chkpt_path)