විසරණ සම්භාවිතා ආකෘති නිරූපණය කිරීම සඳහා යූ-නෙට් ආකෘතිය (ඩීඩීපීඑම්)

ශබ්දයපුරෝකථනය කිරීම සඳහා මෙය යූ-නෙට් පදනම් කරගත් ආකෘතියකි .

යූ-නෙට්යනු ආදර්ශ රූප සටහනේ යූ හැඩයෙන් එය නම ලබා ගනී. විශේෂාංග සිතියම් විභේදනය ක්රමයෙන් අඩු කිරීමෙන් (අඩක්) සහ විභේදනය වැඩි කිරීමෙන් එය ලබා දී ඇති රූපයක් සකසනු ලැබේ. සෑම විභේදනයකදීම පාස්-හරහා සම්බන්ධතාවයක් ඇත.

U-Net diagram from paper

මෙමක්රියාත්මක කිරීම මුල් යූ-නෙට් (අවශේෂ කුට්ටි, බහු-හිස අවධානය) සඳහා වෙනස් කිරීම් රාශියක් අඩංගු වන අතර කාල-පියවර කාවැද්දීම් එකතු කරයි .

24import math
25from typing import Optional, Tuple, Union, List
26
27import torch
28from torch import nn
29
30from labml_helpers.module import Module

ස්විස්ෂ්ක්රියාකාරී ශ්රිතය

33class Swish(Module):
40    def forward(self, x):
41        return x * torch.sigmoid(x)

කාවැද්දීම්සඳහා

44class TimeEmbedding(nn.Module):
  • n_channels කාවැද්දීමේ මානයන් ගණන වේ
49    def __init__(self, n_channels: int):
53        super().__init__()
54        self.n_channels = n_channels

පළමුරේඛීය ස්ථරය

56        self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)

සක්‍රීයකිරීම

58        self.act = Swish()

දෙවනරේඛීය ස්ථරය

60        self.lin2 = nn.Linear(self.n_channels, self.n_channels)
62    def forward(self, t: torch.Tensor):

ට්රාන්ස්ෆෝමරයට සමානසයිනොසොයිඩල් ස්ථාන කාවැද්දීම් සාදන්න

කොහේද half_dim

72        half_dim = self.n_channels // 8
73        emb = math.log(10_000) / (half_dim - 1)
74        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
75        emb = t[:, None] * emb[None, :]
76        emb = torch.cat((emb.sin(), emb.cos()), dim=1)

එම්එල්පීසමඟ පරිවර්තනය කරන්න

79        emb = self.act(self.lin1(emb))
80        emb = self.lin2(emb)

83        return emb

අවශේෂකොටස

අවශේෂකොටසකදී කණ්ඩායම් සාමාන්යකරණය සමග convolution ස්ථර දෙකක් ඇත. සෑම විභේදනයක්ම අවශේෂ කොටස් දෙකකින් සකසනු ලැබේ.

