remove gelu custom impl and use pytorch impl

lakshith
2024-07-27 21:28:07 +05:30
parent cbc38bb26b
commit b3aedf3093


@@ -44,9 +44,6 @@ config = {
     "vocab_size": 50257
 }
 
-import math
-from torch import Tensor
-
 # from transformers
 class Conv1D(nn.Module):
@@ -74,23 +71,12 @@ class Conv1D(nn.Module):
         return x
 
-# from transformers
-class NewGELUActivation(nn.Module):
-    """
-    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
-    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
-    """
-
-    def forward(self, input: Tensor) -> Tensor:
-        return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
-
 class HeadFFN(nn.Module): # todo rename
     def __init__(self, dim):
         super().__init__()
         self.c_fc = Conv1D(dim, config['n_embd'])
         self.c_proj = Conv1D(config['n_embd'], dim)
-        self.act = NewGELUActivation()
+        self.act = nn.functional.gelu
         self.dropout = nn.Dropout(config['resid_pdrop'])
 
     def forward(self, hidden_states):
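
Note: the removed NewGELUActivation computes the tanh approximation of GELU, while torch.nn.functional.gelu defaults to the exact erf-based form, so this swap changes the activation's output very slightly. A minimal sketch of the difference, assuming PyTorch >= 1.12 (where F.gelu gained the approximate argument):

    import math
    import torch
    import torch.nn.functional as F

    def new_gelu(x: torch.Tensor) -> torch.Tensor:
        # tanh approximation, same formula as the removed NewGELUActivation.forward
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

    x = torch.randn(1024)
    print(torch.max(torch.abs(F.gelu(x, approximate='tanh') - new_gelu(x))))  # ~0: identical formula
    print(torch.max(torch.abs(F.gelu(x) - new_gelu(x))))  # small nonzero gap vs. exact erf GELU

Passing approximate='tanh' would reproduce the old behavior exactly; with the default exact form the numerical difference is small in practice.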