Checkpoints - Musterlösung

 1import os
 2import torch
 3from torch import nn
 4from misc import DEVICE, CNNNetwork, load_data, epoch
 5
 6LR = 0.001  # Lernrate
 7
 8
 9def save_checkpoint(model, optimizer, epoch, filename="checkpoint.pth"):
10    """Speichert den aktuellen Zustand des Modells und des Optimierers in einer Datei.
11
12    Parameters:
13    -----------
14    model (nn.Module):
15        Das zu speichernde Modell.
16
17    optimizer (torch.optim.Optimizer):
18        Der Optimierer, dessen Zustand gespeichert werden soll.
19
20    epoch (int):
21        Die aktuelle Epoche, die im Checkpoint gespeichert wird.
22
23    filename (str):
24        Der Name der Datei, in der der Checkpoint gespeichert wird.
25
26    **TODO**:
27    Erzeuge ein Dictionary, das den Zustand des Modells, des Optimierers und die aktuelle Epoche enthält.
28    Den Zustand der Modells und des Optimierers kannst du mit `model.state_dict()` und `optimizer.state_dict()` erhalten.
29    Speichere dieses Dictionary mit `torch.save()` unter dem angegebenen Dateinamen.
30    """
31    torch.save(
32        {
33            "epoch": epoch,
34            "model_state_dict": model.state_dict(),
35            "optimizer_state_dict": optimizer.state_dict(),
36        },
37        filename,
38    )
39
40
41def load_checkpoint(model, optimizer, filename="checkpoint.pth"):
42    """Lädt den Zustand des Modells und des Optimierers aus einer Datei.
43
44    Parameters:
45    -----------
46    model (nn.Module):
47        Das Modell, in das die gespeicherten Zustände geladen werden.
48
49    optimizer (torch.optim.Optimizer):
50        Der Optimierer, dessen Zustand geladen wird.
51
52    filename (str):
53        Der Name der Datei, aus der der Checkpoint geladen wird.
54
55    **TODO**:
56    Versuche, den Checkpoint mit `torch.load()` zu laden.
57    Wenn die Datei nicht gefunden wird, gib eine entsprechende Fehlermeldung aus und starte ohne gespeicherten Zustand.
58    Wenn der Checkpoint geladen wird, versuche, den Zustand des Modells und des Optimizers zu laden.
59    Du kannst `model.load_state_dict()` und `optimizer.load_state_dict()` verwenden um die Zustände ins Modell zu laden.
60    Wenn ein Fehler beim Laden auftritt, gib eine Fehlermeldung aus und starte ohne gespeicherten Zustand.
61    Gibt die aktuelle Epoche zurück, die im Checkpoint gespeichert ist.
62    """
63    try:
64        checkpoint = torch.load(filename, weights_only=True)
65        model.load_state_dict(checkpoint["model_state_dict"])
66        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
67        return checkpoint["epoch"]
68    except Exception as e:
69        print(f"Fehler beim Laden des Checkpoints {filename}: {e}")
70        print("Starte ohne gespeicherten Zustand.")
71        return 0
72
73
74if __name__ == "__main__":
75    training_set, validation_set = load_data()
76
77    # Initialisierung des Modells, Loss-Kriteriums und Optimierers
78    model = CNNNetwork().to(DEVICE)
79    criterion = nn.CrossEntropyLoss()
80    optimizer = torch.optim.Adam(
81        model.parameters(), lr=LR
82    )  # Checkpoint laden, falls vorhanden
83
84    # Checkpoint laden, falls vorhanden
85    dirname = os.path.dirname(os.path.abspath(__file__))
86    chkpt_path = os.path.join(dirname, "checkpoint.pth")
87
88    ep = load_checkpoint(model, optimizer, chkpt_path)
89    if ep > 0:
90        print(f"Checkpoint geladen, fortsetzen bei Epoche {ep}.")
91
92    # Das Modell trainieren
93    for n in range(ep, ep + 30):
94        epoch(model, n, True, training_set, criterion, optimizer)
95        epoch(model, n, False, validation_set, criterion, optimizer)
96
97        # Checkpoint nach jeder Epoche speichern
98        save_checkpoint(model, optimizer, n + 1, chkpt_path)