ජීපීටී-නියෝක්ස්ආකෘතිය

ජීපීටී-නියෝක්ස්ආකෘතියේ ස්ථර සඳහා කේතය සහ 20B මුරපොල පූරණය කිරීමේ කේතය මෙන්න.

ස්ථර load_state වල ඇති ක්රමය එම ස්ථරයේ මුරපොලවල් පූරණය කරයි. මුරපොලවල් පැටවීමේ සහායකයන් ක්රියාත්මක වේ checkpoint.py

16import copy
17import math
18from typing import Dict, Optional, Set, Callable, Any, Generator, Tuple
19
20import torch
21from torch import nn
22from torch.cuda.amp import autocast
23
24from labml import monit, logger
25from labml.logger import Text
26from labml_nn.neox import checkpoint
27from labml_nn.neox.utils.cache import get_cache
30class NeoXModule(nn.Module):
31    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
32        pass

කාවැද්දීමස්ථරය

මෙයමුරපොලට පැටවීම සඳහා කේතය සහිත සම්මත කාවැද්දීම් ස්ථරයකි.

35class Embedding(NeoXModule):
  • n_vocab වචන මාලාවේ ප්රමාණය වේ
  • n_hidden මෙම කාවැද්දීම් ප්රමාණය
42    def __init__(self, n_vocab: int = 50_432, n_hidden: int = 6_144):
47        super().__init__()
48
49        self.emb = nn.Embedding(n_vocab, n_hidden)
  • x හැඩයේ ටෝකන් හැඳුනුම් වේ [batch_size, seq_len]
51    def forward(self, x: torch.Tensor):
55        return self.emb(x)

මුරපොලපූරණය කිරීමට කේතය

57    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
61        with monit.section('Load embedding layer'):
62            checkpoint.merge_params_dim_0(self.emb.weight, 'word_embeddings.weight', p1, p2)

රොටරිස්ථානීය කාවැද්දීම්

ජීපීටී-නියෝක්ස් භ්රමණ ස්ථානීය කාවැද්දීම් (කඹය)භාවිතා කරයි.

අපින්යාය වැඩි සටහන් සමඟ මෙහි කඹය ක්රියාත්මක කිරීම විස්තර කර ඇත.

65class RoPE(nn.Module):
  • d_rope කඹය කාවැද්දීම් සඳහා විශේෂාංග ගණන
  • base සඳහා පදනම වේ , එය පැහැර හරින
75    def __init__(self, d_rope: int, base: float = 10_000.):
80        super().__init__()

විශේෂාංග සඳහා ගබඩා කිරීමට

83        self.theta = None

හැඹිලිය සහ

85        self.cos_cached = None
86        self.sin_cached = None

සඳහාමූලික

89        self.base = base

කඹයසඳහා විශේෂාංග ගණන

91        self.d_rope = d_rope

විශේෂාංගකරකවන්න

93    @staticmethod
94    def rotate_half(x: torch.Tensor):
100        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
101        return torch.cat((-x2, x1), dim=-1)
  • x හැඩය ඇත [..., seq, n_heads, d_k]
  • offset යනු ආරම්භක ස්ථානයයි x . පෙර තනතුරු වල යතුරු සහ විමසුම් අප හැඹිලි කළ විට මෙය සිදු වේ
103    def forward(self, x: torch.Tensor, offset: int = 0):

සත්යඅනුක්රමයේ දිග ලබා ගන්න

111        seq_len = x.shape[-3] + offset

ආරම්භකරන්න

114        if self.theta is None:

116            theta = 1.0 / (self.base ** (torch.arange(0, self.d_rope, 2).float() / self.d_rope))
117            self.theta = theta.to(x.device).to(x.dtype)

ආරම්භකරන්න සහ හැඹිලිය

120        if (
121                self.cos_cached is None or
122                seq_len > self.cos_cached.shape[1] or
123                self.cos_cached.device != x.device or
124                self.cos_cached.dtype != x.dtype
125        ):

ස්ථානදර්ශක ලබා ගන්න

127            seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)

129            idx_theta = torch.einsum("s,d->sd", seq_idx, self.theta)

පේළියසඳහා අපට ඇති පරිදි සංයුක්ත කරන්න

