Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git, synced 2025-08-06 15:22:21 +08:00
cleanup save checkpoint
@@ -703,10 +703,6 @@ class CFR:
            self.tracker(self.info_sets)
            tracker.save()

            # Save checkpoints every $1,000$ iterations
            if (t + 1) % 1_000 == 0:
                experiment.save_checkpoint()

        # Print the information sets
        logger.inspect(self.info_sets)
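Aside: this commit drops the `experiment.save_checkpoint()` call above. If equivalent periodic checkpointing is still wanted, a minimal hand-rolled sketch using plain PyTorch serialization is below; the `state` contents and the checkpoint filename are illustrative assumptions, not part of the repository:

import torch

def maybe_save_checkpoint(state: dict, step: int, every: int = 1_000):
    # Mirrors the removed logic: persist state every `every` iterations
    if (step + 1) % every == 0:
        torch.save(state, f'checkpoint_{step + 1}.pt')

# Usage inside the training loop, e.g.:
# maybe_save_checkpoint({'info_sets': self.info_sets}, t)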
@@ -19,16 +19,16 @@ simplicity.
"""
from typing import List

import torch
import torch.utils.data
import torchvision
from PIL import Image

import torch
import torch.utils.data
from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs, option
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.diffusion.ddpm import DenoiseDiffusion
from labml_nn.diffusion.ddpm.unet import UNet
from labml_nn.helpers.device import DeviceConfigs


class Configs(BaseConfigs):
@@ -154,8 +154,6 @@ class Configs(BaseConfigs):
            self.sample()
            # New line in the console
            tracker.new_line()
            # Save the model
            experiment.save_checkpoint()


class CelebADataset(torch.utils.data.Dataset):
@@ -570,8 +570,6 @@ class Configs(BaseConfigs):
                # Save images at intervals
                batches_done = epoch * len(self.dataloader) + i
                if batches_done % self.sample_interval == 0:
                    # Save models when sampling images
                    experiment.save_checkpoint()
                    # Sample images
                    self.sample_images(batches_done)
@@ -404,7 +404,8 @@ class Configs(BaseConfigs):
                tracker.add('generated', torch.cat([generated_images[:6], real_images[:3]], dim=0))
            # Save model checkpoints
            if (idx + 1) % self.save_checkpoint_interval == 0:
                experiment.save_checkpoint()
                # Save checkpoint
                pass

            # Flush tracker
            tracker.save()
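For context, the `tracker.add(...)` / `tracker.save()` pair seen in this hunk is labml's buffer-then-flush pattern for metrics. A minimal sketch, assuming it runs inside a started labml experiment; the experiment name, metric name, and values are illustrative:

from labml import experiment, tracker

experiment.create(name='sketch')
with experiment.start():
    for step in range(100):
        loss = 1.0 / (step + 1)       # placeholder metric
        tracker.add('loss', loss)     # buffer the metric for this step
        if (step + 1) % 10 == 0:
            tracker.save()            # flush buffered metrics to the log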
@@ -19,9 +19,6 @@ We decided to write a simpler implementation to make it easier for readers who a
import dataclasses

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

from labml import experiment, lab, tracker, monit, logger
from labml.logger import Text
from labml.utils.download import download_file
@@ -31,6 +28,8 @@ from labml_nn.transformers import Encoder, MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.models import EmbeddingsWithPositionalEncoding, TransformerLayer
from labml_nn.transformers.utils import subsequent_mask
from torch import nn
from torch.utils.data import Dataset, DataLoader


class AutoregressiveModel(nn.Module):
@@ -280,9 +279,6 @@ class Trainer:
            if (i + 1) % 10 == 0:
                tracker.save()

        # Save the model
        experiment.save_checkpoint()


def main():
    # Create experiment
@@ -12,9 +12,6 @@ This is the training code for
"""

import torch
from torch import nn
from torch.utils.data import DataLoader, RandomSampler

from labml import monit, lab, tracker, experiment, logger
from labml.logger import Text
from labml_nn.helpers.datasets import TextFileDataset
@@ -22,6 +19,8 @@ from labml_nn.optimizers.noam import Noam
from labml_nn.transformers.retro import model as retro
from labml_nn.transformers.retro.dataset import Dataset, RetroIndex
from labml_nn.transformers.retro.model import RetroModel, NearestNeighborEncoder
from torch import nn
from torch.utils.data import DataLoader, RandomSampler


class Sampler:
@@ -217,7 +216,6 @@ def train():
        logger.log([(prompt.replace('\n', '\\n\n'), Text.subtle),
                    (sampler.sample(prompt, 128).replace('\n', '\\n\n'), Text.none)])
    # Save models
    experiment.save_checkpoint()


#
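The `logger.log` call in this hunk prints a list of (text, style) segments; a minimal sketch of the same usage, with illustrative strings:

from labml import logger
from labml.logger import Text

logger.log([('prompt text ', Text.subtle),
            ('sampled continuation', Text.none)])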
@@ -17,16 +17,16 @@ For simplicity, we do not do a training and validation split.
"""

import numpy as np
import torchvision.transforms.functional

import torch
import torch.utils.data
import torchvision.transforms.functional
from torch import nn

from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.unet.carvana import CarvanaDataset
from labml_nn.unet import UNet
from labml_nn.unet.carvana import CarvanaDataset
from torch import nn


class Configs(BaseConfigs):
@@ -141,7 +141,6 @@ class Configs(BaseConfigs):
            # New line in the console
            tracker.new_line()
            # Save the model
            experiment.save_checkpoint()


def main():