refactor

Varuna Jayasiri
2020-09-01 08:12:22 +05:30
parent df22468706
commit 048c5bbd10
3 changed files with 5 additions and 6 deletions

View File

@@ -7,9 +7,9 @@ import torch.nn.functional as F
 from labml.configs import BaseConfigs, option, calculate
 from labml_helpers.module import Module
-from transformers.mha import MultiHeadAttention
-from transformers.positional_encoding import PositionalEncoding, get_positional_encoding
-from transformers.utils import clone_module_list
+from labml_nn.utils import clone_module_list
+from .mha import MultiHeadAttention
+from .positional_encoding import PositionalEncoding, get_positional_encoding
 class EmbeddingsWithPositionalEncoding(Module):
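
The hunk above is the point of the refactor: absolute imports of a top-level transformers package become package-relative imports, and clone_module_list now comes from labml_nn.utils. A minimal sketch of the pattern, assuming these modules live under labml_nn/transformers/ (the diff does not show file paths, so the paths in the comments are assumptions):

    # Sibling modules in the same package are imported relative to it:
    from .mha import MultiHeadAttention            # assumed: labml_nn/transformers/mha.py
    from .positional_encoding import PositionalEncoding  # assumed: labml_nn/transformers/positional_encoding.py
    # A shared helper that moved up to the package root uses an absolute import:
    from labml_nn.utils import clone_module_list   # assumed: labml_nn/utils.py

Relative imports keep the package self-contained and avoid resolving against any unrelated top-level transformers package on the path.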
@@ -197,7 +197,7 @@ calculate(TransformerConfigs.decoder_mem_attn, 'mha', _mha)
 ### Relative MHA
 def _relative_mha(c: TransformerConfigs):
-    from transformers.relative_mha import RelativeMultiHeadAttention
+    from .relative_mha import RelativeMultiHeadAttention
     return RelativeMultiHeadAttention(c.n_heads, c.d_model)
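
This hunk sits inside labml's configs registration pattern, visible in the hunk header: calculate(TransformerConfigs.decoder_mem_attn, 'mha', _mha) registers a constructor under a string name, and _relative_mha is the constructor for the relative-attention variant. A minimal sketch of how the pieces fit together; the option name 'relative', the config defaults, and the absolute module path are assumptions, not shown in the diff:

    from labml.configs import BaseConfigs, calculate

    class TransformerConfigs(BaseConfigs):
        n_heads: int = 8      # assumed default
        d_model: int = 512    # assumed default
        decoder_mem_attn: 'MultiHeadAttention' = 'mha'  # option picked by name

    def _relative_mha(c: TransformerConfigs):
        # Local import: the module is only loaded if this option is selected.
        from labml_nn.transformers.relative_mha import RelativeMultiHeadAttention  # assumed path
        return RelativeMultiHeadAttention(c.n_heads, c.d_model)

    calculate(TransformerConfigs.decoder_mem_attn, 'relative', _relative_mha)  # assumed option name

Setting decoder_mem_attn = 'relative' on a TransformerConfigs instance would then make labml call _relative_mha to build the attention module.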

View File

@@ -6,9 +6,8 @@ https://arxiv.org/abs/1901.02860
 import torch
 from torch import nn
-from labml_helpers.module import Module
 from labml.logger import inspect
-from transformers.mha import MultiHeadAttention
+from .mha import MultiHeadAttention
 def relative_shift(x: torch.Tensor):
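
The second file is the relative multi-head attention from Transformer-XL (the arXiv link in the hunk header, 1901.02860). relative_shift, whose signature closes the hunk, is conventionally the padding-and-reshaping trick that realigns attention scores from absolute to relative key positions. A minimal sketch of the standard trick from the Transformer-XL reference code; this file's actual implementation may differ in details such as which side is padded:

    import torch

    def relative_shift(x: torch.Tensor) -> torch.Tensor:
        # x holds scores indexed [query, key, ...].
        # Prepend a zero column along the key axis ...
        zero_pad = x.new_zeros(x.shape[0], 1, *x.shape[2:])
        x_padded = torch.cat([zero_pad, x], dim=1)  # [q, k + 1, ...]
        # ... then reinterpret the same memory as [k + 1, q, ...],
        # which shifts row i of the original by i positions.
        x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])
        # Drop the spurious first row and reshape back.
        return x_padded[1:].view_as(x)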