133            idx_theta2 = torch.cat((idx_theta, idx_theta), dim=-1).to(x.device)

ගණනයකරන්න සහ fp32

136            with autocast(enabled=False):
137                idx_theta2 = idx_theta2.float()

හිසමානයක් එක් කරන්න

139                self.cos_cached = idx_theta2.cos()[:, None, :]
140                self.sin_cached = idx_theta2.sin()[:, None, :]

ඒවාහැඹිලිය

143            self.cos_cached = self.cos_cached.to(x.dtype)
144            self.sin_cached = self.sin_cached.to(x.dtype)

විශේෂාංගබෙදන්න. d_rope විශේෂාංග සඳහා පමණක් අපි කඹය යොදන්නෙමු

147        x_rope, x_pass = x[..., :self.d_rope], x[..., self.d_rope:]

හැඹිලියසිට පාපය සහ කෝස් අගයන් ලබා ගන්න

150        cos, sin = self.cos_cached[offset: seq_len], self.sin_cached[offset: seq_len]

කඹයකාවැද්දීම්

සඳහා

162        x_rope = (x_rope * cos) + (self.rotate_half(x_rope) * sin)

කඹයකාවැද්දීම් ලබා නොගත් විශේෂාංග සමඟ සංයුක්ත වන්න

165        return torch.cat((x_rope, x_pass), dim=-1)

අවධානයස්ථරය

168class AttentionLayer(nn.Module):
  • n_hidden කාවැද්දීම් වල විශේෂාංග ගණන
  • n_heads අවධානය යොමු ප්රධානීන් සංඛ්යාව
  • rope_percentage කඹය කාවැද්දීම් එකතු කිරීම සඳහා විශේෂාංග ප්රතිශතය
  • mask_fill අවධානය යොමු න්යාසය සඳහා ආවරණ පිරවුම් අගය
  • is_flash_attention FlashAttention භාවිතා කළ යුතුද යන්න නියම කරයි
173    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, rope_percentage: float = 0.25,
174                 mask_fill: float = -10_000.0, *, is_flash_attention: bool = False):
183        super().__init__()
184
185        self.n_heads = n_heads
186        self.mask_fill = mask_fill

විමසුම, යතුර සහ වටිනාකම සඳහා රේඛීය ස්ථරය

189        self.qkv_lin = nn.Linear(n_hidden, n_hidden * 3)

අවසානරේඛීය ස්ථරය

191        self.output = nn.Linear(n_hidden, n_hidden)

හිසකටවිශේෂාංග ගණන

194        d_k = n_hidden // n_heads

කඹයකාවැද්දීම මොඩියුලය

196        self.rope = RoPE(int(d_k * rope_percentage))

අවධානයපරිමාණ සාධකය

199        self.scale = 1 / math.sqrt(d_k)

හේතුආවරණ හැඹිලි කිරීමට

202        self.causal_mask = None

අවධානයයොමු කරන්න සොෆ්ට්මැක්ස් මොඩියුලය

205        self.softmax = nn.Softmax(dim=-2)
208        if is_flash_attention:
209            try:
210                from flash_attn.flash_attention import FlashAttention
211                self.flash_attention = FlashAttention()
212            except ImportError:
213                logger.log('Install flash attention github.com/HazyResearch/flash-attention. '
214                           'Falling back to normal attention', Text.warning)
215                self.flash_attention = None
216        else:
217            self.flash_attention = None
219    def _get_mask(self, attn: torch.Tensor):

විමසුමසහ යතුරු දිග

227        nq, nk = attn.shape[1:3]

වෙස්මුහුණ සාදන්න

230        if (
231                self.causal_mask is None or
232                self.causal_mask.shape[0] != nq or
233                self.causal_mask.shape[1] != nk or
234                self.causal_mask.device != attn.device
235        ):
236            self.causal_mask = torch.triu(attn.new_ones([nq, nk], dtype=torch.bool), 1 + nk - nq)

හැඹිලියසිට ආපසු

239        return self.causal_mask[None, :, :, None]
  • x හැඩය ඇත [batch_size, seq_len, n_hidden]
241    def forward(self, x: torch.Tensor):

