විසරණ සම්භාවිතා ආකෘති නිරූපණය කිරීම සඳහා යූ-නෙට් ආකෘතිය (ඩීඩීපීඑම්)

ශබ්දයපුරෝකථනය කිරීම සඳහා මෙය යූ-නෙට් පදනම් කරගත් ආකෘතියකි .

යූ-නෙට්යනු ආදර්ශ රූප සටහනේ යූ හැඩයෙන් එය නම ලබා ගනී. විශේෂාංග සිතියම් විභේදනය ක්රමයෙන් අඩු කිරීමෙන් (අඩක්) සහ විභේදනය වැඩි කිරීමෙන් එය ලබා දී ඇති රූපයක් සකසනු ලැබේ. සෑම විභේදනයකදීම පාස්-හරහා සම්බන්ධතාවයක් ඇත.

U-Net diagram from paper

මෙමක්රියාත්මක කිරීම මුල් යූ-නෙට් (අවශේෂ කුට්ටි, බහු-හිස අවධානය) සඳහා වෙනස් කිරීම් රාශියක් අඩංගු වන අතර කාල-පියවර කාවැද්දීම් එකතු කරයි .

24import math
25from typing import Optional, Tuple, Union, List
26
27import torch
28from torch import nn
29
30from labml_helpers.module import Module

ස්විස්ෂ්ක්රියාකාරී ශ්රිතය

33class Swish(Module):
40    def forward(self, x):
41        return x * torch.sigmoid(x)

කාවැද්දීම්සඳහා

44class TimeEmbedding(nn.Module):
  • n_channels කාවැද්දීමේ මානයන් ගණන වේ
49    def __init__(self, n_channels: int):
53        super().__init__()
54        self.n_channels = n_channels

පළමුරේඛීය ස්ථරය

56        self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)

සක්‍රීයකිරීම

58        self.act = Swish()

දෙවනරේඛීය ස්ථරය

60        self.lin2 = nn.Linear(self.n_channels, self.n_channels)
62    def forward(self, t: torch.Tensor):

ට්රාන්ස්ෆෝමරයට සමානසයිනොසොයිඩල් ස්ථාන කාවැද්දීම් සාදන්න

කොහේද half_dim

72        half_dim = self.n_channels // 8
73        emb = math.log(10_000) / (half_dim - 1)
74        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
75        emb = t[:, None] * emb[None, :]
76        emb = torch.cat((emb.sin(), emb.cos()), dim=1)

එම්එල්පීසමඟ පරිවර්තනය කරන්න

79        emb = self.act(self.lin1(emb))
80        emb = self.lin2(emb)

83        return emb

අවශේෂකොටස

අවශේෂකොටසකදී කණ්ඩායම් සාමාන්යකරණය සමග convolution ස්ථර දෙකක් ඇත. සෑම විභේදනයක්ම අවශේෂ කොටස් දෙකකින් සකසනු ලැබේ.

