ජීපීටී-නියෝක්ස්ආකෘතිය

ජීපීටී-නියෝක්ස්ආකෘතියේ ස්ථර සඳහා කේතය සහ 20B මුරපොල පූරණය කිරීමේ කේතය මෙන්න.

ස්ථර load_state වල ඇති ක්රමය එම ස්ථරයේ මුරපොලවල් පූරණය කරයි. මුරපොලවල් පැටවීමේ සහායකයන් ක්රියාත්මක වේ checkpoint.py

16import copy
17import math
18from typing import Dict, Optional, Set, Callable, Any, Generator, Tuple
19
20import torch
21from torch import nn
22from torch.cuda.amp import autocast
23
24from labml import monit
25from labml_nn.neox import checkpoint
26from labml_nn.neox.utils.cache import get_cache
29class NeoXModule(nn.Module):
30    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
31        pass

කාවැද්දීමස්ථරය

මෙයමුරපොලට පැටවීම සඳහා කේතය සහිත සම්මත කාවැද්දීම් ස්ථරයකි.

34class Embedding(NeoXModule):
  • n_vocab වචන මාලාවේ ප්රමාණය වේ
  • n_hidden මෙම කාවැද්දීම් ප්රමාණය
41    def __init__(self, n_vocab: int = 50_432, n_hidden: int = 6_144):
46        super().__init__()
47
48        self.emb = nn.Embedding(n_vocab, n_hidden)
  • x හැඩයේ ටෝකන් හැඳුනුම් වේ [batch_size, seq_len]
50    def forward(self, x: torch.Tensor):
54        return self.emb(x)

මුරපොලපූරණය කිරීමට කේතය

56    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
60        with monit.section('Load embedding layer'):
61            checkpoint.merge_params_dim_0(self.emb.weight, 'word_embeddings.weight', p1, p2)

රොටරිස්ථානීය කාවැද්දීම්

ජීපීටී-නියෝක්ස් භ්රමණ ස්ථානීය කාවැද්දීම් (කඹය)භාවිතා කරයි.

අපින්යාය වැඩි සටහන් සමඟ මෙහි කඹය ක්රියාත්මක කිරීම විස්තර කර ඇත.

64class RoPE(nn.Module):
  • d_rope කඹය කාවැද්දීම් සඳහා විශේෂාංග ගණන
  • base සඳහා පදනම වේ , එය පැහැර හරින
74    def __init__(self, d_rope: int, base: float = 10_000.):
79        super().__init__()

විශේෂාංග සඳහා ගබඩා කිරීමට

82        self.theta = None

හැඹිලිය සහ

84        self.cos_cached = None
85        self.sin_cached = None

සඳහාමූලික

88        self.base = base

කඹයසඳහා විශේෂාංග ගණන

90        self.d_rope = d_rope

විශේෂාංගකරකවන්න

92    @staticmethod
93    def rotate_half(x: torch.Tensor):
99        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
100        return torch.cat((-x2, x1), dim=-1)
  • x හැඩය ඇත [..., seq, n_heads, d_k]
  • offset යනු ආරම්භක ස්ථානයයි x . පෙර තනතුරු වල යතුරු සහ විමසුම් අප හැඹිලි කළ විට මෙය සිදු වේ
102    def forward(self, x: torch.Tensor, offset: int = 0):

සත්යඅනුක්රමයේ දිග ලබා ගන්න

110        seq_len = x.shape[-3] + offset

ආරම්භකරන්න

113        if self.theta is None:

115            theta = 1.0 / (self.base ** (torch.arange(0, self.d_rope, 2).float() / self.d_rope))
116            self.theta = theta.to(x.device).to(x.dtype)

ආරම්භකරන්න සහ හැඹිලිය

119        if (
120                self.cos_cached is None or
121                seq_len > self.cos_cached.shape[1] or
122                self.cos_cached.device != x.device or
123                self.cos_cached.dtype != x.dtype
124        ):