විමසුම, යතුර සහ වටිනාකම් කාවැද්දීම් ලබා ගන්න (සියල්ල සංයුක්ත කර ඇත). පසුගිය මානයක් ප්රමාණය n_hidden -> සිට වෙනස් වනු ඇත 3 x n_hidden

247        qkv = self.qkv_lin(x)

හැඩයවෙනස් කිරීමෙන් හිස් වලට බෙදන්න [batch_size, seq_len, n_heads, 3 * d_k]

250        qkv = qkv.view(*qkv.shape[:-1], self.n_heads, -1)

විමසුමටබෙදන්න, යතුර සහ හැඩය එක් එක් අගය කරන්න [batch_size, seq_len, n_heads, 3 * d_k]

252        q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)

අපිපෙර ටෝකන වල තත්වයන් හැඹිලි කරන්නේ නම්

255        if get_cache().get('use_cache', False):

රාජ්යid ගේ ලබා ගන්න. අපි පෙර රාජ්යයන් ලබා ගැනීමට හා ඉදිරි රාජ්යයන් ගබඩා කිරීම සඳහා භාවිතා

257            prev_state_id, next_state_id = get_cache().get('state_ids')

හැඹිලියතිබේ නම්

259            if prev_state_id is not None:

අතීතයතුරු සහ අගයන් ලබා ගන්න. මේවාට හැඩය ඇත [batch_size, prev_seq_len, n_heads, d_k]

261                k_past, v_past = get_cache().pop(f'attn_kv_{prev_state_id}')

වත්මන්කාවැද්දීම් වල ඕෆ්සෙට්

263                offset = k_past.shape[1]

කඹයකාවැද්දීම් එකතු කරන්න

266                q = self.rope(q, offset=offset)
267                k = self.rope(k, offset=offset)

අතීතයසංයුක්ත කරන්න

270                k = torch.cat([k_past, k], dim=1)
271                v = torch.cat([v_past, v], dim=1)
272            else:

කඹයකාවැද්දීම් එකතු කරන්න

274                q = self.rope(q)
275                k = self.rope(k)

වත්මන්තත්වය සුරකින්න

278            get_cache().push(f'attn_kv_{next_state_id}', (k, v))
279        else:

හැඹිලියක්නැත - කඹය කාවැද්දීම් එකතු කරන්න

281            q = self.rope(q)
282            k = self.rope(k)

ෆ්ලෑෂ් අවධානය භාවිතා කරන්න

285        if self.flash_attention is not None and q.shape[1] == k.shape[1] and q.shape[-1] <= 128:
286            output = self.compute_flash_attention(q, k, v)

එසේ නොමැති නම්, සාමාන්ය අවධානය භාවිතා කරන්න

288        else:
289            output = self.compute_attention(q, k, v)

