ජීපීටී-නියෝක්ස්ආකෘතියේ ස්ථර සඳහා කේතය සහ 20B මුරපොල පූරණය කිරීමේ කේතය මෙන්න.
ස්ථර load_state
වල ඇති ක්රමය එම ස්ථරයේ මුරපොලවල් පූරණය කරයි. මුරපොලවල් පැටවීමේ සහායකයන් ක්රියාත්මක වේ checkpoint.py
16import copy
17import math
18from typing import Dict, Optional, Set, Callable, Any, Generator, Tuple
19
20import torch
21from torch import nn
22from torch.cuda.amp import autocast
23
24from labml import monit, logger
25from labml.logger import Text
26from labml_nn.neox import checkpoint
27from labml_nn.neox.utils.cache import get_cache
30class NeoXModule(nn.Module):
31 def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
32 pass
35class Embedding(NeoXModule):
n_vocab
වචන මාලාවේ ප්රමාණය වේ n_hidden
මෙම කාවැද්දීම් ප්රමාණය42 def __init__(self, n_vocab: int = 50_432, n_hidden: int = 6_144):
47 super().__init__()
48
49 self.emb = nn.Embedding(n_vocab, n_hidden)
x
හැඩයේ ටෝකන් හැඳුනුම් වේ [batch_size, seq_len]
51 def forward(self, x: torch.Tensor):
55 return self.emb(x)
මුරපොලපූරණය කිරීමට කේතය
57 def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
61 with monit.section('Load embedding layer'):
62 checkpoint.merge_params_dim_0(self.emb.weight, 'word_embeddings.weight', p1, p2)
ජීපීටී-නියෝක්ස් භ්රමණ ස්ථානීය කාවැද්දීම් (කඹය)භාවිතා කරයි.
අපින්යාය වැඩි සටහන් සමඟ මෙහි කඹය ක්රියාත්මක කිරීම විස්තර කර ඇත.
65class RoPE(nn.Module):
d_rope
කඹය කාවැද්දීම් සඳහා විශේෂාංග ගණන base
සඳහා පදනම වේ , එය පැහැර හරින 75 def __init__(self, d_rope: int, base: float = 10_000.):
80 super().__init__()
විශේෂාංග සඳහා ගබඩා කිරීමට
83 self.theta = None
හැඹිලිය සහ
85 self.cos_cached = None
86 self.sin_cached = None
සඳහාමූලික
89 self.base = base
කඹයසඳහා විශේෂාංග ගණන
91 self.d_rope = d_rope
93 @staticmethod
94 def rotate_half(x: torch.Tensor):
100 x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
101 return torch.cat((-x2, x1), dim=-1)
x
හැඩය ඇත [..., seq, n_heads, d_k]
offset
යනු ආරම්භක ස්ථානයයි x
. පෙර තනතුරු වල යතුරු සහ විමසුම් අප හැඹිලි කළ විට මෙය සිදු වේ103 def forward(self, x: torch.Tensor, offset: int = 0):
සත්යඅනුක්රමයේ දිග ලබා ගන්න
111 seq_len = x.shape[-3] + offset
ආරම්භකරන්න
114 if self.theta is None:
116 theta = 1.0 / (self.base ** (torch.arange(0, self.d_rope, 2).float() / self.d_rope))
117 self.theta = theta.to(x.device).to(x.dtype)
ආරම්භකරන්න සහ හැඹිලිය
120 if (
121 self.cos_cached is None or
122 seq_len > self.cos_cached.shape[1] or
123 self.cos_cached.device != x.device or
124 self.cos_cached.dtype != x.dtype
125 ):
ස්ථානදර්ශක ලබා ගන්න
127 seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)
129 idx_theta = torch.einsum("s,d->sd", seq_idx, self.theta)
133 idx_theta2 = torch.cat((idx_theta, idx_theta), dim=-1).to(x.device)
ගණනයකරන්න සහ fp32
136 with autocast(enabled=False):
137 idx_theta2 = idx_theta2.float()
හිසමානයක් එක් කරන්න
139 self.cos_cached = idx_theta2.cos()[:, None, :]
140 self.sin_cached = idx_theta2.sin()[:, None, :]
ඒවාහැඹිලිය
143 self.cos_cached = self.cos_cached.to(x.dtype)
144 self.sin_cached = self.sin_cached.to(x.dtype)
විශේෂාංගබෙදන්න. d_rope
විශේෂාංග සඳහා පමණක් අපි කඹය යොදන්නෙමු
147 x_rope, x_pass = x[..., :self.d_rope], x[..., self.d_rope:]
හැඹිලියසිට පාපය සහ කෝස් අගයන් ලබා ගන්න
150 cos, sin = self.cos_cached[offset: seq_len], self.sin_cached[offset: seq_len]
162 x_rope = (x_rope * cos) + (self.rotate_half(x_rope) * sin)
කඹයකාවැද්දීම් ලබා නොගත් විශේෂාංග සමඟ සංයුක්ත වන්න
165 return torch.cat((x_rope, x_pass), dim=-1)
168class AttentionLayer(nn.Module):
n_hidden
කාවැද්දීම් වල විශේෂාංග ගණනn_heads
අවධානය යොමු ප්රධානීන් සංඛ්යාවrope_percentage
කඹය කාවැද්දීම් එකතු කිරීම සඳහා විශේෂාංග ප්රතිශතයmask_fill
අවධානය යොමු න්යාසය සඳහා ආවරණ පිරවුම් අගයis_flash_attention
FlashAttention භාවිතා කළ යුතුද යන්න නියම කරයි173 def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, rope_percentage: float = 0.25,
174 mask_fill: float = -10_000.0, *, is_flash_attention: bool = False):
183 super().__init__()
184
185 self.n_heads = n_heads
186 self.mask_fill = mask_fill
විමසුම, යතුර සහ වටිනාකම සඳහා රේඛීය ස්ථරය
189 self.qkv_lin = nn.Linear(n_hidden, n_hidden * 3)
අවසානරේඛීය ස්ථරය
191 self.output = nn.Linear(n_hidden, n_hidden)
හිසකටවිශේෂාංග ගණන
194 d_k = n_hidden // n_heads
කඹයකාවැද්දීම මොඩියුලය
196 self.rope = RoPE(int(d_k * rope_percentage))
අවධානයපරිමාණ සාධකය
199 self.scale = 1 / math.sqrt(d_k)
හේතුආවරණ හැඹිලි කිරීමට
202 self.causal_mask = None
අවධානයයොමු කරන්න සොෆ්ට්මැක්ස් මොඩියුලය
205 self.softmax = nn.Softmax(dim=-2)
208 if is_flash_attention:
209 try:
210 from flash_attn.flash_attention import FlashAttention
211 self.flash_attention = FlashAttention()
212 except ImportError:
213 logger.log('Install flash attention github.com/HazyResearch/flash-attention. '
214 'Falling back to normal attention', Text.warning)
215 self.flash_attention = None
216 else:
217 self.flash_attention = None
219 def _get_mask(self, attn: torch.Tensor):
විමසුමසහ යතුරු දිග
227 nq, nk = attn.shape[1:3]
වෙස්මුහුණ සාදන්න
230 if (
231 self.causal_mask is None or
232 self.causal_mask.shape[0] != nq or
233 self.causal_mask.shape[1] != nk or
234 self.causal_mask.device != attn.device
235 ):
236 self.causal_mask = torch.triu(attn.new_ones([nq, nk], dtype=torch.bool), 1 + nk - nq)
හැඹිලියසිට ආපසු
239 return self.causal_mask[None, :, :, None]
x
හැඩය ඇත [batch_size, seq_len, n_hidden]
241 def forward(self, x: torch.Tensor):
විමසුම, යතුර සහ වටිනාකම් කාවැද්දීම් ලබා ගන්න (සියල්ල සංයුක්ත කර ඇත). පසුගිය මානයක් ප්රමාණය n_hidden -> සිට වෙනස් වනු ඇත 3 x n_hidden
247 qkv = self.qkv_lin(x)
හැඩයවෙනස් කිරීමෙන් හිස් වලට බෙදන්න [batch_size, seq_len, n_heads, 3 * d_k]
250 qkv = qkv.view(*qkv.shape[:-1], self.n_heads, -1)
විමසුමටබෙදන්න, යතුර සහ හැඩය එක් එක් අගය කරන්න [batch_size, seq_len, n_heads, 3 * d_k]
252 q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)
අපිපෙර ටෝකන වල තත්වයන් හැඹිලි කරන්නේ නම්
255 if get_cache().get('use_cache', False):
රාජ්යid ගේ ලබා ගන්න. අපි පෙර රාජ්යයන් ලබා ගැනීමට හා ඉදිරි රාජ්යයන් ගබඩා කිරීම සඳහා භාවිතා
257 prev_state_id, next_state_id = get_cache().get('state_ids')
හැඹිලියතිබේ නම්
259 if prev_state_id is not None:
අතීතයතුරු සහ අගයන් ලබා ගන්න. මේවාට හැඩය ඇත [batch_size, prev_seq_len, n_heads, d_k]
261 k_past, v_past = get_cache().pop(f'attn_kv_{prev_state_id}')
වත්මන්කාවැද්දීම් වල ඕෆ්සෙට්
263 offset = k_past.shape[1]
කඹයකාවැද්දීම් එකතු කරන්න
266 q = self.rope(q, offset=offset)
267 k = self.rope(k, offset=offset)
අතීතයසංයුක්ත කරන්න
270 k = torch.cat([k_past, k], dim=1)
271 v = torch.cat([v_past, v], dim=1)
272 else:
කඹයකාවැද්දීම් එකතු කරන්න
274 q = self.rope(q)
275 k = self.rope(k)
වත්මන්තත්වය සුරකින්න
278 get_cache().push(f'attn_kv_{next_state_id}', (k, v))
279 else:
හැඹිලියක්නැත - කඹය කාවැද්දීම් එකතු කරන්න
281 q = self.rope(q)
282 k = self.rope(k)
ෆ්ලෑෂ් අවධානය භාවිතා කරන්න
285 if self.flash_attention is not None and q.shape[1] == k.shape[1] and q.shape[-1] <= 128:
286 output = self.compute_flash_attention(q, k, v)
එසේ නොමැති නම්, සාමාන්ය අවධානය භාවිතා කරන්න
288 else:
289 output = self.compute_attention(q, k, v)
[batch_size, seq_len, n_heads, d_k] to
Batch_size, seq_len, n_hidden`වෙතින් නැවත සකස් කරන්න
292 output = output.reshape(*x.shape)
අවසානරේඛීය ස්ථරය
295 return self.output(output)
297 def compute_flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
ඒවා හැඩයට ගොඩගසන්න[batch_size, seq_len, 3, n_heads, d_k]
299 qkv = torch.stack((q, k, v), dim=2)
300 d_k = qkv.shape[-1]
301 if d_k <= 32:
302 pad = 32 - d_k
303 elif d_k <= 64:
304 pad = 64 - d_k
305 elif d_k <= 128:
306 pad = 128 - d_k
307 else:
308 raise ValueError(f'Head size {d_k} too large for flash attention')
309
310 if pad > 0:
311 qkv = torch.cat((qkv, qkv.new_zeros(*qkv.shape[:-1], pad)), dim=-1)
312
313 output, _ = self.flash_attention(qkv, causal=True)
ප්රතිදානය හැඩයෙන් යුක්ත වේ[batch_size, seq_len, n_heads, d_k + padding]
315 output = output[:, :, :, :d_k]
316
317 return output
319 def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
අවධානයගණනය කිරීම සඳහා fp16 කිරීමට ස්වයංක්රීය-වාත්තු අක්රීය
321 with autocast(enabled=False):
322 if q.dtype == torch.float16:
වත්මන්dtype fp16 නම් fp32 බවට පරිවර්තනය කරන්න
324 attn = torch.einsum('bihk,bjhk->bijh', q.float(), k.float())
325 else:
bfloatසඳහා වාත්තු නොකරන්න
327 attn = torch.einsum('bihk,bjhk->bijh', q, k)
පරිමාණඅවධානය
330 attn = attn * self.scale
හේතුවෙස්මුහුණ ලබා ගන්න
333 mask = self._get_mask(attn)
වෙස්යොදන්න
335 attn.masked_fill_(mask, self.mask_fill)
අවධානයසොෆ්ට්මැක්ස්
338 attn = self.softmax(attn)
අවධානයබර තැබූ අගයන් ලබා ගන්න
341 output = torch.einsum('bijh,bjhk->bihk', attn.to(v.dtype), v)
342
343 return output
346class FFNLayer(nn.Module):
n_hidden
කාවැද්දීම ප්රමාණය වේ351 def __init__(self, n_hidden: int = 6_144, d_ff: int = 0):
355 super().__init__()
356
357 if not d_ff:
358 d_ff = n_hidden * 4
පුළුල්රේඛීය ස්ථරය
361 self.dense_h_h4 = nn.Linear(n_hidden, d_ff)
GELUසක්රිය කිරීම
363 self.activation = nn.GELU()
සංකෝචනයරේඛීය ස්ථරය
365 self.dense_h4_h = nn.Linear(d_ff, n_hidden)
x
හැඩය ඇත [batch_size, seq_len, n_hidden]
367 def forward(self, x: torch.Tensor):
371 x = self.dense_h_h4(x)
372 x = self.activation(x)
373 x = self.dense_h4_h(x)
374
375 return x
378class TransformerLayer(NeoXModule):
n_hidden
කාවැද්දීම ප්රමාණය වේn_heads
හිස් සංඛ්යාව වේis_flash_attention
FlashAttention භාවිතා කළ යුතුද යන්න නියම කරයිපිටත ක්රියාත්මක කිරීම අතහැර දැමීම ඇතුළත් නොවේ.
383 def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, *, is_flash_attention: bool = False):
392 super().__init__()
අවධානයටපෙර ස්ථර සාමාන්යකරණය
395 self.pre_ln_attn = nn.LayerNorm(n_hidden)
FFNට පෙර ස්ථර සාමාන්යකරණය
397 self.pre_ln_ffn = nn.LayerNorm(n_hidden)
අවධානයස්ථරය
400 self.attention = AttentionLayer(n_hidden, n_heads, is_flash_attention=is_flash_attention)
FFNස්ථරය
402 self.ffn = FFNLayer(n_hidden)
x
හැඩයේ කාවැද්දීම් වේ [batch_size, seq_len, n_hidden]
404 def forward(self, x: torch.Tensor):
අවශේෂසම්බන්ධතාවය
410 residual = x
නියෝක්ස්සමාන්තරව අවධානය සහ ප්රතිපෝෂණ ජාලය ක්රියාත්මක කරයි
412 attn = self.attention(self.pre_ln_attn(x))
413 ffn = self.ffn(self.pre_ln_ffn(x))
ඒවාසහ අවශේෂ සම්බන්ධතාවය එකතු කරන්න
415 return attn + ffn + residual
මුරපොලපූරණය කිරීමට කේතය
417 def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
421 with monit.section('Load transformer layer'):
අවධානයයොමු ප්රතිදානය පරිණාමනය
423 checkpoint.merge_params_sum(self.attention.output.bias, 'attention.dense.bias', p1, p2)
424 checkpoint.merge_params_dim_1(self.attention.output.weight, 'attention.dense.weight', p1, p2)
අවධානයවිමසුම, යතුර සහ අගය පරිවර්තනය කිරීම
427 checkpoint.merge_params_dim_0(self.attention.qkv_lin.bias, 'attention.query_key_value.bias', p1, p2)
428 checkpoint.merge_params_dim_0(self.attention.qkv_lin.weight, 'attention.query_key_value.weight', p1, p2)
අවධානයටපෙර ස්ථර සම්මතය
431 checkpoint.merge_params_duplicate(self.pre_ln_attn.bias, 'input_layernorm.bias', p1, p2)
432 checkpoint.merge_params_duplicate(self.pre_ln_attn.weight, 'input_layernorm.weight', p1, p2)
FFNදෙවන පරිණාමනය
435 checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.bias, 'mlp.dense_h_to_4h.bias', p1, p2)
436 checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.weight, 'mlp.dense_h_to_4h.weight', p1, p2)
FFNපළමු පරිවර්තනය
439 checkpoint.merge_params_sum(self.ffn.dense_h4_h.bias, 'mlp.dense_4h_to_h.bias', p1, p2)
440 checkpoint.merge_params_dim_1(self.ffn.dense_h4_h.weight, 'mlp.dense_4h_to_h.weight', p1, p2)
FFNට පෙර ස්ථර සම්මතය
443 checkpoint.merge_params_duplicate(self.pre_ln_ffn.bias, 'post_attention_layernorm.bias', p1, p2)
444 checkpoint.merge_params_duplicate(self.pre_ln_ffn.weight, 'post_attention_layernorm.weight', p1, p2)
447class FinalNorm(NeoXModule):
n_hidden
කාවැද්දීම ප්රමාණය වේ452 def __init__(self, n_hidden: int = 6_144):
456 super().__init__()
457
458 self.ln = nn.LayerNorm(n_hidden)
x
හැඩයේ කාවැද්දීම් වේ [batch_size, seq_len, n_hidden]
460 def forward(self, x: torch.Tensor):
464 return self.ln(x)
මුරපොලපූරණය කිරීමට කේතය
466 def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
470 with monit.section('Load final normalization layer'):
471 checkpoint.merge_params_duplicate(self.ln.bias, 'norm.bias', p1, p2)
472 checkpoint.merge_params_duplicate(self.ln.weight, 'norm.weight', p1, p2)
කියවීමේස්ථරය
475class ReadoutLayer(NeoXModule):
n_hidden
කාවැද්දීම ප්රමාණය වේ n_vocab
වචන මාලාවේ ප්රමාණය වේ480 def __init__(self, n_hidden: int = 6_144, n_vocab: int = 50_432):
485 super().__init__()
486
487 self.linear = nn.Linear(n_hidden, n_vocab, bias=False)
x
හැඩයේ කාවැද්දීම් වේ [batch_size, seq_len, n_hidden]
489 def forward(self, x: torch.Tensor):
493 return self.linear(x)
මුරපොලපූරණය කිරීමට කේතය
495 def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
499 with monit.section('Load final linear layer'):
500 checkpoint.merge_params_dim_0(self.linear.weight, 'final_linear.weight', p1, p2)
503class LayerGenerator:
504 pre_created_layers: Dict[Any, Optional[NeoXModule]]
ස්ථර ජනනය කරනු ලබන්නේ මුරපොලවල් මෙන් එකම අනුපිළිවෙලකි.
ස්ථරයක් නොමැතිNone
විට එය ලබා දෙයි; අපි ස්ථර දර්ශක නියෝක්ස් ලෙස භාවිතා කරන අතර අපගේ ක්රියාත්මක කිරීමේදී අපට අවශ්ය නොවන පරිවර්තන ස්ථර දෙකක් තිබේ.
n_vocab
යනු වචන මාලාවේ ටෝකන ගණනn_hidden
මෙම කාවැද්දීම් දී විශේෂාංග සංඛ්යාවn_layers
ට්රාන්ස්ෆෝමර් ස්ථර ගණන වේn_heads
අවධානය යොමු ප්රධානීන් සංඛ්යාව වේfilter_layers
භාවිතා කළ යුතු ස්ථර සමූහයයි. කිසිවක් නොමැති නම් සියලුම ස්ථර භාවිතා කරනු ඇත. අඩු ස්ථර සහිත ආකෘතියේ කුඩා අනුවාදයන් පරීක්ෂා කිරීමට මෙය භාවිතාis_clone_layers
ට්රාන්ස්ෆෝමර් ස්ථර ක්ලෝන කළ යුතුද යන්න නියම කරයි (ටිකක් වේගවත්)dtype
ආකෘතියේ දත්ත වර්ගයයිdevice
ආකෘතියේ උපාංගය වේis_llm_int8
INT8 ප්රමාණකරණය භාවිතා කළ යුතුද යන්න නියම කරයිllm_int8_threshold
පිටත විශේෂාංග වෙන් කිරීම සඳහා භාවිතා කරන එළිපත්ත වේis_flash_attention
FlashAttention භාවිතා කළ යුතුද යන්න නියම කරයි506 def __init__(self, *, n_vocab: int = 50_432, n_hidden: int = 6_144,
507 n_layers: int = 44, n_heads: int = 64,
508 filter_layers: Optional[Set] = None,
509 is_clone_layers: bool = True,
510 dtype: torch.dtype = torch.float,
511 device: torch.device = torch.device('cpu'),
512 is_llm_int8: bool = False,
513 llm_int8_threshold: float = 6.0,
514 is_flash_attention: bool = False
515 ):
538 if filter_layers is None:
539 filter_layers = set(range(n_layers + 3))
540
541 self.n_vocab = n_vocab
542 self.n_hidden = n_hidden
543 self.n_layers = n_layers
544 self.n_heads = n_heads
545 self.filter_layers = filter_layers
546 self.is_clone_layers = is_clone_layers
547 self.dtype = dtype
548 self.device = device
549 self.is_llm_int8 = is_llm_int8
550 self.llm_int8_threshold = llm_int8_threshold
551 self.is_flash_attention = is_flash_attention
552
553 self.pre_created_layers = dict(
554 transformer_layer=None,
555 )
අපිස්තරය උපාංගය වෙත ගෙන ගොස් නිවැරදි දත්ත වර්ගයට පරිවර්තනය කරමු
layer
සකස් කළ යුතු ස්ථරයයි සකස්කළ ස්තරයආපසු ලබා දෙයි
557 def _prepare_layer(self, layer: NeoXModule):
566 return layer.to(self.device, self.dtype)
මෙමශ්රිතය පිරික්සුම් ස්ථානය පැටවීමෙන් පසු ස්ථර පරිවර්තනයන් ක්රියාත්මක කරයි.
දැනටඑය අදාළ වන්නේ INT8 ප්රමාණකරණය පමණි.
layer
සකස් කළ යුතු ස්ථරයයි is_llm_int8
INT8 ප්රමාණකරණය භාවිතා කළ යුතුද යන්න නියම කරයි device
ආකෘතියේ උපාංගය වේ llm_int8_threshold
යනු පිටත විශේෂාංග වෙන් කිරීම සඳහා භාවිතා කරන එළිපත්ත සකස්කළ ස්තරයආපසු ලබා දෙයි
568 @torch.no_grad()
569 def post_load_prepare(self, layer: NeoXModule, *,
570 is_llm_int8: bool = None,
571 device: torch.device = None,
572 llm_int8_threshold: float = None,
573 ):
නියමකර නොමැති නම් පෙරනිමි අගයන් ලබා ගන්න
591 if is_llm_int8 is None:
592 is_llm_int8 = self.is_llm_int8
593 if device is None:
594 device = self.device
595 if llm_int8_threshold is None:
596 llm_int8_threshold = self.llm_int8_threshold
INT8ප්රමාණකරණය භාවිතා නොකරන්නේ නම් මඟ හරින්න
599 if not is_llm_int8:
600 return layer
ට්රාන්ස්ෆෝමර්ස්ථර වල රේඛීය ස්ථර පමණක් පරිවර්තනය කරන්න
603 if not isinstance(layer, TransformerLayer):
604 return layer
607 from labml_nn.neox.utils.llm_int8 import make_llm_int8_linear
රේඛීයස්ථර පරිවර්තනය කරන්න
610 with monit.section('Convert to int8'):
611 layer.attention.output = make_llm_int8_linear(layer.attention.output,
612 device=device,
613 threshold=llm_int8_threshold)
614 layer.attention.qkv_lin = make_llm_int8_linear(layer.attention.qkv_lin,
615 device=device,
616 threshold=llm_int8_threshold)
617 layer.ffn.dense_h_h4 = make_llm_int8_linear(layer.ffn.dense_h_h4,
618 device=device,
619 threshold=llm_int8_threshold)
620 layer.ffn.dense_h4_h = make_llm_int8_linear(layer.ffn.dense_h4_h,
621 device=device,
622 threshold=llm_int8_threshold)
624 return layer
පරාමිතීන්ආරම්භ කිරීමට කාලය ගතවන නිසා හැඹිලි ස්ථර පිටපත් කිරීම නව ස්ථර ආරම්භ කිරීමට වඩා වේගවත් වේ.
name
යනු ස්තරයේ නමයි creator
ස්තරය නිර්මාණය කිරීමේ කාර්යයයි සාදනලද ස්තරය හෝ කැච් ස්ථරයේ පිටපතක්ආපසු ලබා දෙයි
626 def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):
638 if not self.is_clone_layers:
639 return self._prepare_layer(creator())
640
641 if self.pre_created_layers[name] is None:
642 self.pre_created_layers[name] = self._prepare_layer(creator())
643
644 layer = copy.deepcopy(self.pre_created_layers[name])
645 return layer
647 def _create_transformer_layer(self):
648 return self._create_and_cache_layer(
649 'transformer_layer',
650 lambda: TransformerLayer(self.n_hidden, self.n_heads, is_flash_attention=self.is_flash_attention)
651 )
653 def _create_embedding_layer(self):
654 return Embedding(self.n_vocab, self.n_hidden)
656 def _create_final_norm_layer(self):
657 return FinalNorm(self.n_hidden)
659 def _create_readout_layer(self):
660 return ReadoutLayer(self.n_hidden, self.n_vocab)
662 @torch.no_grad()
663 def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:
කාවැද්දීමස්ථරය
668 if 0 in self.filter_layers:
669 with monit.section('Embedding layer'):
670 layer = self._prepare_layer(self._create_embedding_layer())
671 yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')
ට්රාන්ස්ෆෝමර්ස්ථර
674 for i in range(self.n_layers):
ට්රාන්ස්ෆෝමර්ස්ථරය
676 if i + 1 in self.filter_layers:
677 with monit.section(f'Transformer Layer {i}'):
678 yield self._create_transformer_layer(), \
679 (f'layer_{i + 2 :02d}-model_00-model_states.pt',
680 f'layer_{i + 2 :02d}-model_01-model_states.pt')
අවසානසාමාන්යකරණ ස්තරය
683 if self.n_layers + 1 in self.filter_layers:
684 with monit.section('Final norm layer'):
685 layer = self._prepare_layer(self._create_final_norm_layer())
686 yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')
කියවීමේස්ථරය
689 if self.n_layers + 2 in self.filter_layers:
690 with monit.section('Readout layer'):
691 layer = self._prepare_layer(self._create_readout_layer())
692 yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')
693
694 for k in self.pre_created_layers.keys():
695 self.pre_created_layers[k] = None
697 @property
698 def total_layers(self):
702 return self.n_layers + 3
704 @torch.no_grad()
705 def load(self) -> Generator[NeoXModule, None, None]:
709 with monit.section("Layers"):
710 for i, (layer, files) in enumerate(self.get_layers()):
711 if files is not None:
712 layer.load_state(*checkpoint.load_checkpoint_files(files))
713
714 layer = self.post_load_prepare(layer)
715
716 monit.progress(min(0.99, (i + 1) / self.total_layers))
717 yield layer