86class ResidualBlock(Module):
  • in_channels ආදාන නාලිකා ගණන
  • out_channels ආදාන නාලිකා ගණන
  • time_channels කාලය පියවර () කාවැද්දීම් සංඛ්යාව නාලිකා වේ
  • n_groups කණ්ඩායම් සාමාන්යකරණය සඳහා කණ්ඩායම් සංඛ්යාව වේ
  • dropout හැලහැප්මේ අනුපාතය වේ
  • 94    def __init__(self, in_channels: int, out_channels: int, time_channels: int,
    95                 n_groups: int = 32, dropout: float = 0.1):
    103        super().__init__()

    කණ්ඩායම්සාමාන්යකරණය සහ පළමු කැටි ගැසුණු ස්ථරය

    105        self.norm1 = nn.GroupNorm(n_groups, in_channels)
    106        self.act1 = Swish()
    107        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))

    කණ්ඩායම්සාමාන්යකරණය සහ දෙවන කැටි ගැසුණු ස්තරය

    110        self.norm2 = nn.GroupNorm(n_groups, out_channels)
    111        self.act2 = Swish()
    112        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))

    ආදානනාලිකා ගණන ප්රතිදාන නාලිකා ගණනට සමාන නොවේ නම් කෙටිමං සම්බන්ධතාවය ප්රක්ෂේපණය කළ යුතුය

    116        if in_channels != out_channels:
    117            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
    118        else:
    119            self.shortcut = nn.Identity()

    කාලකාවැද්දීම් සඳහා රේඛීය ස්ථරය

    122        self.time_emb = nn.Linear(time_channels, out_channels)
    123        self.time_act = Swish()
    124
    125        self.dropout = nn.Dropout(dropout)
    • x හැඩය ඇත [batch_size, in_channels, height, width]
    • t හැඩය ඇත [batch_size, time_channels]
    127    def forward(self, x: torch.Tensor, t: torch.Tensor):

    පළමුකැටි ගැසුණු ස්ථරය

    133        h = self.conv1(self.act1(self.norm1(x)))

    කාලකාවැද්දීම් එකතු කරන්න

    135        h += self.time_emb(self.time_act(t))[:, :, None, None]

    දෙවනකැටි ගැසුණු ස්ථරය

    137        h = self.conv2(self.dropout(self.act2(self.norm2(h))))

    කෙටිමංසම්බන්ධතාවය එකතු කර ආපසු යන්න

    140        return h + self.shortcut(x)

    අවධානයවාරණ

    මෙය ට්රාන්ස්ෆෝමර් බහු-හිස අවධානයටසමාන වේ.

    143class AttentionBlock(Module):
    • n_channels යනු ආදානයේ නාලිකා ගණන
    • n_heads බහු-හිස අවධානය යොමු ප්රධානීන් සංඛ්යාව වේ
    • d_k එක් එක් හිසෙහි මානයන් ගණන
  • n_groups කණ්ඩායම් සාමාන්යකරණය සඳහා කණ්ඩායම්සංඛ්යාව වේ
  • 150    def __init__(self, n_channels: int, n_heads: int = 1, d_k: int = None, n_groups: int = 32):
    157        super().__init__()

    පෙරනිමි d_k

    160        if d_k is None:
    161            d_k = n_channels

    සාමාන්යකරණයස්ථරය

    163        self.norm = nn.GroupNorm(n_groups, n_channels)

    විමසුම, යතුර සහ අගයන් සඳහා ප්රක්ෂේපණ

    165        self.projection = nn.Linear(n_channels, n_heads * d_k * 3)

    අවසානපරිවර්තනය සඳහා රේඛීය ස්ථරය

    167        self.output = nn.Linear(n_heads * d_k, n_channels)

    තිත්නිෂ්පාදන අවධානය සඳහා පරිමාණය

    169        self.scale = d_k ** -0.5

    171        self.n_heads = n_heads
    172        self.d_k = d_k
    • x හැඩය ඇත [batch_size, in_channels, height, width]
    • t හැඩය ඇත [batch_size, time_channels]
    174    def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None):

    t භාවිතා නොකෙරේ, නමුත් එය තර්කවල තබා ඇත්තේ අවධානය යොමු කිරීමේ ස්ථර ශ්රිතයේ අත්සන සමඟ ගැලපෙන බැවිනි ResidualBlock .

    181        _ = t

    හැඩයලබා ගන්න

    183        batch_size, n_channels, height, width = x.shape

    x හැඩයට වෙනස් කරන්න [batch_size, seq, n_channels]

    185        x = x.view(batch_size, n_channels, -1).permute(0, 2, 1)

    විමසුමලබා ගන්න, යතුර, සහ අගයන් (concatenated) සහ එය හැඩගස්වා [batch_size, seq, n_heads, 3 * d_k]

    187        qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k)

    විමසුම, යතුර සහ අගයන් බෙදීම්. ඔවුන් එක් එක් හැඩය ඇත [batch_size, seq, n_heads, d_k]

    189        q, k, v = torch.chunk(qkv, 3, dim=-1)

    පරිමාණතිත් නිෂ්පාදනයක් ගණනය කරන්න

    191        attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scale

    අනුක්රමිකමානය ඔස්සේ සොෆ්ට්මැක්ස්

    193        attn = attn.softmax(dim=2)

    අගයන්අනුව ගුණ කරන්න

    195        res = torch.einsum('bijh,bjhd->bihd', attn, v)

    නැවතහැඩගස්වන්න [batch_size, seq, n_heads * d_k]

    197        res = res.view(batch_size, -1, self.n_heads * self.d_k)

    බවටපරිවර්තනය කරන්න [batch_size, seq, n_channels]

    199        res = self.output(res)

    මඟහැරීමේ සම්බන්ධතාවය එක් කරන්න

    202        res += x

    හැඩයටවෙනස් කරන්න [batch_size, in_channels, height, width]

    205        res = res.permute(0, 2, 1).view(batch_size, n_channels, height, width)

    208        return res

    බ්ලොක්ඩවුන්

    මෙයඒකාබද්ධ ResidualBlock හා AttentionBlock . මෙම එක් එක් යෝජනාව දී U-Net පළමු භාගයේ දී භාවිතා වේ.

    211class DownBlock(Module):
    218    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
    219        super().__init__()
    220        self.res = ResidualBlock(in_channels, out_channels, time_channels)
    221        if has_attn:
    222            self.attn = AttentionBlock(out_channels)
    223        else:
    224            self.attn = nn.Identity()
    226    def forward(self, x: torch.Tensor, t: torch.Tensor):
    227        x = self.res(x, t)
    228        x = self.attn(x)
    229        return x

    බ්ලොක්දක්වා

    මෙයඒකාබද්ධ ResidualBlock හා AttentionBlock . සෑම විභේදනයකදීම යූ-නෙට් හි දෙවන භාගයේදී මේවා භාවිතා වේ.

    232class UpBlock(Module):
    239    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
    240        super().__init__()

    ආදානයටඇත්තේ යූ-නෙට් හි පළමු භාගයේ සිට එකම විභේදනයේ ප්රතිදානය අප සංයුක්ත කරන in_channels + out_channels බැවිනි

    243        self.res = ResidualBlock(in_channels + out_channels, out_channels, time_channels)
    244        if has_attn:
    245            self.attn = AttentionBlock(out_channels)
    246        else:
    247            self.attn = nn.Identity()
    249    def forward(self, x: torch.Tensor, t: torch.Tensor):
    250        x = self.res(x, t)
    251        x = self.attn(x)
    252        return x

    මැදකොටස

    එයතවත් එකක් ResidualBlock සමඟ ඒකාබද්ධ ResidualBlock වේ. AttentionBlock මෙම කොටස U-Net හි අඩුම විභේදනයෙන් යොදනු ලැබේ.

    255class MiddleBlock(Module):
    263    def __init__(self, n_channels: int, time_channels: int):
    264        super().__init__()
    265        self.res1 = ResidualBlock(n_channels, n_channels, time_channels)
    266        self.attn = AttentionBlock(n_channels)
    267        self.res2 = ResidualBlock(n_channels, n_channels, time_channels)
    269    def forward(self, x: torch.Tensor, t: torch.Tensor):
    270        x = self.res1(x, t)
    271        x = self.attn(x)
    272        x = self.res2(x, t)
    273        return x

    විශේෂාංගසිතියම පරිමාණය කරන්න

    276class Upsample(nn.Module):
    281    def __init__(self, n_channels):
    282        super().__init__()
    283        self.conv = nn.ConvTranspose2d(n_channels, n_channels, (4, 4), (2, 2), (1, 1))
    285    def forward(self, x: torch.Tensor, t: torch.Tensor):

    t භාවිතා නොකෙරේ, නමුත් එය තර්කවල තබා ඇත්තේ අවධානය යොමු කිරීමේ ස්ථර ශ්රිතයේ අත්සන සමඟ ගැලපෙන බැවිනි ResidualBlock .

    288        _ = t
    289        return self.conv(x)

    විශේෂාංගසිතියම පරිමාණය කරන්න

    292class Downsample(nn.Module):
    297    def __init__(self, n_channels):
    298        super().__init__()
    299        self.conv = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1))
    301    def forward(self, x: torch.Tensor, t: torch.Tensor):

    t භාවිතා නොකෙරේ, නමුත් එය තර්කවල තබා ඇත්තේ අවධානය යොමු කිරීමේ ස්ථර ශ්රිතයේ අත්සන සමඟ ගැලපෙන බැවිනි ResidualBlock .

    304        _ = t
    305        return self.conv(x)

    යූ-නෙට්

    308class UNet(Module):
    • image_channels යනු රූපයේ නාලිකා ගණන. RGB සඳහා.
    • n_channels ආරම්භක විශේෂාංග සිතියමේ නාලිකා ගණන අපි රූපය බවට පරිවර්තනය කරමු
    • ch_mults යනු එක් එක් විභේදනයේ නාලිකා අංක ලැයිස්තුවයි. නාලිකා ගණන වේ ch_mults[i] * n_channels
    • is_attn යනු එක් එක් විභේදනයේ දී අවධානය යොමු කළ යුතුද යන්න පෙන්නුම් කරන බූලියන් ලැයිස්තුවකි
    • n_blocks එක් එක් යෝජනාව UpDownBlocks දී සංඛ්යාව වේ
    313    def __init__(self, image_channels: int = 3, n_channels: int = 64,
    314                 ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
    315                 is_attn: Union[Tuple[bool, ...], List[int]] = (False, False, True, True),
    316                 n_blocks: int = 2):
    324        super().__init__()

    යෝජනාගණන

    327        n_resolutions = len(ch_mults)

    විශේෂාංගසිතියමට ව්යාපෘති රූපය

    330        self.image_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))

    කාලයකාවැද්දීම ස්ථරය. කාල කාවැද්දීම n_channels * 4 නාලිකා ඇත

    333        self.time_emb = TimeEmbedding(n_channels * 4)

    U-Netපළමු භාගය - අඩු යෝජනාව

    336        down = []

    නාලිකාගණන

    338        out_channels = in_channels = n_channels

    එක්එක් යෝජනාව සඳහා

    340        for i in range(n_resolutions):

    මෙමවිභේදනයේ ප්රතිදාන නාලිකා ගණන

    342            out_channels = in_channels * ch_mults[i]

    එකතුකරන්න n_blocks

    344            for _ in range(n_blocks):
    345                down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
    346                in_channels = out_channels

    පසුගියහැර අනෙකුත් සියලු යෝජනා දී ආදර්ශ පහළ

    348            if i < n_resolutions - 1:
    349                down.append(Downsample(in_channels))

    මොඩියුලකට්ටලය ඒකාබද්ධ කරන්න

    352        self.down = nn.ModuleList(down)

    මැදකොටස

    355        self.middle = MiddleBlock(out_channels, n_channels * 4, )

    යූ-නෙට්හි දෙවන භාගය - වැඩි කිරීමේ විභේදනය

    358        up = []

    නාලිකාගණන

    360        in_channels = out_channels

    එක්එක් යෝජනාව සඳහා

    362        for i in reversed(range(n_resolutions)):

    n_blocks එකම විභේදනයේ

    364            out_channels = in_channels
    365            for _ in range(n_blocks):
    366                up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))

    නාලිකාගණන අඩු කිරීම සඳහා අවසාන කොටස

    368            out_channels = in_channels // ch_mults[i]
    369            up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
    370            in_channels = out_channels

    පසුගියහැර අනෙකුත් සියලු යෝජනා දී ආදර්ශ දක්වා

    372            if i > 0:
    373                up.append(Upsample(in_channels))

    මොඩියුලකට්ටලය ඒකාබද්ධ කරන්න

    376        self.up = nn.ModuleList(up)

    අවසානසාමාන්යකරණය සහ කැටි ගැසුණු ස්ථරය

    379        self.norm = nn.GroupNorm(8, n_channels)
    380        self.act = Swish()
    381        self.final = nn.Conv2d(in_channels, image_channels, kernel_size=(3, 3), padding=(1, 1))
    • x හැඩය ඇත [batch_size, in_channels, height, width]
    • t හැඩය ඇත [batch_size]
    383    def forward(self, x: torch.Tensor, t: torch.Tensor):

    කාල-පියවරකාවැද්දීම් ලබා ගන්න

    390        t = self.time_emb(t)

    රූපප්රක්ෂේපණය ලබා ගන්න

    393        x = self.image_proj(x)

    h මඟ හැරීමේ සම්බන්ධතාවය සඳහා එක් එක් විභේදනයේ ප්රතිදානයන් ගබඩා කරනු ඇත

    396        h = [x]

    යූ-නෙට්හි පළමු භාගය

    398        for m in self.down:
    399            x = m(x, t)
    400            h.append(x)

    මැද(පහළ)

    403        x = self.middle(x, t)

    යූ-නෙට්හි දෙවන භාගය

    406        for m in self.up:
    407            if isinstance(m, Upsample):
    408                x = m(x, t)
    409            else:

    යූ-නෙට්හි පළමු භාගයේ සිට මඟ හැරීමේ සම්බන්ධතාවය ලබාගෙන සංයුක්ත කරන්න

    411                s = h.pop()
    412                x = torch.cat((x, s), dim=1)

    414                x = m(x, t)

    අවසානසාමාන්යකරණය සහ කැටි කිරීම

    417        return self.final(self.act(self.norm(x)))