14import math
15
16import torch
17import torch.nn as nn
18
19from labml_nn.utils import clone_module_list
20from .feed_forward import FeedForward
21from .mha import MultiHeadAttention
22from .positional_encoding import get_positional_encoding25class EmbeddingsWithPositionalEncoding(nn.Module):32 def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
33 super().__init__()
34 self.linear = nn.Embedding(n_vocab, d_model)
35 self.d_model = d_model
36 self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))38 def forward(self, x: torch.Tensor):
39 pe = self.positional_encodings[:x.shape[0]].requires_grad_(False)
40 return self.linear(x) * math.sqrt(self.d_model) + pe43class EmbeddingsWithLearnedPositionalEncoding(nn.Module):50 def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
51 super().__init__()
52 self.linear = nn.Embedding(n_vocab, d_model)
53 self.d_model = d_model
54 self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)56 def forward(self, x: torch.Tensor):
57 pe = self.positional_encodings[:x.shape[0]]
58 return self.linear(x) * math.sqrt(self.d_model) + peමෙයඑන්කෝඩර් තට්ටුවක් හෝ විකේතක තට්ටුවක් ලෙස ක්රියා කළ හැකිය.
🗒කඩදාසි ඇතුළු සමහර ක්රියාත්මක කිරීම්, ස්ථර-සාමාන්යකරණය සිදු කරන ස්ථානයේ වෙනස්කම් ඇති බව පෙනේ. මෙන්න අපි අවධානය සහ පෝෂක ඉදිරියට ජාල පෙර ස්ථරයක් සාමාන්යකරණය කරන්න, සහ මුල් අවශේෂ දෛශික එකතු කරන්න. විකල්පයක් වන්නේ අපද්රව්ය එකතු කිරීමෙන් පසු ස්ථර සාමාන්යකරණය කිරීමයි. නමුත් පුහුණුවීමේදී මෙය අඩු ස්ථාවර බව අපට පෙනී ගියේය. ට්රාන්ස්ෆෝමර් ගෘහ නිර්මාණ ශිල්පයේ ON Layer සාමාන්යකරණය පිළිබඳපත්රිකාවේ මේ පිළිබඳව සවිස්තරාත්මක සාකච්ඡාවක් අපට හමු විය.
61class TransformerLayer(nn.Module):d_model
ටෝකනය කාවැද්දීමේ ප්රමාණයයි self_attn
ස්වයං අවධානය මොඩියුලය src_attn
යනු ප්රභව අවධානය යොමු කිරීමේ මොඩියුලය (මෙය විකේතකයක් තුළ භාවිතා කරන විට) feed_forward
යනු ආහාර ඉදිරි මොඩියුලයයි dropout_prob
ස්වයං අවධානයෙන් පසු ඉවත් වීමේ සම්භාවිතාව සහ FFN79 def __init__(self, *,
80 d_model: int,
81 self_attn: MultiHeadAttention,
82 src_attn: MultiHeadAttention = None,
83 feed_forward: FeedForward,
84 dropout_prob: float):92 super().__init__()
93 self.size = d_model
94 self.self_attn = self_attn
95 self.src_attn = src_attn
96 self.feed_forward = feed_forward
97 self.dropout = nn.Dropout(dropout_prob)
98 self.norm_self_attn = nn.LayerNorm([d_model])
99 if self.src_attn is not None:
100 self.norm_src_attn = nn.LayerNorm([d_model])
101 self.norm_ff = nn.LayerNorm([d_model])ආහාරඉදිරි ස්ථරයට ආදානය ඉතිරි කර ගත යුතුද යන්න
103 self.is_save_ff_input = False105 def forward(self, *,
106 x: torch.Tensor,
107 mask: torch.Tensor,
108 src: torch.Tensor = None,
109 src_mask: torch.Tensor = None):ස්වයංඅවධානය යොමු කිරීමට පෙර දෛශික සාමාන්යකරණය කරන්න
111 z = self.norm_self_attn(x)ස්වයංඅවධානය හරහා ධාවනය කරන්න, i.e. යතුරු සහ වටිනාකම් ස්වයං සිට
113 self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)ස්වයංඅවධානය ප්රතිඵල එකතු
115 x = x + self.dropout(self_attn)ප්රභවයක්ලබා දෙන්නේ නම්, ප්රභවයට අවධානය යොමු කිරීමෙන් ප්රති results ල ලබා ගන්න. එන්කෝඩර් ප්රතිදානයන් කෙරෙහි අවධානය යොමු කරන විකේතක තට්ටුවක් ඔබට ඇති විට මෙය
වේ120 if src is not None:දෛශිකසාමාන්යකරණය කරන්න
122 z = self.norm_src_attn(x)ප්රභවයටඅවධානය යොමු කරන්න. එනම් යතුරු සහ අගයන් ප්රභවයෙන් වේ
124 attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask)මූලාශ්රඅවධානය යොමු ප්රතිඵල එක් කරන්න
126 x = x + self.dropout(attn_src)පෝෂණයසඳහා සාමාන්යකරණය කරන්න
129 z = self.norm_ff(x)නිශ්චිතවදක්වා ඇත්නම් ආදානය ආහාර ඉදිරි ස්ථරයට සුරකින්න
131 if self.is_save_ff_input:
132 self.ff_input = z.clone()Feed-forwardජාලය හරහා ගමන් කරන්න
134 ff = self.feed_forward(z)ප්රතිපෝෂණඉදිරි ප්රති results ල නැවත එක් කරන්න
136 x = x + self.dropout(ff)
137
138 return x141class Encoder(nn.Module):148 def __init__(self, layer: TransformerLayer, n_layers: int):
149 super().__init__()ට්රාන්ස්ෆෝමර්ස්ථරයේ පිටපත් සාදන්න
151 self.layers = clone_module_list(layer, n_layers)අවසානසාමාන්යකරණ ස්තරය
153 self.norm = nn.LayerNorm([layer.size])155 def forward(self, x: torch.Tensor, mask: torch.Tensor):එක්එක් ට්රාන්ස්ෆෝමර් ස්ථරය හරහා ධාවනය කරන්න
157 for layer in self.layers:
158 x = layer(x=x, mask=mask)අවසානවශයෙන්, දෛශික සාමාන්යකරණය කරන්න
160 return self.norm(x)163class Decoder(nn.Module):170 def __init__(self, layer: TransformerLayer, n_layers: int):
171 super().__init__()ට්රාන්ස්ෆෝමර්ස්ථරයේ පිටපත් සාදන්න
173 self.layers = clone_module_list(layer, n_layers)අවසානසාමාන්යකරණ ස්තරය
175 self.norm = nn.LayerNorm([layer.size])177 def forward(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):එක්එක් ට්රාන්ස්ෆෝමර් ස්ථරය හරහා ධාවනය කරන්න
179 for layer in self.layers:
180 x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask)අවසානවශයෙන්, දෛශික සාමාන්යකරණය කරන්න
182 return self.norm(x)මෙයටෝකන පුරෝකථනය කරන අතර එම අයගේ සොෆ්ට්මැක්ස් ලබා දෙයි. ඔබ භාවිතා කරන්නේ නම් ඔබට මෙය අවශ්ය නොවේ nn.CrossEntropyLoss
.
185class Generator(nn.Module):195 def __init__(self, n_vocab: int, d_model: int):
196 super().__init__()
197 self.projection = nn.Linear(d_model, n_vocab)199 def forward(self, x):
200 return self.projection(x)203class EncoderDecoder(nn.Module):210 def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: nn.Module, tgt_embed: nn.Module, generator: nn.Module):
211 super().__init__()
212 self.encoder = encoder
213 self.decoder = decoder
214 self.src_embed = src_embed
215 self.tgt_embed = tgt_embed
216 self.generator = generatorමෙයඔවුන්ගේ කේතයෙන් වැදගත් විය. ග්ලෝරෝට්/fan_avg සමඟ පරාමිතීන් ආරම්භ කරන්න.
220 for p in self.parameters():
221 if p.dim() > 1:
222 nn.init.xavier_uniform_(p)224 def forward(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):එන්කෝඩරයහරහා ප්රභවය ධාවනය කරන්න
226 enc = self.encode(src, src_mask)විකේතකයහරහා කේතීකරණ සහ ඉලක්ක ධාවනය කරන්න
228 return self.decode(enc, src_mask, tgt, tgt_mask)230 def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
231 return self.encoder(self.src_embed(src), src_mask)233 def decode(self, memory: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
234 return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)