mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-10-29 09:38:56 +08:00)
notes
@@ -518,7 +518,22 @@ class LayerGenerator:
                          device: torch.device = None,
                          llm_int8_threshold: float = None,
                          ):
        # If we are using int8 quantization, we need to convert the layer to int8
        """
        <a id="post_load_prepare"></a>
        ### Layer transformations after loading the checkpoint

        This function implements layer transformations after loading the checkpoint.

        Currently, it only applies the int8 quantization.

        :param layer: is the layer to prepare
        :param is_llm_int8: specifies whether to use int8 quantization
        :param device: is the device of the model
        :param llm_int8_threshold: is the threshold $\alpha$ used to separate outlier features
        :return: the prepared layer
        """

        # Get default values if not specified
        if is_llm_int8 is None:
            is_llm_int8 = self.is_llm_int8
        if device is None:
@@ -526,6 +541,7 @@ class LayerGenerator:
        if llm_int8_threshold is None:
            llm_int8_threshold = self.llm_int8_threshold

        # Skip if not using int8 quantization
        if not is_llm_int8:
            return layer

@@ -536,7 +552,7 @@ class LayerGenerator:
        # Use `make_llm_int8_linear` defined in [utilities](./utils/llm_int8.html).
        from labml_nn.neox.utils.llm_int8 import make_llm_int8_linear

        #
        # Convert the linear layers
        with monit.section('Convert to int8'):
            layer.attention.output = make_llm_int8_linear(layer.attention.output,
                                                          device=device,

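For orientation, here is a minimal usage sketch (not part of this diff) of the layer preparation documented above. It assumes the `post_load_prepare` method and the `LayerGenerator` constructor arguments shown in these hunks; the call pattern is an illustration, not the repository's code.

# Hypothetical usage sketch of the layer preparation above (not part of this commit)
import torch
from labml_nn.neox.model import LayerGenerator

generator = LayerGenerator(is_clone_layers=True,
                           dtype=torch.float16,
                           device=torch.device('cpu'))
# Load the transformer layers in float16 on the CPU
layers = list(generator.load())
# Convert the linear layers of the first layer to LLM.int8() and move it to the GPU
layer = generator.post_load_prepare(layers[0],
                                    is_llm_int8=True,
                                    device=torch.device('cuda:0'),
                                    llm_int8_threshold=6.0)
layer.to(torch.device('cuda:0'))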
@@ -1,3 +1,17 @@
"""
---
title: Generate Text with GPT-NeoX using LLM.int8() quantization
summary: >
    Generate Text with GPT-NeoX using LLM.int8() quantization
---

# Generate Text with GPT-NeoX using LLM.int8() quantization

This shows how to generate text from GPT-NeoX using [LLM.int8() quantization](../utils/llm_int8.html).

This needs a GPU with more than 45GB memory.
"""

from typing import List

import torch
@@ -5,31 +19,10 @@ from torch import nn

from labml import monit
from labml_nn.neox.model import LayerGenerator
from labml_nn.neox.samples.generate import PROMPT, infer
from labml_nn.neox.utils import get_tokens, print_tokens
from labml_nn.neox.utils.cache import get_cache

# Prompt to complete
PROMPT = 'Einstein was born in the German Empire, but moved to Switzerland in 1895, forsaking his German'


def infer(model: nn.Module, ids: List[int], device: torch.device):
    """
    ### Predict the next token

    :param model: is the model
    :param ids: are the input token ids
    :param device: is the device of the model
    """

    with torch.no_grad():
        # Get the tokens
        x = torch.tensor(ids)[None, :].to(device)
        # Eval model
        x = model(x)

    # Return predicted token
    return x[0].max(dim=-1)[1].tolist()


def generate():
    """
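Since `infer` above returns the argmax prediction for every input position, a greedy decoding loop only needs the prediction at the last position. A minimal sketch built on that helper (my illustration, not part of this commit; `greedy_generate` is a hypothetical name):

# Hypothetical greedy decoding loop built on the `infer` helper above
def greedy_generate(model: nn.Module, ids: List[int], device: torch.device, n_new_tokens: int = 32):
    ids = list(ids)
    for _ in range(n_new_tokens):
        # `infer` gives one predicted token id per position; the last one is the next token
        next_token = infer(model, ids, device)[-1]
        ids.append(next_token)
    return ids

This re-runs the full sequence at every step; the actual sample avoids that by using the key/value cache from `labml_nn.neox.utils.cache`.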
@@ -43,12 +36,14 @@ def generate():
    # Device
    device = torch.device('cuda:0')

    # Load layers in float16 onto the CPU. We convert the layers to int8 later, because doing that
    # on the fly after loading the layers to the GPU causes CUDA memory fragmentation
    # (about 3GB of memory can get lost due to fragmentation).
    layer_generator = LayerGenerator(is_clone_layers=True,
                                     dtype=torch.float16,
                                     device=torch.device('cpu'),
                                     # is_llm_int8=True,
                                     )
    # Load layers
    layers = list(layer_generator.load())

    # This reduces CUDA memory fragmentation
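The lines elided between this hunk and the next presumably loop over the loaded layers, converting each one to int8 and moving it to the GPU one at a time; the next hunk shows the tail of that call and `layer.to(device)`. A rough sketch of that loop, stated as an assumption rather than the commit's exact code:

# Assumed shape of the elided per-layer conversion loop (not the commit's exact code)
for layer in monit.iterate('Layers to GPU', layers):
    # Quantize this layer's linear modules before moving it, so large float16
    # copies never land on the GPU; this is what limits memory fragmentation
    layer = layer_generator.post_load_prepare(layer,
                                              device=device,
                                              is_llm_int8=True,
                                              llm_int8_threshold=6.0)
    layer.to(device)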
@@ -60,10 +55,11 @@ def generate():
                                                  )
        layer.to(device)

    # Create `nn.Sequential` model
    model = nn.Sequential(*layers)

    # Clear cache and print memory summary for debugging
    torch.cuda.empty_cache()

    print(torch.cuda.memory_summary())

    # Get token ids

@@ -1,8 +1,36 @@
"""
* [Generate](../samples/llm_int8.html)
* [Evaluation](../evaluation/llm_int8.html)
---
title: LLM.int8() on GPT-NeoX
summary: >
    Transform nn.Linear layers to 8-bit integer layers.
---

# LLM.int8() on GPT-NeoX

This implements a utility function to transform a `nn.Linear` layer to an LLM.int8() linear layer.

The [LLM.int8() paper](https://papers.labml.ai/paper/eb2bcaee1d0011edaa66a71c10a887e7)
shows that you can use int8 quantization, while handling outliers separately, to
reduce the memory footprint without performance degradation in large language models.
They convert weights and inputs to scaled 8-bit integers and do the matrix multiplication
producing int32 results, which are then converted back to float16 and rescaled.
They show that in large language models, some features can give extreme values (outliers)
that dominate the model's output.
These features get clamped in 8-bit integer space, which causes the model's performance to degrade.
As a solution, they pick these outliers (features greater than a specified threshold)
and compute their multiplications separately in float16 space.
Since the percentage of outliers is around 0.01%, this doesn't increase memory usage,
and it prevents the model's performance from degrading.

The code to transform GPT-NeoX layers is defined in [model.py](../model.html#post_load_prepare).

Here are example uses of GPT-NeoX with int8 quantization.

* [Generate Text](../samples/llm_int8.html)
* [Run Evaluation Tests](../evaluation/llm_int8.html)
"""

# Import [`bitsandbytes`](https://github.com/timdettmers/bitsandbytes) package
try:
    from bitsandbytes.nn import Linear8bitLt, Int8Params
except ImportError:
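To make the scheme described in this docstring concrete, here is a small numerical sketch of row-wise absmax quantization with outlier columns kept in higher precision. This is my illustration of the idea, not code from `bitsandbytes` or this repository; the function names are hypothetical and the int32 accumulation is emulated in float for simplicity.

# Illustrative sketch of the LLM.int8() matrix multiplication described above
import torch

def quantize_rowwise(x: torch.Tensor):
    # Scale each row so its absolute maximum maps to 127, then round to int8
    scale = x.abs().amax(dim=-1, keepdim=True) / 127.
    return (x / scale).round().to(torch.int8), scale

def llm_int8_matmul(x: torch.Tensor, w: torch.Tensor, threshold: float = 6.0):
    # `x` is (batch, features) and `w` is (out_features, features), as in nn.Linear
    # Feature columns where any activation exceeds the threshold are outliers
    outlier_cols = (x.abs() > threshold).any(dim=0)
    # Regular features: quantize inputs and weights to int8, multiply, then rescale
    # (real kernels accumulate in int32; emulated here in float)
    xq, sx = quantize_rowwise(x[:, ~outlier_cols])
    wq, sw = quantize_rowwise(w[:, ~outlier_cols])
    regular = (xq.float() @ wq.float().t()) * sx * sw.t()
    # Outlier features: multiplied separately at full precision
    outliers = x[:, outlier_cols] @ w[:, outlier_cols].t()
    return (regular + outliers).to(x.dtype)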
@@ -13,7 +41,18 @@ from torch import nn


def make_llm_int8_linear(linear_module: nn.Linear, device: torch.device, threshold: float = 6.0):
    # Create a Linear8bitLt module
    """
    ## Transform a `nn.Linear` layer to LLM.int8() linear layer

    :param linear_module: is the `nn.Linear` layer to transform
    :param device: is the device of the model
    :param threshold: is the threshold $\alpha$ to use for outlier detection
    """

    #
    assert isinstance(linear_module, nn.Linear)

    # Create an empty Linear8bitLt module
    int8_lin = Linear8bitLt(
        linear_module.in_features,
        linear_module.out_features,
@@ -22,15 +61,15 @@ def make_llm_int8_linear(linear_module: nn.Linear, device: torch.device, thresho
        threshold=threshold,
    )

    # Set the weights
    # Quantize the weights
    int8_lin._parameters['weight'] = Int8Params(linear_module.weight.data.cpu(),
                                                requires_grad=False,
                                                has_fp16_weights=False).to(device)

    # Set the bias.
    # We don't have to convert this to Int8 since it doesn't use a lot of memory.
    # Set the bias in float16 space
    if linear_module.bias is not None:
        int8_lin._parameters['bias'] = nn.Parameter(linear_module.bias.data,
                                                    requires_grad=False)

    #
    return int8_lin

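A minimal usage sketch for `make_llm_int8_linear` (not part of this diff); it assumes `bitsandbytes` is installed and a CUDA device is available, and the layer sizes are arbitrary:

# Hypothetical usage of `make_llm_int8_linear` on a plain linear layer
import torch
from torch import nn
from labml_nn.neox.utils.llm_int8 import make_llm_int8_linear

device = torch.device('cuda:0')
# A float16 linear layer, standing in for e.g. an attention output projection
linear = nn.Linear(1024, 1024, dtype=torch.float16)
# Replace it with an 8-bit layer; the weights are quantized when moved to the GPU
int8_linear = make_llm_int8_linear(linear, device=device, threshold=6.0)
# The forward pass is unchanged, but runs with int8 weights plus float16 outliers
y = int8_linear(torch.randn(1, 1024, dtype=torch.float16, device=device))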