"""
|
|
---
|
|
title: Generate Text with GPT-NeoX
|
|
summary: >
|
|
Generate Text with GPT-NeoX
|
|
---
|
|
|
|
# Generate Text with GPT-NeoX
|
|
|
|
This shows how to generate text from GPT-NeoX with a single GPU.
|
|
|
|
This needs a GPU with more than 45GB memory.
|
|
"""

# Imports
from typing import List

import torch
from torch import nn

from labml import monit
from labml_nn.neox.model import LayerGenerator
from labml_nn.neox.utils import get_tokens, print_tokens
from labml_nn.neox.utils.cache import get_cache

# List of layers to load. This is used for testing.
# You can assign a subset of layers like `{0, 1}` so that it only loads
# the first two transformer layers.
LAYERS = None

# Prompt to complete
PROMPT = 'Einstein was born in the German Empire, but moved to Switzerland in 1895, forsaking his German'


def infer(model: nn.Module, ids: List[int], device: torch.device):
    """
    ### Predict the next token

    :param model: is the model
    :param ids: are the input token ids
    :param device: is the device of the model
    """

    with torch.no_grad():
        # Move the token ids to the device as a batch of size `1`
        x = torch.tensor(ids)[None, :].to(device)
        # Evaluate the model to get the logits
        x = model(x)

    # Greedily pick the highest-scoring token at each position
    return x[0].max(dim=-1)[1].tolist()
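

# `infer` decodes greedily by taking the argmax at each position. Below is a
# minimal sketch of temperature sampling as an alternative; `sample` is
# illustrative and not part of the original `labml_nn` API. It returns a single
# token id, so it would stand in for `infer(...)[-1]` in the generation loop.
def sample(model: nn.Module, ids: List[int], device: torch.device, temperature: float = 0.8):
    """
    ### Sample the next token (illustrative sketch)

    :param model: is the model
    :param ids: are the input token ids
    :param device: is the device of the model
    :param temperature: scales the logits before the softmax
    """

    with torch.no_grad():
        # Move the token ids to the device as a batch of size `1`
        x = torch.tensor(ids)[None, :].to(device)
        # Evaluate the model to get the logits
        x = model(x)

    # Temperature-scale the logits at the last position and sample from the
    # resulting distribution (cast to `float32` for a stable softmax)
    probs = torch.softmax(x[0, -1].float() / temperature, dim=-1)
    return int(torch.multinomial(probs, num_samples=1).item())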


def generate():
    """
    ## Generate text
    """

    # Set up the [cache](../utils/cache.html) to store intermediate key/value
    # pairs for faster generation
    cache = get_cache()
    cache.set('use_cache', True)

    # Device
    device = torch.device('cuda:0')

    # Load layers
    layers = list(LayerGenerator(is_clone_layers=True,
                                 filter_layers=LAYERS,
                                 dtype=torch.float16,
                                 device=device,
                                 ).load())
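
    # `LayerGenerator.load()` yields the layers in order; chaining them with
    # `nn.Sequential` gives the full model, whose output `infer` treats as
    # logits over the vocabulary.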
    model = nn.Sequential(*layers)

    # Get token ids
    ids = get_tokens(PROMPT)

    # Run the model on the whole prompt. `state_ids` holds a `(previous, next)`
    # pair of cache-state ids: `(None, 1)` means there is no previous state to
    # read, and the key/value pairs from this pass are stored as state `1`.
    cache.set('state_ids', (None, 1))
    with monit.section('Infer'):
        next_token = infer(model, ids, device)[-1]

    # Append the predicted token
    ids += [next_token]

    # Predict 99 more tokens (100 in total, counting the one above)
    for i in range(1, 100):
        # Set the state ids so that this step reads the activations cached by
        # the previous step
        cache.set('state_ids', (i, i + 1))
        # Get the next token. Note that we only feed the last token to the model
        # because we cache the key/value pairs of the previous tokens.
        with monit.section('Infer'):
            next_token = infer(model, [next_token], device)[-1]
        # Append the predicted token
        ids += [next_token]
        # Print the tokens generated so far
        print_tokens(ids, [ids])
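

# To smoke-test on a smaller GPU, set `LAYERS = {0, 1}` above so that only the
# first two transformer layers are loaded; the output will not be meaningful,
# but it exercises the whole pipeline with far less memory.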

#
if __name__ == '__main__':
    generate()