mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-14 09:31:42 +08:00
cleanup save checkpoint
This commit is contained in:
@ -703,10 +703,6 @@ class CFR:
|
|||||||
self.tracker(self.info_sets)
|
self.tracker(self.info_sets)
|
||||||
tracker.save()
|
tracker.save()
|
||||||
|
|
||||||
# Save checkpoints every $1,000$ iterations
|
|
||||||
if (t + 1) % 1_000 == 0:
|
|
||||||
experiment.save_checkpoint()
|
|
||||||
|
|
||||||
# Print the information sets
|
# Print the information sets
|
||||||
logger.inspect(self.info_sets)
|
logger.inspect(self.info_sets)
|
||||||
|
|
||||||
|
@ -19,16 +19,16 @@ simplicity.
|
|||||||
"""
|
"""
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.utils.data
|
|
||||||
import torchvision
|
import torchvision
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.data
|
||||||
from labml import lab, tracker, experiment, monit
|
from labml import lab, tracker, experiment, monit
|
||||||
from labml.configs import BaseConfigs, option
|
from labml.configs import BaseConfigs, option
|
||||||
from labml_nn.helpers.device import DeviceConfigs
|
|
||||||
from labml_nn.diffusion.ddpm import DenoiseDiffusion
|
from labml_nn.diffusion.ddpm import DenoiseDiffusion
|
||||||
from labml_nn.diffusion.ddpm.unet import UNet
|
from labml_nn.diffusion.ddpm.unet import UNet
|
||||||
|
from labml_nn.helpers.device import DeviceConfigs
|
||||||
|
|
||||||
|
|
||||||
class Configs(BaseConfigs):
|
class Configs(BaseConfigs):
|
||||||
@ -154,8 +154,6 @@ class Configs(BaseConfigs):
|
|||||||
self.sample()
|
self.sample()
|
||||||
# New line in the console
|
# New line in the console
|
||||||
tracker.new_line()
|
tracker.new_line()
|
||||||
# Save the model
|
|
||||||
experiment.save_checkpoint()
|
|
||||||
|
|
||||||
|
|
||||||
class CelebADataset(torch.utils.data.Dataset):
|
class CelebADataset(torch.utils.data.Dataset):
|
||||||
|
@ -570,8 +570,6 @@ class Configs(BaseConfigs):
|
|||||||
# Save images at intervals
|
# Save images at intervals
|
||||||
batches_done = epoch * len(self.dataloader) + i
|
batches_done = epoch * len(self.dataloader) + i
|
||||||
if batches_done % self.sample_interval == 0:
|
if batches_done % self.sample_interval == 0:
|
||||||
# Save models when sampling images
|
|
||||||
experiment.save_checkpoint()
|
|
||||||
# Sample images
|
# Sample images
|
||||||
self.sample_images(batches_done)
|
self.sample_images(batches_done)
|
||||||
|
|
||||||
|
@ -404,7 +404,8 @@ class Configs(BaseConfigs):
|
|||||||
tracker.add('generated', torch.cat([generated_images[:6], real_images[:3]], dim=0))
|
tracker.add('generated', torch.cat([generated_images[:6], real_images[:3]], dim=0))
|
||||||
# Save model checkpoints
|
# Save model checkpoints
|
||||||
if (idx + 1) % self.save_checkpoint_interval == 0:
|
if (idx + 1) % self.save_checkpoint_interval == 0:
|
||||||
experiment.save_checkpoint()
|
# Save checkpoint
|
||||||
|
pass
|
||||||
|
|
||||||
# Flush tracker
|
# Flush tracker
|
||||||
tracker.save()
|
tracker.save()
|
||||||
|
@ -19,9 +19,6 @@ We decided to write a simpler implementation to make it easier for readers who a
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
|
||||||
from torch.utils.data import Dataset, DataLoader
|
|
||||||
|
|
||||||
from labml import experiment, lab, tracker, monit, logger
|
from labml import experiment, lab, tracker, monit, logger
|
||||||
from labml.logger import Text
|
from labml.logger import Text
|
||||||
from labml.utils.download import download_file
|
from labml.utils.download import download_file
|
||||||
@ -31,6 +28,8 @@ from labml_nn.transformers import Encoder, MultiHeadAttention
|
|||||||
from labml_nn.transformers.feed_forward import FeedForward
|
from labml_nn.transformers.feed_forward import FeedForward
|
||||||
from labml_nn.transformers.models import EmbeddingsWithPositionalEncoding, TransformerLayer
|
from labml_nn.transformers.models import EmbeddingsWithPositionalEncoding, TransformerLayer
|
||||||
from labml_nn.transformers.utils import subsequent_mask
|
from labml_nn.transformers.utils import subsequent_mask
|
||||||
|
from torch import nn
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
|
||||||
|
|
||||||
class AutoregressiveModel(nn.Module):
|
class AutoregressiveModel(nn.Module):
|
||||||
@ -280,9 +279,6 @@ class Trainer:
|
|||||||
if (i + 1) % 10 == 0:
|
if (i + 1) % 10 == 0:
|
||||||
tracker.save()
|
tracker.save()
|
||||||
|
|
||||||
# Save the model
|
|
||||||
experiment.save_checkpoint()
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# Create experiment
|
# Create experiment
|
||||||
|
@ -12,9 +12,6 @@ This is the training code for
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
|
||||||
from torch.utils.data import DataLoader, RandomSampler
|
|
||||||
|
|
||||||
from labml import monit, lab, tracker, experiment, logger
|
from labml import monit, lab, tracker, experiment, logger
|
||||||
from labml.logger import Text
|
from labml.logger import Text
|
||||||
from labml_nn.helpers.datasets import TextFileDataset
|
from labml_nn.helpers.datasets import TextFileDataset
|
||||||
@ -22,6 +19,8 @@ from labml_nn.optimizers.noam import Noam
|
|||||||
from labml_nn.transformers.retro import model as retro
|
from labml_nn.transformers.retro import model as retro
|
||||||
from labml_nn.transformers.retro.dataset import Dataset, RetroIndex
|
from labml_nn.transformers.retro.dataset import Dataset, RetroIndex
|
||||||
from labml_nn.transformers.retro.model import RetroModel, NearestNeighborEncoder
|
from labml_nn.transformers.retro.model import RetroModel, NearestNeighborEncoder
|
||||||
|
from torch import nn
|
||||||
|
from torch.utils.data import DataLoader, RandomSampler
|
||||||
|
|
||||||
|
|
||||||
class Sampler:
|
class Sampler:
|
||||||
@ -217,7 +216,6 @@ def train():
|
|||||||
logger.log([(prompt.replace('\n', '\\n\n'), Text.subtle),
|
logger.log([(prompt.replace('\n', '\\n\n'), Text.subtle),
|
||||||
(sampler.sample(prompt, 128).replace('\n', '\\n\n'), Text.none)])
|
(sampler.sample(prompt, 128).replace('\n', '\\n\n'), Text.none)])
|
||||||
# Save models
|
# Save models
|
||||||
experiment.save_checkpoint()
|
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
|
@ -17,16 +17,16 @@ For simplicity, we do not do a training and validation split.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torchvision.transforms.functional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.data
|
import torch.utils.data
|
||||||
import torchvision.transforms.functional
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
from labml import lab, tracker, experiment, monit
|
from labml import lab, tracker, experiment, monit
|
||||||
from labml.configs import BaseConfigs
|
from labml.configs import BaseConfigs
|
||||||
from labml_nn.helpers.device import DeviceConfigs
|
from labml_nn.helpers.device import DeviceConfigs
|
||||||
from labml_nn.unet.carvana import CarvanaDataset
|
|
||||||
from labml_nn.unet import UNet
|
from labml_nn.unet import UNet
|
||||||
|
from labml_nn.unet.carvana import CarvanaDataset
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
class Configs(BaseConfigs):
|
class Configs(BaseConfigs):
|
||||||
@ -141,7 +141,6 @@ class Configs(BaseConfigs):
|
|||||||
# New line in the console
|
# New line in the console
|
||||||
tracker.new_line()
|
tracker.new_line()
|
||||||
# Save the model
|
# Save the model
|
||||||
experiment.save_checkpoint()
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
Reference in New Issue
Block a user