86class ResidualBlock(Module):
  • in_channels ආදාන නාලිකා ගණන
  • out_channels ආදාන නාලිකා ගණන
  • time_channels කාලය පියවර () කාවැද්දීම් සංඛ්යාව නාලිකා වේ
  • n_groups කණ්ඩායම් සාමාන්යකරණය සඳහා කණ්ඩායම්සංඛ්යාව වේ
  • 94    def __init__(self, in_channels: int, out_channels: int, time_channels: int, n_groups: int = 32):
    101        super().__init__()

    කණ්ඩායම්සාමාන්යකරණය සහ පළමු කැටි ගැසුණු ස්ථරය

    103        self.norm1 = nn.GroupNorm(n_groups, in_channels)
    104        self.act1 = Swish()
    105        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))

    කණ්ඩායම්සාමාන්යකරණය සහ දෙවන කැටි ගැසුණු ස්තරය

    108        self.norm2 = nn.GroupNorm(n_groups, out_channels)
    109        self.act2 = Swish()
    110        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))

    ආදානනාලිකා ගණන ප්රතිදාන නාලිකා ගණනට සමාන නොවේ නම් කෙටිමං සම්බන්ධතාවය ප්රක්ෂේපණය කළ යුතුය

    114        if in_channels != out_channels:
    115            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
    116        else:
    117            self.shortcut = nn.Identity()

    කාලකාවැද්දීම් සඳහා රේඛීය ස්ථරය

    120        self.time_emb = nn.Linear(time_channels, out_channels)
    • x හැඩය ඇත [batch_size, in_channels, height, width]
    • t හැඩය ඇත [batch_size, time_channels]
    122    def forward(self, x: torch.Tensor, t: torch.Tensor):

    පළමුකැටි ගැසුණු ස්ථරය

    128        h = self.conv1(self.act1(self.norm1(x)))

    කාලකාවැද්දීම් එකතු කරන්න

    130        h += self.time_emb(t)[:, :, None, None]

    දෙවනකැටි ගැසුණු ස්ථරය

    132        h = self.conv2(self.act2(self.norm2(h)))

    කෙටිමංසම්බන්ධතාවය එකතු කර ආපසු යන්න

    135        return h + self.shortcut(x)

    අවධානයවාරණ

    මෙය ට්රාන්ස්ෆෝමර් බහු-හිස අවධානයටසමාන වේ.

    138class AttentionBlock(Module):
    • n_channels යනු ආදානයේ නාලිකා ගණන
    • n_heads බහු-හිස අවධානය යොමු ප්රධානීන් සංඛ්යාව වේ
    • d_k එක් එක් හිසෙහි මානයන් ගණන
  • n_groups කණ්ඩායම් සාමාන්යකරණය සඳහා කණ්ඩායම්සංඛ්යාව වේ
  • 145    def __init__(self, n_channels: int, n_heads: int = 1, d_k: int = None, n_groups: int = 32):
    152        super().__init__()

    පෙරනිමි d_k

    155        if d_k is None:
    156            d_k = n_channels

    සාමාන්යකරණයස්ථරය

    158        self.norm = nn.GroupNorm(n_groups, n_channels)

    විමසුම, යතුර සහ අගයන් සඳහා ප්රක්ෂේපණ

    160        self.projection = nn.Linear(n_channels, n_heads * d_k * 3)

    අවසානපරිවර්තනය සඳහා රේඛීය ස්ථරය

    162        self.output = nn.Linear(n_heads * d_k, n_channels)

    තිත්නිෂ්පාදන අවධානය සඳහා පරිමාණය

    164        self.scale = d_k ** -0.5

    166        self.n_heads = n_heads
    167        self.d_k = d_k
    • x හැඩය ඇත [batch_size, in_channels, height, width]
    • t හැඩය ඇත [batch_size, time_channels]
    169    def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None):

    t භාවිතා නොකෙරේ, නමුත් එය තර්කවල තබා ඇත්තේ අවධානය යොමු කිරීමේ ස්ථර ශ්රිතයේ අත්සන සමඟ ගැලපෙන බැවිනි ResidualBlock .

    176        _ = t

    හැඩයලබා ගන්න

    178        batch_size, n_channels, height, width = x.shape

    x හැඩයට වෙනස් කරන්න [batch_size, seq, n_channels]

    180        x = x.view(batch_size, n_channels, -1).permute(0, 2, 1)

    විමසුමලබා ගන්න, යතුර, සහ අගයන් (concatenated) සහ එය හැඩගස්වා [batch_size, seq, n_heads, 3 * d_k]

    182        qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k)

    විමසුම, යතුර සහ අගයන් බෙදීම්. ඔවුන් එක් එක් හැඩය ඇත [batch_size, seq, n_heads, d_k]

    184        q, k, v = torch.chunk(qkv, 3, dim=-1)

    පරිමාණතිත් නිෂ්පාදනයක් ගණනය කරන්න

    186        attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scale

    අනුක්රමිකමානය ඔස්සේ සොෆ්ට්මැක්ස්

    188        attn = attn.softmax(dim=1)

    අගයන්අනුව ගුණ කරන්න

    190        res = torch.einsum('bijh,bjhd->bihd', attn, v)

    නැවතහැඩගස්වන්න [batch_size, seq, n_heads * d_k]

    192        res = res.view(batch_size, -1, self.n_heads * self.d_k)

    බවටපරිවර්තනය කරන්න [batch_size, seq, n_channels]

    194        res = self.output(res)

    මඟහැරීමේ සම්බන්ධතාවය එක් කරන්න

    197        res += x

    හැඩයටවෙනස් කරන්න [batch_size, in_channels, height, width]

    200        res = res.permute(0, 2, 1).view(batch_size, n_channels, height, width)

    203        return res

    බ්ලොක්ඩවුන්

    මෙයඒකාබද්ධ ResidualBlock හා AttentionBlock . මෙම එක් එක් යෝජනාව දී U-Net පළමු භාගයේ දී භාවිතා වේ.

    206class DownBlock(Module):
    213    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
    214        super().__init__()
    215        self.res = ResidualBlock(in_channels, out_channels, time_channels)
    216        if has_attn:
    217            self.attn = AttentionBlock(out_channels)
    218        else:
    219            self.attn = nn.Identity()
    221    def forward(self, x: torch.Tensor, t: torch.Tensor):
    222        x = self.res(x, t)
    223        x = self.attn(x)
    224        return x

    බ්ලොක්දක්වා

    මෙයඒකාබද්ධ ResidualBlock හා AttentionBlock . සෑම විභේදනයකදීම යූ-නෙට් හි දෙවන භාගයේදී මේවා භාවිතා වේ.

    227class UpBlock(Module):
    234    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
    235        super().__init__()

    ආදානයටඇත්තේ යූ-නෙට් හි පළමු භාගයේ සිට එකම විභේදනයේ ප්රතිදානය අප සංයුක්ත කරන in_channels + out_channels බැවිනි

    238        self.res = ResidualBlock(in_channels + out_channels, out_channels, time_channels)
    239        if has_attn:
    240            self.attn = AttentionBlock(out_channels)
    241        else:
    242            self.attn = nn.Identity()
    244    def forward(self, x: torch.Tensor, t: torch.Tensor):
    245        x = self.res(x, t)
    246        x = self.attn(x)
    247        return x

    මැදකොටස

    එයතවත් එකක් ResidualBlock සමඟ ඒකාබද්ධ ResidualBlock වේ. AttentionBlock මෙම කොටස U-Net හි අඩුම විභේදනයෙන් යොදනු ලැබේ.

    250class MiddleBlock(Module):
    258    def __init__(self, n_channels: int, time_channels: int):
    259        super().__init__()
    260        self.res1 = ResidualBlock(n_channels, n_channels, time_channels)
    261        self.attn = AttentionBlock(n_channels)
    262        self.res2 = ResidualBlock(n_channels, n_channels, time_channels)
    264    def forward(self, x: torch.Tensor, t: torch.Tensor):
    265        x = self.res1(x, t)
    266        x = self.attn(x)
    267        x = self.res2(x, t)
    268        return x

    විශේෂාංගසිතියම පරිමාණය කරන්න

    271class Upsample(nn.Module):
    276    def __init__(self, n_channels):
    277        super().__init__()
    278        self.conv = nn.ConvTranspose2d(n_channels, n_channels, (4, 4), (2, 2), (1, 1))
    280    def forward(self, x: torch.Tensor, t: torch.Tensor):

    t භාවිතා නොකෙරේ, නමුත් එය තර්කවල තබා ඇත්තේ අවධානය යොමු කිරීමේ ස්ථර ශ්රිතයේ අත්සන සමඟ ගැලපෙන බැවිනි ResidualBlock .

    283        _ = t
    284        return self.conv(x)

    විශේෂාංගසිතියම පරිමාණය කරන්න

    287class Downsample(nn.Module):
    292    def __init__(self, n_channels):
    293        super().__init__()
    294        self.conv = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1))
    296    def forward(self, x: torch.Tensor, t: torch.Tensor):

    t භාවිතා නොකෙරේ, නමුත් එය තර්කවල තබා ඇත්තේ අවධානය යොමු කිරීමේ ස්ථර ශ්රිතයේ අත්සන සමඟ ගැලපෙන බැවිනි ResidualBlock .

    299        _ = t
    300        return self.conv(x)

    යූ-නෙට්

    303class UNet(Module):
    • image_channels යනු රූපයේ නාලිකා ගණන. RGB සඳහා.
    • n_channels ආරම්භක විශේෂාංග සිතියමේ නාලිකා ගණන අපි රූපය බවට පරිවර්තනය කරමු
    • ch_mults යනු එක් එක් විභේදනයේ නාලිකා අංක ලැයිස්තුවයි. නාලිකා ගණන වේ ch_mults[i] * n_channels
    • is_attn යනු එක් එක් විභේදනයේ දී අවධානය යොමු කළ යුතුද යන්න පෙන්නුම් කරන බූලියන් ලැයිස්තුවකි
    • n_blocks එක් එක් යෝජනාව UpDownBlocks දී සංඛ්යාව වේ
    308    def __init__(self, image_channels: int = 3, n_channels: int = 64,
    309                 ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
    310                 is_attn: Union[Tuple[bool, ...], List[int]] = (False, False, True, True),
    311                 n_blocks: int = 2):
    319        super().__init__()

    යෝජනාගණන

    322        n_resolutions = len(ch_mults)

    විශේෂාංගසිතියමට ව්යාපෘති රූපය

    325        self.image_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))

    කාලයකාවැද්දීම ස්ථරය. කාල කාවැද්දීම n_channels * 4 නාලිකා ඇත

    328        self.time_emb = TimeEmbedding(n_channels * 4)

    U-Netපළමු භාගය - අඩු යෝජනාව

    331        down = []

    නාලිකාගණන

    333        out_channels = in_channels = n_channels

    එක්එක් යෝජනාව සඳහා

    335        for i in range(n_resolutions):

    මෙමවිභේදනයේ ප්රතිදාන නාලිකා ගණන

    337            out_channels = in_channels * ch_mults[i]

    එකතුකරන්න n_blocks

    339            for _ in range(n_blocks):
    340                down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
    341                in_channels = out_channels

    පසුගියහැර අනෙකුත් සියලු යෝජනා දී ආදර්ශ පහළ

    343            if i < n_resolutions - 1:
    344                down.append(Downsample(in_channels))

    මොඩියුලකට්ටලය ඒකාබද්ධ කරන්න

    347        self.down = nn.ModuleList(down)

    මැදකොටස

    350        self.middle = MiddleBlock(out_channels, n_channels * 4, )

    යූ-නෙට්හි දෙවන භාගය - වැඩි කිරීමේ විභේදනය

    353        up = []

    නාලිකාගණන

    355        in_channels = out_channels

    එක්එක් යෝජනාව සඳහා

    357        for i in reversed(range(n_resolutions)):

    n_blocks එකම විභේදනයේ

    359            out_channels = in_channels
    360            for _ in range(n_blocks):
    361                up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))

    නාලිකාගණන අඩු කිරීම සඳහා අවසාන කොටස

    363            out_channels = in_channels // ch_mults[i]
    364            up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
    365            in_channels = out_channels

    පසුගියහැර අනෙකුත් සියලු යෝජනා දී ආදර්ශ දක්වා

    367            if i > 0:
    368                up.append(Upsample(in_channels))

    මොඩියුලකට්ටලය ඒකාබද්ධ කරන්න

    371        self.up = nn.ModuleList(up)

    අවසානසාමාන්යකරණය සහ කැටි ගැසුණු ස්ථරය

    374        self.norm = nn.GroupNorm(8, n_channels)
    375        self.act = Swish()
    376        self.final = nn.Conv2d(in_channels, image_channels, kernel_size=(3, 3), padding=(1, 1))
    • x හැඩය ඇත [batch_size, in_channels, height, width]
    • t හැඩය ඇත [batch_size]
    378    def forward(self, x: torch.Tensor, t: torch.Tensor):

    කාල-පියවරකාවැද්දීම් ලබා ගන්න

    385        t = self.time_emb(t)

    රූපප්රක්ෂේපණය ලබා ගන්න

    388        x = self.image_proj(x)

    h මඟ හැරීමේ සම්බන්ධතාවය සඳහා එක් එක් විභේදනයේ ප්රතිදානයන් ගබඩා කරනු ඇත

    391        h = [x]

    යූ-නෙට්හි පළමු භාගය

    393        for m in self.down:
    394            x = m(x, t)
    395            h.append(x)

    මැද(පහළ)

    398        x = self.middle(x, t)

    යූ-නෙට්හි දෙවන භාගය

    401        for m in self.up:
    402            if isinstance(m, Upsample):
    403                x = m(x, t)
    404            else:

    යූ-නෙට්හි පළමු භාගයේ සිට මඟ හැරීමේ සම්බන්ධතාවය ලබාගෙන සංයුක්ත කරන්න

    406                s = h.pop()
    407                x = torch.cat((x, s), dim=1)

    409                x = m(x, t)

    අවසානසාමාන්යකරණය සහ කැටි කිරීම

    412        return self.final(self.act(self.norm(x)))