mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-11-04 06:16:05 +08:00

typo fixes
@@ -51,7 +51,7 @@ class MNISTCapsuleNetworkModel(Module):
         # This is the decoder mentioned in the paper.
         # It takes the outputs of the $10$ digit capsules, each with $16$ features to reproduce the
-        # image. It goes through linear layers of sizes $512% and $1024$ with $ReLU$ activations.
+        # image. It goes through linear layers of sizes $512$ and $1024$ with $ReLU$ activations.
         self.decoder = nn.Sequential(
             nn.Linear(16 * 10, 512),
             nn.ReLU(),
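The comment fixed above describes the reconstruction decoder: the 16-dimensional outputs of the 10 digit capsules are flattened and passed through fully connected layers of 512 and 1024 units with ReLU activations. A minimal sketch of such a decoder, assuming the usual 28×28 MNIST output with a sigmoid at the end (the 784-unit output and the sigmoid come from the CapsNet paper, not from this diff):

```python
import torch.nn as nn

# Sketch of a CapsNet-style reconstruction decoder.
# Input: flattened capsule outputs, 10 capsules x 16 features = 160 values.
# Output: a 28x28 = 784 pixel reconstruction in [0, 1].
decoder = nn.Sequential(
    nn.Linear(16 * 10, 512),   # first hidden layer of size 512
    nn.ReLU(),
    nn.Linear(512, 1024),      # second hidden layer of size 1024
    nn.ReLU(),
    nn.Linear(1024, 784),      # one value per output pixel
    nn.Sigmoid(),              # pixel intensities in [0, 1]
)
```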
@@ -70,7 +70,7 @@ class MNISTCapsuleNetworkModel(Module):
         x = F.relu(self.conv1(data))
         # Pass through the second convolution layer.
         # Output of this has shape `[batch_size, 32 * 8, 6, 6]`.
-        # *Note that this layer has a stride length of $2$.
+        # *Note that this layer has a stride length of $2$*.
         x = self.conv2(x)

         # Resize and permutate to get the capsules
@@ -85,6 +85,7 @@ class History(_History):
     This defines when a game ends, calculates the utility and sample chance events (dealing cards).

     The history is stored in a string:

     * First two characters are the cards dealt to player 1 and player 2
     * The third character is the action by the first player
     * Fourth character is the action by the second player
@@ -28,7 +28,7 @@ Also, the MLP-mixer uses MLPs of two layers for each mixing and ConvMixer uses a
 The paper recommends removing the residual connection across the channel mixing (point-wise convolution)
 and having only a residual connection over the spatial mixing (depth-wise convolution).
 They also use [Batch normalization](../normalization/batch_norm/index.html) instead
-of [Layer normalization)(../normalization/layer_norm/index.html).
+of [Layer normalization](../normalization/layer_norm/index.html).

 Here's [an experiment](experiment.html) that trains ConvMixer on CIFAR-10.

@@ -18,7 +18,7 @@ Also, the MLP-mixer uses MLPs of two layers for each mixing and ConvMixer uses a
 The paper recommends removing the residual connection across the channel mixing (point-wise convolution)
 and having only a residual connection over the spatial mixing (depth-wise convolution).
 They also use [Batch normalization](https://nn.labml.ai/normalization/batch_norm/index.html) instead
-of [Layer normalization)(../normalization/layer_norm/index.html).
+of [Layer normalization](../normalization/layer_norm/index.html).

 Here's [an experiment](https://nn.labml.ai/conv_mixer/experiment.html) that trains ConvMixer on CIFAR-10.

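Both hunks above describe the same design choice: a residual connection only around the depth-wise (spatial) convolution, none around the point-wise (channel) convolution, and batch normalization rather than layer normalization. A minimal sketch of one such block, assuming a GELU activation and an illustrative kernel size (names and details are not taken from this repository):

```python
import torch
import torch.nn as nn

class ConvMixerBlockSketch(nn.Module):
    """One ConvMixer-style block: residual over spatial mixing only."""

    def __init__(self, dim: int, kernel_size: int = 7):
        super().__init__()
        # Depth-wise convolution mixes information spatially, per channel.
        self.spatial = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size, groups=dim, padding=kernel_size // 2),
            nn.GELU(),
            nn.BatchNorm2d(dim),
        )
        # Point-wise (1x1) convolution mixes information across channels.
        self.channel = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=1),
            nn.GELU(),
            nn.BatchNorm2d(dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual connection only across the spatial (depth-wise) mixing;
        # no residual across the channel (point-wise) mixing.
        x = x + self.spatial(x)
        return self.channel(x)
```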
@@ -245,7 +245,7 @@ def ag_news(c: NLPClassificationConfigs):
     ### AG News dataset

     This loads the AG News dataset and the set the values for
-     `n_classes', `vocab`, `train_loader`, and `valid_loader`.
+     `n_classes`, `vocab`, `train_loader`, and `valid_loader`.
     """

     # Get training and validation datasets
@@ -279,5 +279,5 @@ def ag_news(c: NLPClassificationConfigs):
     valid_loader = DataLoader(valid, batch_size=c.batch_size, shuffle=True,
                               collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))

-    # Return `n_classes', `vocab`, `train_loader`, and `valid_loader`
+    # Return `n_classes`, `vocab`, `train_loader`, and `valid_loader`
     return 4, vocab, train_loader, valid_loader
@@ -145,7 +145,7 @@ class Discriminator(Module):
         super().__init__()
         channels, height, width = input_shape

-        # Output of the discriminator is also a map of probabilities*
+        # Output of the discriminator is also a map of probabilities,
         # whether each region of the image is real or generated
         self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)

@@ -528,8 +528,8 @@ class Configs(BaseConfigs):
         \Bigg]
         \end{align}

-        We use `generator_xy` for $G$ and `generator_yx$ for $F$.
-        We use `discriminator_x$ for $D_X$ and `discriminator_y` for $D_Y$.
+        We use `generator_xy` for $G$ and `generator_yx` for $F$.
+        We use `discriminator_x` for $D_X$ and `discriminator_y` for $D_Y$.
         """

         # Replay buffers to keep generated samples
@@ -83,7 +83,7 @@ where the factors of variations are more linear (disentangled).

 #### AdaIN

-Then $w$ is transformed into two vectors (***styles***) per layer,
+Then $w$ is transformed into two vectors (**styles**) per layer,
  $i$, $y_i = (y_{s,i}, y_{b,i}) = f_{A_i}(w)$ and used for scaling and shifting (biasing)
  in each layer with $\text{AdaIN}$ operator (normalize and scale):
 $$\text{AdaIN}(x_i, y_i) = y_{s, i} \frac{x_i - \mu(x_i)}{\sigma(x_i)} + y_{b,i}$$
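The $\text{AdaIN}$ formula above normalizes each feature map with its own mean and standard deviation, then applies the per-layer style scale $y_{s,i}$ and bias $y_{b,i}$. A small sketch of the operator on a batch of feature maps (the function name and tensor shapes are assumptions, not taken from this diff):

```python
import torch

def adain_sketch(x: torch.Tensor, y_s: torch.Tensor, y_b: torch.Tensor,
                 eps: float = 1e-8) -> torch.Tensor:
    """AdaIN(x, y) = y_s * (x - mu(x)) / sigma(x) + y_b

    * `x` has shape `[batch_size, channels, height, width]`
    * `y_s` and `y_b` have shape `[batch_size, channels]`
    """
    # Per-sample, per-channel statistics over the spatial dimensions
    mu = x.mean(dim=(2, 3), keepdim=True)
    sigma = x.std(dim=(2, 3), keepdim=True)
    # Normalize, then scale and shift with the style vectors
    x_norm = (x - mu) / (sigma + eps)
    return y_s[:, :, None, None] * x_norm + y_b[:, :, None, None]
```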
@@ -202,7 +202,7 @@ class Generator(nn.Module):

     *<small>$A$ denotes a linear layer.
     $B$ denotes a broadcast and scaling operation (noise is a single channel).
-    [*toRGB*](#to_rgb) also has a style modulation which is not shown in the diagram to keep it simple.</small>*
+    [`toRGB`](#to_rgb) also has a style modulation which is not shown in the diagram to keep it simple.</small>*

     The generator starts with a learned constant.
     Then it has a series of blocks. The feature map resolution is doubled at each block
@@ -243,7 +243,7 @@ class Generator(nn.Module):
     def forward(self, w: torch.Tensor, input_noise: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]]):
         """
         * `w` is $w$. In order to mix-styles (use different $w$ for different layers), we provide a separate
-        $w$ for each [generator block](#generator_block). It has shape `[n_blocks, batch_size, d_latent]1.
+        $w$ for each [generator block](#generator_block). It has shape `[n_blocks, batch_size, d_latent]`.
         * `input_noise` is the noise for each block.
         It's a list of pairs of noise sensors because each block (except the initial) has two noise inputs
         after each convolution layer (see the diagram).
@@ -282,7 +282,7 @@ class GeneratorBlock(nn.Module):

     *<small>$A$ denotes a linear layer.
     $B$ denotes a broadcast and scaling operation (noise is a single channel).
-    [*toRGB*](#to_rgb) also has a style modulation which is not shown in the diagram to keep it simple.</small>*
+    [`toRGB`](#to_rgb) also has a style modulation which is not shown in the diagram to keep it simple.</small>*

     The generator block consists of two [style blocks](#style_block) ($3 \times 3$ convolutions with style modulation)
     and an RGB output.
@@ -731,7 +731,7 @@ class EqualizedLinear(nn.Module):
     <a id="equalized_linear"></a>
     ## Learning-rate Equalized Linear Layer

-    This uses [learning-rate equalized weights]($equalized_weights) for a linear layer.
+    This uses [learning-rate equalized weights](#equalized_weights) for a linear layer.
     """

     def __init__(self, in_features: int, out_features: int, bias: float = 0.):
@@ -742,7 +742,7 @@ class EqualizedLinear(nn.Module):
         """

         super().__init__()
-        # [Learning-rate equalized weights]($equalized_weights)
+        # [Learning-rate equalized weights](#equalized_weights)
         self.weight = EqualizedWeight([out_features, in_features])
         # Bias
         self.bias = nn.Parameter(torch.ones(out_features) * bias)
@@ -757,7 +757,7 @@ class EqualizedConv2d(nn.Module):
     <a id="equalized_conv2d"></a>
     ## Learning-rate Equalized 2D Convolution Layer

-    This uses [learning-rate equalized weights]($equalized_weights) for a convolution layer.
+    This uses [learning-rate equalized weights](#equalized_weights) for a convolution layer.
     """

     def __init__(self, in_features: int, out_features: int,
@@ -771,7 +771,7 @@ class EqualizedConv2d(nn.Module):
         super().__init__()
         # Padding size
         self.padding = padding
-        # [Learning-rate equalized weights]($equalized_weights)
+        # [Learning-rate equalized weights](#equalized_weights)
         self.weight = EqualizedWeight([out_features, in_features, kernel_size, kernel_size])
         # Bias
         self.bias = nn.Parameter(torch.ones(out_features))
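The four hunks above all point at the same `#equalized_weights` anchor. The idea behind learning-rate equalization is to store weights drawn from $\mathcal{N}(0, 1)$ and multiply them by a constant $c = 1 / \sqrt{\text{fan-in}}$ on every forward pass, so the effective update scale is similar across layers. A rough sketch of that trick, not the repository's exact implementation:

```python
import math
import torch
import torch.nn as nn

class EqualizedWeightSketch(nn.Module):
    """Stores N(0, 1) weights and rescales them by 1/sqrt(fan_in) on use."""

    def __init__(self, shape):
        super().__init__()
        # Initialize from a unit normal instead of a fan-in-scaled normal.
        self.weight = nn.Parameter(torch.randn(shape))
        # fan-in is the product of all dimensions except the first (output) one.
        self.c = 1 / math.sqrt(math.prod(shape[1:]))

    def forward(self) -> torch.Tensor:
        # The scaling happens at runtime, so gradients (and hence the
        # effective learning rate) stay comparable across layers.
        return self.weight * self.c
```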
@@ -36,14 +36,17 @@ Each group can have it's own hyper-parameters like learning rates.

 In most common cases there will be only one group.
 This is when you initialize your optimizer with,

 ```python
 Optimizer(model.parameters())
 ```

 You can define multiple parameter groups when initializing the optimizer:

 ```python
 Optimizer([{'params': model1.parameters()}, {'params': model2.parameters(), 'lr': 2}])
 ```

 Here we pass a list of groups. Each group is a dictionary with it's parameters under the key 'params'.
 You specify any hyper-parameters as well. If the hyper parameters are not defined they will default
 to the optimizer level defaults.
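The docstring above explains that per-group hyper-parameters override the optimizer-level defaults. A short example of how that plays out with a stock PyTorch optimizer (the two sub-models are placeholders standing in for `model1` and `model2`):

```python
import torch
import torch.nn as nn

# Placeholder sub-models for illustration.
model1 = nn.Linear(4, 4)
model2 = nn.Linear(4, 4)

# `model1` falls back to the optimizer-level default lr=0.1,
# while `model2` overrides it with its own group-level lr=2.
optimizer = torch.optim.SGD(
    [{'params': model1.parameters()},
     {'params': model2.parameters(), 'lr': 2}],
    lr=0.1)

for group in optimizer.param_groups:
    print(group['lr'])  # prints 0.1, then 2
```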
@@ -74,7 +77,7 @@ class GenericAdaptiveOptimizer(Optimizer):

         * `params` is the collection of parameters or set of parameter groups.
         * `defaults` a dictionary of default hyper-parameters
-        * 'lr` is the learning rate, $\alpha$
+        * `lr` is the learning rate, $\alpha$
         * `betas` is the tuple $(\beta_1, \beta_2)$
         * `eps` is $\epsilon$
         """
@@ -174,7 +177,8 @@ class WeightDecay:
         decay from the parameter. If added to the  gradient it will go through the normal optimizer update.
         * `absolute` this flag indicates whether the weight decay coefficient is absolute. This is applicable
         when the decay is performed directly on the parameter. If this is false the actual decay is
-        `weight_decay` * `learning_rate`.
+        `weight_decay`
+        * `learning_rate`.
         """
         # Check hyper-parameters
         if not 0.0 <= weight_decay:
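The `WeightDecay` docstring above distinguishes decay added to the gradient (classic L2 regularization) from decay applied directly to the parameter (decoupled, AdamW-style), with an `absolute` flag controlling whether the decoupled decay uses `weight_decay` itself or `weight_decay * learning_rate`. A sketch of the two update modes; names are illustrative and this is not the repository's code:

```python
import torch

def apply_weight_decay_sketch(param: torch.Tensor, grad: torch.Tensor,
                              weight_decay: float, lr: float,
                              weight_decouple: bool, absolute: bool) -> torch.Tensor:
    """Returns the (possibly modified) gradient; may decay `param` in place."""
    if weight_decouple:
        # Decoupled (AdamW-style): shrink the parameter directly.
        if absolute:
            param.data.mul_(1.0 - weight_decay)
        else:
            param.data.mul_(1.0 - lr * weight_decay)
        return grad
    # Classic L2: add the decay term to the gradient and let the
    # normal optimizer update handle it.
    return grad + weight_decay * param.data
```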
@@ -61,11 +61,11 @@ class AdaBelief(RAdam):
         * `betas` is a tuple of ($\beta_1$, $\beta_2$)
         * `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`
         * `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)
-        * 'optimized_update' is a flag whether to optimize the bias correction of the second moment
+        * `optimized_update` is a flag whether to optimize the bias correction of the second moment
           by doing it after adding $\epsilon$
         * `amsgrad` is a flag indicating whether to use AMSGrad or fallback to plain Adam
-        * `degenerate_to_sgd` whether to use sgd when the rectification term $r_t is intractable
-        * 'rectify' is whether to use RAdam update
+        * `degenerate_to_sgd` whether to use sgd when the rectification term $r_t$ is intractable
+        * `rectify` is whether to use RAdam update
         * `defaults` is a dictionary of default for group values.
          This is useful when you want to extend the class `AdaBelief`.
         """
@@ -155,10 +155,10 @@ class RAdam(AMSGrad):
         * `betas` is a tuple of ($\beta_1$, $\beta_2$)
         * `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`
         * `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)
-        * 'optimized_update' is a flag whether to optimize the bias correction of the second moment
+        * `optimized_update` is a flag whether to optimize the bias correction of the second moment
           by doing it after adding $\epsilon$
         * `amsgrad` is a flag indicating whether to use AMSGrad or fallback to plain Adam
-        * `degenerate_to_sgd` whether to use sgd when the rectification term $r_t is intractable.
+        * `degenerate_to_sgd` whether to use sgd when the rectification term $r_t$ is intractable.
         * `defaults` is a dictionary of default for group values.
          This is useful when you want to extend the class `RAdam`.
         """
@@ -45,10 +45,10 @@ class GAE:
         $\hat{A_t}$

         \begin{align}
-        \delta_t &= r_t + \gamma V(s_{t+1}) - V(s_t)$
+        \delta_t &= r_t + \gamma V(s_{t+1}) - V(s_t)
         \\
         \hat{A_t} &= \delta_t + \gamma \lambda \delta_{t+1} + ... +
-                             (\gamma \lambda)^{T - t + 1} \delta_{T - 1}$
+                             (\gamma \lambda)^{T - t + 1} \delta_{T - 1}
         \\
         &= \delta_t + \gamma \lambda \hat{A_{t+1}}
         \end{align}
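The recursion fixed above, $\hat{A}_t = \delta_t + \gamma \lambda \hat{A}_{t+1}$ with $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$, is typically computed with a single backward pass over the trajectory. A minimal sketch for one environment (array shapes and names are assumptions, and episode-boundary masking is omitted):

```python
import numpy as np

def gae_sketch(rewards: np.ndarray, values: np.ndarray,
               gamma: float = 0.99, lam: float = 0.95) -> np.ndarray:
    """Generalized Advantage Estimation for a single trajectory.

    * `rewards` has shape `[T]`
    * `values` has shape `[T + 1]` (includes the value of the final state)
    """
    T = len(rewards)
    advantages = np.zeros(T)
    last_advantage = 0.0
    # Work backwards so each step can reuse the next step's advantage.
    for t in reversed(range(T)):
        # TD error: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        # A_t = delta_t + gamma * lambda * A_{t+1}
        last_advantage = delta + gamma * lam * last_advantage
        advantages[t] = last_advantage
    return advantages
```

A full implementation would also zero the bootstrap term at episode boundaries using done flags.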
@@ -114,7 +114,7 @@ class StrokesDataset(Dataset):
             # Mask is on until end of sequence
             self.mask[i, :len_seq + 1] = 1

-        # Start-of-sequence is $(0, 0, 1, 0, 0)
+        # Start-of-sequence is $(0, 0, 1, 0, 0)$
         self.data[:, 0, 2] = 1

     def __len__(self):
@@ -42,7 +42,7 @@ AFT Local only apply learned pair-wise position biases locally:
 \begin{align}
 w'_{t,t'} =
 \begin{cases}
-w_{t,t'},  & \text{for $\lvert t-t' \rvert \lt s$} \\
+w_{t,t'},  & {\text{for } \lvert t-t' \rvert \lt s} \\
 0, & \text{otherwise}
 \end{cases}
 \end{align}
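The case expression above keeps the learned position bias $w_{t,t'}$ only when $\lvert t - t' \rvert < s$ and zeroes it otherwise. The corresponding binary mask $m_{t,t'}$ can be built with a pairwise distance matrix; a small sketch (the function name is illustrative):

```python
import torch

def local_mask_sketch(seq_len: int, s: int) -> torch.Tensor:
    """Boolean mask with m[t, t'] = 1 iff |t - t'| < s."""
    positions = torch.arange(seq_len)
    # Pairwise distances |t - t'| as a [seq_len, seq_len] matrix
    distance = (positions[:, None] - positions[None, :]).abs()
    return distance < s

# Example: keep the bias only inside a local window of size s.
# w_local = w * local_mask_sketch(w.shape[0], s)
```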
@@ -79,7 +79,7 @@ class AFTLocal(Module):
     \begin{align}
     w'_{t,t'} =
     \begin{cases}
-    w_{t,t'},  & \text{for $\lvert t-t' \rvert \lt s$} \\
+    w_{t,t'},  & {\text{for } \lvert t-t' \rvert \lt s} \\
     0, & \text{otherwise}
     \end{cases}
     \end{align}
@@ -119,7 +119,7 @@ class AFTLocal(Module):
         \begin{align}
         m_{t,t'} =
         \begin{cases}
-        1,  & \text{for $\lvert t-t' \rvert \lt s$} \\
+        1,  & {\text{for } \lvert t-t' \rvert \lt s} \\
         0, & \text{otherwise}
         \end{cases}
         \end{align}
@@ -170,7 +170,7 @@ class AFTLocal(Module):
         #     \begin{align}
         #     w'_{t,t'} =
         #     \begin{cases}
-        #     w_{t,t'},  & \text{for $\lvert t-t' \rvert \lt s$} \\
+        #     w_{t,t'},  & {\text{for }\lvert t-t' \rvert \lt s} \\
         #     0, & \text{otherwise}
         #     \end{cases}
         #     \end{align}
@@ -40,7 +40,7 @@ We have implemented the latter here since it gives better results.

 This implementation uses pre-layer normalization
 while the paper uses post-layer normalization.
-Pre-layer norm does the layer norm before FFN[../feedforward.html) and
+Pre-layer norm does the layer norm before [FFN](../feedforward.html) and
 self-attention, and the pass-through in the residual connection is not normalized.
 This is supposed to be more stable in standard transformer setups.

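Pre-layer normalization, as described above, normalizes the input of each attention/FFN sub-layer while the residual pass-through stays unnormalized. A compact sketch of that ordering (the `self_attn` and `ffn` modules are placeholders, not this repository's classes):

```python
import torch
import torch.nn as nn

class PreNormBlockSketch(nn.Module):
    """Pre-LN transformer block: normalize the sub-layer input, not the residual."""

    def __init__(self, d_model: int, self_attn: nn.Module, ffn: nn.Module):
        super().__init__()
        self.self_attn = self_attn
        self.ffn = ffn
        self.norm_attn = nn.LayerNorm(d_model)
        self.norm_ffn = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Layer norm is applied before self-attention...
        z = self.norm_attn(x)
        # ...and the residual adds the *unnormalized* x back.
        x = x + self.self_attn(z)
        # Same pattern for the feed-forward network.
        return x + self.ffn(self.norm_ffn(x))
```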
@@ -246,7 +246,7 @@ class AttentionReconstructionLoss:
         This is a reimplementation of ['PrepareForMultiHeadAttention'](../mha.html#PrepareMHA)
         where the projections are done with the parameters detached from gradient computation.

-        * `pmha* is the ['PrepareForMultiHeadAttention'](../mha.html#PrepareMHA) module
+        * `pmha` is the ['PrepareForMultiHeadAttention'](../mha.html#PrepareMHA) module
         * `x` is tensor with the token embeddings
         """

@@ -32,7 +32,7 @@ We have implemented the latter here since it gives better results.

 This implementation uses pre-layer normalization
 while the paper uses post-layer normalization.
-Pre-layer norm does the layer norm before FFN[../feedforward.html) and
+Pre-layer norm does the layer norm before [FFN](../feedforward.html) and
 self-attention, and the pass-through in the residual connection is not normalized.
 This is supposed to be more stable in standard transformer setups.

@@ -201,7 +201,7 @@ class Trainer:
         # Cross-entropy loss
         self.loss_func = nn.CrossEntropyLoss()
         # Number of training epochs;
-        # *note that our dataset definition repeats the data `seq_len` times in a single epoch
+        # *note that our dataset definition repeats the data `seq_len` times in a single epoch*
         self.epochs = configs.epochs
         # Gradient clipping norm
         self.grad_norm_clip = configs.grad_norm_clip
@@ -46,7 +46,7 @@ def knn(queries: torch.Tensor, index: faiss.IndexFlatL2, keys_store: np.ndarray,

     # Normalize $f(c_i)$
     keys_found_n = keys_found / torch.sqrt((keys_found ** 2).sum(-1, keepdims=True) + 1e-10)
-    # Normalize $f($\color{yellowgreen}{c_t})$
+    # Normalize $f(\color{yellowgreen}{c_t})$
     queries_n = queries / torch.sqrt((queries ** 2).sum(-1, keepdims=True) + 1e-10)

     # Get the dot-product, or cosine similarity
@@ -81,7 +81,7 @@ class MLM:
                  masking_prob: float = 0.15, randomize_prob: float = 0.1, no_change_prob: float = 0.1,
                  ):
         """
-        * `padding_token` is the padding token `[PAD].
+        * `padding_token` is the padding token `[PAD]`.
           We will use this to mark the labels that shouldn't be used for loss calculation.
         * `mask_token` is the masking token `[MASK]`.
         * `no_mask_tokens` is a list of tokens that should not be masked.
@@ -155,11 +155,13 @@ class SwitchFeedForward(Module):
         final_output = final_output.view(seq_len, batch_size, d_model)

         # Return
+        #
         # * the final output
         # * number of tokens routed to each expert
         # * sum of probabilities for each expert
         # * number of tokens dropped.
         # * routing probabilities of the selected experts
+        #
         # These are used for the load balancing loss and logging
         return final_output, counts, route_prob.sum(0), len(dropped), route_prob_max
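The values returned above (`counts` and `route_prob.sum(0)`) feed the Switch Transformer's load balancing loss, which pushes the fraction of tokens routed to each expert and the mean routing probability toward a uniform distribution. A hedged sketch of that loss; the variable names mirror the return values, but how they are consumed here is an assumption, not this repository's code:

```python
import torch

def load_balancing_loss_sketch(counts: torch.Tensor, route_prob_sum: torch.Tensor,
                               total_tokens: int) -> torch.Tensor:
    """Switch-style load balancing loss: n_experts * sum(f_i * P_i).

    * `counts[i]` is the number of tokens routed to expert i
    * `route_prob_sum[i]` is the sum of routing probabilities for expert i
    """
    n_experts = counts.shape[0]
    # f_i: fraction of tokens dispatched to expert i
    route_frac = counts.float() / total_tokens
    # P_i: mean routing probability assigned to expert i
    route_prob = route_prob_sum / total_tokens
    # The product is minimized when both distributions are uniform.
    return n_experts * (route_frac * route_prob).sum()
```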