From e19d95f9c34ab0c0a070ad2c4447ad032798a93e Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Sat, 20 Aug 2022 10:45:31 +0530
Subject: [PATCH] notes

---
 docs/neox/evaluation/llm_int8.html | 177 +++++++++++++
 docs/neox/model.html               | 404 +++++++++++++++++++++--------
 docs/neox/readme.html              | 129 +++++++++
 docs/neox/samples/llm_int8.html    | 362 ++++++++++++++++++++++++++
 docs/neox/utils/llm_int8.html      | 247 ++++++++++++++++++
 docs/sitemap.xml                   |  59 +++--
 labml_nn/neox/model.py             |  20 +-
 labml_nn/neox/samples/llm_int8.py  |  44 ++--
 labml_nn/neox/utils/llm_int8.py    |  51 +++-
 9 files changed, 1336 insertions(+), 157 deletions(-)
 create mode 100644 docs/neox/evaluation/llm_int8.html
 create mode 100644 docs/neox/readme.html
 create mode 100644 docs/neox/samples/llm_int8.html
 create mode 100644 docs/neox/utils/llm_int8.html

diff --git a/docs/neox/evaluation/llm_int8.html b/docs/neox/evaluation/llm_int8.html
new file mode 100644
index 00000000..60926b00
--- /dev/null
+++ b/docs/neox/evaluation/llm_int8.html
@@ -0,0 +1,177 @@
New documentation page for `labml_nn/neox/evaluation/llm_int8.py` (rendered HTML boilerplate omitted; only the annotated code content is kept):
+import torch
+from torch import nn
+
+from labml import monit
+from labml_nn.neox.evaluation import run_eval_harness
+from labml_nn.neox.model import LayerGenerator
+
+if __name__ == '__main__':
+    device = torch.device('cuda:0')
+    layer_generator = LayerGenerator(is_clone_layers=True,
+                                     dtype=torch.float16,
+                                     device=torch.device('cpu'),
+                                     )
+    # Load layers
+    layers = list(layer_generator.load())
+
+    # This reduces CUDA memory fragmentation
+    for layer in monit.iterate('Convert to int8', layers, is_children_silent=True):
+        layer_generator.post_load_prepare(layer,
+                                          device=device,
+                                          is_llm_int8=True,
+                                          llm_int8_threshold=6.0,
+                                          )
+        layer.to(device)
+
+    with monit.section('Sequential'):
+        model = nn.Sequential(*layers)
+
+    print(run_eval_harness(model, 'half_precision', [], device))
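Loading the checkpoint to the CPU first and only moving each layer to the GPU after its int8 conversion is what keeps peak GPU memory and fragmentation low. A way to verify that with standard `torch.cuda` counters; this helper is illustrative only and not part of the patch:

```python
import torch


def report_cuda_memory(tag: str, device: torch.device = torch.device('cuda:0')):
    # Current and peak allocations on the device, in GiB
    allocated = torch.cuda.memory_allocated(device) / 2 ** 30
    peak = torch.cuda.max_memory_allocated(device) / 2 ** 30
    print(f'{tag}: allocated={allocated:.2f} GiB, peak={peak:.2f} GiB')


# Example use around the conversion loop above:
#   report_cuda_memory('before conversion')
#   ... convert each layer to int8 and move it to the GPU ...
#   report_cuda_memory('after conversion')
```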
\ No newline at end of file
diff --git a/docs/neox/model.html b/docs/neox/model.html
index 8f793cb5..7759d28e 100644
--- a/docs/neox/model.html
+++ b/docs/neox/model.html
@@ -230,7 +230,7 @@
+ `base` is the base for $\theta_i = 10000^{\frac{-2(i-1)}{d}}$, which defaults to $10000$
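For reference, the quantity $\theta_i$ in the RoPE notes below is the standard rotary-embedding frequency, and the rotation the notes describe pairs feature $i$ with feature $i + \frac{d}{2}$. The notation here is reconstructed from the RoPE paper, not copied from the rendered page:

```latex
\theta_i = 10000^{\frac{-2(i-1)}{d}}, \qquad i \in \{1, 2, \dots, \tfrac{d}{2}\}

\mathrm{RoPE}(x_m)^{(i)} =
\begin{pmatrix}
x_m^{(i)} \cos m\theta_i - x_m^{(i+\frac{d}{2})} \sin m\theta_i \\
x_m^{(i+\frac{d}{2})} \cos m\theta_i + x_m^{(i)} \sin m\theta_i
\end{pmatrix}
```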
The following hunks only re-render the RoPE notes in the docs; the note text for each hunk, with the math restored, is listed once:

@@ -253,7 +253,7 @@
 To store $\theta_i$ for the features
@@ -265,7 +265,7 @@
 Cache $\cos m\theta_i$ and $\sin m\theta_i$
@@ -278,7 +278,7 @@
 Base for $\theta_i$
@@ -357,7 +357,7 @@
 Initialize $\theta$
@@ -369,7 +369,7 @@
 $\theta_i = 10000^{\frac{-2(i-1)}{d}}$
@@ -382,7 +382,7 @@
 Initialize $\cos m\theta_i$ and $\sin m\theta_i$ cache
@@ -399,7 +399,7 @@
 Get position indexes $m$
@@ -411,7 +411,7 @@
 $m \theta_i$
@@ -423,8 +423,8 @@
 Concatenate so that for row $m$ we have
 $[m \theta_0, m \theta_1, \dots, m \theta_{\frac{d}{2}-1}, m \theta_0, m \theta_1, \dots, m \theta_{\frac{d}{2}-1}]$
@@ -436,7 +436,7 @@
 Calculate $\cos m\theta_i$ and $\sin m\theta_i$ in fp32
@@ -501,7 +501,7 @@
 RoPE embeddings, for $i \in \{1, 2, \dots, \frac{d}{2}\}$
@@ -1566,7 +1566,10 @@
 * `dtype` is the data type of the model
 * `device` is the device of the model
+ * `is_llm_int8` specifies whether to use int8 quantization
+ * `llm_int8_threshold` is the threshold used to separate outlier features

   Returns the layers as a generator
@@ -1575,7 +1578,10 @@
 459        filter_layers: Optional[Set] = None,
 460        is_clone_layers: bool = True,
 461        dtype: torch.dtype = torch.float,
-462        device: torch.device = torch.device('cpu')):
+462        device: torch.device = torch.device('cpu'),
+463        is_llm_int8: bool = False,
+464        llm_int8_threshold: float = 6.0,
+465        ):
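`llm_int8_threshold` is the outlier cutoff from LLM.int8(): input features whose magnitude exceeds it are kept in fp16 while the rest go through 8-bit matrix multiplication. A minimal pure-PyTorch sketch of that decomposition (illustration only, not the library code; `llm_int8_matmul_sketch` is a made-up helper):

```python
import torch


def llm_int8_matmul_sketch(x: torch.Tensor, w: torch.Tensor, threshold: float = 6.0):
    """Illustrative LLM.int8()-style matmul of x @ w.T with outlier decomposition.

    x: [n_tokens, d_in] activations, w: [d_out, d_in] weights, both float32.
    """
    # Input features whose magnitude exceeds the threshold are outliers
    outliers = x.abs().amax(dim=0) > threshold

    # Outlier features: keep full precision
    y_outlier = x[:, outliers] @ w[:, outliers].t()

    # Regular features: symmetric 8-bit quantization, per row of x and per row of w
    x_r, w_r = x[:, ~outliers], w[:, ~outliers]
    x_scale = x_r.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.
    w_scale = w_r.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.
    x_q = torch.round(x_r / x_scale).clamp(-127, 127)
    w_q = torch.round(w_r / w_scale).clamp(-127, 127)
    # Real kernels do this multiply in int8 with int32 accumulation;
    # here we keep float tensors and only simulate the rounding error.
    y_regular = (x_q @ w_q.t()) * (x_scale * w_scale.t())

    return y_outlier + y_regular
```

With the default threshold of 6.0 used in this patch, typically only a small number of feature dimensions end up in the fp16 path, so almost all of the weight matrix is stored and multiplied in int8.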
@@ -1586,34 +1592,39 @@
-482        if filter_layers is None:
-483            filter_layers = set(range(n_layers + 3))
-484
-485        self.n_vocab = n_vocab
-486        self.n_hidden = n_hidden
-487        self.n_layers = n_layers
-488        self.n_heads = n_heads
-489        self.filter_layers = filter_layers
-490        self.is_clone_layers = is_clone_layers
-491        self.dtype = dtype
-492        self.device = device
-493
-494        self.pre_created_layers = dict(
-495            transformer_layer=None,
-496        )
+486        if filter_layers is None:
+487            filter_layers = set(range(n_layers + 3))
+488
+489        self.n_vocab = n_vocab
+490        self.n_hidden = n_hidden
+491        self.n_layers = n_layers
+492        self.n_heads = n_heads
+493        self.filter_layers = filter_layers
+494        self.is_clone_layers = is_clone_layers
+495        self.dtype = dtype
+496        self.device = device
+497        self.is_llm_int8 = is_llm_int8
+498        self.llm_int8_threshold = llm_int8_threshold
+499
+500        self.pre_created_layers = dict(
+501            transformer_layer=None,
+502        )

+ Prepares the layer for usage
+
+ We move the layer to the device and convert it to the correct data type
+
+ * `layer` is the layer to prepare
+
+ Returns the prepared layer
-498    def _prepare_layer(self, layer: NeoXModule):
-499        layer = layer.to(self.device, self.dtype)
-500        return layer
+504    def _prepare_layer(self, layer: NeoXModule):
@@ -1624,33 +1635,35 @@
-502    def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):
-503        if self.pre_created_layers[name] is None or not self.is_clone_layers:
-504            layer = creator()
-505        else:
-506            layer = copy.deepcopy(self.pre_created_layers[name])
-507
-508        layer: NeoXModule = self._prepare_layer(layer)
-509
-510        if self.pre_created_layers[name] is None:
-511            self.pre_created_layers[name] = layer
-512
-513        return layer
+513        return layer.to(self.device, self.dtype)

+ ### Layer transformations after loading the checkpoint
+
+ This function implements layer transformations after loading the checkpoint.
+ Currently, it only applies the int8 quantization.
+
+ * `layer` is the layer to prepare
+ * `is_llm_int8` specifies whether to use int8 quantization
+ * `device` is the device of the model
+ * `llm_int8_threshold` is the threshold used to separate outlier features
+
+ Returns the prepared layer
-515    def _create_transformer_layer(self):
-516        return self._create_and_cache_layer(
-517            'transformer_layer',
-518            lambda: TransformerLayer(self.n_hidden, self.n_heads)
-519        )
+515    @torch.no_grad()
+516    def post_load_prepare(self, layer: NeoXModule, *,
+517                          is_llm_int8: bool = None,
+518                          device: torch.device = None,
+519                          llm_int8_threshold: float = None,
+520                          ):
@@ -1658,11 +1671,16 @@
+ Get default values if not specified
-521    def _create_embedding_layer(self):
-522        return Embedding(self.n_vocab, self.n_hidden)
+537        if is_llm_int8 is None:
+538            is_llm_int8 = self.is_llm_int8
+539        if device is None:
+540            device = self.device
+541        if llm_int8_threshold is None:
+542            llm_int8_threshold = self.llm_int8_threshold
@@ -1670,11 +1688,12 @@
+ Skip if not using int8 quantization
-524    def _create_final_norm_layer(self):
-525        return FinalNorm(self.n_hidden)
+545        if not is_llm_int8:
+546            return layer
@@ -1682,11 +1701,12 @@
+ Only convert the linear layers in the transformer layers
-527    def _create_readout_layer(self):
-528        return ReadoutLayer(self.n_hidden, self.n_vocab)
+549        if not isinstance(layer, TransformerLayer):
+550            return layer
@@ -1694,11 +1714,12 @@
+ Use `make_llm_int8_linear` defined in utilities.
-530    @torch.no_grad()
-531    def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:
+553        from labml_nn.neox.utils.llm_int8 import make_llm_int8_linear
@@ -1706,14 +1727,23 @@
- Embedding layer
+ Convert the linear layers
-533        if 0 in self.filter_layers:
-534            with monit.section('Embedding layer'):
-535                layer = self._prepare_layer(self._create_embedding_layer())
-536            yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')
+556        with monit.section('Convert to int8'):
+557            layer.attention.output = make_llm_int8_linear(layer.attention.output,
+558                                                          device=device,
+559                                                          threshold=llm_int8_threshold)
+560            layer.attention.qkv_lin = make_llm_int8_linear(layer.attention.qkv_lin,
+561                                                           device=device,
+562                                                           threshold=llm_int8_threshold)
+563            layer.ffn.dense_h_h4 = make_llm_int8_linear(layer.ffn.dense_h_h4,
+564                                                        device=device,
+565                                                        threshold=llm_int8_threshold)
+566            layer.ffn.dense_h4_h = make_llm_int8_linear(layer.ffn.dense_h4_h,
+567                                                        device=device,
+568                                                        threshold=llm_int8_threshold)
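`post_load_prepare` swaps the four linear projections in each transformer layer for int8 versions. `make_llm_int8_linear` itself lives in `labml_nn/neox/utils/llm_int8.py` (changed in this patch but not shown in this hunk); a rough sketch of what such a helper typically looks like on top of bitsandbytes' `Linear8bitLt` follows. This is an assumption for illustration, not the utility code from the patch, and the bitsandbytes API may differ between versions:

```python
import torch
from torch import nn
# Assumes the bitsandbytes package; class and argument names follow its
# Linear8bitLt / Int8Params API, which may change between versions.
from bitsandbytes.nn import Linear8bitLt, Int8Params


def make_llm_int8_linear_sketch(linear: nn.Linear, device: torch.device, threshold: float = 6.0) -> nn.Module:
    # Create an int8 linear layer with the same shape; activations above
    # `threshold` are routed through an fp16 path inside the kernel.
    int8_lin = Linear8bitLt(
        linear.in_features,
        linear.out_features,
        bias=linear.bias is not None,
        has_fp16_weights=False,
        threshold=threshold,
    )
    # Wrap the existing fp16 weights so they are quantized when moved to the GPU
    int8_lin.weight = Int8Params(linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False)
    if linear.bias is not None:
        int8_lin.bias = nn.Parameter(linear.bias.data.clone(), requires_grad=False)
    return int8_lin.to(device)
```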
@@ -1721,27 +1751,29 @@
- Transformer layers
-539        for i in range(self.n_layers):
+570        return layer
- Transformer layer
+ Creates and caches a layer
+
+ Copying cached layers is faster than initializing new layers because it takes time to initialize parameters.
+
+ * `name` is the name of the layer
+ * `creator` is the function to create the layer
+
+ Returns the created layer or a copy of the cached layer
-541            if i + 1 in self.filter_layers:
-542                with monit.section(f'Transformer Layer {i}'):
-543                    yield self._create_transformer_layer(), \
-544                          (f'layer_{i + 2 :02d}-model_00-model_states.pt',
-545                           f'layer_{i + 2 :02d}-model_01-model_states.pt')
+572    def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):
@@ -1749,14 +1781,17 @@
- Final normalization layer
-548        if self.n_layers + 1 in self.filter_layers:
-549            with monit.section('Final norm layer'):
-550                layer = self._prepare_layer(self._create_final_norm_layer())
-551            yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')
+584        if not self.is_clone_layers:
+585            return self._prepare_layer(creator())
+586
+587        if self.pre_created_layers[name] is None:
+588            self.pre_created_layers[name] = self._prepare_layer(creator())
+589
+590        layer = copy.deepcopy(self.pre_created_layers[name])
+591        return layer
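The note above explains why cloning is used: deep-copying an already prepared layer skips parameter re-initialization. A small self-contained sketch of the same pattern, with a generic module and hypothetical sizes rather than the NeoX transformer layer:

```python
import copy
import time

from torch import nn


def timed(fn, label: str):
    start = time.time()
    out = fn()
    print(f'{label}: {time.time() - start:.3f}s')
    return out


if __name__ == '__main__':
    # Stand-in for an expensive-to-initialize layer
    def create():
        return nn.Sequential(nn.Linear(4096, 4 * 4096), nn.GELU(), nn.Linear(4 * 4096, 4096))

    # Creating from scratch runs random parameter initialization every time
    template = timed(create, 'create new layer')
    # Deep-copying a cached template only copies tensors, which is typically faster
    clone = timed(lambda: copy.deepcopy(template), 'deepcopy cached layer')
```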
@@ -1764,14 +1799,14 @@
- Readout layer
-554        if self.n_layers + 2 in self.filter_layers:
-555            with monit.section('Readout layer'):
-556                layer = self._prepare_layer(self._create_readout_layer())
-557            yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')
+593    def _create_transformer_layer(self):
+594        return self._create_and_cache_layer(
+595            'transformer_layer',
+596            lambda: TransformerLayer(self.n_hidden, self.n_heads)
+597        )
@@ -1782,20 +1817,177 @@
-559    @property
-560    def total_layers(self):
-561        return self.n_layers + 3
-562
-563    @torch.no_grad()
-564    def load(self) -> Generator[NeoXModule, None, None]:
-565        with torch.no_grad():
-566            with monit.section("Layers"):
-567                for i, (layer, files) in enumerate(self.get_layers()):
-568                    if files is not None:
-569                        layer.load_state(*checkpoint.load_checkpoint_files(files))
-570
-571                    monit.progress(min(0.99, (i + 1) / self.total_layers))
-572                    yield layer
+599    def _create_embedding_layer(self):
+600        return Embedding(self.n_vocab, self.n_hidden)
+602    def _create_final_norm_layer(self):
+603        return FinalNorm(self.n_hidden)
+605    def _create_readout_layer(self):
+606        return ReadoutLayer(self.n_hidden, self.n_vocab)
+ Generator to get layers
+608    @torch.no_grad()
+609    def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:
+ Embedding layer
+614        if 0 in self.filter_layers:
+615            with monit.section('Embedding layer'):
+616                layer = self._prepare_layer(self._create_embedding_layer())
+617            yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')
+ Transformer layers
+620        for i in range(self.n_layers):
+ Transformer layer
+622            if i + 1 in self.filter_layers:
+623                with monit.section(f'Transformer Layer {i}'):
+624                    yield self._create_transformer_layer(), \
+625                          (f'layer_{i + 2 :02d}-model_00-model_states.pt',
+626                           f'layer_{i + 2 :02d}-model_01-model_states.pt')
+ Final normalization layer
+629        if self.n_layers + 1 in self.filter_layers:
+630            with monit.section('Final norm layer'):
+631                layer = self._prepare_layer(self._create_final_norm_layer())
+632            yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')
+ Readout layer
+635        if self.n_layers + 2 in self.filter_layers:
+636            with monit.section('Readout layer'):
+637                layer = self._prepare_layer(self._create_readout_layer())
+638            yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')
+639
+640        for k in self.pre_created_layers.keys():
+641            self.pre_created_layers[k] = None
+ Returns the total number of layers
+643    @property
+644    def total_layers(self):
+648        return self.n_layers + 3
+ Generator to load layers
+650    @torch.no_grad()
+651    def load(self) -> Generator[NeoXModule, None, None]:
+655        with monit.section("Layers"):
+656            for i, (layer, files) in enumerate(self.get_layers()):
+657                if files is not None:
+658                    layer.load_state(*checkpoint.load_checkpoint_files(files))
+659
+660                layer = self.post_load_prepare(layer)
+661
+662                monit.progress(min(0.99, (i + 1) / self.total_layers))
+663                yield layer
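Since `load()` now calls `post_load_prepare` on every yielded layer, the conversion can also be driven entirely by the constructor flags added in this patch instead of the explicit loop in the evaluation script above. A minimal sketch assuming that usage (note that, unlike the script above, this moves the fp16 weights to the GPU before conversion, so peak GPU memory is higher):

```python
import torch
from torch import nn

from labml_nn.neox.model import LayerGenerator

# Construct the generator with int8 enabled so load() applies
# post_load_prepare to every transformer layer as it is yielded.
generator = LayerGenerator(is_clone_layers=True,
                           dtype=torch.float16,
                           device=torch.device('cuda:0'),
                           is_llm_int8=True,
                           llm_int8_threshold=6.0,
                           )
model = nn.Sequential(*generator.load())
```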