diff --git a/labml_nn/neox/model.py b/labml_nn/neox/model.py index f8760c55..dd2556d6 100644 --- a/labml_nn/neox/model.py +++ b/labml_nn/neox/model.py @@ -552,7 +552,7 @@ class LayerGenerator: else: layer = copy.deepcopy(self.pre_created_layers[name]) - layer: NeoXModule = self._prepare_layer(layer) + layer: NeoXModule = layer.to(self.device, self.dtype) if self.pre_created_layers[name] is None: self.pre_created_layers[name] = layer @@ -590,7 +590,7 @@ class LayerGenerator: # Transformer layer if i + 1 in self.filter_layers: with monit.section(f'Transformer Layer {i}'): - yield self._create_transformer_layer(), \ + yield self._prepare_layer(self._create_transformer_layer()), \ (f'layer_{i + 2 :02d}-model_00-model_states.pt', f'layer_{i + 2 :02d}-model_01-model_states.pt')