1import os
2import torch
3from torch import nn
4from misc import DEVICE, CNNNetwork, load_data, epoch
5
6LR = 0.001 # Lernrate
7
8
9def save_checkpoint(model, optimizer, epoch, filename="checkpoint.pth"):
10 """Speichert den aktuellen Zustand des Modells und des Optimierers in einer Datei.
11
12 Parameters:
13 -----------
14 model (nn.Module):
15 Das zu speichernde Modell.
16
17 optimizer (torch.optim.Optimizer):
18 Der Optimierer, dessen Zustand gespeichert werden soll.
19
20 epoch (int):
21 Die aktuelle Epoche, die im Checkpoint gespeichert wird.
22
23 filename (str):
24 Der Name der Datei, in der der Checkpoint gespeichert wird.
25
26 **TODO**:
27 Erzeuge ein Dictionary, das den Zustand des Modells, des Optimierers und die aktuelle Epoche enthält.
28 Den Zustand der Modells und des Optimierers kannst du mit `model.state_dict()` und `optimizer.state_dict()` erhalten.
29 Speichere dieses Dictionary mit `torch.save()` unter dem angegebenen Dateinamen.
30 """
31 torch.save(
32 {
33 "epoch": epoch,
34 "model_state_dict": model.state_dict(),
35 "optimizer_state_dict": optimizer.state_dict(),
36 },
37 filename,
38 )
39
40
41def load_checkpoint(model, optimizer, filename="checkpoint.pth"):
42 """Lädt den Zustand des Modells und des Optimierers aus einer Datei.
43
44 Parameters:
45 -----------
46 model (nn.Module):
47 Das Modell, in das die gespeicherten Zustände geladen werden.
48
49 optimizer (torch.optim.Optimizer):
50 Der Optimierer, dessen Zustand geladen wird.
51
52 filename (str):
53 Der Name der Datei, aus der der Checkpoint geladen wird.
54
55 **TODO**:
56 Versuche, den Checkpoint mit `torch.load()` zu laden.
57 Wenn die Datei nicht gefunden wird, gib eine entsprechende Fehlermeldung aus und starte ohne gespeicherten Zustand.
58 Wenn der Checkpoint geladen wird, versuche, den Zustand des Modells und des Optimizers zu laden.
59 Du kannst `model.load_state_dict()` und `optimizer.load_state_dict()` verwenden um die Zustände ins Modell zu laden.
60 Wenn ein Fehler beim Laden auftritt, gib eine Fehlermeldung aus und starte ohne gespeicherten Zustand.
61 Gibt die aktuelle Epoche zurück, die im Checkpoint gespeichert ist.
62 """
63 try:
64 checkpoint = torch.load(filename, weights_only=True)
65 model.load_state_dict(checkpoint["model_state_dict"])
66 optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
67 return checkpoint["epoch"]
68 except Exception as e:
69 print(f"Fehler beim Laden des Checkpoints {filename}: {e}")
70 print("Starte ohne gespeicherten Zustand.")
71 return 0
72
73
74if __name__ == "__main__":
75 training_set, validation_set = load_data()
76
77 # Initialisierung des Modells, Loss-Kriteriums und Optimierers
78 model = CNNNetwork().to(DEVICE)
79 criterion = nn.CrossEntropyLoss()
80 optimizer = torch.optim.Adam(
81 model.parameters(), lr=LR
82 ) # Checkpoint laden, falls vorhanden
83
84 # Checkpoint laden, falls vorhanden
85 dirname = os.path.dirname(os.path.abspath(__file__))
86 chkpt_path = os.path.join(dirname, "checkpoint.pth")
87
88 ep = load_checkpoint(model, optimizer, chkpt_path)
89 if ep > 0:
90 print(f"Checkpoint geladen, fortsetzen bei Epoche {ep}.")
91
92 # Das Modell trainieren
93 for n in range(ep, ep + 30):
94 epoch(model, n, True, training_set, criterion, optimizer)
95 epoch(model, n, False, validation_set, criterion, optimizer)
96
97 # Checkpoint nach jeder Epoche speichern
98 save_checkpoint(model, optimizer, n + 1, chkpt_path)