ස්ථානදර්ශක ලබා ගන්න

126            seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)

128            idx_theta = torch.einsum("s,d->sd", seq_idx, self.theta)

පේළියසඳහා අපට ඇති පරිදි සංයුක්ත කරන්න

132            idx_theta2 = torch.cat((idx_theta, idx_theta), dim=-1).to(x.device)

ගණනයකරන්න සහ fp32

135            with autocast(enabled=False):
136                idx_theta2 = idx_theta2.float()

හිසමානයක් එක් කරන්න

138                self.cos_cached = idx_theta2.cos()[:, None, :]
139                self.sin_cached = idx_theta2.sin()[:, None, :]

ඒවාහැඹිලිය

142            self.cos_cached = self.cos_cached.to(x.dtype)
143            self.sin_cached = self.sin_cached.to(x.dtype)

විශේෂාංගබෙදන්න. d_rope විශේෂාංග සඳහා පමණක් අපි කඹය යොදන්නෙමු

146        x_rope, x_pass = x[..., :self.d_rope], x[..., self.d_rope:]

හැඹිලියසිට පාපය සහ කෝස් අගයන් ලබා ගන්න

149        cos, sin = self.cos_cached[offset: seq_len], self.sin_cached[offset: seq_len]

කඹයකාවැද්දීම්

සඳහා

161        x_rope = (x_rope * cos) + (self.rotate_half(x_rope) * sin)

කඹයකාවැද්දීම් ලබා නොගත් විශේෂාංග සමඟ සංයුක්ත වන්න

164        return torch.cat((x_rope, x_pass), dim=-1)

අවධානයස්ථරය

167class AttentionLayer(nn.Module):
  • n_hidden කාවැද්දීම් වල විශේෂාංග ගණන
  • n_heads අවධානය යොමු ප්රධානීන් සංඛ්යාව
  • rope_percentage කඹය කාවැද්දීම් එකතු කිරීම සඳහා විශේෂාංග ප්රතිශතය
  • mask_fill අවධානය යොමු න්යාසය සඳහා ආවරණ පිරවුම් අගය
172    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, rope_percentage: float = 0.25,
173                 mask_fill: float = -10_000.0):
180        super().__init__()
181
182        self.n_heads = n_heads
183        self.mask_fill = mask_fill

විමසුම, යතුර සහ වටිනාකම සඳහා රේඛීය ස්ථරය

186        self.qkv_lin = nn.Linear(n_hidden, n_hidden * 3)

අවසානරේඛීය ස්ථරය

188        self.output = nn.Linear(n_hidden, n_hidden)

හිසකටවිශේෂාංග ගණන

191        d_k = n_hidden // n_heads

කඹයකාවැද්දීම මොඩියුලය

193        self.rope = RoPE(int(d_k * rope_percentage))

අවධානයපරිමාණ සාධකය

196        self.scale = 1 / math.sqrt(d_k)

හේතුආවරණ හැඹිලි කිරීමට

199        self.causal_mask = None

අවධානයයොමු කරන්න සොෆ්ට්මැක්ස් මොඩියුලය

202        self.softmax = nn.Softmax(dim=-2)
204    def _get_mask(self, attn: torch.Tensor):

විමසුමසහ යතුරු දිග

212        nq, nk = attn.shape[1:3]

වෙස්මුහුණ සාදන්න

215        if (
216                self.causal_mask is None or
217                self.causal_mask.shape[0] != nq or
218                self.causal_mask.shape[1] != nk or
219                self.causal_mask.device != attn.device
220        ):
221            self.causal_mask = torch.triu(attn.new_ones([nq, nk], dtype=torch.bool), 1 + nk - nq)

හැඹිලියසිට ආපසු

224        return self.causal_mask[None, :, :, None]
  • x හැඩය ඇත [batch_size, seq_len, n_hidden]
