cleanup save checkpoint

This commit is contained in:
Varuna Jayasiri
2025-07-20 09:13:11 +05:30
parent 5eecda7e28
commit ebb94842db
7 changed files with 13 additions and 27 deletions

View File

@@ -703,10 +703,6 @@ class CFR:
             self.tracker(self.info_sets)
             tracker.save()
-            # Save checkpoints every $1,000$ iterations
-            if (t + 1) % 1_000 == 0:
-                experiment.save_checkpoint()
         # Print the information sets
         logger.inspect(self.info_sets)

View File

@@ -19,16 +19,16 @@ simplicity.
 """
 from typing import List

+import torch
+import torch.utils.data
 import torchvision
 from PIL import Image
-import torch
-import torch.utils.data

 from labml import lab, tracker, experiment, monit
 from labml.configs import BaseConfigs, option
-from labml_nn.helpers.device import DeviceConfigs
 from labml_nn.diffusion.ddpm import DenoiseDiffusion
 from labml_nn.diffusion.ddpm.unet import UNet
+from labml_nn.helpers.device import DeviceConfigs


 class Configs(BaseConfigs):
@@ -154,8 +154,6 @@ class Configs(BaseConfigs):
             self.sample()
             # New line in the console
             tracker.new_line()
-            # Save the model
-            experiment.save_checkpoint()
class CelebADataset(torch.utils.data.Dataset):

View File

@@ -570,8 +570,6 @@ class Configs(BaseConfigs):
                 # Save images at intervals
                 batches_done = epoch * len(self.dataloader) + i
                 if batches_done % self.sample_interval == 0:
-                    # Save models when sampling images
-                    experiment.save_checkpoint()
                     # Sample images
                     self.sample_images(batches_done)

View File

@@ -404,7 +404,8 @@ class Configs(BaseConfigs):
             tracker.add('generated', torch.cat([generated_images[:6], real_images[:3]], dim=0))
         # Save model checkpoints
         if (idx + 1) % self.save_checkpoint_interval == 0:
-            experiment.save_checkpoint()
+            # Save checkpoint
+            pass
         # Flush tracker
         tracker.save()

View File

@@ -19,9 +19,6 @@ We decided to write a simpler implementation to make it easier for readers who a
 import dataclasses

 import torch
-from torch import nn
-from torch.utils.data import Dataset, DataLoader
-
 from labml import experiment, lab, tracker, monit, logger
 from labml.logger import Text
 from labml.utils.download import download_file
@@ -31,6 +28,8 @@ from labml_nn.transformers import Encoder, MultiHeadAttention
 from labml_nn.transformers.feed_forward import FeedForward
 from labml_nn.transformers.models import EmbeddingsWithPositionalEncoding, TransformerLayer
 from labml_nn.transformers.utils import subsequent_mask
+from torch import nn
+from torch.utils.data import Dataset, DataLoader


 class AutoregressiveModel(nn.Module):
@@ -280,9 +279,6 @@ class Trainer:
             if (i + 1) % 10 == 0:
                 tracker.save()
-        # Save the model
-        experiment.save_checkpoint()


 def main():
     # Create experiment

View File

@@ -12,9 +12,6 @@ This is the training code for
 """
 import torch
-from torch import nn
-from torch.utils.data import DataLoader, RandomSampler
-
 from labml import monit, lab, tracker, experiment, logger
 from labml.logger import Text
 from labml_nn.helpers.datasets import TextFileDataset
@@ -22,6 +19,8 @@ from labml_nn.optimizers.noam import Noam
 from labml_nn.transformers.retro import model as retro
 from labml_nn.transformers.retro.dataset import Dataset, RetroIndex
 from labml_nn.transformers.retro.model import RetroModel, NearestNeighborEncoder
+from torch import nn
+from torch.utils.data import DataLoader, RandomSampler


 class Sampler:
@@ -217,7 +216,6 @@ def train():
         logger.log([(prompt.replace('\n', '\\n\n'), Text.subtle),
                     (sampler.sample(prompt, 128).replace('\n', '\\n\n'), Text.none)])
         # Save models
-        experiment.save_checkpoint()
 #

View File

@@ -17,16 +17,16 @@ For simplicity, we do not do a training and validation split.
 """
 import numpy as np
-import torchvision.transforms.functional
 import torch
 import torch.utils.data
+import torchvision.transforms.functional
+from torch import nn

 from labml import lab, tracker, experiment, monit
 from labml.configs import BaseConfigs
 from labml_nn.helpers.device import DeviceConfigs
-from labml_nn.unet.carvana import CarvanaDataset
 from labml_nn.unet import UNet
-from torch import nn
+from labml_nn.unet.carvana import CarvanaDataset


 class Configs(BaseConfigs):
@@ -141,7 +141,6 @@ class Configs(BaseConfigs):
             # New line in the console
             tracker.new_line()
             # Save the model
-            experiment.save_checkpoint()
def main(): def main():