Author: Varuna Jayasiri
Date:   2022-07-02 14:31:16 +05:30
parent ab4264cbda
commit b6bef1d2fe
8 changed files with 215 additions and 223 deletions
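
Across the files shown below the edit is the same: classes that used to subclass labml_helpers.module.Module now subclass torch.nn.Module directly, and constructor type hints follow suit. A minimal sketch of a module written straight against nn.Module (TinyGenerator is a hypothetical example for illustration, not code from this commit):

import torch
from torch import nn


class TinyGenerator(nn.Module):
    """Hypothetical projection head, written directly against nn.Module."""

    def __init__(self, d_model: int, n_vocab: int):
        super().__init__()
        # Assigning an nn.Module attribute registers it with the parent module.
        self.projection = nn.Linear(d_model, n_vocab)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.projection(x)


gen = TinyGenerator(d_model=64, n_vocab=1000)
# parameters(), state_dict() and .to(device) all work without a helper base class.
print(sum(p.numel() for p in gen.parameters()))  # 65000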

View File

@@ -15,20 +15,20 @@ on an NLP auto-regression task (with Tiny Shakespeare dataset).
 """
 import torch
 from torch import nn
 from labml import experiment
 from labml.configs import option
-from labml_helpers.module import Module
 from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
 from labml_nn.transformers import TransformerConfigs, Encoder
 from labml_nn.transformers.utils import subsequent_mask
-class AutoregressiveTransformer(Module):
+class AutoregressiveTransformer(nn.Module):
     """
     ## Auto-Regressive model
     """
-    def __init__(self, encoder: Encoder, src_embed: Module, generator: Module):
+    def __init__(self, encoder: Encoder, src_embed: nn.Module, generator: nn.Module):
         """
         * `encoder` is the transformer [Encoder](../models.html#Encoder)
         * `src_embed` is the token
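
The constructor now types src_embed and generator as plain nn.Module. A hedged usage sketch of the pieces it expects, with nn.Embedding and nn.Linear as stand-ins for the config-provided embedding and generator head (the actual classes and dimensions come from the experiment configs, which are not shown here):

import torch
from torch import nn

d_model, n_vocab, seq_len, batch_size = 128, 65, 32, 4

src_embed = nn.Embedding(n_vocab, d_model)   # stand-in token embedding
generator = nn.Linear(d_model, n_vocab)      # stand-in projection to vocabulary logits

x = torch.randint(0, n_vocab, (seq_len, batch_size))   # [seq_len, batch_size] token ids
emb = src_embed(x)                                      # [seq_len, batch_size, d_model]

# Causal mask in the spirit of subsequent_mask: position i attends only to j <= i.
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

logits = generator(emb)   # [seq_len, batch_size, n_vocab]; the encoder is omitted in this sketch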

View File

@@ -26,13 +26,12 @@ import math
 from typing import Optional, List
 import torch
-from torch import nn as nn
+from torch import nn
 from labml import tracker
-from labml_helpers.module import Module
-class PrepareForMultiHeadAttention(Module):
+class PrepareForMultiHeadAttention(nn.Module):
     """
     <a id="PrepareMHA"></a>
@@ -68,7 +67,7 @@ class PrepareForMultiHeadAttention(Module):
         return x
-class MultiHeadAttention(Module):
+class MultiHeadAttention(nn.Module):
     r"""
     <a id="MHA"></a>
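
PrepareForMultiHeadAttention is the projection-and-split step that MultiHeadAttention applies to queries, keys and values. A minimal sketch of that idea as a plain nn.Module (the class name, dimensions and single-tensor interface are assumptions for illustration; the real class also handles inputs without a sequence dimension):

import torch
from torch import nn


class HeadSplit(nn.Module):
    """Project to heads * d_k features, then split the last dimension into heads."""

    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool = True):
        super().__init__()
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
        self.heads = heads
        self.d_k = d_k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [seq_len, batch_size, d_model]
        head_shape = x.shape[:-1]
        x = self.linear(x)
        # -> [seq_len, batch_size, heads, d_k]
        return x.view(*head_shape, self.heads, self.d_k)


q = HeadSplit(d_model=512, heads=8, d_k=64)(torch.randn(10, 2, 512))
print(q.shape)  # torch.Size([10, 2, 8, 64])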

View File

@@ -15,7 +15,6 @@ import math
 import torch
 import torch.nn as nn
-from labml_helpers.module import Module
 from labml_nn.utils import clone_module_list
 from .feed_forward import FeedForward
@@ -23,7 +22,7 @@ from .mha import MultiHeadAttention
 from .positional_encoding import get_positional_encoding
-class EmbeddingsWithPositionalEncoding(Module):
+class EmbeddingsWithPositionalEncoding(nn.Module):
     """
     <a id="EmbeddingsWithPositionalEncoding"></a>
@@ -41,7 +40,7 @@ class EmbeddingsWithPositionalEncoding(Module):
         return self.linear(x) * math.sqrt(self.d_model) + pe
-class EmbeddingsWithLearnedPositionalEncoding(Module):
+class EmbeddingsWithLearnedPositionalEncoding(nn.Module):
     """
     <a id="EmbeddingsWithLearnedPositionalEncoding"></a>
@@ -59,7 +58,7 @@ class EmbeddingsWithLearnedPositionalEncoding(Module):
         return self.linear(x) * math.sqrt(self.d_model) + pe
-class TransformerLayer(Module):
+class TransformerLayer(nn.Module):
     """
     <a id="TransformerLayer"></a>
@@ -139,7 +138,7 @@ class TransformerLayer(Module):
         return x
-class Encoder(Module):
+class Encoder(nn.Module):
     """
     <a id="Encoder"></a>
@@ -161,7 +160,7 @@ class Encoder(Module):
         return self.norm(x)
-class Decoder(Module):
+class Decoder(nn.Module):
     """
     <a id="Decoder"></a>
@@ -183,7 +182,7 @@ class Decoder(Module):
         return self.norm(x)
-class Generator(Module):
+class Generator(nn.Module):
     """
     <a id="Generator"></a>
@@ -201,14 +200,14 @@ class Generator(Module):
         return self.projection(x)
-class EncoderDecoder(Module):
+class EncoderDecoder(nn.Module):
     """
     <a id="EncoderDecoder"></a>
     ## Combined Encoder-Decoder
     """
-    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: Module, tgt_embed: Module, generator: Module):
+    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: nn.Module, tgt_embed: nn.Module, generator: nn.Module):
         super().__init__()
         self.encoder = encoder
         self.decoder = decoder
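
EncoderDecoder is a thin container: it stores the encoder, decoder, the two embeddings and the generator, all now typed as nn.Module. A hedged sketch of the same wiring pattern with hypothetical components, mainly to show that attribute assignment is enough for sub-module registration:

import torch
from torch import nn


class Seq2Seq(nn.Module):
    """Hypothetical container shaped like EncoderDecoder."""

    def __init__(self, encoder: nn.Module, decoder: nn.Module,
                 src_embed: nn.Module, tgt_embed: nn.Module, generator: nn.Module):
        super().__init__()
        # Each assignment registers the sub-module, so parameters(),
        # state_dict() and .to(device) cover the whole model.
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator


model = Seq2Seq(nn.Identity(), nn.Identity(),
                nn.Embedding(100, 16), nn.Embedding(100, 16), nn.Linear(16, 100))
print(len(list(model.parameters())))  # 4 tensors: two embedding weights, linear weight and bias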

View File

@@ -26,10 +26,8 @@ import numpy as np
 import torch
 import torch.nn as nn
-from labml_helpers.module import Module
-class PositionalEncoding(Module):
+class PositionalEncoding(nn.Module):
     def __init__(self, d_model: int, dropout_prob: float, max_len: int = 5000):
         super().__init__()
         self.dropout = nn.Dropout(dropout_prob)
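
PositionalEncoding keeps a dropout layer and, in the rest of the class (not shown in this hunk), the position encodings it adds to the embeddings. A hedged sketch of the standard sinusoidal table such a module typically precomputes, PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)); the function name and defaults are placeholders:

import math
import torch


def sinusoidal_encodings(d_model: int, max_len: int = 5000) -> torch.Tensor:
    """Return a [max_len, d_model] table of sinusoidal position encodings."""
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)          # [max_len, 1]
    # 10000^(-2i/d_model) for the even feature indices 0, 2, 4, ...
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32)
                         * (-math.log(10000.0) / d_model))                      # [d_model / 2]
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)   # even dimensions
    pe[:, 1::2] = torch.cos(position * div_term)   # odd dimensions
    return pe


pe = sinusoidal_encodings(d_model=512)
print(pe.shape)  # torch.Size([5000, 512])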