Checkpoints =========== Um Zwischenstände während des Trainings zu speichern und später wiederherzustellen, können Checkpoints verwendet werden. Diese ermöglichen es, den Trainingsprozess zu unterbrechen und später fortzusetzen, ohne von vorne beginnen zu müssen oder um den besten Zustand des Modells zu sichern. PyTorch bietet eine einfache Möglichkeit, Checkpoints zu erstellen. Alle Module in PyTorch erlauben es ihren internen Zustand zu speichern und später wiederherzustellen. Dies nennt man `state_dict `_. Die Menge aller relevanten state_dicts definiert dann einen Checkpoint. Der Checkpoint enthält also alle Informationen, die benötigt werden, um den Zustand des Modells und des Optimierers zu einem bestimmten Zeitpunkt wiederherzustellen. Um das `state_dict` abzurufen implementieren Modudle die Methode `state_dict() `_. Das Laden eines früheren `state_dicts` erfolgt dann über die Methode `load_state_dict() `_. Ein Checkpoint sollte mindestens folgende Informationen enthalten: - Der Zustand des Modells (`model.state_dict()`) - Der Zustand des Optimierers (`optimizer.state_dict()`) - Die aktuelle Epoche (z.B. `epoch`) **Aufgabe 1**: Checkpoint speichern ----------------------------------- Implementiere die Funktion `checkpoints.save_checkpoint(...)`, die den aktuellen Zustand des Modells, des Optimierers und die aktuelle Epoche in einer Datei speichert. Verwende dazu die Funktion `torch.save(...) `_. .. autofunction:: checkpoints.save_checkpoint .. admonition:: Lösung anzeigen :class: toggle .. code-block:: python def save_checkpoint(model, optimizer, epoch, filename='checkpoint.pth'): torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), }, filename) **Aufgabe 2**: Checkpoint laden ------------------------------- Implementiere die Funktion `checkpoints.load_checkpoint(...)`, die den gespeicherten Zustand des Modells, des Optimierers und die aktuelle Epoche aus einer Datei lädt. Verwende dazu die Funktion `torch.load(...) `_. Es macht Sinn das Laden des Checkpoints sowie das Laden der `state_dict` in einem `try`-`except`-Block zu kapseln, um Fehler beim Laden zu behandeln. Wenn der Checkpoint nicht gefunden wird, sollte eine Fehlermeldung ausgegeben werden und das Training ohne gespeicherten Zustand fortgesetzt werden. Dieses Verhalten macht es elegant möglich das Training ohne Checkpoint zu starten, falls der Checkpoint nicht gefunden wird (oder inkompatibel ist), gleichzeitig aber auch den Checkpoint zu laden und von diesem aus fortzusetzen, wenn er vorhanden ist. .. autofunction:: checkpoints.load_checkpoint .. admonition:: Lösung anzeigen :class: toggle .. code-block:: python def load_checkpoint(model, optimizer, filename='checkpoint.pth'): try: checkpoint = torch.load(filename, weights_only=True) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) epoch = checkpoint['epoch'] return epoch except FileNotFoundError: print(f"Checkpoint-Datei {filename} nicht gefunden. Starte ohne gespeicherten Zustand.") return 0 except Exception as e: print(f"Fehler beim Laden des Checkpoints {filename}: {e}") print("Starte ohne gespeicherten Zustand.") return 0 **Musterlösung** ---------------- :doc:`checkpoints_source`