diff --git a/docs/neox/evaluation/half_precision.html b/docs/neox/evaluation/half_precision.html index a9ae1129..73dd3609 100644 --- a/docs/neox/evaluation/half_precision.html +++ b/docs/neox/evaluation/half_precision.html @@ -3,24 +3,24 @@ - + - - + + - + - + - - + + - half_precision.py + Evaluate GPT-NeoX using LLM.int8() quantization on test suite @@ -71,32 +71,97 @@
-
+
+

Evaluate GPT-NeoX using half precision (float16) on test suite

+

This code evaluates GPT-NeoX using half precision (float16) on a suite of tasks.
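Since the diff splits the new script across several documentation cells, here is the same flow condensed into one runnable sketch; the LayerGenerator arguments, the run name 'half_precision', and the 'lambada' task are taken verbatim from the cells below.

import torch
from torch import nn

from labml_nn.neox.evaluation import run_eval_harness
from labml_nn.neox.model import LayerGenerator


def main():
    # Evaluate on the first GPU
    device = torch.device('cuda:0')
    # Load all layers in float16 directly onto the GPU
    layers = list(LayerGenerator(is_clone_layers=True,
                                 filter_layers=None,
                                 dtype=torch.float16,
                                 device=device,
                                 ).load())
    # Stack the layers into a single sequential model
    model = nn.Sequential(*layers)
    # Run the evaluation harness on the LAMBADA task
    print(run_eval_harness(model, 'half_precision', ['lambada'], device))


if __name__ == '__main__':
    main()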

+ +
+
+
13import torch
+14from torch import nn
+15
+16from labml_nn.neox.evaluation import run_eval_harness
+17from labml_nn.neox.model import LayerGenerator
+
+
+
+
+
-
1import torch
-2from torch import nn
-3
-4from labml import monit
-5from labml_nn.neox.evaluation import run_eval_harness
-6from labml_nn.neox.model import LayerGenerator
-7
-8if __name__ == '__main__':
-9    device = torch.device('cuda:0')
-10    layers = list(LayerGenerator(is_clone_layers=True,
-11                                 filter_layers=None,
-12                                 dtype=torch.float16,
-13                                 device=device
-14                                 ).load())
-15
-16    with monit.section('Sequential'):
-17        model = nn.Sequential(*layers)
-18
-19    print(run_eval_harness(model, 'half_precision', ['lambada'], device))
+
20def main():
+
+
+
+
+ +

Device

+ +
+
+
22    device = torch.device('cuda:0')
+
+
+
+
+ +

Load layers

+ +
+
+
24    layers = list(LayerGenerator(is_clone_layers=True,
+25                                 filter_layers=None,
+26                                 dtype=torch.float16,
+27                                 device=device
+28                                 ).load())
+
+
+
+
+ +

Create nn.Sequential model

+ +
+
+
31    model = nn.Sequential(*layers)
+
+
+
+
+ +

Run evaluation harness

+ +
+
+
34    print(run_eval_harness(model, 'half_precision', ['lambada'], device))
+
+
+
+
+ +

+ +
+
+
38if __name__ == '__main__':
+39    main()
-
+
- +

Evaluate GPT-NeoX using LLM.int8() quantization on test suite

+

This code evaluates GPT-NeoX using LLM.int8() quantization on a suite of tasks.
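As with the half-precision script above, the new code is fragmented across cells; condensed, the LLM.int8() evaluation looks like the sketch below. All calls and arguments are taken from the cells that follow: layers are loaded in float16 on the CPU, converted to int8, and only then moved to the GPU.

import torch
from torch import nn

from labml import monit
from labml_nn.neox.evaluation import run_eval_harness
from labml_nn.neox.model import LayerGenerator


def main():
    device = torch.device('cuda:0')
    # Load layers in float16 onto the CPU first; converting after moving to the
    # GPU causes CUDA memory fragmentation (roughly 3GB can be lost)
    layer_generator = LayerGenerator(is_clone_layers=True,
                                     dtype=torch.float16,
                                     device=torch.device('cpu'),
                                     )
    layers = list(layer_generator.load())
    # Convert each layer to LLM.int8() and only then move it to the GPU
    for layer in monit.iterate('Convert to int8', layers, is_children_silent=True):
        layer_generator.post_load_prepare(layer,
                                          device=device,
                                          is_llm_int8=True,
                                          llm_int8_threshold=6.0,
                                          )
        layer.to(device)
    # Stack the layers and run the evaluation harness
    model = nn.Sequential(*layers)
    print(run_eval_harness(model, 'half_precision', [], device))


if __name__ == '__main__':
    main()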

+
-
1import torch
-2from torch import nn
-3
-4from labml import monit
-5from labml_nn.neox.evaluation import run_eval_harness
-6from labml_nn.neox.model import LayerGenerator
-7
-8if __name__ == '__main__':
-9    device = torch.device('cuda:0')
-10    layer_generator = LayerGenerator(is_clone_layers=True,
-11                                     dtype=torch.float16,
-12                                     device=torch.device('cpu'),
-13                                     )
+
14import torch
+15from torch import nn
+16
+17from labml import monit
+18from labml_nn.neox.evaluation import run_eval_harness
+19from labml_nn.neox.model import LayerGenerator
@@ -98,11 +93,10 @@ -

Load layers

- +
-
15    layers = list(layer_generator.load())
+
22def main():
@@ -110,22 +104,94 @@ +

Device

+ +
+
+
24    device = torch.device('cuda:0')
+
+ +
+
+ +

Load the layers in float16 onto the CPU. We convert the layers to int8 later, because doing the conversion on the fly after loading the layers onto the GPU causes CUDA memory fragmentation (about 3GB of memory can be lost to fragmentation).

+ +
+
+
29    layer_generator = LayerGenerator(is_clone_layers=True,
+30                                     dtype=torch.float16,
+31                                     device=torch.device('cpu'),
+32                                     )
+
+
+
+
+ +

Load layers

+ +
+
+
34    layers = list(layer_generator.load())
+
+
+
+
+

This reduces CUDA memory fragmentation

-
18    for layer in monit.iterate('Convert to int8', layers, is_children_silent=True):
-19        layer_generator.post_load_prepare(layer,
-20                                          device=device,
-21                                          is_llm_int8=True,
-22                                          llm_int8_threshold=6.0,
-23                                          )
-24        layer.to(device)
-25
-26    with monit.section('Sequential'):
-27        model = nn.Sequential(*layers)
-28
-29    print(run_eval_harness(model, 'half_precision', [], device))
+
37    for layer in monit.iterate('Convert to int8', layers, is_children_silent=True):
+38        layer_generator.post_load_prepare(layer,
+39                                          device=device,
+40                                          is_llm_int8=True,
+41                                          llm_int8_threshold=6.0,
+42                                          )
+43        layer.to(device)
+
+
+
+
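For context, post_load_prepare with is_llm_int8=True swaps the attention and feed-forward linear layers for int8 versions via make_llm_int8_linear (its call sites appear later in this diff). A rough sketch of that kind of swap follows, assuming the LLM.int8() kernels come from the bitsandbytes package; it illustrates the idea only and is not the actual helper implementation.

import torch
from torch import nn
import bitsandbytes as bnb  # assumption: bitsandbytes provides the LLM.int8() kernels


def to_int8_linear(linear: nn.Linear, device: torch.device, threshold: float = 6.0) -> nn.Module:
    # Build an int8 linear layer with the LLM.int8() outlier threshold
    int8_linear = bnb.nn.Linear8bitLt(linear.in_features, linear.out_features,
                                      bias=linear.bias is not None,
                                      has_fp16_weights=False,
                                      threshold=threshold)
    # Wrap the existing fp16 weights; they are quantized to int8 when moved to the GPU
    int8_linear.weight = bnb.nn.Int8Params(linear.weight.data,
                                           requires_grad=False,
                                           has_fp16_weights=False)
    if linear.bias is not None:
        int8_linear.bias = nn.Parameter(linear.bias.data, requires_grad=False)
    return int8_linear.to(device)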
+ +

Create nn.Sequential model

+ +
+
+
46    model = nn.Sequential(*layers)
+
+
+
+
+ +

Run evaluation harness

+ +
+
+
49    print(run_eval_harness(model, 'half_precision', [], device))
+
+
+
+
+ +

+ +
+
+
53if __name__ == '__main__':
+54    main()
-
537        if is_llm_int8 is None:
-538            is_llm_int8 = self.is_llm_int8
-539        if device is None:
-540            device = self.device
-541        if llm_int8_threshold is None:
-542            llm_int8_threshold = self.llm_int8_threshold
+
538        if is_llm_int8 is None:
+539            is_llm_int8 = self.is_llm_int8
+540        if device is None:
+541            device = self.device
+542        if llm_int8_threshold is None:
+543            llm_int8_threshold = self.llm_int8_threshold
@@ -1692,8 +1693,8 @@
-
545        if not is_llm_int8:
-546            return layer
+
546        if not is_llm_int8:
+547            return layer
@@ -1705,8 +1706,8 @@
-
549        if not isinstance(layer, TransformerLayer):
-550            return layer
+
550        if not isinstance(layer, TransformerLayer):
+551            return layer
@@ -1719,7 +1720,7 @@
-
553        from labml_nn.neox.utils.llm_int8 import make_llm_int8_linear
+
554        from labml_nn.neox.utils.llm_int8 import make_llm_int8_linear
@@ -1731,19 +1732,19 @@
-
556        with monit.section('Convert to int8'):
-557            layer.attention.output = make_llm_int8_linear(layer.attention.output,
-558                                                          device=device,
-559                                                          threshold=llm_int8_threshold)
-560            layer.attention.qkv_lin = make_llm_int8_linear(layer.attention.qkv_lin,
-561                                                           device=device,
-562                                                           threshold=llm_int8_threshold)
-563            layer.ffn.dense_h_h4 = make_llm_int8_linear(layer.ffn.dense_h_h4,
-564                                                        device=device,
-565                                                        threshold=llm_int8_threshold)
-566            layer.ffn.dense_h4_h = make_llm_int8_linear(layer.ffn.dense_h4_h,
-567                                                        device=device,
-568                                                        threshold=llm_int8_threshold)
+
557        with monit.section('Convert to int8'):
+558            layer.attention.output = make_llm_int8_linear(layer.attention.output,
+559                                                          device=device,
+560                                                          threshold=llm_int8_threshold)
+561            layer.attention.qkv_lin = make_llm_int8_linear(layer.attention.qkv_lin,
+562                                                           device=device,
+563                                                           threshold=llm_int8_threshold)
+564            layer.ffn.dense_h_h4 = make_llm_int8_linear(layer.ffn.dense_h_h4,
+565                                                        device=device,
+566                                                        threshold=llm_int8_threshold)
+567            layer.ffn.dense_h4_h = make_llm_int8_linear(layer.ffn.dense_h4_h,
+568                                                        device=device,
+569                                                        threshold=llm_int8_threshold)
@@ -1755,7 +1756,7 @@
-
570        return layer
+
571        return layer
@@ -1773,7 +1774,7 @@
-
572    def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):
+
573    def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):
@@ -1784,14 +1785,14 @@
-
584        if not self.is_clone_layers:
-585            return self._prepare_layer(creator())
-586
-587        if self.pre_created_layers[name] is None:
-588            self.pre_created_layers[name] = self._prepare_layer(creator())
-589
-590        layer = copy.deepcopy(self.pre_created_layers[name])
-591        return layer
+
585        if not self.is_clone_layers:
+586            return self._prepare_layer(creator())
+587
+588        if self.pre_created_layers[name] is None:
+589            self.pre_created_layers[name] = self._prepare_layer(creator())
+590
+591        layer = copy.deepcopy(self.pre_created_layers[name])
+592        return layer
@@ -1802,11 +1803,11 @@
-
593    def _create_transformer_layer(self):
-594        return self._create_and_cache_layer(
-595            'transformer_layer',
-596            lambda: TransformerLayer(self.n_hidden, self.n_heads)
-597        )
+
594    def _create_transformer_layer(self):
+595        return self._create_and_cache_layer(
+596            'transformer_layer',
+597            lambda: TransformerLayer(self.n_hidden, self.n_heads)
+598        )
@@ -1817,8 +1818,8 @@
-
599    def _create_embedding_layer(self):
-600        return Embedding(self.n_vocab, self.n_hidden)
+
600    def _create_embedding_layer(self):
+601        return Embedding(self.n_vocab, self.n_hidden)
@@ -1829,8 +1830,8 @@
-
602    def _create_final_norm_layer(self):
-603        return FinalNorm(self.n_hidden)
+
603    def _create_final_norm_layer(self):
+604        return FinalNorm(self.n_hidden)
@@ -1841,8 +1842,8 @@
-
605    def _create_readout_layer(self):
-606        return ReadoutLayer(self.n_hidden, self.n_vocab)
+
606    def _create_readout_layer(self):
+607        return ReadoutLayer(self.n_hidden, self.n_vocab)
@@ -1854,8 +1855,8 @@
-
608    @torch.no_grad()
-609    def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:
+
609    @torch.no_grad()
+610    def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:
@@ -1867,10 +1868,10 @@
-
614        if 0 in self.filter_layers:
-615            with monit.section('Embedding layer'):
-616                layer = self._prepare_layer(self._create_embedding_layer())
-617            yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')
+
615        if 0 in self.filter_layers:
+616            with monit.section('Embedding layer'):
+617                layer = self._prepare_layer(self._create_embedding_layer())
+618            yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')
@@ -1882,7 +1883,7 @@
-
620        for i in range(self.n_layers):
+
621        for i in range(self.n_layers):
@@ -1894,11 +1895,11 @@
-
622            if i + 1 in self.filter_layers:
-623                with monit.section(f'Transformer Layer {i}'):
-624                    yield self._create_transformer_layer(), \
-625                          (f'layer_{i + 2 :02d}-model_00-model_states.pt',
-626                           f'layer_{i + 2 :02d}-model_01-model_states.pt')
+
623            if i + 1 in self.filter_layers:
+624                with monit.section(f'Transformer Layer {i}'):
+625                    yield self._create_transformer_layer(), \
+626                          (f'layer_{i + 2 :02d}-model_00-model_states.pt',
+627                           f'layer_{i + 2 :02d}-model_01-model_states.pt')
@@ -1910,10 +1911,10 @@
-
629        if self.n_layers + 1 in self.filter_layers:
-630            with monit.section('Final norm layer'):
-631                layer = self._prepare_layer(self._create_final_norm_layer())
-632            yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')
+
630        if self.n_layers + 1 in self.filter_layers:
+631            with monit.section('Final norm layer'):
+632                layer = self._prepare_layer(self._create_final_norm_layer())
+633            yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')
@@ -1925,13 +1926,13 @@
-
635        if self.n_layers + 2 in self.filter_layers:
-636            with monit.section('Readout layer'):
-637                layer = self._prepare_layer(self._create_readout_layer())
-638            yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')
-639
-640        for k in self.pre_created_layers.keys():
-641            self.pre_created_layers[k] = None
+
636        if self.n_layers + 2 in self.filter_layers:
+637            with monit.section('Readout layer'):
+638                layer = self._prepare_layer(self._create_readout_layer())
+639            yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')
+640
+641        for k in self.pre_created_layers.keys():
+642            self.pre_created_layers[k] = None
@@ -1943,8 +1944,8 @@
-
643    @property
-644    def total_layers(self):
+
644    @property
+645    def total_layers(self):
@@ -1955,7 +1956,7 @@
-
648        return self.n_layers + 3
+
649        return self.n_layers + 3
@@ -1967,8 +1968,8 @@
-
650    @torch.no_grad()
-651    def load(self) -> Generator[NeoXModule, None, None]:
+
651    @torch.no_grad()
+652    def load(self) -> Generator[NeoXModule, None, None]:
@@ -1979,15 +1980,15 @@
-
655        with monit.section("Layers"):
-656            for i, (layer, files) in enumerate(self.get_layers()):
-657                if files is not None:
-658                    layer.load_state(*checkpoint.load_checkpoint_files(files))
-659
-660                layer = self.post_load_prepare(layer)
-661
-662                monit.progress(min(0.99, (i + 1) / self.total_layers))
-663                yield layer
+
656        with monit.section("Layers"):
+657            for i, (layer, files) in enumerate(self.get_layers()):
+658                if files is not None:
+659                    layer.load_state(*checkpoint.load_checkpoint_files(files))
+660
+661                layer = self.post_load_prepare(layer)
+662
+663                monit.progress(min(0.99, (i + 1) / self.total_layers))
+664                yield layer

Generate Text with GPT-NeoX using LLM.int8() quantization

This shows how to generate text from GPT-NeoX using LLM.int8() quantization.

-

This needs a GPU with more than 45GB memory.

+

This needs a GPU with 24GB memory.
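Condensed from the cells below (with the monit progress sections and the CUDA memory summary omitted for brevity), the updated generation script follows this flow; every call and argument is taken from the diff.

import torch
from torch import nn

from labml import monit
from labml_nn.neox.model import LayerGenerator
from labml_nn.neox.samples.generate import PROMPT, infer
from labml_nn.neox.utils import get_tokens, print_tokens
from labml_nn.neox.utils.cache import get_cache


def generate():
    # Enable caching of key/value pairs between inference calls
    cache = get_cache()
    cache.set('use_cache', True)
    device = torch.device('cuda:0')

    # Load layers in float16 on the CPU, then convert to int8 and move to the GPU
    layer_generator = LayerGenerator(is_clone_layers=True,
                                     dtype=torch.float16,
                                     device=torch.device('cpu'),
                                     is_llm_int8=False,
                                     )
    layers = list(layer_generator.load())
    for layer in monit.iterate('Convert to int8', layers, is_children_silent=True):
        layer_generator.post_load_prepare(layer,
                                          device=device,
                                          is_llm_int8=True,
                                          llm_int8_threshold=6.0,
                                          )
        layer.to(device)
    model = nn.Sequential(*layers)

    # Feed the full prompt once, then feed only the last predicted token each step,
    # since the key/value pairs of previous tokens are cached
    ids = get_tokens(PROMPT)
    cache.set('state_ids', (None, 1))
    next_token = infer(model, ids, device)[-1]
    ids += [next_token]
    for i in range(1, 100):
        cache.set('state_ids', (i, i + 1))
        next_token = infer(model, [next_token], device)[-1]
        ids += [next_token]
        print_tokens(ids, [ids])


if __name__ == '__main__':
    generate()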

-
15from typing import List
-16
-17import torch
-18from torch import nn
-19
-20from labml import monit
-21from labml_nn.neox.model import LayerGenerator
-22from labml_nn.neox.samples.generate import PROMPT, infer
-23from labml_nn.neox.utils import get_tokens, print_tokens
-24from labml_nn.neox.utils.cache import get_cache
+
15import torch
+16from torch import nn
+17
+18from labml import monit
+19from labml_nn.neox.model import LayerGenerator
+20from labml_nn.neox.samples.generate import PROMPT, infer
+21from labml_nn.neox.utils import get_tokens, print_tokens
+22from labml_nn.neox.utils.cache import get_cache
@@ -102,7 +100,7 @@
-
27def generate():
+
25def generate():
@@ -114,8 +112,8 @@
-
33    cache = get_cache()
-34    cache.set('use_cache', True)
+
31    cache = get_cache()
+32    cache.set('use_cache', True)
@@ -127,7 +125,7 @@
-
37    device = torch.device('cuda:0')
+
35    device = torch.device('cuda:0')
@@ -139,9 +137,12 @@
-
42    layer_generator = LayerGenerator(is_clone_layers=True,
-43                                     dtype=torch.float16,
-44                                     device=torch.device('cpu'),
+
40    layer_generator = LayerGenerator(is_clone_layers=True,
+41                                     dtype=torch.float16,
+42                                     device=torch.device('cpu'),
+43                                     is_llm_int8=False,
+44                                     )
+45    layers = list(layer_generator.load())
@@ -149,12 +150,17 @@ -

is_llm_int8=True,

+

This reduces CUDA memory fragmentation

-
46                                     )
-47    layers = list(layer_generator.load())
+
48    for layer in monit.iterate('Convert to int8', layers, is_children_silent=True):
+49        layer_generator.post_load_prepare(layer,
+50                                          device=device,
+51                                          is_llm_int8=True,
+52                                          llm_int8_threshold=6.0,
+53                                          )
+54        layer.to(device)
@@ -162,17 +168,12 @@ -

This reduces CUDA memory fragmentation

+

Create nn.Sequential model

-
50    for layer in monit.iterate('Convert to int8', layers, is_children_silent=True):
-51        layer_generator.post_load_prepare(layer,
-52                                          device=device,
-53                                          is_llm_int8=True,
-54                                          llm_int8_threshold=6.0,
-55                                          )
-56        layer.to(device)
+
57    model = nn.Sequential(*layers)
@@ -180,12 +181,12 @@ -

Create nn.Sequential model

+

Clear cache and print memory summary for debugging

-
59    model = nn.Sequential(*layers)
+
60    torch.cuda.empty_cache()
+61    print(torch.cuda.memory_summary())
@@ -193,12 +194,11 @@ -

Clear cache and print memory summary for debugging

+

Get token ids

-
62    torch.cuda.empty_cache()
-63    print(torch.cuda.memory_summary())
+
64    ids = get_tokens(PROMPT)
@@ -206,11 +206,15 @@ -

Get token ids

+

Run the model. We use the infer function defined in generate.py

-
66    ids = get_tokens(PROMPT)
+
68    cache.set('state_ids', (None, 1))
+69    with monit.section('Infer'):
+70        next_token = infer(model, ids, device)[-1]
@@ -218,13 +222,11 @@ -

Run the model

+

Append the predicted token

-
69    cache.set('state_ids', (None, 1))
-70    with monit.section('Infer'):
-71        next_token = infer(model, ids, device)[-1]
+
73    ids += [next_token]
@@ -232,11 +234,11 @@ -

Append the predicted token

+

Predict 100 tokens

-
74    ids += [next_token]
+
76    for i in range(1, 100):
@@ -244,11 +246,11 @@ -

Predict 100 tokens

+

Set the state to use cached activations

-
77    for i in range(1, 100):
+
78        cache.set('state_ids', (i, i + 1))
@@ -256,11 +258,12 @@ -

Set the state to use cached activations

+

Get next token. Note that we only feed the last token to the model because we cache the key/value pairs of previous tokens.

-
79        cache.set('state_ids', (i, i + 1))
+
81        with monit.section('Infer'):
+82            next_token = infer(model, [next_token], device)[-1]
@@ -268,12 +271,11 @@ -

Get next token. Note that we only feed the last token to the model because we cache the key/value pairs of previous tokens.

+

Append the predicted token

-
82        with monit.section('Infer'):
-83            next_token = infer(model, [next_token], device)[-1]
+
84        ids += [next_token]
@@ -281,11 +283,11 @@ -

Append the predicted token

+

Print

-
85        ids += [next_token]
+
86        print_tokens(ids, [ids])
@@ -293,24 +295,12 @@ -

Print

- -
-
-
87        print_tokens(ids, [ids])
-
- -
-
-

-
91if __name__ == '__main__':
-92    generate()
+
90if __name__ == '__main__':
+91    generate()