mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-18 11:52:00 +08:00
refactor
@@ -7,9 +7,9 @@ import torch.nn.functional as F

 from labml.configs import BaseConfigs, option, calculate
 from labml_helpers.module import Module
-from transformers.mha import MultiHeadAttention
-from transformers.positional_encoding import PositionalEncoding, get_positional_encoding
-from transformers.utils import clone_module_list
+from labml_nn.utils import clone_module_list
+from .mha import MultiHeadAttention
+from .positional_encoding import PositionalEncoding, get_positional_encoding


 class EmbeddingsWithPositionalEncoding(Module):
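
The hunk above ends at the context line class EmbeddingsWithPositionalEncoding(Module):, whose body is not part of this diff. For orientation, here is a minimal sketch of what such a module typically does (token embeddings scaled by sqrt(d_model) plus fixed sinusoidal positional encodings); the constructor arguments and the inline encoding computation are illustrative assumptions, not the repository's code.

# Minimal sketch, NOT the repository's implementation: token embeddings combined with
# fixed sinusoidal positional encodings, in the spirit of EmbeddingsWithPositionalEncoding.
# The constructor arguments (d_model, n_vocab, max_len) are assumptions; d_model is assumed even.
import math
import torch
from torch import nn


class EmbeddingsWithPositionalEncodingSketch(nn.Module):
    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
        super().__init__()
        self.d_model = d_model
        self.embed = nn.Embedding(n_vocab, d_model)
        # Pre-compute sinusoidal encodings of shape [max_len, 1, d_model]
        position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32)
                             * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('positional_encodings', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: token ids of shape [seq_len, batch_size]
        pe = self.positional_encodings[:x.shape[0]]
        return self.embed(x) * math.sqrt(self.d_model) + pe
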
@@ -197,7 +197,7 @@ calculate(TransformerConfigs.decoder_mem_attn, 'mha', _mha)

 ### Relative MHA
 def _relative_mha(c: TransformerConfigs):
-    from transformers.relative_mha import RelativeMultiHeadAttention
+    from .relative_mha import RelativeMultiHeadAttention
     return RelativeMultiHeadAttention(c.n_heads, c.d_model)

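
The factory above is registered with labml's calculate helper, as the hunk header calculate(TransformerConfigs.decoder_mem_attn, 'mha', _mha) shows. As a rough illustration of that registration pattern, and not code from this commit, the toy config below wires up a computed option the same way; the class, attribute, and option names here are made up.

# Hedged illustration of the labml.configs registration pattern used above.
# ToyConfigs, its attributes, and the factory are hypothetical; only the
# BaseConfigs/calculate usage mirrors the diff.
from torch import nn

from labml.configs import BaseConfigs, calculate


class ToyConfigs(BaseConfigs):
    n_heads: int = 8
    d_model: int = 512
    # the attention module is a computed option, selected by name
    attn: nn.Module = 'toy_mha'


def _toy_mha(c: ToyConfigs):
    # build the component from other config values, like _relative_mha(c) above
    return nn.MultiheadAttention(c.d_model, c.n_heads)


# register the factory under the name 'toy_mha', mirroring
# calculate(TransformerConfigs.decoder_mem_attn, 'mha', _mha)
calculate(ToyConfigs.attn, 'toy_mha', _toy_mha)
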
@@ -6,9 +6,8 @@ https://arxiv.org/abs/1901.02860
 import torch
 from torch import nn

-from labml_helpers.module import Module
 from labml.logger import inspect
-from transformers.mha import MultiHeadAttention
+from .mha import MultiHeadAttention


 def relative_shift(x: torch.Tensor):
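
The second file's hunk ends at the signature of relative_shift, whose body is not shown here. Below is a minimal, hedged sketch of the Transformer-XL style relative-shift trick (https://arxiv.org/abs/1901.02860) that such a function typically implements; the [query_position, relative_position, ...] layout and the exact shifting convention are assumptions, not necessarily what the repository uses.

# Hedged sketch only: a generic Transformer-XL style relative shift, not the body of
# relative_shift from this commit. Assumes scores are laid out as
# [query_position, relative_position, ...] and that misaligned entries are masked later.
import torch


def relative_shift_sketch(x: torch.Tensor) -> torch.Tensor:
    # Append a zero column, so each row gains one extra slot
    zero_pad = x.new_zeros(x.shape[0], 1, *x.shape[2:])
    x_padded = torch.cat([x, zero_pad], dim=1)                          # [q, k + 1, ...]
    # Reinterpret the flat buffer with the first two sizes swapped;
    # this leaves row i effectively shifted right by i positions
    x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])  # [k + 1, q, ...]
    # Drop the last row and restore the original shape
    return x_padded[:-1].view_as(x)


if __name__ == '__main__':
    scores = torch.arange(1., 10.).view(3, 3)
    # rows get progressively shifted; entries that wrap in from neighbouring rows
    # would normally be removed by the attention mask
    print(relative_shift_sketch(scores))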