diff --git a/transformers/__init__.py b/transformers/__init__.py
index 65b23399..2d1e2468 100644
--- a/transformers/__init__.py
+++ b/transformers/__init__.py
@@ -6,7 +6,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from labml.configs import BaseConfigs, option, calculate
-from labml.helpers.pytorch.module import Module
+from labml_helpers.module import Module
 from transformers.mha import MultiHeadAttention
 from transformers.positional_encoding import PositionalEncoding, get_positional_encoding
 from transformers.utils import clone_module_list
diff --git a/transformers/label_smoothing_loss.py b/transformers/label_smoothing_loss.py
index 2dff9f3d..04cbd7f1 100644
--- a/transformers/label_smoothing_loss.py
+++ b/transformers/label_smoothing_loss.py
@@ -3,7 +3,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 
-from labml.helpers.pytorch.module import Module
+from labml_helpers.module import Module
 
 
 class LabelSmoothingLoss(Module):
diff --git a/transformers/mha.py b/transformers/mha.py
index 442e9dd8..c97fd76f 100644
--- a/transformers/mha.py
+++ b/transformers/mha.py
@@ -5,7 +5,7 @@ import torch
 from torch import nn as nn
 from torch.nn import functional as F
 
-from labml.helpers.pytorch.module import Module
+from labml_helpers.module import Module
 
 
 class PrepareForMultiHeadAttention(Module):
diff --git a/transformers/positional_encoding.py b/transformers/positional_encoding.py
index 6421e8d5..731fc008 100644
--- a/transformers/positional_encoding.py
+++ b/transformers/positional_encoding.py
@@ -5,7 +5,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 
-from labml.helpers.pytorch.module import Module
+from labml_helpers.module import Module
 
 
 class PositionalEncoding(Module):
diff --git a/transformers/relative_mha.py b/transformers/relative_mha.py
index 41a14122..17f245ec 100644
--- a/transformers/relative_mha.py
+++ b/transformers/relative_mha.py
@@ -6,7 +6,7 @@ https://arxiv.org/abs/1901.02860
 import torch
 from torch import nn
 
-from labml.helpers.pytorch.module import Module
+from labml_helpers.module import Module
 from labml.logger import inspect
 
 from transformers.mha import MultiHeadAttention
diff --git a/transformers/utils.py b/transformers/utils.py
index 4f59b341..20d63bb7 100644
--- a/transformers/utils.py
+++ b/transformers/utils.py
@@ -2,7 +2,7 @@ import copy
 
 from torch import nn
 
-from labml.helpers.pytorch.module import Module
+from labml_helpers.module import Module
 
 
 def clone_module_list(module: Module, n: int):