ශබ්දයපුරෝකථනය කිරීම සඳහා මෙය යූ-නෙට් පදනම් කරගත් ආකෘතියකි .
යූ-නෙට්යනු ආදර්ශ රූප සටහනේ යූ හැඩයෙන් එය නම ලබා ගනී. විශේෂාංග සිතියම් විභේදනය ක්රමයෙන් අඩු කිරීමෙන් (අඩක්) සහ විභේදනය වැඩි කිරීමෙන් එය ලබා දී ඇති රූපයක් සකසනු ලැබේ. සෑම විභේදනයකදීම පාස්-හරහා සම්බන්ධතාවයක් ඇත.
මෙමක්රියාත්මක කිරීම මුල් යූ-නෙට් (අවශේෂ කුට්ටි, බහු-හිස අවධානය) සඳහා වෙනස් කිරීම් රාශියක් අඩංගු වන අතර කාල-පියවර කාවැද්දීම් එකතු කරයි .
24import math
25from typing import Optional, Tuple, Union, List
26
27import torch
28from torch import nn
29
30from labml_helpers.module import Module
33class Swish(Module):
40 def forward(self, x):
41 return x * torch.sigmoid(x)
44class TimeEmbedding(nn.Module):
n_channels
කාවැද්දීමේ මානයන් ගණන වේ49 def __init__(self, n_channels: int):
53 super().__init__()
54 self.n_channels = n_channels
පළමුරේඛීය ස්ථරය
56 self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)
සක්රීයකිරීම
58 self.act = Swish()
දෙවනරේඛීය ස්ථරය
60 self.lin2 = nn.Linear(self.n_channels, self.n_channels)
62 def forward(self, t: torch.Tensor):
72 half_dim = self.n_channels // 8
73 emb = math.log(10_000) / (half_dim - 1)
74 emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
75 emb = t[:, None] * emb[None, :]
76 emb = torch.cat((emb.sin(), emb.cos()), dim=1)
එම්එල්පීසමඟ පරිවර්තනය කරන්න
79 emb = self.act(self.lin1(emb))
80 emb = self.lin2(emb)
83 return emb
අවශේෂකොටසකදී කණ්ඩායම් සාමාන්යකරණය සමග convolution ස්ථර දෙකක් ඇත. සෑම විභේදනයක්ම අවශේෂ කොටස් දෙකකින් සකසනු ලැබේ.
86class ResidualBlock(Module):
in_channels
ආදාන නාලිකා ගණනout_channels
ආදාන නාලිකා ගණනtime_channels
කාලය පියවර () කාවැද්දීම් සංඛ්යාව නාලිකා වේn_groups
කණ්ඩායම් සාමාන්යකරණය සඳහා කණ්ඩායම් සංඛ්යාව වේdropout
හැලහැප්මේ අනුපාතය වේ94 def __init__(self, in_channels: int, out_channels: int, time_channels: int,
95 n_groups: int = 32, dropout: float = 0.1):
103 super().__init__()
කණ්ඩායම්සාමාන්යකරණය සහ පළමු කැටි ගැසුණු ස්ථරය
105 self.norm1 = nn.GroupNorm(n_groups, in_channels)
106 self.act1 = Swish()
107 self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
කණ්ඩායම්සාමාන්යකරණය සහ දෙවන කැටි ගැසුණු ස්තරය
110 self.norm2 = nn.GroupNorm(n_groups, out_channels)
111 self.act2 = Swish()
112 self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
ආදානනාලිකා ගණන ප්රතිදාන නාලිකා ගණනට සමාන නොවේ නම් කෙටිමං සම්බන්ධතාවය ප්රක්ෂේපණය කළ යුතුය
116 if in_channels != out_channels:
117 self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
118 else:
119 self.shortcut = nn.Identity()
කාලකාවැද්දීම් සඳහා රේඛීය ස්ථරය
122 self.time_emb = nn.Linear(time_channels, out_channels)
123 self.time_act = Swish()
124
125 self.dropout = nn.Dropout(dropout)
x
හැඩය ඇත [batch_size, in_channels, height, width]
t
හැඩය ඇත [batch_size, time_channels]
127 def forward(self, x: torch.Tensor, t: torch.Tensor):
පළමුකැටි ගැසුණු ස්ථරය
133 h = self.conv1(self.act1(self.norm1(x)))
කාලකාවැද්දීම් එකතු කරන්න
135 h += self.time_emb(self.time_act(t))[:, :, None, None]
දෙවනකැටි ගැසුණු ස්ථරය
137 h = self.conv2(self.dropout(self.act2(self.norm2(h))))
කෙටිමංසම්බන්ධතාවය එකතු කර ආපසු යන්න
140 return h + self.shortcut(x)
143class AttentionBlock(Module):
n_channels
යනු ආදානයේ නාලිකා ගණන n_heads
බහු-හිස අවධානය යොමු ප්රධානීන් සංඛ්යාව වේ d_k
එක් එක් හිසෙහි මානයන් ගණන n_groups
කණ්ඩායම් සාමාන්යකරණය සඳහා කණ්ඩායම්සංඛ්යාව වේ150 def __init__(self, n_channels: int, n_heads: int = 1, d_k: int = None, n_groups: int = 32):
157 super().__init__()
පෙරනිමි d_k
160 if d_k is None:
161 d_k = n_channels
සාමාන්යකරණයස්ථරය
163 self.norm = nn.GroupNorm(n_groups, n_channels)
විමසුම, යතුර සහ අගයන් සඳහා ප්රක්ෂේපණ
165 self.projection = nn.Linear(n_channels, n_heads * d_k * 3)
අවසානපරිවර්තනය සඳහා රේඛීය ස්ථරය
167 self.output = nn.Linear(n_heads * d_k, n_channels)
තිත්නිෂ්පාදන අවධානය සඳහා පරිමාණය
169 self.scale = d_k ** -0.5
171 self.n_heads = n_heads
172 self.d_k = d_k
x
හැඩය ඇත [batch_size, in_channels, height, width]
t
හැඩය ඇත [batch_size, time_channels]
174 def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None):
t
භාවිතා නොකෙරේ, නමුත් එය තර්කවල තබා ඇත්තේ අවධානය යොමු කිරීමේ ස්ථර ශ්රිතයේ අත්සන සමඟ ගැලපෙන බැවිනි ResidualBlock
.
181 _ = t
හැඩයලබා ගන්න
183 batch_size, n_channels, height, width = x.shape
x
හැඩයට වෙනස් කරන්න [batch_size, seq, n_channels]
185 x = x.view(batch_size, n_channels, -1).permute(0, 2, 1)
විමසුමලබා ගන්න, යතුර, සහ අගයන් (concatenated) සහ එය හැඩගස්වා [batch_size, seq, n_heads, 3 * d_k]
187 qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k)
විමසුම, යතුර සහ අගයන් බෙදීම්. ඔවුන් එක් එක් හැඩය ඇත [batch_size, seq, n_heads, d_k]
189 q, k, v = torch.chunk(qkv, 3, dim=-1)
පරිමාණතිත් නිෂ්පාදනයක් ගණනය කරන්න
191 attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scale
අනුක්රමිකමානය ඔස්සේ සොෆ්ට්මැක්ස්
193 attn = attn.softmax(dim=2)
අගයන්අනුව ගුණ කරන්න
195 res = torch.einsum('bijh,bjhd->bihd', attn, v)
නැවතහැඩගස්වන්න [batch_size, seq, n_heads * d_k]
197 res = res.view(batch_size, -1, self.n_heads * self.d_k)
බවටපරිවර්තනය කරන්න [batch_size, seq, n_channels]
199 res = self.output(res)
මඟහැරීමේ සම්බන්ධතාවය එක් කරන්න
202 res += x
හැඩයටවෙනස් කරන්න [batch_size, in_channels, height, width]
205 res = res.permute(0, 2, 1).view(batch_size, n_channels, height, width)
208 return res
මෙයඒකාබද්ධ ResidualBlock
හා AttentionBlock
. මෙම එක් එක් යෝජනාව දී U-Net පළමු භාගයේ දී භාවිතා වේ.
211class DownBlock(Module):
218 def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
219 super().__init__()
220 self.res = ResidualBlock(in_channels, out_channels, time_channels)
221 if has_attn:
222 self.attn = AttentionBlock(out_channels)
223 else:
224 self.attn = nn.Identity()
226 def forward(self, x: torch.Tensor, t: torch.Tensor):
227 x = self.res(x, t)
228 x = self.attn(x)
229 return x
මෙයඒකාබද්ධ ResidualBlock
හා AttentionBlock
. සෑම විභේදනයකදීම යූ-නෙට් හි දෙවන භාගයේදී මේවා භාවිතා වේ.
232class UpBlock(Module):
239 def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
240 super().__init__()
ආදානයටඇත්තේ යූ-නෙට් හි පළමු භාගයේ සිට එකම විභේදනයේ ප්රතිදානය අප සංයුක්ත කරන in_channels + out_channels
බැවිනි
243 self.res = ResidualBlock(in_channels + out_channels, out_channels, time_channels)
244 if has_attn:
245 self.attn = AttentionBlock(out_channels)
246 else:
247 self.attn = nn.Identity()
249 def forward(self, x: torch.Tensor, t: torch.Tensor):
250 x = self.res(x, t)
251 x = self.attn(x)
252 return x
එයතවත් එකක් ResidualBlock
සමඟ ඒකාබද්ධ ResidualBlock
වේ. AttentionBlock
මෙම කොටස U-Net හි අඩුම විභේදනයෙන් යොදනු ලැබේ.
255class MiddleBlock(Module):
263 def __init__(self, n_channels: int, time_channels: int):
264 super().__init__()
265 self.res1 = ResidualBlock(n_channels, n_channels, time_channels)
266 self.attn = AttentionBlock(n_channels)
267 self.res2 = ResidualBlock(n_channels, n_channels, time_channels)
269 def forward(self, x: torch.Tensor, t: torch.Tensor):
270 x = self.res1(x, t)
271 x = self.attn(x)
272 x = self.res2(x, t)
273 return x
276class Upsample(nn.Module):
281 def __init__(self, n_channels):
282 super().__init__()
283 self.conv = nn.ConvTranspose2d(n_channels, n_channels, (4, 4), (2, 2), (1, 1))
285 def forward(self, x: torch.Tensor, t: torch.Tensor):
t
භාවිතා නොකෙරේ, නමුත් එය තර්කවල තබා ඇත්තේ අවධානය යොමු කිරීමේ ස්ථර ශ්රිතයේ අත්සන සමඟ ගැලපෙන බැවිනි ResidualBlock
.
288 _ = t
289 return self.conv(x)
292class Downsample(nn.Module):
297 def __init__(self, n_channels):
298 super().__init__()
299 self.conv = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1))
301 def forward(self, x: torch.Tensor, t: torch.Tensor):
t
භාවිතා නොකෙරේ, නමුත් එය තර්කවල තබා ඇත්තේ අවධානය යොමු කිරීමේ ස්ථර ශ්රිතයේ අත්සන සමඟ ගැලපෙන බැවිනි ResidualBlock
.
304 _ = t
305 return self.conv(x)
308class UNet(Module):
image_channels
යනු රූපයේ නාලිකා ගණන. RGB සඳහා. n_channels
ආරම්භක විශේෂාංග සිතියමේ නාලිකා ගණන අපි රූපය බවට පරිවර්තනය කරමු ch_mults
යනු එක් එක් විභේදනයේ නාලිකා අංක ලැයිස්තුවයි. නාලිකා ගණන වේ ch_mults[i] * n_channels
is_attn
යනු එක් එක් විභේදනයේ දී අවධානය යොමු කළ යුතුද යන්න පෙන්නුම් කරන බූලියන් ලැයිස්තුවකි n_blocks
එක් එක් යෝජනාව UpDownBlocks
දී සංඛ්යාව වේ313 def __init__(self, image_channels: int = 3, n_channels: int = 64,
314 ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
315 is_attn: Union[Tuple[bool, ...], List[int]] = (False, False, True, True),
316 n_blocks: int = 2):
324 super().__init__()
යෝජනාගණන
327 n_resolutions = len(ch_mults)
විශේෂාංගසිතියමට ව්යාපෘති රූපය
330 self.image_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))
කාලයකාවැද්දීම ස්ථරය. කාල කාවැද්දීම n_channels * 4
නාලිකා ඇත
333 self.time_emb = TimeEmbedding(n_channels * 4)
336 down = []
නාලිකාගණන
338 out_channels = in_channels = n_channels
එක්එක් යෝජනාව සඳහා
340 for i in range(n_resolutions):
මෙමවිභේදනයේ ප්රතිදාන නාලිකා ගණන
342 out_channels = in_channels * ch_mults[i]
එකතුකරන්න n_blocks
344 for _ in range(n_blocks):
345 down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
346 in_channels = out_channels
පසුගියහැර අනෙකුත් සියලු යෝජනා දී ආදර්ශ පහළ
348 if i < n_resolutions - 1:
349 down.append(Downsample(in_channels))
මොඩියුලකට්ටලය ඒකාබද්ධ කරන්න
352 self.down = nn.ModuleList(down)
මැදකොටස
355 self.middle = MiddleBlock(out_channels, n_channels * 4, )
358 up = []
නාලිකාගණන
360 in_channels = out_channels
එක්එක් යෝජනාව සඳහා
362 for i in reversed(range(n_resolutions)):
n_blocks
එකම විභේදනයේ
364 out_channels = in_channels
365 for _ in range(n_blocks):
366 up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
නාලිකාගණන අඩු කිරීම සඳහා අවසාන කොටස
368 out_channels = in_channels // ch_mults[i]
369 up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
370 in_channels = out_channels
පසුගියහැර අනෙකුත් සියලු යෝජනා දී ආදර්ශ දක්වා
372 if i > 0:
373 up.append(Upsample(in_channels))
මොඩියුලකට්ටලය ඒකාබද්ධ කරන්න
376 self.up = nn.ModuleList(up)
අවසානසාමාන්යකරණය සහ කැටි ගැසුණු ස්ථරය
379 self.norm = nn.GroupNorm(8, n_channels)
380 self.act = Swish()
381 self.final = nn.Conv2d(in_channels, image_channels, kernel_size=(3, 3), padding=(1, 1))
x
හැඩය ඇත [batch_size, in_channels, height, width]
t
හැඩය ඇත [batch_size]
383 def forward(self, x: torch.Tensor, t: torch.Tensor):
කාල-පියවරකාවැද්දීම් ලබා ගන්න
390 t = self.time_emb(t)
රූපප්රක්ෂේපණය ලබා ගන්න
393 x = self.image_proj(x)
h
මඟ හැරීමේ සම්බන්ධතාවය සඳහා එක් එක් විභේදනයේ ප්රතිදානයන් ගබඩා කරනු ඇත
396 h = [x]
යූ-නෙට්හි පළමු භාගය
398 for m in self.down:
399 x = m(x, t)
400 h.append(x)
මැද(පහළ)
403 x = self.middle(x, t)
යූ-නෙට්හි දෙවන භාගය
406 for m in self.up:
407 if isinstance(m, Upsample):
408 x = m(x, t)
409 else:
යූ-නෙට්හි පළමු භාගයේ සිට මඟ හැරීමේ සම්බන්ධතාවය ලබාගෙන සංයුක්ත කරන්න
411 s = h.pop()
412 x = torch.cat((x, s), dim=1)
414 x = m(x, t)
අවසානසාමාන්යකරණය සහ කැටි කිරීම
417 return self.final(self.act(self.norm(x)))