mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-15 18:27:20 +08:00
lab_helpers
This commit is contained in:
@ -6,7 +6,7 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from labml.configs import BaseConfigs, option, calculate
|
from labml.configs import BaseConfigs, option, calculate
|
||||||
from labml.helpers.pytorch.module import Module
|
from labml_helpers.module import Module
|
||||||
from transformers.mha import MultiHeadAttention
|
from transformers.mha import MultiHeadAttention
|
||||||
from transformers.positional_encoding import PositionalEncoding, get_positional_encoding
|
from transformers.positional_encoding import PositionalEncoding, get_positional_encoding
|
||||||
from transformers.utils import clone_module_list
|
from transformers.utils import clone_module_list
|
||||||
|
@ -3,7 +3,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from labml.helpers.pytorch.module import Module
|
from labml_helpers.module import Module
|
||||||
|
|
||||||
|
|
||||||
class LabelSmoothingLoss(Module):
|
class LabelSmoothingLoss(Module):
|
||||||
|
@ -5,7 +5,7 @@ import torch
|
|||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from labml.helpers.pytorch.module import Module
|
from labml_helpers.module import Module
|
||||||
|
|
||||||
|
|
||||||
class PrepareForMultiHeadAttention(Module):
|
class PrepareForMultiHeadAttention(Module):
|
||||||
|
@ -5,7 +5,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from labml.helpers.pytorch.module import Module
|
from labml_helpers.module import Module
|
||||||
|
|
||||||
|
|
||||||
class PositionalEncoding(Module):
|
class PositionalEncoding(Module):
|
||||||
|
@ -6,7 +6,7 @@ https://arxiv.org/abs/1901.02860
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from labml.helpers.pytorch.module import Module
|
from labml_helpers.module import Module
|
||||||
from labml.logger import inspect
|
from labml.logger import inspect
|
||||||
from transformers.mha import MultiHeadAttention
|
from transformers.mha import MultiHeadAttention
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@ import copy
|
|||||||
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from labml.helpers.pytorch.module import Module
|
from labml_helpers.module import Module
|
||||||
|
|
||||||
|
|
||||||
def clone_module_list(module: Module, n: int):
|
def clone_module_list(module: Module, n: int):
|
||||||
|
Reference in New Issue
Block a user