Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-08-16 19:01:12 +08:00)
deep norm fix
@@ -110,7 +110,7 @@ class DeepNorm(nn.Module):
         :param gx: is the output of the current sub-layer $\mathop{G}_l (x_l, \theta_l)$
         """
         # $$x_{l + 1} = \mathop{LN}\Big( \alpha x_l + \mathop{G}_l \big(x_l, \theta_l \big)\Big)$$
-        return x + self.alpha * gx
+        return self.layer_norm(x + self.alpha * gx)


 class DeepNormTransformerLayer(nn.Module):
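For context, the DeepNorm update in the comment above is $x_{l + 1} = \mathop{LN}\big( \alpha x_l + \mathop{G}_l (x_l, \theta_l)\big)$, so the scaled residual sum must be passed through layer normalization rather than returned directly. Below is a minimal standalone sketch of that update; the module name, constructor arguments, and the use of `nn.LayerNorm` are illustrative assumptions, not the repository's exact class.

```python
import torch
import torch.nn as nn


class DeepNormSketch(nn.Module):
    """Minimal sketch of the DeepNorm residual: x_{l+1} = LN(alpha * x_l + G_l(x_l, theta_l))."""

    def __init__(self, normalized_shape: int, alpha: float):
        super().__init__()
        # alpha scales the residual branch; the DeepNet paper derives it from network depth
        self.alpha = alpha
        # layer normalization applied after the residual addition
        self.layer_norm = nn.LayerNorm(normalized_shape)

    def forward(self, x: torch.Tensor, gx: torch.Tensor) -> torch.Tensor:
        # gx is the output of the sub-layer G_l(x_l, theta_l);
        # the residual sum is normalized, as in the formula above
        return self.layer_norm(x + self.alpha * gx)


# usage: shapes are illustrative
x = torch.randn(2, 10, 64)
gx = torch.randn(2, 10, 64)
out = DeepNormSketch(64, alpha=2.0)(x, gx)
```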
@@ -8,7 +8,6 @@ summary: >
 # [DeepNorm](index.html) Experiment

 [](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/normalization/deep_norm/experiment.ipynb)
-[](https://www.comet.ml/labml/deep-norm/61d817f80ff143c8825fba4aacd431d4?experiment-tab=chart&showOutliers=true&smoothing=0&transformY=smoothing&xAxis=step)
 """

 import copy
@@ -131,7 +130,7 @@ def main():
     #### Create and run the experiment
     """
     # Create experiment
-    experiment.create(name="deep_norm", writers={'screen', 'web_api', 'comet'})
+    experiment.create(name="deep_norm", writers={'screen', 'web_api'})
     # Create configs
     conf = Configs()
     # Override configurations
@@ -83,6 +83,14 @@ class LayerNorm(Module):
         """
         super().__init__()

+        # Convert `normalized_shape` to `torch.Size`
+        if isinstance(normalized_shape, int):
+            normalized_shape = torch.Size([normalized_shape])
+        elif isinstance(normalized_shape, list):
+            normalized_shape = torch.Size(normalized_shape)
+        assert isinstance(normalized_shape, torch.Size)
+
+        #
         self.normalized_shape = normalized_shape
         self.eps = eps
         self.elementwise_affine = elementwise_affine
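The added lines above normalize the `normalized_shape` argument so the custom `LayerNorm` accepts an `int`, a `list`, or a `torch.Size`. A minimal sketch of that conversion in isolation (as a standalone function rather than inside the class constructor):

```python
from typing import List, Union

import torch


def to_normalized_shape(normalized_shape: Union[int, List[int], torch.Size]) -> torch.Size:
    # an int such as 64 becomes torch.Size([64])
    if isinstance(normalized_shape, int):
        normalized_shape = torch.Size([normalized_shape])
    # a list such as [10, 64] becomes torch.Size([10, 64])
    elif isinstance(normalized_shape, list):
        normalized_shape = torch.Size(normalized_shape)
    # anything else must already be a torch.Size
    assert isinstance(normalized_shape, torch.Size)
    return normalized_shape


# usage: all three input forms yield an equivalent torch.Size
assert to_normalized_shape(64) == torch.Size([64])
assert to_normalized_shape([10, 64]) == torch.Size([10, 64])
assert to_normalized_shape(torch.Size([64])) == torch.Size([64])
```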