ජීපීටී-නියෝක්ස්ආකෘතියේ ස්ථර සඳහා කේතය සහ 20B මුරපොල පූරණය කිරීමේ කේතය මෙන්න.
ස්ථර load_state
වල ඇති ක්රමය එම ස්ථරයේ මුරපොලවල් පූරණය කරයි. මුරපොලවල් පැටවීමේ සහායකයන් ක්රියාත්මක වේ checkpoint.py
16import copy
17import math
18from typing import Dict, Optional, Set, Callable, Any, Generator, Tuple
19
20import torch
21from torch import nn
22from torch.cuda.amp import autocast
23
24from labml import monit
25from labml_nn.neox import checkpoint
26from labml_nn.neox.utils.cache import get_cache29class NeoXModule(nn.Module):30 def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
31 pass34class Embedding(NeoXModule):n_vocab
වචන මාලාවේ ප්රමාණය වේ n_hidden
මෙම කාවැද්දීම් ප්රමාණය41 def __init__(self, n_vocab: int = 50_432, n_hidden: int = 6_144):46 super().__init__()
47
48 self.emb = nn.Embedding(n_vocab, n_hidden)x
හැඩයේ ටෝකන් හැඳුනුම් වේ [batch_size, seq_len]
50 def forward(self, x: torch.Tensor):54 return self.emb(x)මුරපොලපූරණය කිරීමට කේතය
56 def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):60 with monit.section('Load embedding layer'):
61 checkpoint.merge_params_dim_0(self.emb.weight, 'word_embeddings.weight', p1, p2)ජීපීටී-නියෝක්ස් භ්රමණ ස්ථානීය කාවැද්දීම් (කඹය)භාවිතා කරයි.
අපින්යාය වැඩි සටහන් සමඟ මෙහි කඹය ක්රියාත්මක කිරීම විස්තර කර ඇත.
64class RoPE(nn.Module):d_rope
කඹය කාවැද්දීම් සඳහා විශේෂාංග ගණන base
සඳහා පදනම වේ , එය පැහැර හරින 74 def __init__(self, d_rope: int, base: float = 10_000.):79 super().__init__()විශේෂාංග සඳහා ගබඩා කිරීමට
82 self.theta = Noneහැඹිලිය සහ
84 self.cos_cached = None
85 self.sin_cached = Noneසඳහාමූලික
88 self.base = baseකඹයසඳහා විශේෂාංග ගණන
90 self.d_rope = d_rope92 @staticmethod
93 def rotate_half(x: torch.Tensor):99 x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
100 return torch.cat((-x2, x1), dim=-1)x
හැඩය ඇත [..., seq, n_heads, d_k]
offset
යනු ආරම්භක ස්ථානයයි x
. පෙර තනතුරු වල යතුරු සහ විමසුම් අප හැඹිලි කළ විට මෙය සිදු වේ102 def forward(self, x: torch.Tensor, offset: int = 0):සත්යඅනුක්රමයේ දිග ලබා ගන්න
110 seq_len = x.shape[-3] + offsetආරම්භකරන්න
113 if self.theta is None:115 theta = 1.0 / (self.base ** (torch.arange(0, self.d_rope, 2).float() / self.d_rope))
116 self.theta = theta.to(x.device).to(x.dtype)ආරම්භකරන්න සහ හැඹිලිය
119 if (
120 self.cos_cached is None or
121 seq_len > self.cos_cached.shape[1] or
122 self.cos_cached.device != x.device or
123 self.cos_cached.dtype != x.dtype
124 ):ස්ථානදර්ශක ලබා ගන්න
126 seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)128 idx_theta = torch.einsum("s,d->sd", seq_idx, self.theta)132 idx_theta2 = torch.cat((idx_theta, idx_theta), dim=-1).to(x.device)ගණනයකරන්න සහ fp32
135 with autocast(enabled=False):
136 idx_theta2 = idx_theta2.float()හිසමානයක් එක් කරන්න
138 self.cos_cached = idx_theta2.cos()[:, None, :]
139 self.sin_cached = idx_theta2.sin()[:, None, :]ඒවාහැඹිලිය
142 self.cos_cached = self.cos_cached.to(x.dtype)
143 self.sin_cached = self.sin_cached.to(x.dtype)විශේෂාංගබෙදන්න. d_rope
විශේෂාංග සඳහා පමණක් අපි කඹය යොදන්නෙමු
146 x_rope, x_pass = x[..., :self.d_rope], x[..., self.d_rope:]හැඹිලියසිට පාපය සහ කෝස් අගයන් ලබා ගන්න
149 cos, sin = self.cos_cached[offset: seq_len], self.sin_cached[offset: seq_len]161 x_rope = (x_rope * cos) + (self.rotate_half(x_rope) * sin)කඹයකාවැද්දීම් ලබා නොගත් විශේෂාංග සමඟ සංයුක්ත වන්න
164 return torch.cat((x_rope, x_pass), dim=-1)167class AttentionLayer(nn.Module):n_hidden
කාවැද්දීම් වල විශේෂාංග ගණන n_heads
අවධානය යොමු ප්රධානීන් සංඛ්යාව rope_percentage
කඹය කාවැද්දීම් එකතු කිරීම සඳහා විශේෂාංග ප්රතිශතය mask_fill
අවධානය යොමු න්යාසය සඳහා ආවරණ පිරවුම් අගය172 def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, rope_percentage: float = 0.25,
173 mask_fill: float = -10_000.0):180 super().__init__()
181
182 self.n_heads = n_heads
183 self.mask_fill = mask_fillවිමසුම, යතුර සහ වටිනාකම සඳහා රේඛීය ස්ථරය
186 self.qkv_lin = nn.Linear(n_hidden, n_hidden * 3)අවසානරේඛීය ස්ථරය
188 self.output = nn.Linear(n_hidden, n_hidden)හිසකටවිශේෂාංග ගණන
191 d_k = n_hidden // n_headsකඹයකාවැද්දීම මොඩියුලය
193 self.rope = RoPE(int(d_k * rope_percentage))අවධානයපරිමාණ සාධකය
196 self.scale = 1 / math.sqrt(d_k)හේතුආවරණ හැඹිලි කිරීමට
199 self.causal_mask = Noneඅවධානයයොමු කරන්න සොෆ්ට්මැක්ස් මොඩියුලය
202 self.softmax = nn.Softmax(dim=-2)204 def _get_mask(self, attn: torch.Tensor):විමසුමසහ යතුරු දිග
212 nq, nk = attn.shape[1:3]වෙස්මුහුණ සාදන්න
215 if (
216 self.causal_mask is None or
217 self.causal_mask.shape[0] != nq or
218 self.causal_mask.shape[1] != nk or
219 self.causal_mask.device != attn.device
220 ):
221 self.causal_mask = torch.triu(attn.new_ones([nq, nk], dtype=torch.bool), 1 + nk - nq)හැඹිලියසිට ආපසු
224 return self.causal_mask[None, :, :, None]x
හැඩය ඇත [batch_size, seq_len, n_hidden]
226 def forward(self, x: torch.Tensor):විමසුම, යතුර සහ වටිනාකම් කාවැද්දීම් ලබා ගන්න (සියල්ල සංයුක්ත කර ඇත). පසුගිය මානයක් ප්රමාණය n_hidden -> සිට වෙනස් වනු ඇත 3 x n_hidden
232 qkv = self.qkv_lin(x)හැඩයවෙනස් කිරීමෙන් හිස් වලට බෙදන්න [batch_size, seq_len, n_heads, 3 * d_k]
235 qkv = qkv.view(*qkv.shape[:-1], self.n_heads, -1)විමසුමටබෙදන්න, යතුර සහ හැඩය එක් එක් අගය කරන්න [batch_size, seq_len, n_heads, 3 * d_k]
237 q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)අපිපෙර ටෝකන වල තත්වයන් හැඹිලි කරන්නේ නම්
240 if get_cache().get('use_cache', False):රාජ්යid ගේ ලබා ගන්න. අපි පෙර රාජ්යයන් ලබා ගැනීමට හා ඉදිරි රාජ්යයන් ගබඩා කිරීම සඳහා භාවිතා
242 prev_state_id, next_state_id = get_cache().get('state_ids')හැඹිලියතිබේ නම්
244 if prev_state_id is not None:අතීතයතුරු සහ අගයන් ලබා ගන්න. මේවාට හැඩය ඇත [batch_size, prev_seq_len, n_heads, d_k]
246 k_past, v_past = get_cache().pop(f'attn_kv_{prev_state_id}')වත්මන්කාවැද්දීම් වල ඕෆ්සෙට්
248 offset = k_past.shape[1]කඹයකාවැද්දීම් එකතු කරන්න
251 q = self.rope(q, offset=offset)
252 k = self.rope(k, offset=offset)අතීතයසංයුක්ත කරන්න
255 k = torch.cat([k_past, k], dim=1)
256 v = torch.cat([v_past, v], dim=1)
257 else:කඹයකාවැද්දීම් එකතු කරන්න
259 q = self.rope(q)
260 k = self.rope(k)වත්මන්තත්වය සුරකින්න
263 get_cache().push(f'attn_kv_{next_state_id}', (k, v))
264 else:හැඹිලියක්නැත - කඹය කාවැද්දීම් එකතු කරන්න
266 q = self.rope(q)
267 k = self.rope(k)අවධානයගණනය කිරීම සඳහා fp16 කිරීමට ස්වයංක්රීය-වාත්තු අක්රීය
270 with autocast(enabled=False):
271 if q.dtype == torch.float16:වත්මන්dtype fp16 නම් fp32 බවට පරිවර්තනය කරන්න
273 attn = torch.einsum('bihk,bjhk->bijh', q.float(), k.float())
274 else:bfloatසඳහා වාත්තු නොකරන්න
276 attn = torch.einsum('bihk,bjhk->bijh', q, k)පරිමාණඅවධානය
279 attn = attn * self.scaleහේතුවෙස්මුහුණ ලබා ගන්න
282 mask = self._get_mask(attn)වෙස්යොදන්න
284 attn.masked_fill_(mask, self.mask_fill)අවධානයසොෆ්ට්මැක්ස්
287 attn = self.softmax(attn)අවධානයබර තැබූ අගයන් ලබා ගන්න
290 output = torch.einsum('bijh,bjhk->bihk', attn.to(v.dtype), v)[batch_size, seq_len, n_heads, d_k] to
Batch_size, seq_len, n_hidden`වෙතින් නැවත සකස් කරන්න
293 output = output.reshape(*x.shape)අවසානරේඛීය ස්ථරය
296 return self.output(output)299class FFNLayer(nn.Module):n_hidden
කාවැද්දීම ප්රමාණය වේ304 def __init__(self, n_hidden: int = 6_144, d_ff: int = 0):308 super().__init__()
309
310 if not d_ff:
311 d_ff = n_hidden * 4පුළුල්රේඛීය ස්ථරය
314 self.dense_h_h4 = nn.Linear(n_hidden, d_ff)GELUසක්රිය කිරීම
316 self.activation = nn.GELU()සංකෝචනයරේඛීය ස්ථරය
318 self.dense_h4_h = nn.Linear(d_ff, n_hidden)x
හැඩය ඇත [batch_size, seq_len, n_hidden]
320 def forward(self, x: torch.Tensor):324 x = self.dense_h_h4(x)
325 x = self.activation(x)
326 x = self.dense_h4_h(x)
327
328 return x331class TransformerLayer(NeoXModule):n_hidden
කාවැද්දීම ප්රමාණය වේ n_heads
හිස් සංඛ්යාව වේපිටතක්රියාත්මක කිරීම අතහැර දැමීම ඇතුළත් නොවේ.
336 def __init__(self, n_hidden: int = 6_144, n_heads: int = 64):343 super().__init__()අවධානයටපෙර ස්ථර සාමාන්යකරණය
346 self.pre_ln_attn = nn.LayerNorm(n_hidden)FFNට පෙර ස්ථර සාමාන්යකරණය
348 self.pre_ln_ffn = nn.LayerNorm(n_hidden)අවධානයස්ථරය
351 self.attention = AttentionLayer(n_hidden, n_heads)FFNස්ථරය
353 self.ffn = FFNLayer(n_hidden)x
හැඩයේ කාවැද්දීම් වේ [batch_size, seq_len, n_hidden]
355 def forward(self, x: torch.Tensor):අවශේෂසම්බන්ධතාවය
361 residual = xනියෝක්ස්සමාන්තරව අවධානය සහ ප්රතිපෝෂණ ජාලය ක්රියාත්මක කරයි
363 attn = self.attention(self.pre_ln_attn(x))
364 ffn = self.ffn(self.pre_ln_ffn(x))ඒවාසහ අවශේෂ සම්බන්ධතාවය එකතු කරන්න
366 return attn + ffn + residualමුරපොලපූරණය කිරීමට කේතය
368 def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):372 with monit.section('Load transformer layer'):අවධානයයොමු ප්රතිදානය පරිණාමනය
374 checkpoint.merge_params_sum(self.attention.output.bias, 'attention.dense.bias', p1, p2)
375 checkpoint.merge_params_dim_1(self.attention.output.weight, 'attention.dense.weight', p1, p2)අවධානයවිමසුම, යතුර සහ අගය පරිවර්තනය කිරීම
378 checkpoint.merge_params_dim_0(self.attention.qkv_lin.bias, 'attention.query_key_value.bias', p1, p2)
379 checkpoint.merge_params_dim_0(self.attention.qkv_lin.weight, 'attention.query_key_value.weight', p1, p2)අවධානයටපෙර ස්ථර සම්මතය
382 checkpoint.merge_params_duplicate(self.pre_ln_attn.bias, 'input_layernorm.bias', p1, p2)
383 checkpoint.merge_params_duplicate(self.pre_ln_attn.weight, 'input_layernorm.weight', p1, p2)FFNදෙවන පරිණාමනය
386 checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.bias, 'mlp.dense_h_to_4h.bias', p1, p2)
387 checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.weight, 'mlp.dense_h_to_4h.weight', p1, p2)FFNපළමු පරිවර්තනය
390 checkpoint.merge_params_sum(self.ffn.dense_h4_h.bias, 'mlp.dense_4h_to_h.bias', p1, p2)
391 checkpoint.merge_params_dim_1(self.ffn.dense_h4_h.weight, 'mlp.dense_4h_to_h.weight', p1, p2)FFNට පෙර ස්ථර සම්මතය
394 checkpoint.merge_params_duplicate(self.pre_ln_ffn.bias, 'post_attention_layernorm.bias', p1, p2)
395 checkpoint.merge_params_duplicate(self.pre_ln_ffn.weight, 'post_attention_layernorm.weight', p1, p2)398class FinalNorm(NeoXModule):n_hidden
කාවැද්දීම ප්රමාණය වේ403 def __init__(self, n_hidden: int = 6_144):407 super().__init__()
408
409 self.ln = nn.LayerNorm(n_hidden)x
හැඩයේ කාවැද්දීම් වේ [batch_size, seq_len, n_hidden]
411 def forward(self, x: torch.Tensor):415 return self.ln(x)මුරපොලපූරණය කිරීමට කේතය
417 def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):421 with monit.section('Load final normalization layer'):
422 checkpoint.merge_params_duplicate(self.ln.bias, 'norm.bias', p1, p2)
423 checkpoint.merge_params_duplicate(self.ln.weight, 'norm.weight', p1, p2)කියවීමේස්ථරය
426class ReadoutLayer(NeoXModule):n_hidden
කාවැද්දීම ප්රමාණය වේ n_vocab
වචන මාලාවේ ප්රමාණය වේ431 def __init__(self, n_hidden: int = 6_144, n_vocab: int = 50_432):436 super().__init__()
437
438 self.linear = nn.Linear(n_hidden, n_vocab, bias=False)x
හැඩයේ කාවැද්දීම් වේ [batch_size, seq_len, n_hidden]
440 def forward(self, x: torch.Tensor):444 return self.linear(x)මුරපොලපූරණය කිරීමට කේතය
446 def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):450 with monit.section('Load final linear layer'):
451 checkpoint.merge_params_dim_0(self.linear.weight, 'final_linear.weight', p1, p2)454class LayerGenerator:
455 pre_created_layers: Dict[Any, Optional[NeoXModule]]ස්ථරජනනය කරනු ලබන්නේ මුරපොලවල් මෙන් එකම අනුපිළිවෙලකි.
ස්ථරයක්නොමැති None
විට එය ලබා දෙයි; අපි ස්ථර දර්ශක නියෝක්ස් ලෙස භාවිතා කරන අතර අපගේ ක්රියාත්මක කිරීමේදී අපට අවශ්ය නොවන පරිවර්තන ස්ථර දෙකක් තිබේ.
n_vocab
යනු වචන මාලාවේ ටෝකන ගණන n_hidden
කාවැද්දීම් වල ඇති ලක්ෂණ ගණන n_layers
ට්රාන්ස්ෆෝමර් ස්ථර ගණන n_heads
අවධානය යොමු ප්රධානීන් සංඛ්යාව වේ filter_layers
භාවිතා කළ යුතු ස්ථර සමූහයයි. කිසිවක් නොමැති නම් සියලුම ස්ථර භාවිතා කරනු ඇත. අඩු ස්ථර සහිත ආකෘතියේ කුඩා අනුවාදයන් පරීක්ෂා කිරීමට මෙය භාවිතා is_clone_layers
ට්රාන්ස්ෆෝමර් ස්ථර ක්ලෝන කළ යුතුද යන්න නියම කරයි (ටිකක් වේගවත්) dtype
ආකෘතියේ දත්ත වර්ගයයි device
ආකෘතියේ උපාංගය වේ is_llm_int8
INT8 ප්රමාණකරණය භාවිතා කළ යුතුද යන්න නියම කරයි llm_int8_threshold
යනු පිටත විශේෂාංග වෙන් කිරීම සඳහා භාවිතා කරන එළිපත්ත457 def __init__(self, *, n_vocab: int = 50_432, n_hidden: int = 6_144,
458 n_layers: int = 44, n_heads: int = 64,
459 filter_layers: Optional[Set] = None,
460 is_clone_layers: bool = True,
461 dtype: torch.dtype = torch.float,
462 device: torch.device = torch.device('cpu'),
463 is_llm_int8: bool = False,
464 llm_int8_threshold: float = 6.0,
465 ):486 if filter_layers is None:
487 filter_layers = set(range(n_layers + 3))
488
489 self.n_vocab = n_vocab
490 self.n_hidden = n_hidden
491 self.n_layers = n_layers
492 self.n_heads = n_heads
493 self.filter_layers = filter_layers
494 self.is_clone_layers = is_clone_layers
495 self.dtype = dtype
496 self.device = device
497 self.is_llm_int8 = is_llm_int8
498 self.llm_int8_threshold = llm_int8_threshold
499
500 self.pre_created_layers = dict(
501 transformer_layer=None,
502 )අපිස්තරය උපාංගය වෙත ගෙන ගොස් නිවැරදි දත්ත වර්ගයට පරිවර්තනය කරමු
layer
සකස් කළ යුතු ස්ථරයයි සකස්කළ ස්තරයආපසු ලබා දෙයි
504 def _prepare_layer(self, layer: NeoXModule):513 return layer.to(self.device, self.dtype)මෙමශ්රිතය පිරික්සුම් ස්ථානය පැටවීමෙන් පසු ස්ථර පරිවර්තනයන් ක්රියාත්මක කරයි.
දැනටඑය අදාළ වන්නේ INT8 ප්රමාණකරණය පමණි.
layer
සකස් කළ යුතු ස්ථරයයි is_llm_int8
INT8 ප්රමාණකරණය භාවිතා කළ යුතුද යන්න නියම කරයි device
ආකෘතියේ උපාංගය වේ llm_int8_threshold
යනු පිටත විශේෂාංග වෙන් කිරීම සඳහා භාවිතා කරන එළිපත්ත සකස්කළ ස්තරයආපසු ලබා දෙයි
515 @torch.no_grad()
516 def post_load_prepare(self, layer: NeoXModule, *,
517 is_llm_int8: bool = None,
518 device: torch.device = None,
519 llm_int8_threshold: float = None,
520 ):නියමකර නොමැති නම් පෙරනිමි අගයන් ලබා ගන්න
538 if is_llm_int8 is None:
539 is_llm_int8 = self.is_llm_int8
540 if device is None:
541 device = self.device
542 if llm_int8_threshold is None:
543 llm_int8_threshold = self.llm_int8_thresholdINT8ප්රමාණකරණය භාවිතා නොකරන්නේ නම් මඟ හරින්න
546 if not is_llm_int8:
547 return layerට්රාන්ස්ෆෝමර්ස්ථර වල රේඛීය ස්ථර පමණක් පරිවර්තනය කරන්න
550 if not isinstance(layer, TransformerLayer):
551 return layer554 from labml_nn.neox.utils.llm_int8 import make_llm_int8_linearරේඛීයස්ථර පරිවර්තනය කරන්න
557 with monit.section('Convert to int8'):
558 layer.attention.output = make_llm_int8_linear(layer.attention.output,
559 device=device,
560 threshold=llm_int8_threshold)
561 layer.attention.qkv_lin = make_llm_int8_linear(layer.attention.qkv_lin,
562 device=device,
563 threshold=llm_int8_threshold)
564 layer.ffn.dense_h_h4 = make_llm_int8_linear(layer.ffn.dense_h_h4,
565 device=device,
566 threshold=llm_int8_threshold)
567 layer.ffn.dense_h4_h = make_llm_int8_linear(layer.ffn.dense_h4_h,
568 device=device,
569 threshold=llm_int8_threshold)571 return layerපරාමිතීන්ආරම්භ කිරීමට කාලය ගතවන නිසා හැඹිලි ස්ථර පිටපත් කිරීම නව ස්ථර ආරම්භ කිරීමට වඩා වේගවත් වේ.
name
යනු ස්තරයේ නමයි creator
ස්තරය නිර්මාණය කිරීමේ කාර්යයයි සාදනලද ස්තරය හෝ කැච් ස්ථරයේ පිටපතක්ආපසු ලබා දෙයි
573 def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):585 if not self.is_clone_layers:
586 return self._prepare_layer(creator())
587
588 if self.pre_created_layers[name] is None:
589 self.pre_created_layers[name] = self._prepare_layer(creator())
590
591 layer = copy.deepcopy(self.pre_created_layers[name])
592 return layer594 def _create_transformer_layer(self):
595 return self._create_and_cache_layer(
596 'transformer_layer',
597 lambda: TransformerLayer(self.n_hidden, self.n_heads)
598 )600 def _create_embedding_layer(self):
601 return Embedding(self.n_vocab, self.n_hidden)603 def _create_final_norm_layer(self):
604 return FinalNorm(self.n_hidden)606 def _create_readout_layer(self):
607 return ReadoutLayer(self.n_hidden, self.n_vocab)609 @torch.no_grad()
610 def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:කාවැද්දීමස්ථරය
615 if 0 in self.filter_layers:
616 with monit.section('Embedding layer'):
617 layer = self._prepare_layer(self._create_embedding_layer())
618 yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')ට්රාන්ස්ෆෝමර්ස්ථර
621 for i in range(self.n_layers):ට්රාන්ස්ෆෝමර්ස්ථරය
623 if i + 1 in self.filter_layers:
624 with monit.section(f'Transformer Layer {i}'):
625 yield self._create_transformer_layer(), \
626 (f'layer_{i + 2 :02d}-model_00-model_states.pt',
627 f'layer_{i + 2 :02d}-model_01-model_states.pt')අවසානසාමාන්යකරණ ස්තරය
630 if self.n_layers + 1 in self.filter_layers:
631 with monit.section('Final norm layer'):
632 layer = self._prepare_layer(self._create_final_norm_layer())
633 yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')කියවීමේස්ථරය
636 if self.n_layers + 2 in self.filter_layers:
637 with monit.section('Readout layer'):
638 layer = self._prepare_layer(self._create_readout_layer())
639 yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')
640
641 for k in self.pre_created_layers.keys():
642 self.pre_created_layers[k] = None644 @property
645 def total_layers(self):649 return self.n_layers + 3651 @torch.no_grad()
652 def load(self) -> Generator[NeoXModule, None, None]:656 with monit.section("Layers"):
657 for i, (layer, files) in enumerate(self.get_layers()):
658 if files is not None:
659 layer.load_state(*checkpoint.load_checkpoint_files(files))
660
661 layer = self.post_load_prepare(layer)
662
663 monit.progress(min(0.99, (i + 1) / self.total_layers))
664 yield layer