මෙය කාර්වානා දත්ත කට්ටලයේ යූ-නෙට් ආකෘතියක් පුහුණු කරයි. ඔබට බාගත කිරීමේ උපදෙස් සොයාගත හැකිය Kaggle.
carvana/train
ෆෝල්ඩරය තුළ පුහුණු පින්තූර සහ carvana/train_masks
ෆෝල්ඩරයේ වෙස් මුහුණු සුරකින්න.
සරලබව සඳහා, අපි පුහුණුවක් සහ වලංගු භේදයක් නොකරමු.
19import numpy as np
20import torch
21import torch.utils.data
22import torchvision.transforms.functional
23from torch import nn
24
25from labml import lab, tracker, experiment, monit
26from labml.configs import BaseConfigs
27from labml_helpers.device import DeviceConfigs
28from labml_nn.unet.carvana import CarvanaDataset
29from labml_nn.unet import UNet32class Configs(BaseConfigs):ආකෘතියපුහුණු කිරීමේ උපකරණය. DeviceConfigs
ලබා ගත හැකි CUDA උපාංගයක් අහුලනවා හෝ CPU කිරීමට පෙරනිමි.
39 device: torch.device = DeviceConfigs()රූපයේනාලිකා ගණන. RGB සඳහා.
45 image_channels: int = 3නිමැවුම්ආවරණයේ නාලිකා ගණන. ද්විමය ආවරණ සඳහා.
47 mask_channels: int = 1කණ්ඩායම්ප්රමාණය
50 batch_size: int = 1ඉගෙනුම්අනුපාතය
52 learning_rate: float = 2.5e-4පුහුණුඑපොච් ගණන
55 epochs: int = 4දත්තකට්ටලය
58 dataset: CarvanaDatasetදත්තකාරකය
60 data_loader: torch.utils.data.DataLoaderපාඩුශ්රිතය
63 loss_func = nn.BCELoss()ද්විමයවර්ගීකරණය සඳහා සිග්මෝයිඩ් ශ්රිතය
65 sigmoid = nn.Sigmoid()ආදම්ප්රශස්තකරණය
68 optimizer: torch.optim.Adam70 def init(self):කාර්වානා දත්ත කට්ටලය ආරම්භ කරන්න
72 self.dataset = CarvanaDataset(lab.get_data_path() / 'carvana' / 'train',
73 lab.get_data_path() / 'carvana' / 'train_masks')ආකෘතියආරම්භ කරන්න
75 self.model = UNet(self.image_channels, self.mask_channels).to(self.device)දත්තකාරකය සාදන්න
78 self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size,
79 shuffle=True, pin_memory=True)ප්රශස්තකරණයසාදන්න
81 self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)රූපලොග් වීම
84 tracker.set_image("sample", True)86 @torch.no_grad()
87 def sample(self, idx=-1):අහඹුනියැදියක් ලබා ගන්න
93 x, _ = self.dataset[np.random.randint(len(self.dataset))]උපාංගයවෙත දත්ත ගෙනයන්න
95 x = x.to(self.device)පුරෝකථනයකළ වෙස් මුහුණ ලබා ගන්න
98 mask = self.sigmoid(self.model(x[None, :]))වෙස්මුහුණෙහි ප්රමාණයට රූපය වගා කරන්න
100 x = torchvision.transforms.functional.center_crop(x, [mask.shape[2], mask.shape[3]])ලොග්සාම්පල
102 tracker.save('sample', x * mask)104 def train(self):112 for _, (image, mask) in monit.mix(('Train', self.data_loader), (self.sample, list(range(50)))):ගෝලීයපියවර වැඩි කිරීම
114 tracker.add_global_step()උපාංගයවෙත දත්ත ගෙනයන්න
116 image, mask = image.to(self.device), mask.to(self.device)අනුක්රමිකශුන්ය කරන්න
119 self.optimizer.zero_grad()පුරෝකථනයකරන ලද වෙස්මුහුණු පිවිසුම් ලබා ගන්න
121 logits = self.model(image)ඉලක්කගතවෙස්මුහුණ පිවිසුම් ප්රමාණයට වගා කරන්න. යූ-නෙට් හි සංවලිත ස්ථර වල අපි පෑඩින් භාවිතා නොකරන්නේ නම් පිවිසුම් ප්රමාණය කුඩා වේ.
124 mask = torchvision.transforms.functional.center_crop(mask, [logits.shape[2], logits.shape[3]])අලාභයගණනය කරන්න
126 loss = self.loss_func(self.sigmoid(logits), mask)අනුක්රමිකගණනය
128 loss.backward()ප්රශස්තිකරණපියවරක් ගන්න
130 self.optimizer.step()අලාභයලුහුබඳින්න
132 tracker.save('loss', loss)134 def run(self):138 for _ in monit.loop(self.epochs):ආකෘතියපුහුණු කරන්න
140 self.train()කොන්සෝලයේනව රේඛාවක්
142 tracker.new_line()ආකෘතියසුරකින්න
144 experiment.save_checkpoint()147def main():අත්හදාබැලීම සාදන්න
149 experiment.create(name='unet')වින්යාසයන්සාදන්න
152 configs = Configs()වින්යාසයන්සකසන්න. ශබ්දකෝෂයේ අගයන් සම්මත කිරීමෙන් ඔබට පෙරනිමි අභිබවා යා හැකිය.
155 experiment.configs(configs, {})ආරම්භකරන්න
158 configs.init()ඉතිරිකිරීම සහ පැටවීම සඳහා ආකෘති සකසන්න
161 experiment.add_pytorch_models({'model': configs.model})පුහුණුලූපය ආරම්භ කර ක්රියාත්මක කරන්න
164 with experiment.start():
165 configs.run()169if __name__ == '__main__':
170 main()