[batch_size, seq_len, n_heads, d_k] to Batch_size, seq_len, n_hidden`වෙතින් නැවත සකස් කරන්න

292        output = output.reshape(*x.shape)

අවසානරේඛීය ස්ථරය

295        return self.output(output)
297    def compute_flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):

ඒවා හැඩයට ගොඩගසන්න[batch_size, seq_len, 3, n_heads, d_k]

299        qkv = torch.stack((q, k, v), dim=2)
300        d_k = qkv.shape[-1]
301        if d_k <= 32:
302            pad = 32 - d_k
303        elif d_k <= 64:
304            pad = 64 - d_k
305        elif d_k <= 128:
306            pad = 128 - d_k
307        else:
308            raise ValueError(f'Head size {d_k} too large for flash attention')
309
310        if pad > 0:
311            qkv = torch.cat((qkv, qkv.new_zeros(*qkv.shape[:-1], pad)), dim=-1)
312
313        output, _ = self.flash_attention(qkv, causal=True)

ප්රතිදානය හැඩයෙන් යුක්ත වේ[batch_size, seq_len, n_heads, d_k + padding]

315        output = output[:, :, :, :d_k]
316
317        return output
319    def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):

අවධානයගණනය කිරීම සඳහා fp16 කිරීමට ස්වයංක්රීය-වාත්තු අක්රීය

321        with autocast(enabled=False):
322            if q.dtype == torch.float16:

වත්මන්dtype fp16 නම් fp32 බවට පරිවර්තනය කරන්න

324                attn = torch.einsum('bihk,bjhk->bijh', q.float(), k.float())
325            else:

bfloatසඳහා වාත්තු නොකරන්න

327                attn = torch.einsum('bihk,bjhk->bijh', q, k)

පරිමාණඅවධානය

330            attn = attn * self.scale

හේතුවෙස්මුහුණ ලබා ගන්න

333            mask = self._get_mask(attn)

වෙස්යොදන්න

335            attn.masked_fill_(mask, self.mask_fill)

අවධානයසොෆ්ට්මැක්ස්

338            attn = self.softmax(attn)

අවධානයබර තැබූ අගයන් ලබා ගන්න

341        output = torch.einsum('bijh,bjhk->bihk', attn.to(v.dtype), v)
342
343        return output

ප්රතිපෝෂණජාලය

346class FFNLayer(nn.Module):
  • n_hidden කාවැද්දීම ප්රමාණය වේ
351    def __init__(self, n_hidden: int = 6_144, d_ff: int = 0):
355        super().__init__()
356
357        if not d_ff:
358            d_ff = n_hidden * 4

පුළුල්රේඛීය ස්ථරය

361        self.dense_h_h4 = nn.Linear(n_hidden, d_ff)

GELUසක්රිය කිරීම

363        self.activation = nn.GELU()

සංකෝචනයරේඛීය ස්ථරය

365        self.dense_h4_h = nn.Linear(d_ff, n_hidden)
  • x හැඩය ඇත [batch_size, seq_len, n_hidden]
367    def forward(self, x: torch.Tensor):
371        x = self.dense_h_h4(x)
372        x = self.activation(x)
373        x = self.dense_h4_h(x)
374
375        return x

ට්රාන්ස්ෆෝමර්ස්ථරය

378class TransformerLayer(NeoXModule):
  • n_hidden කාවැද්දීම ප්රමාණය වේ
  • n_heads හිස් සංඛ්යාව වේ
  • is_flash_attention FlashAttention භාවිතා කළ යුතුද යන්න නියම කරයි

පිටත ක්රියාත්මක කිරීම අතහැර දැමීම ඇතුළත් නොවේ.

383    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, *, is_flash_attention: bool = False):
392        super().__init__()

අවධානයටපෙර ස්ථර සාමාන්යකරණය

395        self.pre_ln_attn = nn.LayerNorm(n_hidden)

FFNට පෙර ස්ථර සාමාන්යකරණය

397        self.pre_ln_ffn = nn.LayerNorm(n_hidden)

අවධානයස්ථරය

400        self.attention = AttentionLayer(n_hidden, n_heads, is_flash_attention=is_flash_attention)

FFNස්ථරය

402        self.ffn = FFNLayer(n_hidden)
  • x හැඩයේ කාවැද්දීම් වේ [batch_size, seq_len, n_hidden]
404    def forward(self, x: torch.Tensor):

අවශේෂසම්බන්ධතාවය

410        residual = x

නියෝක්ස්සමාන්තරව අවධානය සහ ප්රතිපෝෂණ ජාලය ක්රියාත්මක කරයි

412        attn = self.attention(self.pre_ln_attn(x))
413        ffn = self.ffn(self.pre_ln_ffn(x))

ඒවාසහ අවශේෂ සම්බන්ධතාවය එකතු කරන්න

415        return attn + ffn + residual

මුරපොලපූරණය කිරීමට කේතය

417    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
421        with monit.section('Load transformer layer'):

අවධානයයොමු ප්රතිදානය පරිණාමනය

423            checkpoint.merge_params_sum(self.attention.output.bias, 'attention.dense.bias', p1, p2)
424            checkpoint.merge_params_dim_1(self.attention.output.weight, 'attention.dense.weight', p1, p2)

අවධානයවිමසුම, යතුර සහ අගය පරිවර්තනය කිරීම

427            checkpoint.merge_params_dim_0(self.attention.qkv_lin.bias, 'attention.query_key_value.bias', p1, p2)
428            checkpoint.merge_params_dim_0(self.attention.qkv_lin.weight, 'attention.query_key_value.weight', p1, p2)

අවධානයටපෙර ස්ථර සම්මතය

431            checkpoint.merge_params_duplicate(self.pre_ln_attn.bias, 'input_layernorm.bias', p1, p2)
432            checkpoint.merge_params_duplicate(self.pre_ln_attn.weight, 'input_layernorm.weight', p1, p2)

FFNදෙවන පරිණාමනය

435            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.bias, 'mlp.dense_h_to_4h.bias', p1, p2)
436            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.weight, 'mlp.dense_h_to_4h.weight', p1, p2)

FFNපළමු පරිවර්තනය

439            checkpoint.merge_params_sum(self.ffn.dense_h4_h.bias, 'mlp.dense_4h_to_h.bias', p1, p2)
440            checkpoint.merge_params_dim_1(self.ffn.dense_h4_h.weight, 'mlp.dense_4h_to_h.weight', p1, p2)

FFNට පෙර ස්ථර සම්මතය

443            checkpoint.merge_params_duplicate(self.pre_ln_ffn.bias, 'post_attention_layernorm.bias', p1, p2)
444            checkpoint.merge_params_duplicate(self.pre_ln_ffn.weight, 'post_attention_layernorm.weight', p1, p2)

අවසානසාමාන්යකරණ ස්තරය

447class FinalNorm(NeoXModule):
  • n_hidden කාවැද්දීම ප්රමාණය වේ
452    def __init__(self, n_hidden: int = 6_144):
456        super().__init__()
457
458        self.ln = nn.LayerNorm(n_hidden)
  • x හැඩයේ කාවැද්දීම් වේ [batch_size, seq_len, n_hidden]
460    def forward(self, x: torch.Tensor):
464        return self.ln(x)

මුරපොලපූරණය කිරීමට කේතය

466    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
470        with monit.section('Load final normalization layer'):
471            checkpoint.merge_params_duplicate(self.ln.bias, 'norm.bias', p1, p2)
472            checkpoint.merge_params_duplicate(self.ln.weight, 'norm.weight', p1, p2)

කියවීමේස්ථරය

475class ReadoutLayer(NeoXModule):
  • n_hidden කාවැද්දීම ප්රමාණය වේ
  • n_vocab වචන මාලාවේ ප්රමාණය වේ
480    def __init__(self, n_hidden: int = 6_144, n_vocab: int = 50_432):
485        super().__init__()
486
487        self.linear = nn.Linear(n_hidden, n_vocab, bias=False)
  • x හැඩයේ කාවැද්දීම් වේ [batch_size, seq_len, n_hidden]
489    def forward(self, x: torch.Tensor):
493        return self.linear(x)

මුරපොලපූරණය කිරීමට කේතය

495    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
499        with monit.section('Load final linear layer'):
500            checkpoint.merge_params_dim_0(self.linear.weight, 'final_linear.weight', p1, p2)
503class LayerGenerator:
504    pre_created_layers: Dict[Any, Optional[NeoXModule]]

ස්ථර නිර්මාණය කිරීමට උත්පාදක යන්ත්රය

ස්ථර ජනනය කරනු ලබන්නේ මුරපොලවල් මෙන් එකම අනුපිළිවෙලකි.

ස්ථරයක් නොමැතිNone විට එය ලබා දෙයි; අපි ස්ථර දර්ශක නියෝක්ස් ලෙස භාවිතා කරන අතර අපගේ ක්රියාත්මක කිරීමේදී අපට අවශ්ය නොවන පරිවර්තන ස්ථර දෙකක් තිබේ.

  • n_vocab යනු වචන මාලාවේ ටෝකන ගණන
  • n_hidden මෙම කාවැද්දීම් දී විශේෂාංග සංඛ්යාව
  • n_layers ට්රාන්ස්ෆෝමර් ස්ථර ගණන වේ
  • n_heads අවධානය යොමු ප්රධානීන් සංඛ්යාව වේ
  • filter_layers භාවිතා කළ යුතු ස්ථර සමූහයයි. කිසිවක් නොමැති නම් සියලුම ස්ථර භාවිතා කරනු ඇත. අඩු ස්ථර සහිත ආකෘතියේ කුඩා අනුවාදයන් පරීක්ෂා කිරීමට මෙය භාවිතා
  • කරයි
  • is_clone_layers ට්රාන්ස්ෆෝමර් ස්ථර ක්ලෝන කළ යුතුද යන්න නියම කරයි (ටිකක් වේගවත්)
  • dtype ආකෘතියේ දත්ත වර්ගයයි
  • device ආකෘතියේ උපාංගය වේ
  • is_llm_int8 INT8 ප්රමාණකරණය භාවිතා කළ යුතුද යන්න නියම කරයි
  • llm_int8_threshold පිටත විශේෂාංග වෙන් කිරීම සඳහා භාවිතා කරන එළිපත්ත වේ
  • is_flash_attention FlashAttention භාවිතා කළ යුතුද යන්න නියම කරයි
506    def __init__(self, *, n_vocab: int = 50_432, n_hidden: int = 6_144,
507                 n_layers: int = 44, n_heads: int = 64,
508                 filter_layers: Optional[Set] = None,
509                 is_clone_layers: bool = True,
510                 dtype: torch.dtype = torch.float,
511                 device: torch.device = torch.device('cpu'),
512                 is_llm_int8: bool = False,
513                 llm_int8_threshold: float = 6.0,
514                 is_flash_attention: bool = False
515                 ):
538        if filter_layers is None:
539            filter_layers = set(range(n_layers + 3))
540
541        self.n_vocab = n_vocab
542        self.n_hidden = n_hidden
543        self.n_layers = n_layers
544        self.n_heads = n_heads
545        self.filter_layers = filter_layers
546        self.is_clone_layers = is_clone_layers
547        self.dtype = dtype
548        self.device = device
549        self.is_llm_int8 = is_llm_int8
550        self.llm_int8_threshold = llm_int8_threshold
551        self.is_flash_attention = is_flash_attention
552
553        self.pre_created_layers = dict(
554            transformer_layer=None,
555        )

භාවිතයසඳහා ස්තරය සකස් කරයි

අපිස්තරය උපාංගය වෙත ගෙන ගොස් නිවැරදි දත්ත වර්ගයට පරිවර්තනය කරමු

  • layer සකස් කළ යුතු ස්ථරයයි
  • සකස්කළ ස්තරයආපසු ලබා දෙයි

557    def _prepare_layer(self, layer: NeoXModule):
566        return layer.to(self.device, self.dtype)

පිරික්සුම්ස්ථානය පැටවීමෙන් පසු ස්ථර පරිවර්තනයන්

මෙමශ්රිතය පිරික්සුම් ස්ථානය පැටවීමෙන් පසු ස්ථර පරිවර්තනයන් ක්රියාත්මක කරයි.

දැනටඑය අදාළ වන්නේ INT8 ප්රමාණකරණය පමණි.

  • layer සකස් කළ යුතු ස්ථරයයි
  • is_llm_int8 INT8 ප්රමාණකරණය භාවිතා කළ යුතුද යන්න නියම කරයි
  • device ආකෘතියේ උපාංගය වේ
  • llm_int8_threshold යනු පිටත විශේෂාංග වෙන් කිරීම සඳහා භාවිතා කරන එළිපත්ත
  • සකස්කළ ස්තරයආපසු ලබා දෙයි

568    @torch.no_grad()
569    def post_load_prepare(self, layer: NeoXModule, *,
570                          is_llm_int8: bool = None,
571                          device: torch.device = None,
572                          llm_int8_threshold: float = None,
573                          ):

නියමකර නොමැති නම් පෙරනිමි අගයන් ලබා ගන්න

591        if is_llm_int8 is None:
592            is_llm_int8 = self.is_llm_int8
593        if device is None:
594            device = self.device
595        if llm_int8_threshold is None:
596            llm_int8_threshold = self.llm_int8_threshold

INT8ප්රමාණකරණය භාවිතා නොකරන්නේ නම් මඟ හරින්න

599        if not is_llm_int8:
600            return layer

ට්රාන්ස්ෆෝමර්ස්ථර වල රේඛීය ස්ථර පමණක් පරිවර්තනය කරන්න

603        if not isinstance(layer, TransformerLayer):
604            return layer

උපයෝගිතාවල make_llm_int8_linear අර්ථ දක්වා ඇති භාවිතා කරන්න.

607        from labml_nn.neox.utils.llm_int8 import make_llm_int8_linear

රේඛීයස්ථර පරිවර්තනය කරන්න

610        with monit.section('Convert to int8'):
611            layer.attention.output = make_llm_int8_linear(layer.attention.output,
612                                                          device=device,
613                                                          threshold=llm_int8_threshold)
614            layer.attention.qkv_lin = make_llm_int8_linear(layer.attention.qkv_lin,
615                                                           device=device,
616                                                           threshold=llm_int8_threshold)
617            layer.ffn.dense_h_h4 = make_llm_int8_linear(layer.ffn.dense_h_h4,
618                                                        device=device,
619                                                        threshold=llm_int8_threshold)
620            layer.ffn.dense_h4_h = make_llm_int8_linear(layer.ffn.dense_h4_h,
621                                                        device=device,
622                                                        threshold=llm_int8_threshold)

624        return layer

ස්තරයක්නිර්මාණය කර හැච් කරයි

පරාමිතීන්ආරම්භ කිරීමට කාලය ගතවන නිසා හැඹිලි ස්ථර පිටපත් කිරීම නව ස්ථර ආරම්භ කිරීමට වඩා වේගවත් වේ.

  • name යනු ස්තරයේ නමයි
  • creator ස්තරය නිර්මාණය කිරීමේ කාර්යයයි
  • සාදනලද ස්තරය හෝ කැච් ස්ථරයේ පිටපතක්ආපසු ලබා දෙයි

626    def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):
638        if not self.is_clone_layers:
639            return self._prepare_layer(creator())
640
641        if self.pre_created_layers[name] is None:
642            self.pre_created_layers[name] = self._prepare_layer(creator())
643
644        layer = copy.deepcopy(self.pre_created_layers[name])
645        return layer
647    def _create_transformer_layer(self):
648        return self._create_and_cache_layer(
649            'transformer_layer',
650            lambda: TransformerLayer(self.n_hidden, self.n_heads, is_flash_attention=self.is_flash_attention)
651        )
653    def _create_embedding_layer(self):
654        return Embedding(self.n_vocab, self.n_hidden)
656    def _create_final_norm_layer(self):
657        return FinalNorm(self.n_hidden)
659    def _create_readout_layer(self):
660        return ReadoutLayer(self.n_hidden, self.n_vocab)

ස්ථරලබා ගැනීම සඳහා උත්පාදක යන්ත්රය

662    @torch.no_grad()
663    def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:

කාවැද්දීමස්ථරය

668        if 0 in self.filter_layers:
669            with monit.section('Embedding layer'):
670                layer = self._prepare_layer(self._create_embedding_layer())
671            yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')

ට්රාන්ස්ෆෝමර්ස්ථර

674        for i in range(self.n_layers):

ට්රාන්ස්ෆෝමර්ස්ථරය

676            if i + 1 in self.filter_layers:
677                with monit.section(f'Transformer Layer {i}'):
678                    yield self._create_transformer_layer(), \
679                          (f'layer_{i + 2 :02d}-model_00-model_states.pt',
680                           f'layer_{i + 2 :02d}-model_01-model_states.pt')

අවසානසාමාන්යකරණ ස්තරය

683        if self.n_layers + 1 in self.filter_layers:
684            with monit.section('Final norm layer'):
685                layer = self._prepare_layer(self._create_final_norm_layer())
686            yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')

කියවීමේස්ථරය

689        if self.n_layers + 2 in self.filter_layers:
690            with monit.section('Readout layer'):
691                layer = self._prepare_layer(self._create_readout_layer())
692            yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')
693
694        for k in self.pre_created_layers.keys():
695            self.pre_created_layers[k] = None

මුළුස්ථර ගණන නැවත ලබා දෙයි

697    @property
698    def total_layers(self):
702        return self.n_layers + 3

ස්ථරපැටවීමට උත්පාදක යන්ත්රය

704    @torch.no_grad()
705    def load(self) -> Generator[NeoXModule, None, None]:
709        with monit.section("Layers"):
710            for i, (layer, files) in enumerate(self.get_layers()):
711                if files is not None:
712                    layer.load_state(*checkpoint.load_checkpoint_files(files))
713
714                layer = self.post_load_prepare(layer)
715
716                monit.progress(min(0.99, (i + 1) / self.total_layers))
717                yield layer