mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-11-02 04:37:46 +08:00
typo fixes
This commit is contained in:
@ -51,7 +51,7 @@ class MNISTCapsuleNetworkModel(Module):
|
|||||||
|
|
||||||
# This is the decoder mentioned in the paper.
|
# This is the decoder mentioned in the paper.
|
||||||
# It takes the outputs of the $10$ digit capsules, each with $16$ features to reproduce the
|
# It takes the outputs of the $10$ digit capsules, each with $16$ features to reproduce the
|
||||||
# image. It goes through linear layers of sizes $512% and $1024$ with $ReLU$ activations.
|
# image. It goes through linear layers of sizes $512$ and $1024$ with $ReLU$ activations.
|
||||||
self.decoder = nn.Sequential(
|
self.decoder = nn.Sequential(
|
||||||
nn.Linear(16 * 10, 512),
|
nn.Linear(16 * 10, 512),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
@ -70,7 +70,7 @@ class MNISTCapsuleNetworkModel(Module):
|
|||||||
x = F.relu(self.conv1(data))
|
x = F.relu(self.conv1(data))
|
||||||
# Pass through the second convolution layer.
|
# Pass through the second convolution layer.
|
||||||
# Output of this has shape `[batch_size, 32 * 8, 6, 6]`.
|
# Output of this has shape `[batch_size, 32 * 8, 6, 6]`.
|
||||||
# *Note that this layer has a stride length of $2$.
|
# *Note that this layer has a stride length of $2$*.
|
||||||
x = self.conv2(x)
|
x = self.conv2(x)
|
||||||
|
|
||||||
# Resize and permutate to get the capsules
|
# Resize and permutate to get the capsules
|
||||||
|
|||||||
@ -85,6 +85,7 @@ class History(_History):
|
|||||||
This defines when a game ends, calculates the utility and sample chance events (dealing cards).
|
This defines when a game ends, calculates the utility and sample chance events (dealing cards).
|
||||||
|
|
||||||
The history is stored in a string:
|
The history is stored in a string:
|
||||||
|
|
||||||
* First two characters are the cards dealt to player 1 and player 2
|
* First two characters are the cards dealt to player 1 and player 2
|
||||||
* The third character is the action by the first player
|
* The third character is the action by the first player
|
||||||
* Fourth character is the action by the second player
|
* Fourth character is the action by the second player
|
||||||
|
|||||||
@ -28,7 +28,7 @@ Also, the MLP-mixer uses MLPs of two layers for each mixing and ConvMixer uses a
|
|||||||
The paper recommends removing the residual connection across the channel mixing (point-wise convolution)
|
The paper recommends removing the residual connection across the channel mixing (point-wise convolution)
|
||||||
and having only a residual connection over the spatial mixing (depth-wise convolution).
|
and having only a residual connection over the spatial mixing (depth-wise convolution).
|
||||||
They also use [Batch normalization](../normalization/batch_norm/index.html) instead
|
They also use [Batch normalization](../normalization/batch_norm/index.html) instead
|
||||||
of [Layer normalization)(../normalization/layer_norm/index.html).
|
of [Layer normalization](../normalization/layer_norm/index.html).
|
||||||
|
|
||||||
Here's [an experiment](experiment.html) that trains ConvMixer on CIFAR-10.
|
Here's [an experiment](experiment.html) that trains ConvMixer on CIFAR-10.
|
||||||
|
|
||||||
|
|||||||
@ -18,7 +18,7 @@ Also, the MLP-mixer uses MLPs of two layers for each mixing and ConvMixer uses a
|
|||||||
The paper recommends removing the residual connection across the channel mixing (point-wise convolution)
|
The paper recommends removing the residual connection across the channel mixing (point-wise convolution)
|
||||||
and having only a residual connection over the spatial mixing (depth-wise convolution).
|
and having only a residual connection over the spatial mixing (depth-wise convolution).
|
||||||
They also use [Batch normalization](https://nn.labml.ai/normalization/batch_norm/index.html) instead
|
They also use [Batch normalization](https://nn.labml.ai/normalization/batch_norm/index.html) instead
|
||||||
of [Layer normalization)(../normalization/layer_norm/index.html).
|
of [Layer normalization](../normalization/layer_norm/index.html).
|
||||||
|
|
||||||
Here's [an experiment](https://nn.labml.ai/conv_mixer/experiment.html) that trains ConvMixer on CIFAR-10.
|
Here's [an experiment](https://nn.labml.ai/conv_mixer/experiment.html) that trains ConvMixer on CIFAR-10.
|
||||||
|
|
||||||
|
|||||||
@ -245,7 +245,7 @@ def ag_news(c: NLPClassificationConfigs):
|
|||||||
### AG News dataset
|
### AG News dataset
|
||||||
|
|
||||||
This loads the AG News dataset and the set the values for
|
This loads the AG News dataset and the set the values for
|
||||||
`n_classes', `vocab`, `train_loader`, and `valid_loader`.
|
`n_classes`, `vocab`, `train_loader`, and `valid_loader`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Get training and validation datasets
|
# Get training and validation datasets
|
||||||
@ -279,5 +279,5 @@ def ag_news(c: NLPClassificationConfigs):
|
|||||||
valid_loader = DataLoader(valid, batch_size=c.batch_size, shuffle=True,
|
valid_loader = DataLoader(valid, batch_size=c.batch_size, shuffle=True,
|
||||||
collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))
|
collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))
|
||||||
|
|
||||||
# Return `n_classes', `vocab`, `train_loader`, and `valid_loader`
|
# Return `n_classes`, `vocab`, `train_loader`, and `valid_loader`
|
||||||
return 4, vocab, train_loader, valid_loader
|
return 4, vocab, train_loader, valid_loader
|
||||||
|
|||||||
@ -145,7 +145,7 @@ class Discriminator(Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
channels, height, width = input_shape
|
channels, height, width = input_shape
|
||||||
|
|
||||||
# Output of the discriminator is also a map of probabilities*
|
# Output of the discriminator is also a map of probabilities,
|
||||||
# whether each region of the image is real or generated
|
# whether each region of the image is real or generated
|
||||||
self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)
|
self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)
|
||||||
|
|
||||||
@ -528,8 +528,8 @@ class Configs(BaseConfigs):
|
|||||||
\Bigg]
|
\Bigg]
|
||||||
\end{align}
|
\end{align}
|
||||||
|
|
||||||
We use `generator_xy` for $G$ and `generator_yx$ for $F$.
|
We use `generator_xy` for $G$ and `generator_yx` for $F$.
|
||||||
We use `discriminator_x$ for $D_X$ and `discriminator_y` for $D_Y$.
|
We use `discriminator_x` for $D_X$ and `discriminator_y` for $D_Y$.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Replay buffers to keep generated samples
|
# Replay buffers to keep generated samples
|
||||||
|
|||||||
@ -83,7 +83,7 @@ where the factors of variations are more linear (disentangled).
|
|||||||
|
|
||||||
#### AdaIN
|
#### AdaIN
|
||||||
|
|
||||||
Then $w$ is transformed into two vectors (***styles***) per layer,
|
Then $w$ is transformed into two vectors (**styles**) per layer,
|
||||||
$i$, $y_i = (y_{s,i}, y_{b,i}) = f_{A_i}(w)$ and used for scaling and shifting (biasing)
|
$i$, $y_i = (y_{s,i}, y_{b,i}) = f_{A_i}(w)$ and used for scaling and shifting (biasing)
|
||||||
in each layer with $\text{AdaIN}$ operator (normalize and scale):
|
in each layer with $\text{AdaIN}$ operator (normalize and scale):
|
||||||
$$\text{AdaIN}(x_i, y_i) = y_{s, i} \frac{x_i - \mu(x_i)}{\sigma(x_i)} + y_{b,i}$$
|
$$\text{AdaIN}(x_i, y_i) = y_{s, i} \frac{x_i - \mu(x_i)}{\sigma(x_i)} + y_{b,i}$$
|
||||||
@ -202,7 +202,7 @@ class Generator(nn.Module):
|
|||||||
|
|
||||||
*<small>$A$ denotes a linear layer.
|
*<small>$A$ denotes a linear layer.
|
||||||
$B$ denotes a broadcast and scaling operation (noise is a single channel).
|
$B$ denotes a broadcast and scaling operation (noise is a single channel).
|
||||||
[*toRGB*](#to_rgb) also has a style modulation which is not shown in the diagram to keep it simple.</small>*
|
[`toRGB`](#to_rgb) also has a style modulation which is not shown in the diagram to keep it simple.</small>*
|
||||||
|
|
||||||
The generator starts with a learned constant.
|
The generator starts with a learned constant.
|
||||||
Then it has a series of blocks. The feature map resolution is doubled at each block
|
Then it has a series of blocks. The feature map resolution is doubled at each block
|
||||||
@ -243,7 +243,7 @@ class Generator(nn.Module):
|
|||||||
def forward(self, w: torch.Tensor, input_noise: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]]):
|
def forward(self, w: torch.Tensor, input_noise: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]]):
|
||||||
"""
|
"""
|
||||||
* `w` is $w$. In order to mix-styles (use different $w$ for different layers), we provide a separate
|
* `w` is $w$. In order to mix-styles (use different $w$ for different layers), we provide a separate
|
||||||
$w$ for each [generator block](#generator_block). It has shape `[n_blocks, batch_size, d_latent]1.
|
$w$ for each [generator block](#generator_block). It has shape `[n_blocks, batch_size, d_latent]`.
|
||||||
* `input_noise` is the noise for each block.
|
* `input_noise` is the noise for each block.
|
||||||
It's a list of pairs of noise sensors because each block (except the initial) has two noise inputs
|
It's a list of pairs of noise sensors because each block (except the initial) has two noise inputs
|
||||||
after each convolution layer (see the diagram).
|
after each convolution layer (see the diagram).
|
||||||
@ -282,7 +282,7 @@ class GeneratorBlock(nn.Module):
|
|||||||
|
|
||||||
*<small>$A$ denotes a linear layer.
|
*<small>$A$ denotes a linear layer.
|
||||||
$B$ denotes a broadcast and scaling operation (noise is a single channel).
|
$B$ denotes a broadcast and scaling operation (noise is a single channel).
|
||||||
[*toRGB*](#to_rgb) also has a style modulation which is not shown in the diagram to keep it simple.</small>*
|
[`toRGB`](#to_rgb) also has a style modulation which is not shown in the diagram to keep it simple.</small>*
|
||||||
|
|
||||||
The generator block consists of two [style blocks](#style_block) ($3 \times 3$ convolutions with style modulation)
|
The generator block consists of two [style blocks](#style_block) ($3 \times 3$ convolutions with style modulation)
|
||||||
and an RGB output.
|
and an RGB output.
|
||||||
@ -731,7 +731,7 @@ class EqualizedLinear(nn.Module):
|
|||||||
<a id="equalized_linear"></a>
|
<a id="equalized_linear"></a>
|
||||||
## Learning-rate Equalized Linear Layer
|
## Learning-rate Equalized Linear Layer
|
||||||
|
|
||||||
This uses [learning-rate equalized weights]($equalized_weights) for a linear layer.
|
This uses [learning-rate equalized weights](#equalized_weights) for a linear layer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_features: int, out_features: int, bias: float = 0.):
|
def __init__(self, in_features: int, out_features: int, bias: float = 0.):
|
||||||
@ -742,7 +742,7 @@ class EqualizedLinear(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# [Learning-rate equalized weights]($equalized_weights)
|
# [Learning-rate equalized weights](#equalized_weights)
|
||||||
self.weight = EqualizedWeight([out_features, in_features])
|
self.weight = EqualizedWeight([out_features, in_features])
|
||||||
# Bias
|
# Bias
|
||||||
self.bias = nn.Parameter(torch.ones(out_features) * bias)
|
self.bias = nn.Parameter(torch.ones(out_features) * bias)
|
||||||
@ -757,7 +757,7 @@ class EqualizedConv2d(nn.Module):
|
|||||||
<a id="equalized_conv2d"></a>
|
<a id="equalized_conv2d"></a>
|
||||||
## Learning-rate Equalized 2D Convolution Layer
|
## Learning-rate Equalized 2D Convolution Layer
|
||||||
|
|
||||||
This uses [learning-rate equalized weights]($equalized_weights) for a convolution layer.
|
This uses [learning-rate equalized weights](#equalized_weights) for a convolution layer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_features: int, out_features: int,
|
def __init__(self, in_features: int, out_features: int,
|
||||||
@ -771,7 +771,7 @@ class EqualizedConv2d(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
# Padding size
|
# Padding size
|
||||||
self.padding = padding
|
self.padding = padding
|
||||||
# [Learning-rate equalized weights]($equalized_weights)
|
# [Learning-rate equalized weights](#equalized_weights)
|
||||||
self.weight = EqualizedWeight([out_features, in_features, kernel_size, kernel_size])
|
self.weight = EqualizedWeight([out_features, in_features, kernel_size, kernel_size])
|
||||||
# Bias
|
# Bias
|
||||||
self.bias = nn.Parameter(torch.ones(out_features))
|
self.bias = nn.Parameter(torch.ones(out_features))
|
||||||
|
|||||||
@ -36,14 +36,17 @@ Each group can have it's own hyper-parameters like learning rates.
|
|||||||
|
|
||||||
In most common cases there will be only one group.
|
In most common cases there will be only one group.
|
||||||
This is when you initialize your optimizer with,
|
This is when you initialize your optimizer with,
|
||||||
|
|
||||||
```python
|
```python
|
||||||
Optimizer(model.parameters())
|
Optimizer(model.parameters())
|
||||||
```
|
```
|
||||||
|
|
||||||
You can define multiple parameter groups when initializing the optimizer:
|
You can define multiple parameter groups when initializing the optimizer:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
Optimizer([{'params': model1.parameters()}, {'params': model2.parameters(), 'lr': 2}])
|
Optimizer([{'params': model1.parameters()}, {'params': model2.parameters(), 'lr': 2}])
|
||||||
```
|
```
|
||||||
|
|
||||||
Here we pass a list of groups. Each group is a dictionary with it's parameters under the key 'params'.
|
Here we pass a list of groups. Each group is a dictionary with it's parameters under the key 'params'.
|
||||||
You specify any hyper-parameters as well. If the hyper parameters are not defined they will default
|
You specify any hyper-parameters as well. If the hyper parameters are not defined they will default
|
||||||
to the optimizer level defaults.
|
to the optimizer level defaults.
|
||||||
@ -74,7 +77,7 @@ class GenericAdaptiveOptimizer(Optimizer):
|
|||||||
|
|
||||||
* `params` is the collection of parameters or set of parameter groups.
|
* `params` is the collection of parameters or set of parameter groups.
|
||||||
* `defaults` a dictionary of default hyper-parameters
|
* `defaults` a dictionary of default hyper-parameters
|
||||||
* 'lr` is the learning rate, $\alpha$
|
* `lr` is the learning rate, $\alpha$
|
||||||
* `betas` is the tuple $(\beta_1, \beta_2)$
|
* `betas` is the tuple $(\beta_1, \beta_2)$
|
||||||
* `eps` is $\epsilon$
|
* `eps` is $\epsilon$
|
||||||
"""
|
"""
|
||||||
@ -174,7 +177,8 @@ class WeightDecay:
|
|||||||
decay from the parameter. If added to the gradient it will go through the normal optimizer update.
|
decay from the parameter. If added to the gradient it will go through the normal optimizer update.
|
||||||
* `absolute` this flag indicates whether the weight decay coefficient is absolute. This is applicable
|
* `absolute` this flag indicates whether the weight decay coefficient is absolute. This is applicable
|
||||||
when the decay is performed directly on the parameter. If this is false the actual decay is
|
when the decay is performed directly on the parameter. If this is false the actual decay is
|
||||||
`weight_decay` * `learning_rate`.
|
`weight_decay`
|
||||||
|
* `learning_rate`.
|
||||||
"""
|
"""
|
||||||
# Check hyper-parameters
|
# Check hyper-parameters
|
||||||
if not 0.0 <= weight_decay:
|
if not 0.0 <= weight_decay:
|
||||||
|
|||||||
@ -61,11 +61,11 @@ class AdaBelief(RAdam):
|
|||||||
* `betas` is a tuple of ($\beta_1$, $\beta_2$)
|
* `betas` is a tuple of ($\beta_1$, $\beta_2$)
|
||||||
* `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`
|
* `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`
|
||||||
* `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)
|
* `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)
|
||||||
* 'optimized_update' is a flag whether to optimize the bias correction of the second moment
|
* `optimized_update` is a flag whether to optimize the bias correction of the second moment
|
||||||
by doing it after adding $\epsilon$
|
by doing it after adding $\epsilon$
|
||||||
* `amsgrad` is a flag indicating whether to use AMSGrad or fallback to plain Adam
|
* `amsgrad` is a flag indicating whether to use AMSGrad or fallback to plain Adam
|
||||||
* `degenerate_to_sgd` whether to use sgd when the rectification term $r_t is intractable
|
* `degenerate_to_sgd` whether to use sgd when the rectification term $r_t$ is intractable
|
||||||
* 'rectify' is whether to use RAdam update
|
* `rectify` is whether to use RAdam update
|
||||||
* `defaults` is a dictionary of default for group values.
|
* `defaults` is a dictionary of default for group values.
|
||||||
This is useful when you want to extend the class `AdaBelief`.
|
This is useful when you want to extend the class `AdaBelief`.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -155,10 +155,10 @@ class RAdam(AMSGrad):
|
|||||||
* `betas` is a tuple of ($\beta_1$, $\beta_2$)
|
* `betas` is a tuple of ($\beta_1$, $\beta_2$)
|
||||||
* `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`
|
* `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`
|
||||||
* `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)
|
* `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)
|
||||||
* 'optimized_update' is a flag whether to optimize the bias correction of the second moment
|
* `optimized_update` is a flag whether to optimize the bias correction of the second moment
|
||||||
by doing it after adding $\epsilon$
|
by doing it after adding $\epsilon$
|
||||||
* `amsgrad` is a flag indicating whether to use AMSGrad or fallback to plain Adam
|
* `amsgrad` is a flag indicating whether to use AMSGrad or fallback to plain Adam
|
||||||
* `degenerate_to_sgd` whether to use sgd when the rectification term $r_t is intractable.
|
* `degenerate_to_sgd` whether to use sgd when the rectification term $r_t$ is intractable.
|
||||||
* `defaults` is a dictionary of default for group values.
|
* `defaults` is a dictionary of default for group values.
|
||||||
This is useful when you want to extend the class `RAdam`.
|
This is useful when you want to extend the class `RAdam`.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -45,10 +45,10 @@ class GAE:
|
|||||||
$\hat{A_t}$
|
$\hat{A_t}$
|
||||||
|
|
||||||
\begin{align}
|
\begin{align}
|
||||||
\delta_t &= r_t + \gamma V(s_{t+1}) - V(s_t)$
|
\delta_t &= r_t + \gamma V(s_{t+1}) - V(s_t)
|
||||||
\\
|
\\
|
||||||
\hat{A_t} &= \delta_t + \gamma \lambda \delta_{t+1} + ... +
|
\hat{A_t} &= \delta_t + \gamma \lambda \delta_{t+1} + ... +
|
||||||
(\gamma \lambda)^{T - t + 1} \delta_{T - 1}$
|
(\gamma \lambda)^{T - t + 1} \delta_{T - 1}
|
||||||
\\
|
\\
|
||||||
&= \delta_t + \gamma \lambda \hat{A_{t+1}}
|
&= \delta_t + \gamma \lambda \hat{A_{t+1}}
|
||||||
\end{align}
|
\end{align}
|
||||||
|
|||||||
@ -114,7 +114,7 @@ class StrokesDataset(Dataset):
|
|||||||
# Mask is on until end of sequence
|
# Mask is on until end of sequence
|
||||||
self.mask[i, :len_seq + 1] = 1
|
self.mask[i, :len_seq + 1] = 1
|
||||||
|
|
||||||
# Start-of-sequence is $(0, 0, 1, 0, 0)
|
# Start-of-sequence is $(0, 0, 1, 0, 0)$
|
||||||
self.data[:, 0, 2] = 1
|
self.data[:, 0, 2] = 1
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
|
|||||||
@ -42,7 +42,7 @@ AFT Local only apply learned pair-wise position biases locally:
|
|||||||
\begin{align}
|
\begin{align}
|
||||||
w'_{t,t'} =
|
w'_{t,t'} =
|
||||||
\begin{cases}
|
\begin{cases}
|
||||||
w_{t,t'}, & \text{for $\lvert t-t' \rvert \lt s$} \\
|
w_{t,t'}, & {\text{for } \lvert t-t' \rvert \lt s} \\
|
||||||
0, & \text{otherwise}
|
0, & \text{otherwise}
|
||||||
\end{cases}
|
\end{cases}
|
||||||
\end{align}
|
\end{align}
|
||||||
@ -79,7 +79,7 @@ class AFTLocal(Module):
|
|||||||
\begin{align}
|
\begin{align}
|
||||||
w'_{t,t'} =
|
w'_{t,t'} =
|
||||||
\begin{cases}
|
\begin{cases}
|
||||||
w_{t,t'}, & \text{for $\lvert t-t' \rvert \lt s$} \\
|
w_{t,t'}, & {\text{for } \lvert t-t' \rvert \lt s} \\
|
||||||
0, & \text{otherwise}
|
0, & \text{otherwise}
|
||||||
\end{cases}
|
\end{cases}
|
||||||
\end{align}
|
\end{align}
|
||||||
@ -119,7 +119,7 @@ class AFTLocal(Module):
|
|||||||
\begin{align}
|
\begin{align}
|
||||||
m_{t,t'} =
|
m_{t,t'} =
|
||||||
\begin{cases}
|
\begin{cases}
|
||||||
1, & \text{for $\lvert t-t' \rvert \lt s$} \\
|
1, & {\text{for } \lvert t-t' \rvert \lt s} \\
|
||||||
0, & \text{otherwise}
|
0, & \text{otherwise}
|
||||||
\end{cases}
|
\end{cases}
|
||||||
\end{align}
|
\end{align}
|
||||||
@ -170,7 +170,7 @@ class AFTLocal(Module):
|
|||||||
# \begin{align}
|
# \begin{align}
|
||||||
# w'_{t,t'} =
|
# w'_{t,t'} =
|
||||||
# \begin{cases}
|
# \begin{cases}
|
||||||
# w_{t,t'}, & \text{for $\lvert t-t' \rvert \lt s$} \\
|
# w_{t,t'}, & {\text{for }\lvert t-t' \rvert \lt s} \\
|
||||||
# 0, & \text{otherwise}
|
# 0, & \text{otherwise}
|
||||||
# \end{cases}
|
# \end{cases}
|
||||||
# \end{align}
|
# \end{align}
|
||||||
|
|||||||
@ -40,7 +40,7 @@ We have implemented the latter here since it gives better results.
|
|||||||
|
|
||||||
This implementation uses pre-layer normalization
|
This implementation uses pre-layer normalization
|
||||||
while the paper uses post-layer normalization.
|
while the paper uses post-layer normalization.
|
||||||
Pre-layer norm does the layer norm before FFN[../feedforward.html) and
|
Pre-layer norm does the layer norm before [FFN](../feedforward.html) and
|
||||||
self-attention, and the pass-through in the residual connection is not normalized.
|
self-attention, and the pass-through in the residual connection is not normalized.
|
||||||
This is supposed to be more stable in standard transformer setups.
|
This is supposed to be more stable in standard transformer setups.
|
||||||
|
|
||||||
@ -246,7 +246,7 @@ class AttentionReconstructionLoss:
|
|||||||
This is a reimplementation of ['PrepareForMultiHeadAttention'](../mha.html#PrepareMHA)
|
This is a reimplementation of ['PrepareForMultiHeadAttention'](../mha.html#PrepareMHA)
|
||||||
where the projections are done with the parameters detached from gradient computation.
|
where the projections are done with the parameters detached from gradient computation.
|
||||||
|
|
||||||
* `pmha* is the ['PrepareForMultiHeadAttention'](../mha.html#PrepareMHA) module
|
* `pmha` is the ['PrepareForMultiHeadAttention'](../mha.html#PrepareMHA) module
|
||||||
* `x` is tensor with the token embeddings
|
* `x` is tensor with the token embeddings
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@ -32,7 +32,7 @@ We have implemented the latter here since it gives better results.
|
|||||||
|
|
||||||
This implementation uses pre-layer normalization
|
This implementation uses pre-layer normalization
|
||||||
while the paper uses post-layer normalization.
|
while the paper uses post-layer normalization.
|
||||||
Pre-layer norm does the layer norm before FFN[../feedforward.html) and
|
Pre-layer norm does the layer norm before [FFN](../feedforward.html) and
|
||||||
self-attention, and the pass-through in the residual connection is not normalized.
|
self-attention, and the pass-through in the residual connection is not normalized.
|
||||||
This is supposed to be more stable in standard transformer setups.
|
This is supposed to be more stable in standard transformer setups.
|
||||||
|
|
||||||
|
|||||||
@ -201,7 +201,7 @@ class Trainer:
|
|||||||
# Cross-entropy loss
|
# Cross-entropy loss
|
||||||
self.loss_func = nn.CrossEntropyLoss()
|
self.loss_func = nn.CrossEntropyLoss()
|
||||||
# Number of training epochs;
|
# Number of training epochs;
|
||||||
# *note that our dataset definition repeats the data `seq_len` times in a single epoch
|
# *note that our dataset definition repeats the data `seq_len` times in a single epoch*
|
||||||
self.epochs = configs.epochs
|
self.epochs = configs.epochs
|
||||||
# Gradient clipping norm
|
# Gradient clipping norm
|
||||||
self.grad_norm_clip = configs.grad_norm_clip
|
self.grad_norm_clip = configs.grad_norm_clip
|
||||||
|
|||||||
@ -46,7 +46,7 @@ def knn(queries: torch.Tensor, index: faiss.IndexFlatL2, keys_store: np.ndarray,
|
|||||||
|
|
||||||
# Normalize $f(c_i)$
|
# Normalize $f(c_i)$
|
||||||
keys_found_n = keys_found / torch.sqrt((keys_found ** 2).sum(-1, keepdims=True) + 1e-10)
|
keys_found_n = keys_found / torch.sqrt((keys_found ** 2).sum(-1, keepdims=True) + 1e-10)
|
||||||
# Normalize $f($\color{yellowgreen}{c_t})$
|
# Normalize $f(\color{yellowgreen}{c_t})$
|
||||||
queries_n = queries / torch.sqrt((queries ** 2).sum(-1, keepdims=True) + 1e-10)
|
queries_n = queries / torch.sqrt((queries ** 2).sum(-1, keepdims=True) + 1e-10)
|
||||||
|
|
||||||
# Get the dot-product, or cosine similarity
|
# Get the dot-product, or cosine similarity
|
||||||
|
|||||||
@ -81,7 +81,7 @@ class MLM:
|
|||||||
masking_prob: float = 0.15, randomize_prob: float = 0.1, no_change_prob: float = 0.1,
|
masking_prob: float = 0.15, randomize_prob: float = 0.1, no_change_prob: float = 0.1,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
* `padding_token` is the padding token `[PAD].
|
* `padding_token` is the padding token `[PAD]`.
|
||||||
We will use this to mark the labels that shouldn't be used for loss calculation.
|
We will use this to mark the labels that shouldn't be used for loss calculation.
|
||||||
* `mask_token` is the masking token `[MASK]`.
|
* `mask_token` is the masking token `[MASK]`.
|
||||||
* `no_mask_tokens` is a list of tokens that should not be masked.
|
* `no_mask_tokens` is a list of tokens that should not be masked.
|
||||||
|
|||||||
@ -155,11 +155,13 @@ class SwitchFeedForward(Module):
|
|||||||
final_output = final_output.view(seq_len, batch_size, d_model)
|
final_output = final_output.view(seq_len, batch_size, d_model)
|
||||||
|
|
||||||
# Return
|
# Return
|
||||||
|
#
|
||||||
# * the final output
|
# * the final output
|
||||||
# * number of tokens routed to each expert
|
# * number of tokens routed to each expert
|
||||||
# * sum of probabilities for each expert
|
# * sum of probabilities for each expert
|
||||||
# * number of tokens dropped.
|
# * number of tokens dropped.
|
||||||
# * routing probabilities of the selected experts
|
# * routing probabilities of the selected experts
|
||||||
|
#
|
||||||
# These are used for the load balancing loss and logging
|
# These are used for the load balancing loss and logging
|
||||||
return final_output, counts, route_prob.sum(0), len(dropped), route_prob_max
|
return final_output, counts, route_prob.sum(0), len(dropped), route_prob_max
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user