diff --git a/docs/activations/index.html b/docs/activations/index.html
index 47bffbdd..6f9e2821 100644
--- a/docs/activations/index.html
+++ b/docs/activations/index.html
@@ -84,7 +84,6 @@
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/activations/swish.html b/docs/activations/swish.html
index a53db223..e1b3984c 100644
--- a/docs/activations/swish.html
+++ b/docs/activations/swish.html
@@ -123,7 +123,6 @@
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/adaptive_computation/index.html b/docs/adaptive_computation/index.html
index 798b23e4..718fc2dc 100644
--- a/docs/adaptive_computation/index.html
+++ b/docs/adaptive_computation/index.html
@@ -88,7 +88,6 @@
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/adaptive_computation/parity.html b/docs/adaptive_computation/parity.html
index 9e9119b3..925c31d7 100644
--- a/docs/adaptive_computation/parity.html
+++ b/docs/adaptive_computation/parity.html
@@ -236,7 +236,6 @@
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/adaptive_computation/ponder_net/experiment.html b/docs/adaptive_computation/ponder_net/experiment.html
index 372c80c8..3ff6165a 100644
--- a/docs/adaptive_computation/ponder_net/experiment.html
+++ b/docs/adaptive_computation/ponder_net/experiment.html
@@ -599,7 +599,6 @@
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/adaptive_computation/ponder_net/index.html b/docs/adaptive_computation/ponder_net/index.html
index 5ac045dc..6c8d1813 100644
--- a/docs/adaptive_computation/ponder_net/index.html
+++ b/docs/adaptive_computation/ponder_net/index.html
@@ -765,7 +765,6 @@ s is odd and false otherwise.
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/adaptive_computation/ponder_net/readme.html b/docs/adaptive_computation/ponder_net/readme.html
index 07e4c924..e242a77f 100644
--- a/docs/adaptive_computation/ponder_net/readme.html
+++ b/docs/adaptive_computation/ponder_net/readme.html
@@ -88,7 +88,6 @@
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/adaptive_computation/readme.html b/docs/adaptive_computation/readme.html
index 8268a4a6..52f739d0 100644
--- a/docs/adaptive_computation/readme.html
+++ b/docs/adaptive_computation/readme.html
@@ -88,7 +88,6 @@
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/capsule_networks/index.html b/docs/capsule_networks/index.html
index a3a535e3..3d0df89a 100644
--- a/docs/capsule_networks/index.html
+++ b/docs/capsule_networks/index.html
@@ -449,7 +449,6 @@ M1001 80h400000v40h-400000z'/>
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/capsule_networks/mnist.html b/docs/capsule_networks/mnist.html
index c24ff3f6..6bd85274 100644
--- a/docs/capsule_networks/mnist.html
+++ b/docs/capsule_networks/mnist.html
@@ -559,7 +559,6 @@
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/capsule_networks/readme.html b/docs/capsule_networks/readme.html
index 59f56294..4ba2c9df 100644
--- a/docs/capsule_networks/readme.html
+++ b/docs/capsule_networks/readme.html
@@ -92,7 +92,6 @@
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/cfr/analytics.html b/docs/cfr/analytics.html
index 32b645c3..818b1e63 100644
--- a/docs/cfr/analytics.html
+++ b/docs/cfr/analytics.html
@@ -165,7 +165,6 @@
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/cfr/index.html b/docs/cfr/index.html
index 3a1ddeb0..b095575e 100644
--- a/docs/cfr/index.html
+++ b/docs/cfr/index.html
@@ -1372,7 +1372,6 @@ M834 80h400000v40h-400000z'/>
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/cfr/infoset_saver.html b/docs/cfr/infoset_saver.html
index 8d589755..0b34e2eb 100644
--- a/docs/cfr/infoset_saver.html
+++ b/docs/cfr/infoset_saver.html
@@ -146,7 +146,6 @@
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/cfr/kuhn/index.html b/docs/cfr/kuhn/index.html
index 1a26a86e..1104bfc8 100644
--- a/docs/cfr/kuhn/index.html
+++ b/docs/cfr/kuhn/index.html
@@ -846,7 +846,6 @@
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/conv_mixer/experiment.html b/docs/conv_mixer/experiment.html
index e1126261..f2fdc0f3 100644
--- a/docs/conv_mixer/experiment.html
+++ b/docs/conv_mixer/experiment.html
@@ -338,7 +338,6 @@
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/conv_mixer/index.html b/docs/conv_mixer/index.html
index e646a809..bcb01fae 100644
--- a/docs/conv_mixer/index.html
+++ b/docs/conv_mixer/index.html
@@ -686,7 +686,6 @@
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/conv_mixer/readme.html b/docs/conv_mixer/readme.html
index d6a50371..47aed1fc 100644
--- a/docs/conv_mixer/readme.html
+++ b/docs/conv_mixer/readme.html
@@ -91,7 +91,6 @@
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/diffusion/ddpm/evaluate.html b/docs/diffusion/ddpm/evaluate.html
index 5f759a45..c25ad829 100644
--- a/docs/diffusion/ddpm/evaluate.html
+++ b/docs/diffusion/ddpm/evaluate.html
@@ -1354,7 +1354,6 @@ M834 80h400000v40h-400000z'/>
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/diffusion/ddpm/experiment.html b/docs/diffusion/ddpm/experiment.html
index b7d9a09f..0f9b88e5 100644
--- a/docs/diffusion/ddpm/experiment.html
+++ b/docs/diffusion/ddpm/experiment.html
@@ -946,7 +946,6 @@
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/diffusion/ddpm/index.html b/docs/diffusion/ddpm/index.html
index 0e33c2a8..3680db5d 100644
--- a/docs/diffusion/ddpm/index.html
+++ b/docs/diffusion/ddpm/index.html
@@ -845,7 +845,6 @@ M834 80h400000v40h-400000z'/>
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/diffusion/ddpm/readme.html b/docs/diffusion/ddpm/readme.html
index 8c637386..568c4d2e 100644
--- a/docs/diffusion/ddpm/readme.html
+++ b/docs/diffusion/ddpm/readme.html
@@ -90,7 +90,6 @@
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/diffusion/ddpm/unet.html b/docs/diffusion/ddpm/unet.html
index 53c06fb8..db279241 100644
--- a/docs/diffusion/ddpm/unet.html
+++ b/docs/diffusion/ddpm/unet.html
@@ -1407,7 +1407,6 @@ M834 80h400000v40h-400000z'/>
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/diffusion/ddpm/utils.html b/docs/diffusion/ddpm/utils.html
index 041ab383..b2bc2224 100644
--- a/docs/diffusion/ddpm/utils.html
+++ b/docs/diffusion/ddpm/utils.html
@@ -110,7 +110,6 @@
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/diffusion/index.html b/docs/diffusion/index.html
index bdef2c14..61736161 100644
--- a/docs/diffusion/index.html
+++ b/docs/diffusion/index.html
@@ -86,7 +86,6 @@
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/distillation/index.html b/docs/distillation/index.html
index 7ad42a71..db2cbfe4 100644
--- a/docs/distillation/index.html
+++ b/docs/distillation/index.html
@@ -740,7 +740,6 @@
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/distillation/large.html b/docs/distillation/large.html
index 3acc0035..6ce89365 100644
--- a/docs/distillation/large.html
+++ b/docs/distillation/large.html
@@ -351,7 +351,6 @@
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/distillation/readme.html b/docs/distillation/readme.html
index da4d6c81..39f785bc 100644
--- a/docs/distillation/readme.html
+++ b/docs/distillation/readme.html
@@ -90,7 +90,6 @@
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/distillation/small.html b/docs/distillation/small.html
index f252465e..238fad85 100644
--- a/docs/distillation/small.html
+++ b/docs/distillation/small.html
@@ -338,7 +338,6 @@
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/experiments/cifar10.html b/docs/experiments/cifar10.html
index d5b76ad5..edb811d8 100644
--- a/docs/experiments/cifar10.html
+++ b/docs/experiments/cifar10.html
@@ -403,7 +403,6 @@
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/experiments/index.html b/docs/experiments/index.html
index 7afaf1be..1187dc77 100644
--- a/docs/experiments/index.html
+++ b/docs/experiments/index.html
@@ -73,7 +73,6 @@
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/experiments/mnist.html b/docs/experiments/mnist.html
index 9a3284a7..d0ce5731 100644
--- a/docs/experiments/mnist.html
+++ b/docs/experiments/mnist.html
@@ -440,7 +440,6 @@
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/experiments/nlp_autoregression.html b/docs/experiments/nlp_autoregression.html
index 1fa02cb2..fe824990 100644
--- a/docs/experiments/nlp_autoregression.html
+++ b/docs/experiments/nlp_autoregression.html
@@ -1042,7 +1042,6 @@
 function handleImages() {
     var images = document.querySelectorAll('p>img')
-    console.log(images);
     for (var i = 0; i < images.length; ++i) {
         handleImage(images[i])
     }
 }
diff --git a/docs/experiments/nlp_classification.html b/docs/experiments/nlp_classification.html
index c9a5b61a..917dcda1 100644
--- a/docs/experiments/nlp_classification.html
+++ b/docs/experiments/nlp_classification.html
@@ -79,15 +79,16 @@
 15import torchtext
 16from torch import nn
 17from torch.utils.data import DataLoader
-18from torchtext.vocab import Vocab
-19
-20from labml import lab, tracker, monit
-21from labml.configs import option
-22from labml_helpers.device import DeviceConfigs
-23from labml_helpers.metrics.accuracy import Accuracy
-24from labml_helpers.module import Module
-25from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
-26from labml_nn.optimizers.configs import OptimizerConfigs
+18import torchtext.vocab
+19from torchtext.vocab import Vocab
+20
+21from labml import lab, tracker, monit
+22from labml.configs import option
+23from labml_helpers.device import DeviceConfigs
+24from labml_helpers.metrics.accuracy import Accuracy
+25from labml_helpers.module import Module
+26from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
+27from labml_nn.optimizers.configs import OptimizerConfigs
-29class NLPClassificationConfigs(TrainValidConfigs):
+30class NLPClassificationConfigs(TrainValidConfigs):
-40 optimizer: torch.optim.Adam
+41 optimizer: torch.optim.Adam
-42 device: torch.device = DeviceConfigs()
+43 device: torch.device = DeviceConfigs()
-45 model: Module
+46 model: Module
-47 batch_size: int = 16
+48 batch_size: int = 16
-49 seq_len: int = 512
+50 seq_len: int = 512
-51 vocab: Vocab = 'ag_news'
+52 vocab: Vocab = 'ag_news'
-53 n_tokens: int
+54 n_tokens: int
-55 n_classes: int = 'ag_news'
+56 n_classes: int = 'ag_news'
-57 tokenizer: Callable = 'character'
+58 tokenizer: Callable = 'character'
-60 is_save_models = True
+61 is_save_models = True
-63 loss_func = nn.CrossEntropyLoss()
+64 loss_func = nn.CrossEntropyLoss()
-65 accuracy = Accuracy()
+66 accuracy = Accuracy()
-67 d_model: int = 512
+68 d_model: int = 512
-69 grad_norm_clip: float = 1.0
+70 grad_norm_clip: float = 1.0
-72 train_loader: DataLoader = 'ag_news'
+73 train_loader: DataLoader = 'ag_news'
-74 valid_loader: DataLoader = 'ag_news'
+75 valid_loader: DataLoader = 'ag_news'
-76 def init(self):
+77 def init(self):
-81 tracker.set_scalar("accuracy.*", True)
-82 tracker.set_scalar("loss.*", True)
+82 tracker.set_scalar("accuracy.*", True)
+83 tracker.set_scalar("loss.*", True)
-84 hook_model_outputs(self.mode, self.model, 'model')
+85 hook_model_outputs(self.mode, self.model, 'model')
-89 self.state_modules = [self.accuracy]
+90 self.state_modules = [self.accuracy]
-91 def step(self, batch: any, batch_idx: BatchIndex):
+92 def step(self, batch: any, batch_idx: BatchIndex):
-97 data, target = batch[0].to(self.device), batch[1].to(self.device)
+98 data, target = batch[0].to(self.device), batch[1].to(self.device)
-100 if self.mode.is_train:
-101 tracker.add_global_step(data.shape[1])
+101 if self.mode.is_train:
+102 tracker.add_global_step(data.shape[1])
-104 with self.mode.update(is_log_activations=batch_idx.is_last):
+105 with self.mode.update(is_log_activations=batch_idx.is_last):
-108 output, *_ = self.model(data)
+109 output, *_ = self.model(data)
-111 loss = self.loss_func(output, target)
-112 tracker.add("loss.", loss)
+112 loss = self.loss_func(output, target)
+113 tracker.add("loss.", loss)
-115 self.accuracy(output, target)
-116 self.accuracy.track()
+116 self.accuracy(output, target)
+117 self.accuracy.track()
-119 if self.mode.is_train:
+120 if self.mode.is_train:
-121 loss.backward()
+122 loss.backward()
-123 torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
+124 torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
-125 self.optimizer.step()
+126 self.optimizer.step()
-127 if batch_idx.is_last:
-128 tracker.add('model', self.model)
+128 if batch_idx.is_last:
+129 tracker.add('model', self.model)
-130 self.optimizer.zero_grad()
+131 self.optimizer.zero_grad()
-133 tracker.save()
+134 tracker.save()
-136@option(NLPClassificationConfigs.optimizer)
-137def _optimizer(c: NLPClassificationConfigs):
+137@option(NLPClassificationConfigs.optimizer)
+138def _optimizer(c: NLPClassificationConfigs):
-142 optimizer = OptimizerConfigs()
-143 optimizer.parameters = c.model.parameters()
-144 optimizer.optimizer = 'Adam'
-145 optimizer.d_model = c.d_model
-146
-147 return optimizer
+143 optimizer = OptimizerConfigs()
+144 optimizer.parameters = c.model.parameters()
+145 optimizer.optimizer = 'Adam'
+146 optimizer.d_model = c.d_model
+147
+148 return optimizer
-150@option(NLPClassificationConfigs.tokenizer)
-151def basic_english():
+151@option(NLPClassificationConfigs.tokenizer)
+152def basic_english():
-165 from torchtext.data import get_tokenizer
-166 return get_tokenizer('basic_english')
+166 from torchtext.data import get_tokenizer
+167 return get_tokenizer('basic_english')
-169def character_tokenizer(x: str):
+170def character_tokenizer(x: str):
-173 return list(x)
+174 return list(x)
-176@option(NLPClassificationConfigs.tokenizer)
-177def character():
+177@option(NLPClassificationConfigs.tokenizer)
+178def character():
-181 return character_tokenizer
+182 return character_tokenizer
-184@option(NLPClassificationConfigs.n_tokens)
-185def _n_tokens(c: NLPClassificationConfigs):
+185@option(NLPClassificationConfigs.n_tokens)
+186def _n_tokens(c: NLPClassificationConfigs):
-189 return len(c.vocab) + 2
+190 return len(c.vocab) + 2
-192class CollateFunc:
+193class CollateFunc:
-197 def __init__(self, tokenizer, vocab: Vocab, seq_len: int, padding_token: int, classifier_token: int):
+198 def __init__(self, tokenizer, vocab: Vocab, seq_len: int, padding_token: int, classifier_token: int):
-205 self.classifier_token = classifier_token
-206 self.padding_token = padding_token
-207 self.seq_len = seq_len
-208 self.vocab = vocab
-209 self.tokenizer = tokenizer
+206 self.classifier_token = classifier_token
+207 self.padding_token = padding_token
+208 self.seq_len = seq_len
+209 self.vocab = vocab
+210 self.tokenizer = tokenizer
-211 def __call__(self, batch):
+212 def __call__(self, batch):
-217 data = torch.full((self.seq_len, len(batch)), self.padding_token, dtype=torch.long)
+218 data = torch.full((self.seq_len, len(batch)), self.padding_token, dtype=torch.long)
-219 labels = torch.zeros(len(batch), dtype=torch.long)
+220 labels = torch.zeros(len(batch), dtype=torch.long)
-222 for (i, (_label, _text)) in enumerate(batch):
+223 for (i, (_label, _text)) in enumerate(batch):
-224 labels[i] = int(_label) - 1
+225 labels[i] = int(_label) - 1
-226 _text = [self.vocab[token] for token in self.tokenizer(_text)]
+227 _text = [self.vocab[token] for token in self.tokenizer(_text)]
-228 _text = _text[:self.seq_len]
+229 _text = _text[:self.seq_len]
-230 data[:len(_text), i] = data.new_tensor(_text)
+231 data[:len(_text), i] = data.new_tensor(_text)
-233 data[-1, :] = self.classifier_token
+234 data[-1, :] = self.classifier_token
-236 return data, labels
+237 return data, labels
-239@option([NLPClassificationConfigs.n_classes,
-240 NLPClassificationConfigs.vocab,
-241 NLPClassificationConfigs.train_loader,
-242 NLPClassificationConfigs.valid_loader])
-243def ag_news(c: NLPClassificationConfigs):
+240@option([NLPClassificationConfigs.n_classes,
+241 NLPClassificationConfigs.vocab,
+242 NLPClassificationConfigs.train_loader,
+243 NLPClassificationConfigs.valid_loader])
+244def ag_news(c: NLPClassificationConfigs):
-252 train, valid = torchtext.datasets.AG_NEWS(root=str(lab.get_data_path() / 'ag_news'), split=('train', 'test'))
+253 train, valid = torchtext.datasets.AG_NEWS(root=str(lab.get_data_path() / 'ag_news'), split=('train', 'test'))
-255 with monit.section('Load data'):
-256 from labml_nn.utils import MapStyleDataset
+256 with monit.section('Load data'):
+257 from labml_nn.utils import MapStyleDataset
-259 train, valid = MapStyleDataset(train), MapStyleDataset(valid)
+260 train, valid = MapStyleDataset(train), MapStyleDataset(valid)
-262 tokenizer = c.tokenizer
+263 tokenizer = c.tokenizer
-265 counter = Counter()
+266 counter = Counter()
-267 for (label, line) in train:
-268 counter.update(tokenizer(line))
+268 for (label, line) in train:
+269 counter.update(tokenizer(line))
-270 for (label, line) in valid:
-271 counter.update(tokenizer(line))
+271 for (label, line) in valid:
+272 counter.update(tokenizer(line))
-273 vocab = Vocab(counter, min_freq=1)
+274 vocab = torchtext.vocab.vocab(counter, min_freq=1)
-276 train_loader = DataLoader(train, batch_size=c.batch_size, shuffle=True,
-277 collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))
+277 train_loader = DataLoader(train, batch_size=c.batch_size, shuffle=True,
+278 collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))
-279 valid_loader = DataLoader(valid, batch_size=c.batch_size, shuffle=True,
-280 collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))
+280 valid_loader = DataLoader(valid, batch_size=c.batch_size, shuffle=True,
+281 collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))
-283 return 4, vocab, train_loader, valid_loader
+284 return 4, vocab, train_loader, valid_loader
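
The substantive change in the hunks above is the vocabulary build: the removed line constructed `Vocab(counter, min_freq=1)` directly, while the new line (together with the added `import torchtext.vocab`) goes through the `torchtext.vocab.vocab()` factory that recent torchtext releases provide. A minimal sketch of the new-style call, assuming a recent torchtext; the toy corpus and the `<unk>` handling are illustrative, not part of the change:

    # Build a Vocab with the factory function that replaces the old
    # Vocab(counter, min_freq=...) constructor in recent torchtext releases.
    from collections import Counter

    import torchtext.vocab

    counter = Counter()
    for line in ['hello world', 'hello there']:  # stand-in corpus
        counter.update(line.split())             # count token frequencies

    vocab = torchtext.vocab.vocab(counter, min_freq=1)

    # Unlike the old class, looking up an unseen token raises unless a
    # default index is configured.
    vocab.insert_token('<unk>', 0)
    vocab.set_default_index(vocab['<unk>'])

    print(len(vocab))      # vocabulary size, as used by _n_tokens above
    print(vocab['hello'])  # token -> integer id

Since `min_freq=1` keeps every token seen in the train and valid splits, `_n_tokens` above can size the embedding table as `len(c.vocab) + 2`, reserving the two extra ids for the padding and classifier tokens that `CollateFunc` appends.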
-56class GeneratorResNet(Module):
+57class GeneratorResNet(Module):
-61 def __init__(self, input_channels: int, n_residual_blocks: int):
-62 super().__init__()
+62 def __init__(self, input_channels: int, n_residual_blocks: int):
+63 super().__init__()
-70 out_features = 64
-71 layers = [
-72 nn.Conv2d(input_channels, out_features, kernel_size=7, padding=3, padding_mode='reflect'),
-73 nn.InstanceNorm2d(out_features),
-74 nn.ReLU(inplace=True),
-75 ]
-76 in_features = out_features
+71 out_features = 64
+72 layers = [
+73 nn.Conv2d(input_channels, out_features, kernel_size=7, padding=3, padding_mode='reflect'),
+74 nn.InstanceNorm2d(out_features),
+75 nn.ReLU(inplace=True),
+76 ]
+77 in_features = out_features
-80 for _ in range(2):
-81 out_features *= 2
-82 layers += [
-83 nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
-84 nn.InstanceNorm2d(out_features),
-85 nn.ReLU(inplace=True),
-86 ]
-87 in_features = out_features
+81 for _ in range(2):
+82 out_features *= 2
+83 layers += [
+84 nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
+85 nn.InstanceNorm2d(out_features),
+86 nn.ReLU(inplace=True),
+87 ]
+88 in_features = out_features
-91 for _ in range(n_residual_blocks):
-92 layers += [ResidualBlock(out_features)]
+92 for _ in range(n_residual_blocks):
+93 layers += [ResidualBlock(out_features)]
-96 for _ in range(2):
-97 out_features //= 2
-98 layers += [
-99 nn.Upsample(scale_factor=2),
-100 nn.Conv2d(in_features, out_features, kernel_size=3, stride=1, padding=1),
-101 nn.InstanceNorm2d(out_features),
-102 nn.ReLU(inplace=True),
-103 ]
-104 in_features = out_features
+97 for _ in range(2):
+98 out_features //= 2
+99 layers += [
+100 nn.Upsample(scale_factor=2),
+101 nn.Conv2d(in_features, out_features, kernel_size=3, stride=1, padding=1),
+102 nn.InstanceNorm2d(out_features),
+103 nn.ReLU(inplace=True),
+104 ]
+105 in_features = out_features
-107 layers += [nn.Conv2d(out_features, input_channels, 7, padding=3, padding_mode='reflect'), nn.Tanh()]
+108 layers += [nn.Conv2d(out_features, input_channels, 7, padding=3, padding_mode='reflect'), nn.Tanh()]
-110 self.layers = nn.Sequential(*layers)
+111 self.layers = nn.Sequential(*layers)
-113 self.apply(weights_init_normal)
+114 self.apply(weights_init_normal)
-115 def forward(self, x):
-116 return self.layers(x)
+116 def forward(self, x):
+117 return self.layers(x)
-119class ResidualBlock(Module):
+120class ResidualBlock(Module):
-124 def __init__(self, in_features: int):
-125 super().__init__()
-126 self.block = nn.Sequential(
-127 nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
-128 nn.InstanceNorm2d(in_features),
-129 nn.ReLU(inplace=True),
-130 nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
-131 nn.InstanceNorm2d(in_features),
-132 nn.ReLU(inplace=True),
-133 )
+125 def __init__(self, in_features: int):
+126 super().__init__()
+127 self.block = nn.Sequential(
+128 nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
+129 nn.InstanceNorm2d(in_features),
+130 nn.ReLU(inplace=True),
+131 nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
+132 nn.InstanceNorm2d(in_features),
+133 nn.ReLU(inplace=True),
+134 )
-135 def forward(self, x: torch.Tensor):
-136 return x + self.block(x)
+136 def forward(self, x: torch.Tensor):
+137 return x + self.block(x)
-139class Discriminator(Module):
+140class Discriminator(Module):
-144 def __init__(self, input_shape: Tuple[int, int, int]):
-145 super().__init__()
-146 channels, height, width = input_shape
+145 def __init__(self, input_shape: Tuple[int, int, int]):
+146 super().__init__()
+147 channels, height, width = input_shape
-150 self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)
-151
-152 self.layers = nn.Sequential(
+151 self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)
+152
+153 self.layers = nn.Sequential(
-154 DiscriminatorBlock(channels, 64, normalize=False),
-155 DiscriminatorBlock(64, 128),
-156 DiscriminatorBlock(128, 256),
-157 DiscriminatorBlock(256, 512),
+155 DiscriminatorBlock(channels, 64, normalize=False),
+156 DiscriminatorBlock(64, 128),
+157 DiscriminatorBlock(128, 256),
+158 DiscriminatorBlock(256, 512),
-160 nn.ZeroPad2d((1, 0, 1, 0)),
-161 nn.Conv2d(512, 1, kernel_size=4, padding=1)
-162 )
+161 nn.ZeroPad2d((1, 0, 1, 0)),
+162 nn.Conv2d(512, 1, kernel_size=4, padding=1)
+163 )
-165 self.apply(weights_init_normal)
+166 self.apply(weights_init_normal)
-167 def forward(self, img):
-168 return self.layers(img)
+168 def forward(self, img):
+169 return self.layers(img)
-171class DiscriminatorBlock(Module):
+172class DiscriminatorBlock(Module):
-179 def __init__(self, in_filters: int, out_filters: int, normalize: bool = True):
-180 super().__init__()
-181 layers = [nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)]
-182 if normalize:
-183 layers.append(nn.InstanceNorm2d(out_filters))
-184 layers.append(nn.LeakyReLU(0.2, inplace=True))
-185 self.layers = nn.Sequential(*layers)
+180 def __init__(self, in_filters: int, out_filters: int, normalize: bool = True):
+181 super().__init__()
+182 layers = [nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)]
+183 if normalize:
+184 layers.append(nn.InstanceNorm2d(out_filters))
+185 layers.append(nn.LeakyReLU(0.2, inplace=True))
+186 self.layers = nn.Sequential(*layers)
-187 def forward(self, x: torch.Tensor):
-188 return self.layers(x)
+188 def forward(self, x: torch.Tensor):
+189 return self.layers(x)
-191def weights_init_normal(m):
+192def weights_init_normal(m):
-195 classname = m.__class__.__name__
-196 if classname.find("Conv") != -1:
-197 torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
+196 classname = m.__class__.__name__
+197 if classname.find("Conv") != -1:
+198 torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
-200def load_image(path: str):
+201def load_image(path: str):
-204 image = Image.open(path)
-205 if image.mode != 'RGB':
-206 image = Image.new("RGB", image.size).paste(image)
-207
-208 return image
+205 image = Image.open(path)
+206 if image.mode != 'RGB':
+207 image = Image.new("RGB", image.size).paste(image)
+208
+209 return image
-211class ImageDataset(Dataset):
+212class ImageDataset(Dataset):
-216 @staticmethod
-217 def download(dataset_name: str):
+217 @staticmethod
+218 def download(dataset_name: str):
-222 url = f'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/{dataset_name}.zip'
+223 url = f'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/{dataset_name}.zip'
-224 root = lab.get_data_path() / 'cycle_gan'
-225 if not root.exists():
-226 root.mkdir(parents=True)
+225 root = lab.get_data_path() / 'cycle_gan'
+226 if not root.exists():
+227 root.mkdir(parents=True)
-228 archive = root / f'{dataset_name}.zip'
+229 archive = root / f'{dataset_name}.zip'
-230 download_file(url, archive)
+231 download_file(url, archive)
-232 with zipfile.ZipFile(archive, 'r') as f:
-233 f.extractall(root)
+233 with zipfile.ZipFile(archive, 'r') as f:
+234 f.extractall(root)
-235 def __init__(self, dataset_name: str, transforms_, mode: str):
+236 def __init__(self, dataset_name: str, transforms_, mode: str):
-244 root = lab.get_data_path() / 'cycle_gan' / dataset_name
+245 root = lab.get_data_path() / 'cycle_gan' / dataset_name
-246 if not root.exists():
-247 self.download(dataset_name)
+247 if not root.exists():
+248 self.download(dataset_name)
-250 self.transform = transforms.Compose(transforms_)
+251 self.transform = transforms.Compose(transforms_)
-253 path_a = root / f'{mode}A'
-254 path_b = root / f'{mode}B'
-255 self.files_a = sorted(str(f) for f in path_a.iterdir())
-256 self.files_b = sorted(str(f) for f in path_b.iterdir())
+254 path_a = root / f'{mode}A'
+255 path_b = root / f'{mode}B'
+256 self.files_a = sorted(str(f) for f in path_a.iterdir())
+257 self.files_b = sorted(str(f) for f in path_b.iterdir())
-258 def __getitem__(self, index):
+259 def __getitem__(self, index):
-262 return {"x": self.transform(load_image(self.files_a[index % len(self.files_a)])),
-263 "y": self.transform(load_image(self.files_b[index % len(self.files_b)]))}
+263 return {"x": self.transform(load_image(self.files_a[index % len(self.files_a)])),
+264 "y": self.transform(load_image(self.files_b[index % len(self.files_b)]))}
-265 def __len__(self):
+266 def __len__(self):
-267 return max(len(self.files_a), len(self.files_b))
+268 return max(len(self.files_a), len(self.files_b))
-270class ReplayBuffer:
+271class ReplayBuffer:
-284 def __init__(self, max_size: int = 50):
-285 self.max_size = max_size
-286 self.data = []
+285 def __init__(self, max_size: int = 50):
+286 self.max_size = max_size
+287 self.data = []
-288 def push_and_pop(self, data: torch.Tensor):
+289 def push_and_pop(self, data: torch.Tensor):
-290 data = data.detach()
-291 res = []
-292 for element in data:
-293 if len(self.data) < self.max_size:
-294 self.data.append(element)
-295 res.append(element)
-296 else:
-297 if random.uniform(0, 1) > 0.5:
-298 i = random.randint(0, self.max_size - 1)
-299 res.append(self.data[i].clone())
-300 self.data[i] = element
-301 else:
-302 res.append(element)
-303 return torch.stack(res)
+291 data = data.detach()
+292 res = []
+293 for element in data:
+294 if len(self.data) < self.max_size:
+295 self.data.append(element)
+296 res.append(element)
+297 else:
+298 if random.uniform(0, 1) > 0.5:
+299 i = random.randint(0, self.max_size - 1)
+300 res.append(self.data[i].clone())
+301 self.data[i] = element
+302 else:
+303 res.append(element)
+304 return torch.stack(res)
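
`ReplayBuffer` above is the generated-image history trick used by CycleGAN: the discriminators are updated on a mix of the newest fakes and fakes produced by earlier versions of the generators, which damps oscillation during adversarial training. A toy illustration of how `push_and_pop` behaves, assuming the `ReplayBuffer` class from the hunk above is in scope (the shapes and values are made up):

    import torch

    buffer = ReplayBuffer(max_size=2)

    for step in range(4):
        # one fake "generated" image per step, filled with the step index
        fake = torch.full((1, 3, 4, 4), float(step))
        out = buffer.push_and_pop(fake)
        # While the buffer is filling, `out` echoes the input; once it is full,
        # each image has a 50% chance of being swapped for an older stored one.
        print(step, out.unique().tolist())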
-306class Configs(BaseConfigs):
+307class Configs(BaseConfigs):
-310 device: torch.device = DeviceConfigs()
+311 device: torch.device = DeviceConfigs()
-313 epochs: int = 200
-314 dataset_name: str = 'monet2photo'
-315 batch_size: int = 1
-316
-317 data_loader_workers = 8
-318
-319 learning_rate = 0.0002
-320 adam_betas = (0.5, 0.999)
-321 decay_start = 100
+314 epochs: int = 200
+315 dataset_name: str = 'monet2photo'
+316 batch_size: int = 1
+317
+318 data_loader_workers = 8
+319
+320 learning_rate = 0.0002
+321 adam_betas = (0.5, 0.999)
+322 decay_start = 100
-325 gan_loss = torch.nn.MSELoss()
+326 gan_loss = torch.nn.MSELoss()
-328 cycle_loss = torch.nn.L1Loss()
-329 identity_loss = torch.nn.L1Loss()
+329 cycle_loss = torch.nn.L1Loss()
+330 identity_loss = torch.nn.L1Loss()
-332 img_height = 256
-333 img_width = 256
-334 img_channels = 3
+333 img_height = 256
+334 img_width = 256
+335 img_channels = 3
-337 n_residual_blocks = 9
+338 n_residual_blocks = 9
-340 cyclic_loss_coefficient = 10.0
-341 identity_loss_coefficient = 5.
-342
-343 sample_interval = 500
+341 cyclic_loss_coefficient = 10.0
+342 identity_loss_coefficient = 5.
+343
+344 sample_interval = 500
-346 generator_xy: GeneratorResNet
-347 generator_yx: GeneratorResNet
-348 discriminator_x: Discriminator
-349 discriminator_y: Discriminator
+347 generator_xy: GeneratorResNet
+348 generator_yx: GeneratorResNet
+349 discriminator_x: Discriminator
+350 discriminator_y: Discriminator
-352 generator_optimizer: torch.optim.Adam
-353 discriminator_optimizer: torch.optim.Adam
+353 generator_optimizer: torch.optim.Adam
+354 discriminator_optimizer: torch.optim.Adam
-356 generator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
-357 discriminator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
+357 generator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
+358 discriminator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
-360 dataloader: DataLoader
-361 valid_dataloader: DataLoader
+361 dataloader: DataLoader
+362 valid_dataloader: DataLoader
-363 def sample_images(self, n: int):
+364 def sample_images(self, n: int):
-365 batch = next(iter(self.valid_dataloader))
-366 self.generator_xy.eval()
-367 self.generator_yx.eval()
-368 with torch.no_grad():
-369 data_x, data_y = batch['x'].to(self.generator_xy.device), batch['y'].to(self.generator_yx.device)
-370 gen_y = self.generator_xy(data_x)
-371 gen_x = self.generator_yx(data_y)
+366 batch = next(iter(self.valid_dataloader))
+367 self.generator_xy.eval()
+368 self.generator_yx.eval()
+369 with torch.no_grad():
+370 data_x, data_y = batch['x'].to(self.generator_xy.device), batch['y'].to(self.generator_yx.device)
+371 gen_y = self.generator_xy(data_x)
+372 gen_x = self.generator_yx(data_y)
-374 data_x = make_grid(data_x, nrow=5, normalize=True)
-375 data_y = make_grid(data_y, nrow=5, normalize=True)
-376 gen_x = make_grid(gen_x, nrow=5, normalize=True)
-377 gen_y = make_grid(gen_y, nrow=5, normalize=True)
+375 data_x = make_grid(data_x, nrow=5, normalize=True)
+376 data_y = make_grid(data_y, nrow=5, normalize=True)
+377 gen_x = make_grid(gen_x, nrow=5, normalize=True)
+378 gen_y = make_grid(gen_y, nrow=5, normalize=True)
-380 image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)
+381 image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)
-383 plot_image(image_grid)
+384 plot_image(image_grid)
-385 def initialize(self):
+386 def initialize(self):
-389 input_shape = (self.img_channels, self.img_height, self.img_width)
+390 input_shape = (self.img_channels, self.img_height, self.img_width)
-392 self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
-393 self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
-394 self.discriminator_x = Discriminator(input_shape).to(self.device)
-395 self.discriminator_y = Discriminator(input_shape).to(self.device)
+393 self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
+394 self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
+395 self.discriminator_x = Discriminator(input_shape).to(self.device)
+396 self.discriminator_y = Discriminator(input_shape).to(self.device)
-398 self.generator_optimizer = torch.optim.Adam(
-399 itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()),
-400 lr=self.learning_rate, betas=self.adam_betas)
-401 self.discriminator_optimizer = torch.optim.Adam(
-402 itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()),
-403 lr=self.learning_rate, betas=self.adam_betas)
+399 self.generator_optimizer = torch.optim.Adam(
+400 itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()),
+401 lr=self.learning_rate, betas=self.adam_betas)
+402 self.discriminator_optimizer = torch.optim.Adam(
+403 itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()),
+404 lr=self.learning_rate, betas=self.adam_betas)
-408 decay_epochs = self.epochs - self.decay_start
-409 self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
-410 self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
-411 self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
-412 self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
+409 decay_epochs = self.epochs - self.decay_start
+410 self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
+411 self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
+412 self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
+413 self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
-415 transforms_ = [
-416 transforms.Resize(int(self.img_height * 1.12), Image.BICUBIC),
-417 transforms.RandomCrop((self.img_height, self.img_width)),
-418 transforms.RandomHorizontalFlip(),
-419 transforms.ToTensor(),
-420 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
-421 ]
+416 transforms_ = [
+417 transforms.Resize(int(self.img_height * 1.12), InterpolationMode.BICUBIC),
+418 transforms.RandomCrop((self.img_height, self.img_width)),
+419 transforms.RandomHorizontalFlip(),
+420 transforms.ToTensor(),
+421 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+422 ]
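
The only content change in the transforms hunk above is the interpolation argument: `Image.BICUBIC` becomes `InterpolationMode.BICUBIC`. torchvision 0.9 and later deprecate passing PIL constants to `transforms.Resize` in favour of torchvision's own enum, so the source file presumably also gains a `from torchvision.transforms import InterpolationMode` import in a part of the change not shown here. A standalone sketch of the migrated pipeline, with `img_height`/`img_width` mirroring the `Configs` defaults above:

    from torchvision import transforms
    from torchvision.transforms import InterpolationMode

    img_height = img_width = 256

    transform = transforms.Compose([
        # previously: transforms.Resize(int(img_height * 1.12), Image.BICUBIC)
        transforms.Resize(int(img_height * 1.12), InterpolationMode.BICUBIC),
        transforms.RandomCrop((img_height, img_width)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])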
-424 self.dataloader = DataLoader(
-425 ImageDataset(self.dataset_name, transforms_, 'train'),
-426 batch_size=self.batch_size,
-427 shuffle=True,
-428 num_workers=self.data_loader_workers,
-429 )
+425 self.dataloader = DataLoader(
+426 ImageDataset(self.dataset_name, transforms_, 'train'),
+427 batch_size=self.batch_size,
+428 shuffle=True,
+429 num_workers=self.data_loader_workers,
+430 )
-432 self.valid_dataloader = DataLoader(
-433 ImageDataset(self.dataset_name, transforms_, "test"),
-434 batch_size=5,
-435 shuffle=True,
-436 num_workers=self.data_loader_workers,
-437 )
+433 self.valid_dataloader = DataLoader(
+434 ImageDataset(self.dataset_name, transforms_, "test"),
+435 batch_size=5,
+436 shuffle=True,
+437 num_workers=self.data_loader_workers,
+438 )
-439 def run(self):
+440 def run(self):
-541 gen_x_buffer = ReplayBuffer()
-542 gen_y_buffer = ReplayBuffer()
+542 gen_x_buffer = ReplayBuffer()
+543 gen_y_buffer = ReplayBuffer()
-545 for epoch in monit.loop(self.epochs):
+546 for epoch in monit.loop(self.epochs):
-547 for i, batch in monit.enum('Train', self.dataloader):
+548 for i, batch in monit.enum('Train', self.dataloader):
-549 data_x, data_y = batch['x'].to(self.device), batch['y'].to(self.device)
+550 data_x, data_y = batch['x'].to(self.device), batch['y'].to(self.device)
-552 true_labels = torch.ones(data_x.size(0), *self.discriminator_x.output_shape,
-553 device=self.device, requires_grad=False)
+553 true_labels = torch.ones(data_x.size(0), *self.discriminator_x.output_shape,
+554 device=self.device, requires_grad=False)
-555 false_labels = torch.zeros(data_x.size(0), *self.discriminator_x.output_shape,
-556 device=self.device, requires_grad=False)
+556 false_labels = torch.zeros(data_x.size(0), *self.discriminator_x.output_shape,
+557 device=self.device, requires_grad=False)
-560 gen_x, gen_y = self.optimize_generators(data_x, data_y, true_labels)
+561 gen_x, gen_y = self.optimize_generators(data_x, data_y, true_labels)
-563 self.optimize_discriminator(data_x, data_y,
-564 gen_x_buffer.push_and_pop(gen_x), gen_y_buffer.push_and_pop(gen_y),
-565 true_labels, false_labels)
+564 self.optimize_discriminator(data_x, data_y,
+565 gen_x_buffer.push_and_pop(gen_x), gen_y_buffer.push_and_pop(gen_y),
+566 true_labels, false_labels)
-568 tracker.save()
-569 tracker.add_global_step(max(len(data_x), len(data_y)))
+569 tracker.save()
+570 tracker.add_global_step(max(len(data_x), len(data_y)))
-572 batches_done = epoch * len(self.dataloader) + i
-573 if batches_done % self.sample_interval == 0:
+573 batches_done = epoch * len(self.dataloader) + i
+574 if batches_done % self.sample_interval == 0:
-575 experiment.save_checkpoint()
+576 experiment.save_checkpoint()
-577 self.sample_images(batches_done)
+578 self.sample_images(batches_done)
-580 self.generator_lr_scheduler.step()
-581 self.discriminator_lr_scheduler.step()
+581 self.generator_lr_scheduler.step()
+582 self.discriminator_lr_scheduler.step()
-583 tracker.new_line()
+584 tracker.new_line()
-585 def optimize_generators(self, data_x: torch.Tensor, data_y: torch.Tensor, true_labels: torch.Tensor):
+586 def optimize_generators(self, data_x: torch.Tensor, data_y: torch.Tensor, true_labels: torch.Tensor):
-591 self.generator_xy.train()
-592 self.generator_yx.train()
+592 self.generator_xy.train()
+593 self.generator_yx.train()
-597 loss_identity = (self.identity_loss(self.generator_yx(data_x), data_x) +
-598 self.identity_loss(self.generator_xy(data_y), data_y))
+598 loss_identity = (self.identity_loss(self.generator_yx(data_x), data_x) +
+599 self.identity_loss(self.generator_xy(data_y), data_y))
-601 gen_y = self.generator_xy(data_x)
-602 gen_x = self.generator_yx(data_y)
+602 gen_y = self.generator_xy(data_x)
+603 gen_x = self.generator_yx(data_y)
-607 loss_gan = (self.gan_loss(self.discriminator_y(gen_y), true_labels) +
-608 self.gan_loss(self.discriminator_x(gen_x), true_labels))
+608 loss_gan = (self.gan_loss(self.discriminator_y(gen_y), true_labels) +
+609 self.gan_loss(self.discriminator_x(gen_x), true_labels))
-615 loss_cycle = (self.cycle_loss(self.generator_yx(gen_y), data_x) +
-616 self.cycle_loss(self.generator_xy(gen_x), data_y))
+616 loss_cycle = (self.cycle_loss(self.generator_yx(gen_y), data_x) +
+617 self.cycle_loss(self.generator_xy(gen_x), data_y))
-619 loss_generator = (loss_gan +
-620 self.cyclic_loss_coefficient * loss_cycle +
-621 self.identity_loss_coefficient * loss_identity)
+620 loss_generator = (loss_gan +
+621 self.cyclic_loss_coefficient * loss_cycle +
+622 self.identity_loss_coefficient * loss_identity)
-624 self.generator_optimizer.zero_grad()
-625 loss_generator.backward()
-626 self.generator_optimizer.step()
+625 self.generator_optimizer.zero_grad()
+626 loss_generator.backward()
+627 self.generator_optimizer.step()
-629 tracker.add({'loss.generator': loss_generator,
-630 'loss.generator.cycle': loss_cycle,
-631 'loss.generator.gan': loss_gan,
-632 'loss.generator.identity': loss_identity})
+630 tracker.add({'loss.generator': loss_generator,
+631 'loss.generator.cycle': loss_cycle,
+632 'loss.generator.gan': loss_gan,
+633 'loss.generator.identity': loss_identity})
-635 return gen_x, gen_y
+636 return gen_x, gen_y
-637 def optimize_discriminator(self, data_x: torch.Tensor, data_y: torch.Tensor,
-638 gen_x: torch.Tensor, gen_y: torch.Tensor,
-639 true_labels: torch.Tensor, false_labels: torch.Tensor):
+638 def optimize_discriminator(self, data_x: torch.Tensor, data_y: torch.Tensor,
+639 gen_x: torch.Tensor, gen_y: torch.Tensor,
+640 true_labels: torch.Tensor, false_labels: torch.Tensor):
-652 loss_discriminator = (self.gan_loss(self.discriminator_x(data_x), true_labels) +
-653 self.gan_loss(self.discriminator_x(gen_x), false_labels) +
-654 self.gan_loss(self.discriminator_y(data_y), true_labels) +
-655 self.gan_loss(self.discriminator_y(gen_y), false_labels))
+653 loss_discriminator = (self.gan_loss(self.discriminator_x(data_x), true_labels) +
+654 self.gan_loss(self.discriminator_x(gen_x), false_labels) +
+655 self.gan_loss(self.discriminator_y(data_y), true_labels) +
+656 self.gan_loss(self.discriminator_y(gen_y), false_labels))
-658 self.discriminator_optimizer.zero_grad()
-659 loss_discriminator.backward()
-660 self.discriminator_optimizer.step()
+659 self.discriminator_optimizer.zero_grad()
+660 loss_discriminator.backward()
+661 self.discriminator_optimizer.step()
-663 tracker.add({'loss.discriminator': loss_discriminator})
+664 tracker.add({'loss.discriminator': loss_discriminator})
-666def train():
+667def train():
-671 conf = Configs()
+672 conf = Configs()
-673 experiment.create(name='cycle_gan')
+674 experiment.create(name='cycle_gan')
-676 experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'})
-677 conf.initialize()
+677 experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'})
+678 conf.initialize()
-682 experiment.add_pytorch_models(get_modules(conf))
+683 experiment.add_pytorch_models(get_modules(conf))
-684 with experiment.start():
+685 with experiment.start():
-686 conf.run()
+687 conf.run()
-689def plot_image(img: torch.Tensor):
+690def plot_image(img: torch.Tensor):
-693 from matplotlib import pyplot as plt
+694 from matplotlib import pyplot as plt
-696 img = img.cpu()
+697 img = img.cpu()
-698 img_min, img_max = img.min(), img.max()
+699 img_min, img_max = img.min(), img.max()
-700 img = (img - img_min) / (img_max - img_min + 1e-5)
+701 img = (img - img_min) / (img_max - img_min + 1e-5)
-702 img = img.permute(1, 2, 0)
+703 img = img.permute(1, 2, 0)
-704 plt.imshow(img)
+705 plt.imshow(img)
-706 plt.axis('off')
+707 plt.axis('off')
-708 plt.show()
+709 plt.show()
-711def evaluate():
+712def evaluate():
-716 trained_run_uuid = 'f73c1164184711eb9190b74249275441'
+717 trained_run_uuid = 'f73c1164184711eb9190b74249275441'
-718 conf = Configs()
+719 conf = Configs()
-720 experiment.create(name='cycle_gan_inference')
+721 experiment.create(name='cycle_gan_inference')
-722 conf_dict = experiment.load_configs(trained_run_uuid)
+723 conf_dict = experiment.load_configs(trained_run_uuid)
-731 experiment.configs(conf, conf_dict)
-732 conf.initialize()
+732 experiment.configs(conf, conf_dict)
+733 conf.initialize()
-737 experiment.add_pytorch_models(get_modules(conf))
+738 experiment.add_pytorch_models(get_modules(conf))
-740 experiment.load(trained_run_uuid)
+741 experiment.load(trained_run_uuid)
-743 with experiment.start():
+744 with experiment.start():
-745 transforms_ = [
-746 transforms.ToTensor(),
-747 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
-748 ]
+746 transforms_ = [
+747 transforms.ToTensor(),
+748 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+749 ]
-754 dataset = ImageDataset(conf.dataset_name, transforms_, 'train')
+755 dataset = ImageDataset(conf.dataset_name, transforms_, 'train')
-756 x_image = dataset[10]['x']
+757 x_image = dataset[10]['x']
-758 plot_image(x_image)
+759 plot_image(x_image)
-761 conf.generator_xy.eval()
-762 conf.generator_yx.eval()
+762 conf.generator_xy.eval()
+763 conf.generator_yx.eval()
-765 with torch.no_grad():
+766 with torch.no_grad():
-767 data = x_image.unsqueeze(0).to(conf.device)
-768 generated_y = conf.generator_xy(data)
+768 data = x_image.unsqueeze(0).to(conf.device)
+769 generated_y = conf.generator_xy(data)
-771 plot_image(generated_y[0].cpu())
-772
-773
-774if __name__ == '__main__':
-775 train()
+772 plot_image(generated_y[0].cpu())
+773
+774
+775if __name__ == '__main__':
+776 train()
-184 dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=32,
+184 dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=8,
 185 shuffle=True, drop_last=True, pin_memory=True)
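
The hunk above drops the loader from 32 workers to 8. Worker count is a host-dependent trade-off: each worker is a separate process that holds its own dataset state and competes for CPU, so going far past the core count mostly adds memory pressure without improving throughput. A hedged rule of thumb, where `dataset` and `batch_size` stand in for the surrounding experiment code that is not shown here:

    import os

    import torch

    # Cap workers at the machine's CPU count instead of hard-coding a large value.
    num_workers = min(8, os.cpu_count() or 1)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,
                                             shuffle=True, drop_last=True, pin_memory=True)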
-466if __name__ == '__main__':
-467 main()
+467if __name__ == '__main__':
+468 main()