This shows how to generate text from GPT-NeoX using LLM.int8() quantization.
This needs a GPU with 24 GB of memory.
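
As a rough illustration of what LLM.int8() quantization does (and of the `llm_int8_threshold=6.0` argument passed to `post_load_prepare` below): most of each weight matrix is stored and multiplied as 8-bit integers, while the few "outlier" feature dimensions whose activation magnitudes exceed the threshold are computed in float16. The sketch below is a simplified, self-contained illustration of that mixed-precision decomposition; the function name and shapes are made up for this example, and the sample itself delegates the real conversion to `layer_generator.post_load_prepare`.

import torch


def llm_int8_matmul_sketch(x: torch.Tensor, w: torch.Tensor, threshold: float = 6.0) -> torch.Tensor:
    # x: activations, shape [n_tokens, n_features]; w: weights, shape [n_features, n_out]
    # Feature dimensions where any activation magnitude exceeds the threshold are "outliers"
    outliers = (x.abs() > threshold).any(dim=0)

    # Outlier features are multiplied in the original (higher) precision
    y_fp = x[:, outliers] @ w[outliers, :]
    if not (~outliers).any():
        # Degenerate case: every feature dimension is an outlier
        return y_fp

    # Remaining features: absmax-quantize activations per token (row) and weights per output (column)
    x_rest, w_rest = x[:, ~outliers], w[~outliers, :]
    x_scale = x_rest.abs().amax(dim=1, keepdim=True).clamp(min=1e-6) / 127.0
    w_scale = w_rest.abs().amax(dim=0, keepdim=True).clamp(min=1e-6) / 127.0
    x_q = (x_rest / x_scale).round().clamp(-127, 127)
    w_q = (w_rest / w_scale).round().clamp(-127, 127)

    # A real implementation runs this product with int8 kernels; here we keep float tensors
    # holding integer values, purely to show the arithmetic
    y_int8 = (x_q @ w_q) * x_scale * w_scale

    return y_fp + y_int8

For example, with `x = torch.randn(4, 16)` and `w = torch.randn(16, 8)` the result stays close to `x @ w`, while most of the multiplication could in principle run on int8 hardware. With that in mind, the sample below loads the GPT-NeoX layers, converts them to int8, and generates text.
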
import torch
from torch import nn

from labml import monit
from labml_nn.neox.model import LayerGenerator
from labml_nn.neox.samples.generate import PROMPT, infer
from labml_nn.neox.utils import get_tokens, print_tokens
from labml_nn.neox.utils.cache import get_cache


def generate():
    # Set up the cache; it stores the key/value pairs of previous tokens so that
    # each generation step only needs to process the newly predicted token
    cache = get_cache()
    cache.set('use_cache', True)

    # Device
    device = torch.device('cuda:0')

    # Load the layers in float16 onto the CPU. We convert the layers to int8 later, because
    # doing that on the fly after loading the layers onto the GPU causes CUDA memory
    # fragmentation (about 3GB of memory can get lost due to fragmentation).
    layer_generator = LayerGenerator(is_clone_layers=True,
                                     dtype=torch.float16,
                                     device=torch.device('cpu'),
                                     is_llm_int8=False,
                                     )
    layers = list(layer_generator.load())

    # Convert each layer to int8 and move it to the GPU. Doing the conversion here,
    # rather than after loading everything onto the GPU, reduces CUDA memory fragmentation.
    for layer in monit.iterate('Convert to int8', layers, is_children_silent=True):
        layer_generator.post_load_prepare(layer,
                                          device=device,
                                          is_llm_int8=True,
                                          llm_int8_threshold=6.0,
                                          )
        layer.to(device)

    # Create an nn.Sequential model from the layers
    model = nn.Sequential(*layers)

    # Clear the cache and print the memory summary for debugging
    torch.cuda.empty_cache()
    print(torch.cuda.memory_summary())

    # Get the token ids of the prompt
    ids = get_tokens(PROMPT)

    # Run the model on the full prompt. We use the infer function defined in generate.py.
    cache.set('state_ids', (None, 1))
    with monit.section('Infer'):
        next_token = infer(model, ids, device)[-1]

    # Append the predicted token
    ids += [next_token]

    # Generate 100 tokens in total; the first was predicted above, so the loop adds the remaining 99
    for i in range(1, 100):
        # Set the state to use cached activations
        cache.set('state_ids', (i, i + 1))

        # Get the next token. Note that we only feed the last token to the model because
        # we cache the key/value pairs of previous tokens.
        with monit.section('Infer'):
            next_token = infer(model, [next_token], device)[-1]

        # Append the predicted token
        ids += [next_token]

        # Print the tokens generated so far
        print_tokens(ids, [ids])


if __name__ == '__main__':
    generate()
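
A note on why the loop above only feeds the newly predicted token back into the model: the layers keep the key/value projections of all earlier positions in the cache (which is, roughly, what the `cache.set('state_ids', ...)` calls manage), so attention for the new token can reuse them instead of re-running the whole sequence. The snippet below is a minimal, generic sketch of single-token decoding with a key/value cache for one attention head; it is only an illustration and not how the labml_nn cache is implemented.

import torch


def decode_step(q_new: torch.Tensor, k_new: torch.Tensor, v_new: torch.Tensor,
                k_cache: torch.Tensor, v_cache: torch.Tensor):
    # q_new, k_new, v_new: projections of the newly fed token, shape [1, d]
    # k_cache, v_cache: cached projections of all previous tokens, shape [t, d]

    # Extend the cache with the new token's key/value
    k = torch.cat([k_cache, k_new], dim=0)  # [t + 1, d]
    v = torch.cat([v_cache, v_new], dim=0)  # [t + 1, d]

    # The new token attends over the full history without re-running earlier tokens
    scores = (q_new @ k.T) / k.shape[-1] ** 0.5  # [1, t + 1]
    out = scores.softmax(dim=-1) @ v  # [1, d]

    # Return the attention output for the new token and the updated cache
    return out, k, v

Because only the projections of the new token are computed at each step, the per-token cost stays small apart from the attention over the cached history, which is why `infer` is called with just `[next_token]` inside the loop.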