GPT-නියෝක්ස්සමඟ පෙළ ජනනය කරන්න

තනිGPU එකක් සමඟ GPT-neox වෙතින් පෙළ ජනනය කරන්නේ කෙසේද යන්න මෙයින් පෙන්වයි.

මේසඳහා 45GB ට වඩා වැඩි මතකයක් සහිත GPU එකක් අවශ්ය වේ.

15

ආනයන

16from typing import List
17
18import torch
19from torch import nn
20
21from labml import monit
22from labml_nn.neox.model import LayerGenerator
23from labml_nn.neox.utils import get_tokens, print_tokens
24from labml_nn.neox.utils.cache import get_cache

පැටවියයුතු ස්ථර ලැයිස්තුව. මෙය පරීක්ෂා කිරීම සඳහා භාවිතා වේ. ට්රාන්ස්ෆෝමර් ස්ථර වලට පළමු පටවනු ලබන {0, 1} පරිදි ඔබට වැනි ස්ථර උප කුලකයක් පැවරිය හැකිය.

29LAYERS = None

සම්පූර්ණකිරීමට විමසන්න

32PROMPT = 'Einstein was born in the German Empire, but moved to Switzerland in 1895, forsaking his German'

ඊළඟටෝකනය පුරෝකථනය කරන්න

  • model ආකෘතිය වේ
  • ids ආදාන ටෝකන හැඳුනුම් වේ
  • device ආකෘතියේ උපාංගය වේ
  • 35def infer(model: nn.Module, ids: List[int], device: torch.device):
    44    with torch.no_grad():

    ටෝකනලබා ගන්න

    46        x = torch.tensor(ids)[None, :].to(device)

    එවාල්ආකෘතිය

    48        x = model(x)

    පුරෝකථනයකළ ටෝකනය ආපසු යන්න

    51    return x[0].max(dim=-1)[1].tolist()

    පෙළජනනය කරන්න

    54def generate():

    වේගවත්උත්පාදනය සඳහා අතරමැදි යතුර/වටිනාකම් යුගල හැඹිලි කිරීමට හැඹිලිය පිහිටුවන්න

    60    cache = get_cache()
    61    cache.set('use_cache', True)

    උපාංගය

    64    device = torch.device('cuda:0')

    ස්ථරපූරණය කරන්න

    67    layers = list(LayerGenerator(is_clone_layers=True,
    68                                 filter_layers=LAYERS,
    69                                 dtype=torch.float16,
    70                                 device=device,
    71                                 ).load())
    72
    73    model = nn.Sequential(*layers)

    ටෝකන්හැඳුනුම්පත් ලබා ගන්න

    76    ids = get_tokens(PROMPT)

    ආකෘතියධාවනය කරන්න

    79    cache.set('state_ids', (None, 1))
    80    with monit.section('Infer'):
    81        next_token = infer(model, ids, device)[-1]

    පුරෝකථනයකළ ටෝකනය එක් කරන්න

    84    ids += [next_token]

    ටෝකන100 ක් පුරෝකථනය කරන්න

    87    for i in range(1, 100):

    හැඹිලිසක්රිය කිරීම් භාවිතා කිරීමට රාජ්යය සකසන්න

    89        cache.set('state_ids', (i, i + 1))

    ඊළඟටෝකනය ලබා ගන්න. පෙර ටෝකන වල යතුර/අගය යුගල හැඹිලි කරන නිසා අපි ආකෘතියට අවසාන ටෝකනය පමණක් පෝෂණය කරන බව සලකන්න.

    92        with monit.section('Infer'):
    93            next_token = infer(model, [next_token], device)[-1]

    පුරෝකථනයකළ ටෝකනය එක් කරන්න

    95        ids += [next_token]

    මුද්රණය

    97        print_tokens(ids, [ids])

    101if __name__ == '__main__':
    102    generate()