From ebb94842dbdfbf29f94d2b19d9beae50d6cbb837 Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Sun, 20 Jul 2025 09:13:11 +0530
Subject: [PATCH] cleanup save checkpoint

---
 labml_nn/cfr/__init__.py                     | 4 ----
 labml_nn/diffusion/ddpm/experiment.py        | 8 +++-----
 labml_nn/gan/cycle_gan/__init__.py           | 2 --
 labml_nn/gan/stylegan/experiment.py          | 3 ++-
 labml_nn/transformers/glu_variants/simple.py | 8 ++------
 labml_nn/transformers/retro/train.py         | 6 ++----
 labml_nn/unet/experiment.py                  | 9 ++++-----
 7 files changed, 13 insertions(+), 27 deletions(-)

diff --git a/labml_nn/cfr/__init__.py b/labml_nn/cfr/__init__.py
index 40b5033e..48e49605 100644
--- a/labml_nn/cfr/__init__.py
+++ b/labml_nn/cfr/__init__.py
@@ -703,10 +703,6 @@ class CFR:
                 self.tracker(self.info_sets)
                 tracker.save()
 
-            # Save checkpoints every $1,000$ iterations
-            if (t + 1) % 1_000 == 0:
-                experiment.save_checkpoint()
-
         # Print the information sets
         logger.inspect(self.info_sets)
 
diff --git a/labml_nn/diffusion/ddpm/experiment.py b/labml_nn/diffusion/ddpm/experiment.py
index a185522c..1bb4c648 100644
--- a/labml_nn/diffusion/ddpm/experiment.py
+++ b/labml_nn/diffusion/ddpm/experiment.py
@@ -19,16 +19,16 @@ simplicity.
 """
 from typing import List
 
-import torch
-import torch.utils.data
 import torchvision
 from PIL import Image
 
+import torch
+import torch.utils.data
 from labml import lab, tracker, experiment, monit
 from labml.configs import BaseConfigs, option
-from labml_nn.helpers.device import DeviceConfigs
 from labml_nn.diffusion.ddpm import DenoiseDiffusion
 from labml_nn.diffusion.ddpm.unet import UNet
+from labml_nn.helpers.device import DeviceConfigs
 
 
 class Configs(BaseConfigs):
@@ -154,8 +154,6 @@ class Configs(BaseConfigs):
             self.sample()
             # New line in the console
             tracker.new_line()
-            # Save the model
-            experiment.save_checkpoint()
 
 
 class CelebADataset(torch.utils.data.Dataset):
diff --git a/labml_nn/gan/cycle_gan/__init__.py b/labml_nn/gan/cycle_gan/__init__.py
index a09823c7..0a78e261 100644
--- a/labml_nn/gan/cycle_gan/__init__.py
+++ b/labml_nn/gan/cycle_gan/__init__.py
@@ -570,8 +570,6 @@ class Configs(BaseConfigs):
                 # Save images at intervals
                 batches_done = epoch * len(self.dataloader) + i
                 if batches_done % self.sample_interval == 0:
-                    # Save models when sampling images
-                    experiment.save_checkpoint()
                     # Sample images
                     self.sample_images(batches_done)
 
diff --git a/labml_nn/gan/stylegan/experiment.py b/labml_nn/gan/stylegan/experiment.py
index 60f6606b..7a33aba9 100644
--- a/labml_nn/gan/stylegan/experiment.py
+++ b/labml_nn/gan/stylegan/experiment.py
@@ -404,7 +404,8 @@ class Configs(BaseConfigs):
             tracker.add('generated', torch.cat([generated_images[:6], real_images[:3]], dim=0))
         # Save model checkpoints
         if (idx + 1) % self.save_checkpoint_interval == 0:
-            experiment.save_checkpoint()
+            # Save checkpoint
+            pass
         # Flush tracker
         tracker.save()
 
diff --git a/labml_nn/transformers/glu_variants/simple.py b/labml_nn/transformers/glu_variants/simple.py
index 0b821d9e..67f54cce 100644
--- a/labml_nn/transformers/glu_variants/simple.py
+++ b/labml_nn/transformers/glu_variants/simple.py
@@ -19,9 +19,6 @@ We decided to write a simpler implementation to make it easier for readers who a
 import dataclasses
 
 import torch
-from torch import nn
-from torch.utils.data import Dataset, DataLoader
-
 from labml import experiment, lab, tracker, monit, logger
 from labml.logger import Text
 from labml.utils.download import download_file
@@ -31,6 +28,8 @@ from labml_nn.transformers import Encoder, MultiHeadAttention
 from labml_nn.transformers.feed_forward import FeedForward
 from labml_nn.transformers.models import EmbeddingsWithPositionalEncoding, TransformerLayer
 from labml_nn.transformers.utils import subsequent_mask
+from torch import nn
+from torch.utils.data import Dataset, DataLoader
 
 
 class AutoregressiveModel(nn.Module):
@@ -280,9 +279,6 @@ class Trainer:
                 if (i + 1) % 10 == 0:
                     tracker.save()
 
-            # Save the model
-            experiment.save_checkpoint()
-
 
 def main():
     # Create experiment
diff --git a/labml_nn/transformers/retro/train.py b/labml_nn/transformers/retro/train.py
index ec4d7bb4..61372c9e 100644
--- a/labml_nn/transformers/retro/train.py
+++ b/labml_nn/transformers/retro/train.py
@@ -12,9 +12,6 @@ This is the training code for
 """
 
 import torch
-from torch import nn
-from torch.utils.data import DataLoader, RandomSampler
-
 from labml import monit, lab, tracker, experiment, logger
 from labml.logger import Text
 from labml_nn.helpers.datasets import TextFileDataset
@@ -22,6 +19,8 @@ from labml_nn.optimizers.noam import Noam
 from labml_nn.transformers.retro import model as retro
 from labml_nn.transformers.retro.dataset import Dataset, RetroIndex
 from labml_nn.transformers.retro.model import RetroModel, NearestNeighborEncoder
+from torch import nn
+from torch.utils.data import DataLoader, RandomSampler
 
 
 class Sampler:
@@ -217,7 +216,6 @@ def train():
         logger.log([(prompt.replace('\n', '\\n\n'), Text.subtle),
                     (sampler.sample(prompt, 128).replace('\n', '\\n\n'), Text.none)])
         # Save models
-        experiment.save_checkpoint()
 
 
 #
diff --git a/labml_nn/unet/experiment.py b/labml_nn/unet/experiment.py
index 06dbe269..1a45ca43 100644
--- a/labml_nn/unet/experiment.py
+++ b/labml_nn/unet/experiment.py
@@ -17,16 +17,16 @@ For simplicity, we do not do a training and validation split.
 """
 
 import numpy as np
+import torchvision.transforms.functional
+
 import torch
 import torch.utils.data
-import torchvision.transforms.functional
-from torch import nn
-
 from labml import lab, tracker, experiment, monit
 from labml.configs import BaseConfigs
 from labml_nn.helpers.device import DeviceConfigs
-from labml_nn.unet.carvana import CarvanaDataset
 from labml_nn.unet import UNet
+from labml_nn.unet.carvana import CarvanaDataset
+from torch import nn
 
 
 class Configs(BaseConfigs):
@@ -141,7 +141,6 @@ class Configs(BaseConfigs):
             # New line in the console
             tracker.new_line()
             # Save the model
-            experiment.save_checkpoint()
 
 
 def main():