Checkpoints
Um Zwischenstände während des Trainings zu speichern und später wiederherzustellen, können Checkpoints verwendet werden. Diese ermöglichen es, den Trainingsprozess zu unterbrechen und später fortzusetzen, ohne von vorne beginnen zu müssen oder um den besten Zustand des Modells zu sichern. PyTorch bietet eine einfache Möglichkeit, Checkpoints zu erstellen.
Alle Module in PyTorch erlauben es ihren internen Zustand zu speichern und später wiederherzustellen. Dies nennt man state_dict. Die Menge aller relevanten state_dicts definiert dann einen Checkpoint. Der Checkpoint enthält also alle Informationen, die benötigt werden, um den Zustand des Modells und des Optimierers zu einem bestimmten Zeitpunkt wiederherzustellen.
Um das state_dict abzurufen implementieren Modudle die Methode state_dict(). Das Laden eines früheren state_dicts erfolgt dann über die Methode load_state_dict().
Ein Checkpoint sollte mindestens folgende Informationen enthalten:
Der Zustand des Modells (model.state_dict())
Der Zustand des Optimierers (optimizer.state_dict())
Die aktuelle Epoche (z.B. epoch)
Aufgabe 1: Checkpoint speichern
Implementiere die Funktion checkpoints.save_checkpoint(…), die den aktuellen Zustand des Modells, des Optimierers und die aktuelle Epoche in einer Datei speichert. Verwende dazu die Funktion torch.save(…).
Lösung anzeigen
def save_checkpoint(model, optimizer, epoch, filename='checkpoint.pth'):
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
}, filename)
Aufgabe 2: Checkpoint laden
Implementiere die Funktion checkpoints.load_checkpoint(…), die den gespeicherten Zustand des Modells, des Optimierers und die aktuelle Epoche aus einer Datei lädt. Verwende dazu die Funktion torch.load(…). Es macht Sinn das Laden des Checkpoints sowie das Laden der state_dict in einem try-except-Block zu kapseln, um Fehler beim Laden zu behandeln. Wenn der Checkpoint nicht gefunden wird, sollte eine Fehlermeldung ausgegeben werden und das Training ohne gespeicherten Zustand fortgesetzt werden. Dieses Verhalten macht es elegant möglich das Training ohne Checkpoint zu starten, falls der Checkpoint nicht gefunden wird (oder inkompatibel ist), gleichzeitig aber auch den Checkpoint zu laden und von diesem aus fortzusetzen, wenn er vorhanden ist.
Lösung anzeigen
def load_checkpoint(model, optimizer, filename='checkpoint.pth'):
try:
checkpoint = torch.load(filename, weights_only=True)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
return epoch
except FileNotFoundError:
print(f"Checkpoint-Datei {filename} nicht gefunden. Starte ohne gespeicherten Zustand.")
return 0
except Exception as e:
print(f"Fehler beim Laden des Checkpoints {filename}: {e}")
print("Starte ohne gespeicherten Zustand.")
return 0