cleanup save checkpoint

This commit is contained in:
Varuna Jayasiri
2025-07-20 09:13:11 +05:30
parent 5eecda7e28
commit ebb94842db
7 changed files with 13 additions and 27 deletions

View File

@@ -703,10 +703,6 @@ class CFR:
self.tracker(self.info_sets)
tracker.save()
# Save checkpoints every $1,000$ iterations
if (t + 1) % 1_000 == 0:
experiment.save_checkpoint()
# Print the information sets
logger.inspect(self.info_sets)

View File

@@ -19,16 +19,16 @@ simplicity.
"""
from typing import List
import torch
import torch.utils.data
import torchvision
from PIL import Image
import torch
import torch.utils.data
from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs, option
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.diffusion.ddpm import DenoiseDiffusion
from labml_nn.diffusion.ddpm.unet import UNet
from labml_nn.helpers.device import DeviceConfigs
class Configs(BaseConfigs):
@@ -154,8 +154,6 @@ class Configs(BaseConfigs):
self.sample()
# New line in the console
tracker.new_line()
# Save the model
experiment.save_checkpoint()
class CelebADataset(torch.utils.data.Dataset):

View File

@@ -570,8 +570,6 @@ class Configs(BaseConfigs):
# Save images at intervals
batches_done = epoch * len(self.dataloader) + i
if batches_done % self.sample_interval == 0:
# Save models when sampling images
experiment.save_checkpoint()
# Sample images
self.sample_images(batches_done)

View File

@@ -404,7 +404,8 @@ class Configs(BaseConfigs):
tracker.add('generated', torch.cat([generated_images[:6], real_images[:3]], dim=0))
# Save model checkpoints
if (idx + 1) % self.save_checkpoint_interval == 0:
experiment.save_checkpoint()
# Save checkpoint
pass
# Flush tracker
tracker.save()

View File

@@ -19,9 +19,6 @@ We decided to write a simpler implementation to make it easier for readers who a
import dataclasses
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from labml import experiment, lab, tracker, monit, logger
from labml.logger import Text
from labml.utils.download import download_file
@@ -31,6 +28,8 @@ from labml_nn.transformers import Encoder, MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.models import EmbeddingsWithPositionalEncoding, TransformerLayer
from labml_nn.transformers.utils import subsequent_mask
from torch import nn
from torch.utils.data import Dataset, DataLoader
class AutoregressiveModel(nn.Module):
@@ -280,9 +279,6 @@ class Trainer:
if (i + 1) % 10 == 0:
tracker.save()
# Save the model
experiment.save_checkpoint()
def main():
# Create experiment

View File

@@ -12,9 +12,6 @@ This is the training code for
"""
import torch
from torch import nn
from torch.utils.data import DataLoader, RandomSampler
from labml import monit, lab, tracker, experiment, logger
from labml.logger import Text
from labml_nn.helpers.datasets import TextFileDataset
@@ -22,6 +19,8 @@ from labml_nn.optimizers.noam import Noam
from labml_nn.transformers.retro import model as retro
from labml_nn.transformers.retro.dataset import Dataset, RetroIndex
from labml_nn.transformers.retro.model import RetroModel, NearestNeighborEncoder
from torch import nn
from torch.utils.data import DataLoader, RandomSampler
class Sampler:
@@ -217,7 +216,6 @@ def train():
logger.log([(prompt.replace('\n', '\\n\n'), Text.subtle),
(sampler.sample(prompt, 128).replace('\n', '\\n\n'), Text.none)])
# Save models
experiment.save_checkpoint()
#

View File

@@ -17,16 +17,16 @@ For simplicity, we do not do a training and validation split.
"""
import numpy as np
import torchvision.transforms.functional
import torch
import torch.utils.data
import torchvision.transforms.functional
from torch import nn
from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.unet.carvana import CarvanaDataset
from labml_nn.unet import UNet
from labml_nn.unet.carvana import CarvanaDataset
from torch import nn
class Configs(BaseConfigs):
@@ -141,7 +141,6 @@ class Configs(BaseConfigs):
# New line in the console
tracker.new_line()
# Save the model
experiment.save_checkpoint()
def main():