226    def forward(self, x: torch.Tensor):

විමසුම, යතුර සහ වටිනාකම් කාවැද්දීම් ලබා ගන්න (සියල්ල සංයුක්ත කර ඇත). පසුගිය මානයක් ප්රමාණය n_hidden -> සිට වෙනස් වනු ඇත 3 x n_hidden

232        qkv = self.qkv_lin(x)

හැඩයවෙනස් කිරීමෙන් හිස් වලට බෙදන්න [batch_size, seq_len, n_heads, 3 * d_k]

235        qkv = qkv.view(*qkv.shape[:-1], self.n_heads, -1)

විමසුමටබෙදන්න, යතුර සහ හැඩය එක් එක් අගය කරන්න [batch_size, seq_len, n_heads, 3 * d_k]

237        q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)

අපිපෙර ටෝකන වල තත්වයන් හැඹිලි කරන්නේ නම්

240        if get_cache().get('use_cache', False):

රාජ්යid ගේ ලබා ගන්න. අපි පෙර රාජ්යයන් ලබා ගැනීමට හා ඉදිරි රාජ්යයන් ගබඩා කිරීම සඳහා භාවිතා

242            prev_state_id, next_state_id = get_cache().get('state_ids')

හැඹිලියතිබේ නම්

244            if prev_state_id is not None:

අතීතයතුරු සහ අගයන් ලබා ගන්න. මේවාට හැඩය ඇත [batch_size, prev_seq_len, n_heads, d_k]

246                k_past, v_past = get_cache().pop(f'attn_kv_{prev_state_id}')

වත්මන්කාවැද්දීම් වල ඕෆ්සෙට්

248                offset = k_past.shape[1]

කඹයකාවැද්දීම් එකතු කරන්න

251                q = self.rope(q, offset=offset)
252                k = self.rope(k, offset=offset)

අතීතයසංයුක්ත කරන්න

255                k = torch.cat([k_past, k], dim=1)
256                v = torch.cat([v_past, v], dim=1)
257            else:

කඹයකාවැද්දීම් එකතු කරන්න

259                q = self.rope(q)
260                k = self.rope(k)

වත්මන්තත්වය සුරකින්න

263            get_cache().push(f'attn_kv_{next_state_id}', (k, v))
264        else:

හැඹිලියක්නැත - කඹය කාවැද්දීම් එකතු කරන්න

266            q = self.rope(q)
267            k = self.rope(k)

අවධානයගණනය කිරීම සඳහා fp16 කිරීමට ස්වයංක්රීය-වාත්තු අක්රීය

270        with autocast(enabled=False):
271            if q.dtype == torch.float16:

වත්මන්dtype fp16 නම් fp32 බවට පරිවර්තනය කරන්න

273                attn = torch.einsum('bihk,bjhk->bijh', q.float(), k.float())
274            else:

bfloatසඳහා වාත්තු නොකරන්න

276                attn = torch.einsum('bihk,bjhk->bijh', q, k)

පරිමාණඅවධානය

279            attn = attn * self.scale

හේතුවෙස්මුහුණ ලබා ගන්න

282            mask = self._get_mask(attn)

වෙස්යොදන්න

284            attn.masked_fill_(mask, self.mask_fill)

අවධානයසොෆ්ට්මැක්ස්

287            attn = self.softmax(attn)

අවධානයබර තැබූ අගයන් ලබා ගන්න

290        output = torch.einsum('bijh,bjhk->bihk', attn.to(v.dtype), v)

