diff --git a/docs/neox/evaluation/half_precision.html b/docs/neox/evaluation/half_precision.html
index a9ae1129..73dd3609 100644
--- a/docs/neox/evaluation/half_precision.html
+++ b/docs/neox/evaluation/half_precision.html
@@ -3,24 +3,24 @@
This code evaluates GPT-NeoX in half precision on a suite of tasks.
+13import torch
+14from torch import nn
+15
+16from labml_nn.neox.evaluation import run_eval_harness
+17from labml_nn.neox.model import LayerGenerator
1import torch
-2from torch import nn
-3
-4from labml import monit
-5from labml_nn.neox.evaluation import run_eval_harness
-6from labml_nn.neox.model import LayerGenerator
-7
-8if __name__ == '__main__':
-9 device = torch.device('cuda:0')
-10 layers = list(LayerGenerator(is_clone_layers=True,
-11 filter_layers=None,
-12 dtype=torch.float16,
-13 device=device
-14 ).load())
-15
-16 with monit.section('Sequential'):
-17 model = nn.Sequential(*layers)
-18
-19 print(run_eval_harness(model, 'half_precision', ['lambada'], device))
20def main():
Device
+22 device = torch.device('cuda:0')
Load layers
+24 layers = list(LayerGenerator(is_clone_layers=True,
+25 filter_layers=None,
+26 dtype=torch.float16,
+27 device=device
+28 ).load())
Create nn.Sequential model
31 model = nn.Sequential(*layers)
34 print(run_eval_harness(model, 'half_precision', ['lambada'], device))
38if __name__ == '__main__':
+39 main()
This code evaluates GPT-NeoX using LLM.int8() quantization on a suite of tasks.
+1import torch
-2from torch import nn
-3
-4from labml import monit
-5from labml_nn.neox.evaluation import run_eval_harness
-6from labml_nn.neox.model import LayerGenerator
-7
-8if __name__ == '__main__':
-9 device = torch.device('cuda:0')
-10 layer_generator = LayerGenerator(is_clone_layers=True,
-11 dtype=torch.float16,
-12 device=torch.device('cpu'),
-13 )
14import torch
+15from torch import nn
+16
+17from labml import monit
+18from labml_nn.neox.evaluation import run_eval_harness
+19from labml_nn.neox.model import LayerGenerator
Load layers
-15 layers = list(layer_generator.load())
22def main():
Device
+24 device = torch.device('cuda:0')
Load the layers in float16 onto the CPU. We convert the layers to int8 later, because doing the conversion on the fly while loading layers to the GPU causes CUDA memory fragmentation (about 3GB of memory can be lost to fragmentation). A minimal standalone sketch of this two-phase pattern is given after this listing.
+29 layer_generator = LayerGenerator(is_clone_layers=True,
+30 dtype=torch.float16,
+31 device=torch.device('cpu'),
+32 )
Load layers
+34 layers = list(layer_generator.load())
This reduces CUDA memory fragmentation
18 for layer in monit.iterate('Convert to int8', layers, is_children_silent=True):
-19 layer_generator.post_load_prepare(layer,
-20 device=device,
-21 is_llm_int8=True,
-22 llm_int8_threshold=6.0,
-23 )
-24 layer.to(device)
-25
-26 with monit.section('Sequential'):
-27 model = nn.Sequential(*layers)
-28
-29 print(run_eval_harness(model, 'half_precision', [], device))
37 for layer in monit.iterate('Convert to int8', layers, is_children_silent=True):
+38 layer_generator.post_load_prepare(layer,
+39 device=device,
+40 is_llm_int8=True,
+41 llm_int8_threshold=6.0,
+42 )
+43 layer.to(device)
Create nn.Sequential model
46 model = nn.Sequential(*layers)
49 print(run_eval_harness(model, 'half_precision', [], device))
53if __name__ == '__main__':
+54 main()
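For reference, here is a minimal standalone sketch of the two-phase pattern described above: load every layer in float16 on the CPU, then convert each layer to int8 and move it to the GPU before assembling the model. It reuses only the APIs shown in the listing (LayerGenerator, post_load_prepare, run_eval_harness, monit.iterate); the wrapper name evaluate_int8 is hypothetical and not part of the documented source.

import torch
from torch import nn

from labml import monit
from labml_nn.neox.evaluation import run_eval_harness
from labml_nn.neox.model import LayerGenerator


def evaluate_int8():
    # Hypothetical wrapper; mirrors main() in the listing above.
    device = torch.device('cuda:0')

    # Phase 1: load all layers in float16 on the CPU, so that the int8
    # conversion does not fragment CUDA memory.
    layer_generator = LayerGenerator(is_clone_layers=True,
                                     dtype=torch.float16,
                                     device=torch.device('cpu'),
                                     )
    layers = list(layer_generator.load())

    # Phase 2: convert each layer to int8 and only then move it to the GPU.
    for layer in monit.iterate('Convert to int8', layers, is_children_silent=True):
        layer_generator.post_load_prepare(layer,
                                          device=device,
                                          is_llm_int8=True,
                                          llm_int8_threshold=6.0,
                                          )
        layer.to(device)

    # Assemble the model and run the evaluation harness.
    model = nn.Sequential(*layers)
    print(run_eval_harness(model, 'half_precision', [], device))


if __name__ == '__main__':
    evaluate_int8()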
537 if is_llm_int8 is None:
-538 is_llm_int8 = self.is_llm_int8
-539 if device is None:
-540 device = self.device
-541 if llm_int8_threshold is None:
-542 llm_int8_threshold = self.llm_int8_threshold
538 if is_llm_int8 is None:
+539 is_llm_int8 = self.is_llm_int8
+540 if device is None:
+541 device = self.device
+542 if llm_int8_threshold is None:
+543 llm_int8_threshold = self.llm_int8_threshold
545 if not is_llm_int8:
-546 return layer
546 if not is_llm_int8:
+547 return layer
549 if not isinstance(layer, TransformerLayer):
-550 return layer
550 if not isinstance(layer, TransformerLayer):
+551 return layer
553 from labml_nn.neox.utils.llm_int8 import make_llm_int8_linear
554 from labml_nn.neox.utils.llm_int8 import make_llm_int8_linear
556 with monit.section('Convert to int8'):
-557 layer.attention.output = make_llm_int8_linear(layer.attention.output,
-558 device=device,
-559 threshold=llm_int8_threshold)
-560 layer.attention.qkv_lin = make_llm_int8_linear(layer.attention.qkv_lin,
-561 device=device,
-562 threshold=llm_int8_threshold)
-563 layer.ffn.dense_h_h4 = make_llm_int8_linear(layer.ffn.dense_h_h4,
-564 device=device,
-565 threshold=llm_int8_threshold)
-566 layer.ffn.dense_h4_h = make_llm_int8_linear(layer.ffn.dense_h4_h,
-567 device=device,
-568 threshold=llm_int8_threshold)
557 with monit.section('Convert to int8'):
+558 layer.attention.output = make_llm_int8_linear(layer.attention.output,
+559 device=device,
+560 threshold=llm_int8_threshold)
+561 layer.attention.qkv_lin = make_llm_int8_linear(layer.attention.qkv_lin,
+562 device=device,
+563 threshold=llm_int8_threshold)
+564 layer.ffn.dense_h_h4 = make_llm_int8_linear(layer.ffn.dense_h_h4,
+565 device=device,
+566 threshold=llm_int8_threshold)
+567 layer.ffn.dense_h4_h = make_llm_int8_linear(layer.ffn.dense_h4_h,
+568 device=device,
+569 threshold=llm_int8_threshold)
570 return layer
571 return layer
572 def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):
573 def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):
584 if not self.is_clone_layers:
-585 return self._prepare_layer(creator())
-586
-587 if self.pre_created_layers[name] is None:
-588 self.pre_created_layers[name] = self._prepare_layer(creator())
-589
-590 layer = copy.deepcopy(self.pre_created_layers[name])
-591 return layer
585 if not self.is_clone_layers:
+586 return self._prepare_layer(creator())
+587
+588 if self.pre_created_layers[name] is None:
+589 self.pre_created_layers[name] = self._prepare_layer(creator())
+590
+591 layer = copy.deepcopy(self.pre_created_layers[name])
+592 return layer
593 def _create_transformer_layer(self):
-594 return self._create_and_cache_layer(
-595 'transformer_layer',
-596 lambda: TransformerLayer(self.n_hidden, self.n_heads)
-597 )
594 def _create_transformer_layer(self):
+595 return self._create_and_cache_layer(
+596 'transformer_layer',
+597 lambda: TransformerLayer(self.n_hidden, self.n_heads)
+598 )
599 def _create_embedding_layer(self):
-600 return Embedding(self.n_vocab, self.n_hidden)
600 def _create_embedding_layer(self):
+601 return Embedding(self.n_vocab, self.n_hidden)
602 def _create_final_norm_layer(self):
-603 return FinalNorm(self.n_hidden)
603 def _create_final_norm_layer(self):
+604 return FinalNorm(self.n_hidden)
605 def _create_readout_layer(self):
-606 return ReadoutLayer(self.n_hidden, self.n_vocab)
606 def _create_readout_layer(self):
+607 return ReadoutLayer(self.n_hidden, self.n_vocab)
608 @torch.no_grad()
-609 def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:
609 @torch.no_grad()
+610 def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:
614 if 0 in self.filter_layers:
-615 with monit.section('Embedding layer'):
-616 layer = self._prepare_layer(self._create_embedding_layer())
-617 yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')
615 if 0 in self.filter_layers:
+616 with monit.section('Embedding layer'):
+617 layer = self._prepare_layer(self._create_embedding_layer())
+618 yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')
620 for i in range(self.n_layers):
621 for i in range(self.n_layers):
622 if i + 1 in self.filter_layers:
-623 with monit.section(f'Transformer Layer {i}'):
-624 yield self._create_transformer_layer(), \
-625 (f'layer_{i + 2 :02d}-model_00-model_states.pt',
-626 f'layer_{i + 2 :02d}-model_01-model_states.pt')
623 if i + 1 in self.filter_layers:
+624 with monit.section(f'Transformer Layer {i}'):
+625 yield self._create_transformer_layer(), \
+626 (f'layer_{i + 2 :02d}-model_00-model_states.pt',
+627 f'layer_{i + 2 :02d}-model_01-model_states.pt')
629 if self.n_layers + 1 in self.filter_layers:
-630 with monit.section('Final norm layer'):
-631 layer = self._prepare_layer(self._create_final_norm_layer())
-632 yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')
630 if self.n_layers + 1 in self.filter_layers:
+631 with monit.section('Final norm layer'):
+632 layer = self._prepare_layer(self._create_final_norm_layer())
+633 yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')
635 if self.n_layers + 2 in self.filter_layers:
-636 with monit.section('Readout layer'):
-637 layer = self._prepare_layer(self._create_readout_layer())
-638 yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')
-639
-640 for k in self.pre_created_layers.keys():
-641 self.pre_created_layers[k] = None
636 if self.n_layers + 2 in self.filter_layers:
+637 with monit.section('Readout layer'):
+638 layer = self._prepare_layer(self._create_readout_layer())
+639 yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')
+640
+641 for k in self.pre_created_layers.keys():
+642 self.pre_created_layers[k] = None
643 @property
-644 def total_layers(self):
644 @property
+645 def total_layers(self):
648 return self.n_layers + 3
649 return self.n_layers + 3
650 @torch.no_grad()
-651 def load(self) -> Generator[NeoXModule, None, None]:
651 @torch.no_grad()
+652 def load(self) -> Generator[NeoXModule, None, None]:
655 with monit.section("Layers"):
-656 for i, (layer, files) in enumerate(self.get_layers()):
-657 if files is not None:
-658 layer.load_state(*checkpoint.load_checkpoint_files(files))
-659
-660 layer = self.post_load_prepare(layer)
-661
-662 monit.progress(min(0.99, (i + 1) / self.total_layers))
-663 yield layer
656 with monit.section("Layers"):
+657 for i, (layer, files) in enumerate(self.get_layers()):
+658 if files is not None:
+659 layer.load_state(*checkpoint.load_checkpoint_files(files))
+660
+661 layer = self.post_load_prepare(layer)
+662
+663 monit.progress(min(0.99, (i + 1) / self.total_layers))
+664 yield layer
This shows how to generate text from GPT-NeoX using LLM.int8() quantization.
-This needs a GPU with more than 45GB of memory.
+This needs a GPU with 24GB of memory.
15from typing import List
-16
-17import torch
-18from torch import nn
-19
-20from labml import monit
-21from labml_nn.neox.model import LayerGenerator
-22from labml_nn.neox.samples.generate import PROMPT, infer
-23from labml_nn.neox.utils import get_tokens, print_tokens
-24from labml_nn.neox.utils.cache import get_cache
15import torch
+16from torch import nn
+17
+18from labml import monit
+19from labml_nn.neox.model import LayerGenerator
+20from labml_nn.neox.samples.generate import PROMPT, infer
+21from labml_nn.neox.utils import get_tokens, print_tokens
+22from labml_nn.neox.utils.cache import get_cache
27def generate():
25def generate():
33 cache = get_cache()
-34 cache.set('use_cache', True)
31 cache = get_cache()
+32 cache.set('use_cache', True)
37 device = torch.device('cuda:0')
35 device = torch.device('cuda:0')
42 layer_generator = LayerGenerator(is_clone_layers=True,
-43 dtype=torch.float16,
-44 device=torch.device('cpu'),
40 layer_generator = LayerGenerator(is_clone_layers=True,
+41 dtype=torch.float16,
+42 device=torch.device('cpu'),
+43 is_llm_int8=False,
+44 )
+45 layers = list(layer_generator.load())
46 )
-47 layers = list(layer_generator.load())
48 for layer in monit.iterate('Convert to int8', layers, is_children_silent=True):
+49 layer_generator.post_load_prepare(layer,
+50 device=device,
+51 is_llm_int8=True,
+52 llm_int8_threshold=6.0,
+53 )
+54 layer.to(device)
50 for layer in monit.iterate('Convert to int8', layers, is_children_silent=True):
-51 layer_generator.post_load_prepare(layer,
-52 device=device,
-53 is_llm_int8=True,
-54 llm_int8_threshold=6.0,
-55 )
-56 layer.to(device)
57 model = nn.Sequential(*layers)
Create nn.Sequential model
Clear cache and print memory summary for debugging
59 model = nn.Sequential(*layers)
60 torch.cuda.empty_cache()
+61 print(torch.cuda.memory_summary())
62 torch.cuda.empty_cache()
-63 print(torch.cuda.memory_summary())
64 ids = get_tokens(PROMPT)
Get token ids
+Run the model. We use the infer function defined in generate.py.
66 ids = get_tokens(PROMPT)
68 cache.set('state_ids', (None, 1))
+69 with monit.section('Infer'):
+70 next_token = infer(model, ids, device)[-1]
69 cache.set('state_ids', (None, 1))
-70 with monit.section('Infer'):
-71 next_token = infer(model, ids, device)[-1]
73 ids += [next_token]
74 ids += [next_token]
76 for i in range(1, 100):
77 for i in range(1, 100):
78 cache.set('state_ids', (i, i + 1))
Set the state to use cached activations
+Get next token. Note that we only feed the last token to the model because we cache the key/value pairs of previous tokens (a minimal sketch of this cached loop follows the listing).
79 cache.set('state_ids', (i, i + 1))
81 with monit.section('Infer'):
+82 next_token = infer(model, [next_token], device)[-1]
Get next token. Note that we only feed the last token to the model because we cache the key/value pairs of previous tokens.
+Append the predicted token
82 with monit.section('Infer'):
-83 next_token = infer(model, [next_token], device)[-1]
84 ids += [next_token]
85 ids += [next_token]
86 print_tokens(ids, [ids])
87 print_tokens(ids, [ids])
91if __name__ == '__main__':
-92 generate()
90if __name__ == '__main__':
+91 generate()
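To make the role of the key/value cache explicit, here is a minimal sketch of the cached generation loop shown above: the full prompt is fed once, and every later step feeds only the last predicted token while the cache supplies the keys/values of earlier tokens. It reuses the helpers imported in the listing (get_cache, get_tokens, infer, print_tokens, PROMPT, monit) and assumes model is an nn.Sequential built as above; the wrapper name generate_with_cache is hypothetical.

def generate_with_cache(model, device):
    # Hypothetical wrapper around the loop in generate() above.
    # Enable key/value caching so earlier tokens are not recomputed.
    cache = get_cache()
    cache.set('use_cache', True)

    # Feed the whole prompt once to fill the cache.
    ids = get_tokens(PROMPT)
    cache.set('state_ids', (None, 1))
    with monit.section('Infer'):
        next_token = infer(model, ids, device)[-1]
    ids += [next_token]

    # Each later step feeds only the last token; cached keys/values cover the rest.
    for i in range(1, 100):
        cache.set('state_ids', (i, i + 1))
        with monit.section('Infer'):
            next_token = infer(model, [next_token], device)[-1]
        ids += [next_token]

    # Print the generated sequence.
    print_tokens(ids, [ids])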