This commit is contained in:
Varuna Jayasiri
2022-08-19 09:52:13 +05:30
parent 223bb0e7c0
commit 1d92b5dc62

View File

@ -552,7 +552,7 @@ class LayerGenerator:
else: else:
layer = copy.deepcopy(self.pre_created_layers[name]) layer = copy.deepcopy(self.pre_created_layers[name])
layer: NeoXModule = self._prepare_layer(layer) layer: NeoXModule = layer.to(self.device, self.dtype)
if self.pre_created_layers[name] is None: if self.pre_created_layers[name] is None:
self.pre_created_layers[name] = layer self.pre_created_layers[name] = layer
@ -590,7 +590,7 @@ class LayerGenerator:
# Transformer layer # Transformer layer
if i + 1 in self.filter_layers: if i + 1 in self.filter_layers:
with monit.section(f'Transformer Layer {i}'): with monit.section(f'Transformer Layer {i}'):
yield self._create_transformer_layer(), \ yield self._prepare_layer(self._create_transformer_layer()), \
(f'layer_{i + 2 :02d}-model_00-model_states.pt', (f'layer_{i + 2 :02d}-model_00-model_states.pt',
f'layer_{i + 2 :02d}-model_01-model_states.pt') f'layer_{i + 2 :02d}-model_01-model_states.pt')