diff --git a/docs/neox/evaluation/llm_int8.html b/docs/neox/evaluation/llm_int8.html
new file mode 100644
index 00000000..60926b00
--- /dev/null
+++ b/docs/neox/evaluation/llm_int8.html
@@ -0,0 +1,177 @@
+1import torch
+2from torch import nn
+3
+4from labml import monit
+5from labml_nn.neox.evaluation import run_eval_harness
+6from labml_nn.neox.model import LayerGenerator
+7
+8if __name__ == '__main__':
+9 device = torch.device('cuda:0')
+10 layer_generator = LayerGenerator(is_clone_layers=True,
+11 dtype=torch.float16,
+12 device=torch.device('cpu'),
+13 )
+Load layers
+15 layers = list(layer_generator.load())
+This reduces CUDA memory fragmentation
+18 for layer in monit.iterate('Convert to int8', layers, is_children_silent=True):
+19 layer_generator.post_load_prepare(layer,
+20 device=device,
+21 is_llm_int8=True,
+22 llm_int8_threshold=6.0,
+23 )
+24 layer.to(device)
+25
+26 with monit.section('Sequential'):
+27 model = nn.Sequential(*layers)
+28
+29 print(run_eval_harness(model, 'half_precision', [], device))
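For intuition, here is a minimal sketch of the outlier decomposition that `llm_int8_threshold=6.0` controls in LLM.int8(): activation columns that contain an outlier (absolute value above the threshold) are multiplied in float16, and the remaining columns are quantized to int8 with absmax scaling. The function below is hypothetical and only illustrates the idea; the actual conversion is done by `make_llm_int8_linear` using bitsandbytes kernels.

```python
import torch


def int8_matmul_sketch(x: torch.Tensor, w: torch.Tensor, threshold: float = 6.0) -> torch.Tensor:
    """Illustrative LLM.int8()-style `x @ w.T`; `x` is `[batch, d_in]`, `w` is `[d_out, d_in]`."""
    # Input columns that contain an outlier feature (magnitude above the threshold)
    outliers = (x.abs() > threshold).any(dim=0)

    # Outlier columns are multiplied in fp16
    y = x[:, outliers].half() @ w[:, outliers].half().t()

    if (~outliers).any():
        # Remaining columns are quantized to int8 with per-row absmax scales
        x_r, w_r = x[:, ~outliers], w[:, ~outliers]
        s_x = x_r.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) / 127.
        s_w = w_r.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) / 127.
        x_q = (x_r / s_x).round().clamp(-127, 127)
        w_q = (w_r / s_w).round().clamp(-127, 127)
        # Integer matmul (emulated in float here), then dequantize and add the fp16 part
        y = y + (x_q @ w_q.t()) * (s_x @ s_w.t())

    return y
```

In practice bitsandbytes runs the int8 multiplication with dedicated CUDA kernels and keeps the quantized weights resident on the GPU, which is where the memory savings come from.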
d_rope
 is the number of features for RoPE embeddings
base
 is the base for $\theta_i$, which defaults to $10000$
Concatenate so that for row $m$ we have
$[m \theta_0, m \theta_1, \dots, m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, \dots, m \theta_{\frac{d}{2}}]$
RoPE embeddings
for $i \in \{1, 2, \dots, \frac{d}{2}\}$
device
 is the device of the model
Returns the layers as a generator
+is_llm_int8
+ specifies whether to use int8 quantization
+llm_int8_threshold
+ is the threshold used to separate outlier features
-482 if filter_layers is None:
-483 filter_layers = set(range(n_layers + 3))
-484
-485 self.n_vocab = n_vocab
-486 self.n_hidden = n_hidden
-487 self.n_layers = n_layers
-488 self.n_heads = n_heads
-489 self.filter_layers = filter_layers
-490 self.is_clone_layers = is_clone_layers
-491 self.dtype = dtype
-492 self.device = device
-493
-494 self.pre_created_layers = dict(
-495 transformer_layer=None,
-496 )
+486 if filter_layers is None:
+487 filter_layers = set(range(n_layers + 3))
+488
+489 self.n_vocab = n_vocab
+490 self.n_hidden = n_hidden
+491 self.n_layers = n_layers
+492 self.n_heads = n_heads
+493 self.filter_layers = filter_layers
+494 self.is_clone_layers = is_clone_layers
+495 self.dtype = dtype
+496 self.device = device
+497 self.is_llm_int8 = is_llm_int8
+498 self.llm_int8_threshold = llm_int8_threshold
+499
+500 self.pre_created_layers = dict(
+501 transformer_layer=None,
+502 )
+We move the layer to the device and convert it to the correct data type
+layer
+ is the layer to prepare
+Returns the prepared layer
-498 def _prepare_layer(self, layer: NeoXModule):
-499 layer = layer.to(self.device, self.dtype)
-500 return layer
+504 def _prepare_layer(self, layer: NeoXModule):
-502 def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):
-503 if self.pre_created_layers[name] is None or not self.is_clone_layers:
-504 layer = creator()
-505 else:
-506 layer = copy.deepcopy(self.pre_created_layers[name])
-507
-508 layer: NeoXModule = self._prepare_layer(layer)
-509
-510 if self.pre_created_layers[name] is None:
-511 self.pre_created_layers[name] = layer
-512
-513 return layer
+513 return layer.to(self.device, self.dtype)
+### Layer transformations after loading the checkpoint
+This function implements layer transformations after loading the checkpoint.
+Currently, it only applies the int8 quantization.
+layer
+ is the layer to prepare
+is_llm_int8
+ specifies whether to use int8 quantization
+device
+ is the device of the model
+llm_int8_threshold
+ is the threshold used to separate outlier features
+Returns the prepared layer
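As a usage sketch (the settings below are illustrative, not taken from this diff): the int8 options can instead be passed to the `LayerGenerator` constructor, in which case `load()` calls `post_load_prepare()` on each layer after its checkpoint is loaded, rather than converting layers manually as in the evaluation script above (which loads on the CPU first to reduce memory fragmentation).

```python
import torch

from labml_nn.neox.model import LayerGenerator

# Quantize each transformer layer as it is loaded; assumes a single CUDA device
layer_generator = LayerGenerator(is_clone_layers=True,
                                 dtype=torch.float16,
                                 device=torch.device('cuda:0'),
                                 is_llm_int8=True,
                                 llm_int8_threshold=6.0,
                                 )
layers = list(layer_generator.load())
```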
-515 def _create_transformer_layer(self):
-516 return self._create_and_cache_layer(
-517 'transformer_layer',
-518 lambda: TransformerLayer(self.n_hidden, self.n_heads)
-519 )
+515 @torch.no_grad()
+516 def post_load_prepare(self, layer: NeoXModule, *,
+517 is_llm_int8: bool = None,
+518 device: torch.device = None,
+519 llm_int8_threshold: float = None,
+520 ):
+Get default values if not specified
-521 def _create_embedding_layer(self):
-522 return Embedding(self.n_vocab, self.n_hidden)
+537 if is_llm_int8 is None:
+538 is_llm_int8 = self.is_llm_int8
+539 if device is None:
+540 device = self.device
+541 if llm_int8_threshold is None:
+542 llm_int8_threshold = self.llm_int8_threshold
+Skip if not using int8 quantization
-524 def _create_final_norm_layer(self):
-525 return FinalNorm(self.n_hidden)
+545 if not is_llm_int8:
+546 return layer
+Only convert the linear layers in the transformer layers
-527 def _create_readout_layer(self):
-528 return ReadoutLayer(self.n_hidden, self.n_vocab)
+549 if not isinstance(layer, TransformerLayer):
+550 return layer
-530 @torch.no_grad()
-531 def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:
+553 from labml_nn.neox.utils.llm_int8 import make_llm_int8_linear
-533 if 0 in self.filter_layers:
-534 with monit.section('Embedding layer'):
-535 layer = self._prepare_layer(self._create_embedding_layer())
-536 yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')
+556 with monit.section('Convert to int8'):
+557 layer.attention.output = make_llm_int8_linear(layer.attention.output,
+558 device=device,
+559 threshold=llm_int8_threshold)
+560 layer.attention.qkv_lin = make_llm_int8_linear(layer.attention.qkv_lin,
+561 device=device,
+562 threshold=llm_int8_threshold)
+563 layer.ffn.dense_h_h4 = make_llm_int8_linear(layer.ffn.dense_h_h4,
+564 device=device,
+565 threshold=llm_int8_threshold)
+566 layer.ffn.dense_h4_h = make_llm_int8_linear(layer.ffn.dense_h4_h,
+567 device=device,
+568 threshold=llm_int8_threshold)
-539 for i in range(self.n_layers):
+570 return layer
Transformer layer
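`make_llm_int8_linear` lives in `labml_nn.neox.utils.llm_int8`. Roughly, it swaps an `nn.Linear` for a bitsandbytes `Linear8bitLt` whose weights are quantized to int8 when the parameter is moved to the GPU, while columns with activations above `threshold` are handled in float16. The sketch below is an approximation under that assumption, not the exact helper:

```python
import torch
from torch import nn
from bitsandbytes.nn import Int8Params, Linear8bitLt


def make_llm_int8_linear_sketch(linear: nn.Linear, device: torch.device,
                                threshold: float = 6.0) -> Linear8bitLt:
    """Replace a float linear layer with an int8 one (illustrative approximation)."""
    # Same shape as the original layer; keep outliers above `threshold` in fp16
    int8_linear = Linear8bitLt(linear.in_features, linear.out_features,
                               bias=linear.bias is not None,
                               has_fp16_weights=False,
                               threshold=threshold)
    # Wrap the existing weights; quantization happens when the parameter moves to CUDA
    int8_linear.weight = Int8Params(linear.weight.data.clone(),
                                    requires_grad=False,
                                    has_fp16_weights=False)
    if linear.bias is not None:
        int8_linear.bias = nn.Parameter(linear.bias.data.clone(), requires_grad=False)

    return int8_linear.to(device)
```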
+Copying cached layers is faster than initializing new layers because it takes time to initialize parameters.
+name
+ is the name of the layer
+creator
+ is the function to create the layer
+Returns the created layer or a copy of the cached layer
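To see why cloning helps, here is a rough, self-contained comparison (the `nn.Linear` stand-in and its size are hypothetical, chosen only to mimic a large layer): creating a fresh module runs random weight initialization, while `copy.deepcopy` only copies already-allocated tensors, and in both cases the weights are overwritten by `load_state` afterwards.

```python
import copy
import time

from torch import nn

# Hypothetical stand-in for a large transformer sub-layer
template = nn.Linear(6_144, 4 * 6_144)

start = time.time()
fresh = nn.Linear(6_144, 4 * 6_144)   # runs weight initialization
print(f'fresh init: {time.time() - start:.3f}s')

start = time.time()
clone = copy.deepcopy(template)       # copies existing tensors, no re-initialization
print(f'deepcopy:   {time.time() - start:.3f}s')
```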
-541 if i + 1 in self.filter_layers:
-542 with monit.section(f'Transformer Layer {i}'):
-543 yield self._create_transformer_layer(), \
-544 (f'layer_{i + 2 :02d}-model_00-model_states.pt',
-545 f'layer_{i + 2 :02d}-model_01-model_states.pt')
+572 def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):
Final normalization layer
-548 if self.n_layers + 1 in self.filter_layers:
-549 with monit.section('Final norm layer'):
-550 layer = self._prepare_layer(self._create_final_norm_layer())
-551 yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')
+584 if not self.is_clone_layers:
+585 return self._prepare_layer(creator())
+586
+587 if self.pre_created_layers[name] is None:
+588 self.pre_created_layers[name] = self._prepare_layer(creator())
+589
+590 layer = copy.deepcopy(self.pre_created_layers[name])
+591 return layer
Readout layer
-554 if self.n_layers + 2 in self.filter_layers:
-555 with monit.section('Readout layer'):
-556 layer = self._prepare_layer(self._create_readout_layer())
-557 yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')
+593 def _create_transformer_layer(self):
+594 return self._create_and_cache_layer(
+595 'transformer_layer',
+596 lambda: TransformerLayer(self.n_hidden, self.n_heads)
+597 )
-559 @property
-560 def total_layers(self):
-561 return self.n_layers + 3
-562
-563 @torch.no_grad()
-564 def load(self) -> Generator[NeoXModule, None, None]:
-565 with torch.no_grad():
-566 with monit.section("Layers"):
-567 for i, (layer, files) in enumerate(self.get_layers()):
-568 if files is not None:
-569 layer.load_state(*checkpoint.load_checkpoint_files(files))
-570
-571 monit.progress(min(0.99, (i + 1) / self.total_layers))
-572 yield layer
+599 def _create_embedding_layer(self):
+600 return Embedding(self.n_vocab, self.n_hidden)
+602 def _create_final_norm_layer(self):
+603 return FinalNorm(self.n_hidden)
+605 def _create_readout_layer(self):
+606 return ReadoutLayer(self.n_hidden, self.n_vocab)
+608 @torch.no_grad()
+609 def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:
+Embedding layer
+614 if 0 in self.filter_layers:
+615 with monit.section('Embedding layer'):
+616 layer = self._prepare_layer(self._create_embedding_layer())
+617 yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')
+Transformer layers
+620 for i in range(self.n_layers):
+Transformer layer
+622 if i + 1 in self.filter_layers:
+623 with monit.section(f'Transformer Layer {i}'):
+624 yield self._create_transformer_layer(), \
+625 (f'layer_{i + 2 :02d}-model_00-model_states.pt',
+626 f'layer_{i + 2 :02d}-model_01-model_states.pt')
+Final normalization layer
+629 if self.n_layers + 1 in self.filter_layers:
+630 with monit.section('Final norm layer'):
+631 layer = self._prepare_layer(self._create_final_norm_layer())
+632 yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')
+Readout layer
+635 if self.n_layers + 2 in self.filter_layers:
+636 with monit.section('Readout layer'):
+637 layer = self._prepare_layer(self._create_readout_layer())
+638 yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')
+639
+640 for k in self.pre_created_layers.keys():
+641 self.pre_created_layers[k] = None
+643 @property
+644 def total_layers(self):
+648 return self.n_layers + 3
+650 @torch.no_grad()
+651 def load(self) -> Generator[NeoXModule, None, None]:
+655 with monit.section("Layers"):
+656 for i, (layer, files) in enumerate(self.get_layers()):
+657 if files is not None:
+658 layer.load_state(*checkpoint.load_checkpoint_files(files))
+659
+660 layer = self.post_load_prepare(layer)
+661
+662 monit.progress(min(0.99, (i + 1) / self.total_layers))
+663 yield layer