[batch_size, seq_len, n_heads, d_k] to Batch_size, seq_len, n_hidden`වෙතින් නැවත සකස් කරන්න

293        output = output.reshape(*x.shape)

අවසානරේඛීය ස්ථරය

296        return self.output(output)

ප්රතිපෝෂණජාලය

299class FFNLayer(nn.Module):
  • n_hidden කාවැද්දීම ප්රමාණය වේ
304    def __init__(self, n_hidden: int = 6_144, d_ff: int = 0):
308        super().__init__()
309
310        if not d_ff:
311            d_ff = n_hidden * 4

පුළුල්රේඛීය ස්ථරය

314        self.dense_h_h4 = nn.Linear(n_hidden, d_ff)

GELUසක්රිය කිරීම

316        self.activation = nn.GELU()

සංකෝචනයරේඛීය ස්ථරය

318        self.dense_h4_h = nn.Linear(d_ff, n_hidden)
  • x හැඩය ඇත [batch_size, seq_len, n_hidden]
320    def forward(self, x: torch.Tensor):
324        x = self.dense_h_h4(x)
325        x = self.activation(x)
326        x = self.dense_h4_h(x)
327
328        return x

ට්රාන්ස්ෆෝමර්ස්ථරය

331class TransformerLayer(NeoXModule):
  • n_hidden කාවැද්දීම ප්රමාණය වේ
  • n_heads හිස් සංඛ්යාව වේ

පිටතක්රියාත්මක කිරීම අතහැර දැමීම ඇතුළත් නොවේ.

336    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64):
343        super().__init__()

අවධානයටපෙර ස්ථර සාමාන්යකරණය

346        self.pre_ln_attn = nn.LayerNorm(n_hidden)

FFNට පෙර ස්ථර සාමාන්යකරණය

348        self.pre_ln_ffn = nn.LayerNorm(n_hidden)

අවධානයස්ථරය

351        self.attention = AttentionLayer(n_hidden, n_heads)

FFNස්ථරය

353        self.ffn = FFNLayer(n_hidden)
  • x හැඩයේ කාවැද්දීම් වේ [batch_size, seq_len, n_hidden]
355    def forward(self, x: torch.Tensor):

අවශේෂසම්බන්ධතාවය

361        residual = x

නියෝක්ස්සමාන්තරව අවධානය සහ ප්රතිපෝෂණ ජාලය ක්රියාත්මක කරයි

363        attn = self.attention(self.pre_ln_attn(x))
364        ffn = self.ffn(self.pre_ln_ffn(x))

ඒවාසහ අවශේෂ සම්බන්ධතාවය එකතු කරන්න

366        return attn + ffn + residual

මුරපොලපූරණය කිරීමට කේතය

368    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
372        with monit.section('Load transformer layer'):

අවධානයයොමු ප්රතිදානය පරිණාමනය

374            checkpoint.merge_params_sum(self.attention.output.bias, 'attention.dense.bias', p1, p2)
375            checkpoint.merge_params_dim_1(self.attention.output.weight, 'attention.dense.weight', p1, p2)

අවධානයවිමසුම, යතුර සහ අගය පරිවර්තනය කිරීම

378            checkpoint.merge_params_dim_0(self.attention.qkv_lin.bias, 'attention.query_key_value.bias', p1, p2)
379            checkpoint.merge_params_dim_0(self.attention.qkv_lin.weight, 'attention.query_key_value.weight', p1, p2)

අවධානයටපෙර ස්ථර සම්මතය

382            checkpoint.merge_params_duplicate(self.pre_ln_attn.bias, 'input_layernorm.bias', p1, p2)
383            checkpoint.merge_params_duplicate(self.pre_ln_attn.weight, 'input_layernorm.weight', p1, p2)

FFNදෙවන පරිණාමනය

386            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.bias, 'mlp.dense_h_to_4h.bias', p1, p2)
387            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.weight, 'mlp.dense_h_to_4h.weight', p1, p2)

FFNපළමු පරිවර්තනය

390            checkpoint.merge_params_sum(self.ffn.dense_h4_h.bias, 'mlp.dense_4h_to_h.bias', p1, p2)
391            checkpoint.merge_params_dim_1(self.ffn.dense_h4_h.weight, 'mlp.dense_4h_to_h.weight', p1, p2)

FFNට පෙර ස්ථර සම්මතය

394            checkpoint.merge_params_duplicate(self.pre_ln_ffn.bias, 'post_attention_layernorm.bias', p1, p2)
395            checkpoint.merge_params_duplicate(self.pre_ln_ffn.weight, 'post_attention_layernorm.weight', p1, p2)

අවසානසාමාන්යකරණ ස්තරය

398class FinalNorm(NeoXModule):
  • n_hidden කාවැද්දීම ප්රමාණය වේ
403    def __init__(self, n_hidden: int = 6_144):
407        super().__init__()
408
409        self.ln = nn.LayerNorm(n_hidden)
  • x හැඩයේ කාවැද්දීම් වේ [batch_size, seq_len, n_hidden]
411    def forward(self, x: torch.Tensor):
415        return self.ln(x)

මුරපොලපූරණය කිරීමට කේතය

417    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
421        with monit.section('Load final normalization layer'):
422            checkpoint.merge_params_duplicate(self.ln.bias, 'norm.bias', p1, p2)
423            checkpoint.merge_params_duplicate(self.ln.weight, 'norm.weight', p1, p2)

කියවීමේස්ථරය

426class ReadoutLayer(NeoXModule):
  • n_hidden කාවැද්දීම ප්රමාණය වේ
  • n_vocab වචන මාලාවේ ප්රමාණය වේ
431    def __init__(self, n_hidden: int = 6_144, n_vocab: int = 50_432):
436        super().__init__()
437
438        self.linear = nn.Linear(n_hidden, n_vocab, bias=False)
  • x හැඩයේ කාවැද්දීම් වේ [batch_size, seq_len, n_hidden]
440    def forward(self, x: torch.Tensor):
444        return self.linear(x)

මුරපොලපූරණය කිරීමට කේතය

446    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
450        with monit.section('Load final linear layer'):
451            checkpoint.merge_params_dim_0(self.linear.weight, 'final_linear.weight', p1, p2)
454class LayerGenerator:
455    pre_created_layers: Dict[Any, Optional[NeoXModule]]

ස්ථරනිර්මාණය කිරීමට උත්පාදක යන්ත්රය

ස්ථරජනනය කරනු ලබන්නේ මුරපොලවල් මෙන් එකම අනුපිළිවෙලකි.

ස්ථරයක්නොමැති None විට එය ලබා දෙයි; අපි ස්ථර දර්ශක නියෝක්ස් ලෙස භාවිතා කරන අතර අපගේ ක්රියාත්මක කිරීමේදී අපට අවශ්ය නොවන පරිවර්තන ස්ථර දෙකක් තිබේ.

  • n_vocab යනු වචන මාලාවේ ටෝකන ගණන
  • n_hidden කාවැද්දීම් වල ඇති ලක්ෂණ ගණන
  • n_layers ට්රාන්ස්ෆෝමර් ස්ථර ගණන
  • n_heads අවධානය යොමු ප්රධානීන් සංඛ්යාව වේ
  • filter_layers භාවිතා කළ යුතු ස්ථර සමූහයයි. කිසිවක් නොමැති නම් සියලුම ස්ථර භාවිතා කරනු ඇත. අඩු ස්ථර සහිත ආකෘතියේ කුඩා අනුවාදයන් පරීක්ෂා කිරීමට මෙය භාවිතා
  • කරයි
  • is_clone_layers ට්රාන්ස්ෆෝමර් ස්ථර ක්ලෝන කළ යුතුද යන්න නියම කරයි (ටිකක් වේගවත්)
  • dtype ආකෘතියේ දත්ත වර්ගයයි
  • device ආකෘතියේ උපාංගය වේ
  • is_llm_int8 INT8 ප්රමාණකරණය භාවිතා කළ යුතුද යන්න නියම කරයි
  • llm_int8_threshold යනු පිටත විශේෂාංග වෙන් කිරීම සඳහා භාවිතා කරන එළිපත්ත
457    def __init__(self, *, n_vocab: int = 50_432, n_hidden: int = 6_144,
458                 n_layers: int = 44, n_heads: int = 64,
459                 filter_layers: Optional[Set] = None,
460                 is_clone_layers: bool = True,
461                 dtype: torch.dtype = torch.float,
462                 device: torch.device = torch.device('cpu'),
463                 is_llm_int8: bool = False,
464                 llm_int8_threshold: float = 6.0,
465                 ):
486        if filter_layers is None:
487            filter_layers = set(range(n_layers + 3))
488
489        self.n_vocab = n_vocab
490        self.n_hidden = n_hidden
491        self.n_layers = n_layers
492        self.n_heads = n_heads
493        self.filter_layers = filter_layers
494        self.is_clone_layers = is_clone_layers
495        self.dtype = dtype
496        self.device = device
497        self.is_llm_int8 = is_llm_int8
498        self.llm_int8_threshold = llm_int8_threshold
499
500        self.pre_created_layers = dict(
501            transformer_layer=None,
502        )

භාවිතයසඳහා ස්තරය සකස් කරයි

අපිස්තරය උපාංගය වෙත ගෙන ගොස් නිවැරදි දත්ත වර්ගයට පරිවර්තනය කරමු

  • layer සකස් කළ යුතු ස්ථරයයි
  • සකස්කළ ස්තරයආපසු ලබා දෙයි

504    def _prepare_layer(self, layer: NeoXModule):
513        return layer.to(self.device, self.dtype)

පිරික්සුම්ස්ථානය පැටවීමෙන් පසු ස්ථර පරිවර්තනයන්

මෙමශ්රිතය පිරික්සුම් ස්ථානය පැටවීමෙන් පසු ස්ථර පරිවර්තනයන් ක්රියාත්මක කරයි.

දැනටඑය අදාළ වන්නේ INT8 ප්රමාණකරණය පමණි.

  • layer සකස් කළ යුතු ස්ථරයයි
  • is_llm_int8 INT8 ප්රමාණකරණය භාවිතා කළ යුතුද යන්න නියම කරයි
  • device ආකෘතියේ උපාංගය වේ
  • llm_int8_threshold යනු පිටත විශේෂාංග වෙන් කිරීම සඳහා භාවිතා කරන එළිපත්ත
  • සකස්කළ ස්තරයආපසු ලබා දෙයි

515    @torch.no_grad()
516    def post_load_prepare(self, layer: NeoXModule, *,
517                          is_llm_int8: bool = None,
518                          device: torch.device = None,
519                          llm_int8_threshold: float = None,
520                          ):

නියමකර නොමැති නම් පෙරනිමි අගයන් ලබා ගන්න

538        if is_llm_int8 is None:
539            is_llm_int8 = self.is_llm_int8
540        if device is None:
541            device = self.device
542        if llm_int8_threshold is None:
543            llm_int8_threshold = self.llm_int8_threshold

INT8ප්රමාණකරණය භාවිතා නොකරන්නේ නම් මඟ හරින්න

546        if not is_llm_int8:
547            return layer

ට්රාන්ස්ෆෝමර්ස්ථර වල රේඛීය ස්ථර පමණක් පරිවර්තනය කරන්න

550        if not isinstance(layer, TransformerLayer):
551            return layer

උපයෝගිතාවල make_llm_int8_linear අර්ථ දක්වා ඇති භාවිතා කරන්න.

554        from labml_nn.neox.utils.llm_int8 import make_llm_int8_linear

රේඛීයස්ථර පරිවර්තනය කරන්න

557        with monit.section('Convert to int8'):
558            layer.attention.output = make_llm_int8_linear(layer.attention.output,
559                                                          device=device,
560                                                          threshold=llm_int8_threshold)
561            layer.attention.qkv_lin = make_llm_int8_linear(layer.attention.qkv_lin,
562                                                           device=device,
563                                                           threshold=llm_int8_threshold)
564            layer.ffn.dense_h_h4 = make_llm_int8_linear(layer.ffn.dense_h_h4,
565                                                        device=device,
566                                                        threshold=llm_int8_threshold)
567            layer.ffn.dense_h4_h = make_llm_int8_linear(layer.ffn.dense_h4_h,
568                                                        device=device,
569                                                        threshold=llm_int8_threshold)

571        return layer

ස්තරයක්නිර්මාණය කර හැච් කරයි

පරාමිතීන්ආරම්භ කිරීමට කාලය ගතවන නිසා හැඹිලි ස්ථර පිටපත් කිරීම නව ස්ථර ආරම්භ කිරීමට වඩා වේගවත් වේ.

  • name යනු ස්තරයේ නමයි
  • creator ස්තරය නිර්මාණය කිරීමේ කාර්යයයි
  • සාදනලද ස්තරය හෝ කැච් ස්ථරයේ පිටපතක්ආපසු ලබා දෙයි

573    def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):
585        if not self.is_clone_layers:
586            return self._prepare_layer(creator())
587
588        if self.pre_created_layers[name] is None:
589            self.pre_created_layers[name] = self._prepare_layer(creator())
590
591        layer = copy.deepcopy(self.pre_created_layers[name])
592        return layer
594    def _create_transformer_layer(self):
595        return self._create_and_cache_layer(
596            'transformer_layer',
597            lambda: TransformerLayer(self.n_hidden, self.n_heads)
598        )
600    def _create_embedding_layer(self):
601        return Embedding(self.n_vocab, self.n_hidden)
603    def _create_final_norm_layer(self):
604        return FinalNorm(self.n_hidden)
606    def _create_readout_layer(self):
607        return ReadoutLayer(self.n_hidden, self.n_vocab)

ස්ථරලබා ගැනීම සඳහා උත්පාදක යන්ත්රය

609    @torch.no_grad()
610    def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:

කාවැද්දීමස්ථරය

615        if 0 in self.filter_layers:
616            with monit.section('Embedding layer'):
617                layer = self._prepare_layer(self._create_embedding_layer())
618            yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')

ට්රාන්ස්ෆෝමර්ස්ථර

621        for i in range(self.n_layers):

ට්රාන්ස්ෆෝමර්ස්ථරය

623            if i + 1 in self.filter_layers:
624                with monit.section(f'Transformer Layer {i}'):
625                    yield self._create_transformer_layer(), \
626                          (f'layer_{i + 2 :02d}-model_00-model_states.pt',
627                           f'layer_{i + 2 :02d}-model_01-model_states.pt')

අවසානසාමාන්යකරණ ස්තරය

630        if self.n_layers + 1 in self.filter_layers:
631            with monit.section('Final norm layer'):
632                layer = self._prepare_layer(self._create_final_norm_layer())
633            yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')

කියවීමේස්ථරය

636        if self.n_layers + 2 in self.filter_layers:
637            with monit.section('Readout layer'):
638                layer = self._prepare_layer(self._create_readout_layer())
639            yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')
640
641        for k in self.pre_created_layers.keys():
642            self.pre_created_layers[k] = None

මුළුස්ථර ගණන නැවත ලබා දෙයි

644    @property
645    def total_layers(self):
649        return self.n_layers + 3

ස්ථරපැටවීමට උත්පාදක යන්ත්රය

651    @torch.no_grad()
652    def load(self) -> Generator[NeoXModule, None, None]:
656        with monit.section("Layers"):
657            for i, (layer, files) in enumerate(self.get_layers()):
658                if files is not None:
659                    layer.load_state(*checkpoint.load_checkpoint_files(files))
660
661                layer = self.post_load_prepare(layer)
662
663                monit.progress(min(0.99, (i + 1) / self.total_layers))
664                yield layer