This shows how to generate text from GPT-NeoX with a single GPU.

This needs a GPU with more than 45 GB of memory: the 20-billion-parameter model takes roughly 40 GB for its float16 weights alone.
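Before loading the model, it can help to check that the GPU actually has enough memory. Here is a minimal sketch using PyTorch's CUDA utilities; the 45 GB threshold is just the figure quoted above:

```python
import torch

# Sketch: verify GPU 0 has enough memory for the float16 weights
if torch.cuda.is_available():
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1024 ** 3
    print(f'GPU 0: {total_gb:.1f} GB')
    assert total_gb > 45, 'GPT-NeoX needs more than 45 GB of GPU memory'
```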
Imports:

```python
from typing import List

import torch
from torch import nn

from labml import monit
from labml_nn.neox.model import LayerGenerator
from labml_nn.neox.utils import get_tokens, print_tokens
from labml_nn.neox.utils.cache import get_cache
```

List of layers to load. This is used for testing. You can assign a subset of layers like `{0, 1}` so that it only loads the first two transformer layers.
```python
LAYERS = None
```

Prompt to complete:

```python
PROMPT = 'Einstein was born in the German Empire, but moved to Switzerland in 1895, forsaking his German'
```

`infer` runs the model and returns the predicted token at each position. `model` is the model, `ids` are the input token ids, and `device` is the device of the model.

```python
def infer(model: nn.Module, ids: List[int], device: torch.device):
    with torch.no_grad():
        # Get the tokens
        x = torch.tensor(ids)[None, :].to(device)
        # Eval model
        x = model(x)

    # Return predicted tokens (argmax at each position)
    return x[0].max(dim=-1)[1].tolist()
```

`generate` puts everything together. First, set up the cache so that key/value pairs of previous tokens are reused:

```python
def generate():
    cache = get_cache()
    cache.set('use_cache', True)
```

Device:

```python
    device = torch.device('cuda:0')
```

Load the layers in float16:

```python
    layers = list(LayerGenerator(is_clone_layers=True,
                                 filter_layers=LAYERS,
                                 dtype=torch.float16,
                                 device=device,
                                 ).load())
```
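If you want to confirm how much memory the loaded weights actually take, you could query PyTorch's allocator stats at this point; a small sketch, not part of the original sample:

```python
    # Sketch: report GPU memory held by the loaded float16 weights
    print(f'{torch.cuda.memory_allocated(device) / 1024 ** 3:.1f} GB allocated')
```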
Stack the layers into a sequential model:

```python
    model = nn.Sequential(*layers)
```
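The sample relies on `torch.no_grad()` alone; as a precaution you could also switch the model to eval mode in case any layer behaves differently in training mode. This is an addition, not part of the original sample:

```python
    # Precaution (not in the original sample): disable training-mode behavior
    model.eval()
```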
Get token ids:

```python
    ids = get_tokens(PROMPT)
```

Run the model on the full prompt:

```python
    cache.set('state_ids', (None, 1))
    with monit.section('Infer'):
        # Take the prediction at the last position as the next token
        next_token = infer(model, ids, device)[-1]
```

Append the predicted token:

```python
    ids += [next_token]
```

Predict 100 tokens:

```python
    for i in range(1, 100):
        # Set the state to use cached activations
        cache.set('state_ids', (i, i + 1))
        # Get the next token. Note that we only feed the last token to the model
        # because we cache the key/value pairs of previous tokens.
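        # (Each step appends the new token's key/value pairs to the cache, so
        # attention still sees the full sequence while only one new position
        # is computed per step.)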
        with monit.section('Infer'):
            next_token = infer(model, [next_token], device)[-1]
        # Append the predicted token
        ids += [next_token]
        # Print the generated text
        print_tokens(ids, [ids])
```

Run it:

```python
if __name__ == '__main__':
    generate()
```
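`infer` is greedy: it always picks the most likely token, so the completion is deterministic. For more varied output you could sample instead. Below is a minimal sketch of temperature plus top-k sampling; `sample_next_token` is a hypothetical helper (the temperature and top-k values are arbitrary), and you would call it on the last position's logits (`x[0, -1]` in `infer`) instead of taking the argmax.

```python
import torch


def sample_next_token(logits: torch.Tensor, temperature: float = 0.8, top_k: int = 40) -> int:
    # Temperature scaling: higher values flatten the distribution
    logits = logits / temperature
    # Keep only the top-k candidates
    values, indices = torch.topk(logits, top_k)
    # Softmax over the truncated logits (in float32 for numerical stability)
    probs = torch.softmax(values.float(), dim=-1)
    # Draw one token from the truncated distribution
    choice = torch.multinomial(probs, num_samples=1)
    return int(indices[choice])
```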