ශබ්දයපුරෝකථනය කිරීම සඳහා මෙය යූ-නෙට් පදනම් කරගත් ආකෘතියකි .
යූ-නෙට්යනු ආදර්ශ රූප සටහනේ යූ හැඩයෙන් එය නම ලබා ගනී. විශේෂාංග සිතියම් විභේදනය ක්රමයෙන් අඩු කිරීමෙන් (අඩක්) සහ විභේදනය වැඩි කිරීමෙන් එය ලබා දී ඇති රූපයක් සකසනු ලැබේ. සෑම විභේදනයකදීම පාස්-හරහා සම්බන්ධතාවයක් ඇත.

මෙමක්රියාත්මක කිරීම මුල් යූ-නෙට් (අවශේෂ කුට්ටි, බහු-හිස අවධානය) සඳහා වෙනස් කිරීම් රාශියක් අඩංගු වන අතර කාල-පියවර කාවැද්දීම් එකතු කරයි .
24import math
25from typing import Optional, Tuple, Union, List
26
27import torch
28from torch import nn
29
30from labml_helpers.module import Module33class Swish(Module):40 def forward(self, x):
41 return x * torch.sigmoid(x)44class TimeEmbedding(nn.Module):n_channels
කාවැද්දීමේ මානයන් ගණන වේ49 def __init__(self, n_channels: int):53 super().__init__()
54 self.n_channels = n_channelsපළමුරේඛීය ස්ථරය
56 self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)සක්රීයකිරීම
58 self.act = Swish()දෙවනරේඛීය ස්ථරය
60 self.lin2 = nn.Linear(self.n_channels, self.n_channels)62 def forward(self, t: torch.Tensor):72 half_dim = self.n_channels // 8
73 emb = math.log(10_000) / (half_dim - 1)
74 emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
75 emb = t[:, None] * emb[None, :]
76 emb = torch.cat((emb.sin(), emb.cos()), dim=1)එම්එල්පීසමඟ පරිවර්තනය කරන්න
79 emb = self.act(self.lin1(emb))
80 emb = self.lin2(emb)83 return embඅවශේෂකොටසකදී කණ්ඩායම් සාමාන්යකරණය සමග convolution ස්ථර දෙකක් ඇත. සෑම විභේදනයක්ම අවශේෂ කොටස් දෙකකින් සකසනු ලැබේ.
86class ResidualBlock(Module):in_channels
ආදාන නාලිකා ගණන out_channels
ආදාන නාලිකා ගණන time_channels
කාලය පියවර () කාවැද්දීම් සංඛ්යාව නාලිකා වේ n_groups
කණ්ඩායම් සාමාන්යකරණය සඳහා කණ්ඩායම්සංඛ්යාව වේ94 def __init__(self, in_channels: int, out_channels: int, time_channels: int, n_groups: int = 32):101 super().__init__()කණ්ඩායම්සාමාන්යකරණය සහ පළමු කැටි ගැසුණු ස්ථරය
103 self.norm1 = nn.GroupNorm(n_groups, in_channels)
104 self.act1 = Swish()
105 self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))කණ්ඩායම්සාමාන්යකරණය සහ දෙවන කැටි ගැසුණු ස්තරය
108 self.norm2 = nn.GroupNorm(n_groups, out_channels)
109 self.act2 = Swish()
110 self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))ආදානනාලිකා ගණන ප්රතිදාන නාලිකා ගණනට සමාන නොවේ නම් කෙටිමං සම්බන්ධතාවය ප්රක්ෂේපණය කළ යුතුය
114 if in_channels != out_channels:
115 self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
116 else:
117 self.shortcut = nn.Identity()කාලකාවැද්දීම් සඳහා රේඛීය ස්ථරය
120 self.time_emb = nn.Linear(time_channels, out_channels)x
හැඩය ඇත [batch_size, in_channels, height, width]
t
හැඩය ඇත [batch_size, time_channels]
122 def forward(self, x: torch.Tensor, t: torch.Tensor):පළමුකැටි ගැසුණු ස්ථරය
128 h = self.conv1(self.act1(self.norm1(x)))කාලකාවැද්දීම් එකතු කරන්න
130 h += self.time_emb(t)[:, :, None, None]දෙවනකැටි ගැසුණු ස්ථරය
132 h = self.conv2(self.act2(self.norm2(h)))කෙටිමංසම්බන්ධතාවය එකතු කර ආපසු යන්න
135 return h + self.shortcut(x)138class AttentionBlock(Module):n_channels
යනු ආදානයේ නාලිකා ගණන n_heads
බහු-හිස අවධානය යොමු ප්රධානීන් සංඛ්යාව වේ d_k
එක් එක් හිසෙහි මානයන් ගණන n_groups
කණ්ඩායම් සාමාන්යකරණය සඳහා කණ්ඩායම්සංඛ්යාව වේ145 def __init__(self, n_channels: int, n_heads: int = 1, d_k: int = None, n_groups: int = 32):152 super().__init__()පෙරනිමි d_k
155 if d_k is None:
156 d_k = n_channelsසාමාන්යකරණයස්ථරය
158 self.norm = nn.GroupNorm(n_groups, n_channels)විමසුම, යතුර සහ අගයන් සඳහා ප්රක්ෂේපණ
160 self.projection = nn.Linear(n_channels, n_heads * d_k * 3)අවසානපරිවර්තනය සඳහා රේඛීය ස්ථරය
162 self.output = nn.Linear(n_heads * d_k, n_channels)තිත්නිෂ්පාදන අවධානය සඳහා පරිමාණය
164 self.scale = d_k ** -0.5166 self.n_heads = n_heads
167 self.d_k = d_kx
හැඩය ඇත [batch_size, in_channels, height, width]
t
හැඩය ඇත [batch_size, time_channels]
169 def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None):t
භාවිතා නොකෙරේ, නමුත් එය තර්කවල තබා ඇත්තේ අවධානය යොමු කිරීමේ ස්ථර ශ්රිතයේ අත්සන සමඟ ගැලපෙන බැවිනි ResidualBlock
.
176 _ = tහැඩයලබා ගන්න
178 batch_size, n_channels, height, width = x.shapex
හැඩයට වෙනස් කරන්න [batch_size, seq, n_channels]
180 x = x.view(batch_size, n_channels, -1).permute(0, 2, 1)විමසුමලබා ගන්න, යතුර, සහ අගයන් (concatenated) සහ එය හැඩගස්වා [batch_size, seq, n_heads, 3 * d_k]
182 qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k)විමසුම, යතුර සහ අගයන් බෙදීම්. ඔවුන් එක් එක් හැඩය ඇත [batch_size, seq, n_heads, d_k]
184 q, k, v = torch.chunk(qkv, 3, dim=-1)පරිමාණතිත් නිෂ්පාදනයක් ගණනය කරන්න
186 attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scaleඅනුක්රමිකමානය ඔස්සේ සොෆ්ට්මැක්ස්
188 attn = attn.softmax(dim=1)අගයන්අනුව ගුණ කරන්න
190 res = torch.einsum('bijh,bjhd->bihd', attn, v)නැවතහැඩගස්වන්න [batch_size, seq, n_heads * d_k]
192 res = res.view(batch_size, -1, self.n_heads * self.d_k)බවටපරිවර්තනය කරන්න [batch_size, seq, n_channels]
194 res = self.output(res)මඟහැරීමේ සම්බන්ධතාවය එක් කරන්න
197 res += xහැඩයටවෙනස් කරන්න [batch_size, in_channels, height, width]
200 res = res.permute(0, 2, 1).view(batch_size, n_channels, height, width)203 return resමෙයඒකාබද්ධ ResidualBlock
හා AttentionBlock
. මෙම එක් එක් යෝජනාව දී U-Net පළමු භාගයේ දී භාවිතා වේ.
206class DownBlock(Module):213 def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
214 super().__init__()
215 self.res = ResidualBlock(in_channels, out_channels, time_channels)
216 if has_attn:
217 self.attn = AttentionBlock(out_channels)
218 else:
219 self.attn = nn.Identity()221 def forward(self, x: torch.Tensor, t: torch.Tensor):
222 x = self.res(x, t)
223 x = self.attn(x)
224 return xමෙයඒකාබද්ධ ResidualBlock
හා AttentionBlock
. සෑම විභේදනයකදීම යූ-නෙට් හි දෙවන භාගයේදී මේවා භාවිතා වේ.
227class UpBlock(Module):234 def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
235 super().__init__()ආදානයටඇත්තේ යූ-නෙට් හි පළමු භාගයේ සිට එකම විභේදනයේ ප්රතිදානය අප සංයුක්ත කරන in_channels + out_channels
බැවිනි
238 self.res = ResidualBlock(in_channels + out_channels, out_channels, time_channels)
239 if has_attn:
240 self.attn = AttentionBlock(out_channels)
241 else:
242 self.attn = nn.Identity()244 def forward(self, x: torch.Tensor, t: torch.Tensor):
245 x = self.res(x, t)
246 x = self.attn(x)
247 return xඑයතවත් එකක් ResidualBlock
සමඟ ඒකාබද්ධ ResidualBlock
වේ. AttentionBlock
මෙම කොටස U-Net හි අඩුම විභේදනයෙන් යොදනු ලැබේ.
250class MiddleBlock(Module):258 def __init__(self, n_channels: int, time_channels: int):
259 super().__init__()
260 self.res1 = ResidualBlock(n_channels, n_channels, time_channels)
261 self.attn = AttentionBlock(n_channels)
262 self.res2 = ResidualBlock(n_channels, n_channels, time_channels)264 def forward(self, x: torch.Tensor, t: torch.Tensor):
265 x = self.res1(x, t)
266 x = self.attn(x)
267 x = self.res2(x, t)
268 return x271class Upsample(nn.Module):276 def __init__(self, n_channels):
277 super().__init__()
278 self.conv = nn.ConvTranspose2d(n_channels, n_channels, (4, 4), (2, 2), (1, 1))280 def forward(self, x: torch.Tensor, t: torch.Tensor):t
භාවිතා නොකෙරේ, නමුත් එය තර්කවල තබා ඇත්තේ අවධානය යොමු කිරීමේ ස්ථර ශ්රිතයේ අත්සන සමඟ ගැලපෙන බැවිනි ResidualBlock
.
283 _ = t
284 return self.conv(x)287class Downsample(nn.Module):292 def __init__(self, n_channels):
293 super().__init__()
294 self.conv = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1))296 def forward(self, x: torch.Tensor, t: torch.Tensor):t
භාවිතා නොකෙරේ, නමුත් එය තර්කවල තබා ඇත්තේ අවධානය යොමු කිරීමේ ස්ථර ශ්රිතයේ අත්සන සමඟ ගැලපෙන බැවිනි ResidualBlock
.
299 _ = t
300 return self.conv(x)303class UNet(Module):image_channels
යනු රූපයේ නාලිකා ගණන. RGB සඳහා. n_channels
ආරම්භක විශේෂාංග සිතියමේ නාලිකා ගණන අපි රූපය බවට පරිවර්තනය කරමු ch_mults
යනු එක් එක් විභේදනයේ නාලිකා අංක ලැයිස්තුවයි. නාලිකා ගණන වේ ch_mults[i] * n_channels
is_attn
යනු එක් එක් විභේදනයේ දී අවධානය යොමු කළ යුතුද යන්න පෙන්නුම් කරන බූලියන් ලැයිස්තුවකි n_blocks
එක් එක් යෝජනාව UpDownBlocks
දී සංඛ්යාව වේ308 def __init__(self, image_channels: int = 3, n_channels: int = 64,
309 ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
310 is_attn: Union[Tuple[bool, ...], List[int]] = (False, False, True, True),
311 n_blocks: int = 2):319 super().__init__()යෝජනාගණන
322 n_resolutions = len(ch_mults)විශේෂාංගසිතියමට ව්යාපෘති රූපය
325 self.image_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))කාලයකාවැද්දීම ස්ථරය. කාල කාවැද්දීම n_channels * 4
නාලිකා ඇත
328 self.time_emb = TimeEmbedding(n_channels * 4)331 down = []නාලිකාගණන
333 out_channels = in_channels = n_channelsඑක්එක් යෝජනාව සඳහා
335 for i in range(n_resolutions):මෙමවිභේදනයේ ප්රතිදාන නාලිකා ගණන
337 out_channels = in_channels * ch_mults[i]එකතුකරන්න n_blocks
339 for _ in range(n_blocks):
340 down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
341 in_channels = out_channelsපසුගියහැර අනෙකුත් සියලු යෝජනා දී ආදර්ශ පහළ
343 if i < n_resolutions - 1:
344 down.append(Downsample(in_channels))මොඩියුලකට්ටලය ඒකාබද්ධ කරන්න
347 self.down = nn.ModuleList(down)මැදකොටස
350 self.middle = MiddleBlock(out_channels, n_channels * 4, )353 up = []නාලිකාගණන
355 in_channels = out_channelsඑක්එක් යෝජනාව සඳහා
357 for i in reversed(range(n_resolutions)):n_blocks
එකම විභේදනයේ
359 out_channels = in_channels
360 for _ in range(n_blocks):
361 up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))නාලිකාගණන අඩු කිරීම සඳහා අවසාන කොටස
363 out_channels = in_channels // ch_mults[i]
364 up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
365 in_channels = out_channelsපසුගියහැර අනෙකුත් සියලු යෝජනා දී ආදර්ශ දක්වා
367 if i > 0:
368 up.append(Upsample(in_channels))මොඩියුලකට්ටලය ඒකාබද්ධ කරන්න
371 self.up = nn.ModuleList(up)අවසානසාමාන්යකරණය සහ කැටි ගැසුණු ස්ථරය
374 self.norm = nn.GroupNorm(8, n_channels)
375 self.act = Swish()
376 self.final = nn.Conv2d(in_channels, image_channels, kernel_size=(3, 3), padding=(1, 1))x
හැඩය ඇත [batch_size, in_channels, height, width]
t
හැඩය ඇත [batch_size]
378 def forward(self, x: torch.Tensor, t: torch.Tensor):කාල-පියවරකාවැද්දීම් ලබා ගන්න
385 t = self.time_emb(t)රූපප්රක්ෂේපණය ලබා ගන්න
388 x = self.image_proj(x)h
මඟ හැරීමේ සම්බන්ධතාවය සඳහා එක් එක් විභේදනයේ ප්රතිදානයන් ගබඩා කරනු ඇත
391 h = [x]යූ-නෙට්හි පළමු භාගය
393 for m in self.down:
394 x = m(x, t)
395 h.append(x)මැද(පහළ)
398 x = self.middle(x, t)යූ-නෙට්හි දෙවන භාගය
401 for m in self.up:
402 if isinstance(m, Upsample):
403 x = m(x, t)
404 else:යූ-නෙට්හි පළමු භාගයේ සිට මඟ හැරීමේ සම්බන්ධතාවය ලබාගෙන සංයුක්ත කරන්න
406 s = h.pop()
407 x = torch.cat((x, s), dim=1)409 x = m(x, t)අවසානසාමාන්යකරණය සහ කැටි කිරීම
412 return self.final(self.act(self.norm(x)))