From a450afd1bd1afe9ecc4c3d5f7bc4ee9d4d1f302f Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Fri, 3 Jun 2022 10:13:03 +0530
Subject: [PATCH] bug fix

---
 docs/transformers/rope/index.html          | 44 +++++-----
 docs/transformers/rope/value_pe/index.html | 82 +++++++++----------
 labml_nn/transformers/rope/__init__.py     |  5 +-
 .../transformers/rope/value_pe/__init__.py |  3 +-
 4 files changed, 68 insertions(+), 66 deletions(-)

diff --git a/docs/transformers/rope/index.html b/docs/transformers/rope/index.html
index 27455c61..930aeba2 100644
--- a/docs/transformers/rope/index.html
+++ b/docs/transformers/rope/index.html
@@ -235,7 +235,7 @@
156        self.cos_cached = idx_theta2.cos()[:, None, None, :]
-157        self.sin_cached = idx_theta2.cos()[:, None, None, :]
+157        self.sin_cached = idx_theta2.sin()[:, None, None, :]
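This hunk is the actual fix: `sin_cached` was built from `idx_theta2.cos()`, so the sine and cosine caches were identical and the rotation degenerated. A minimal sketch of the corrected cache construction, assuming the standard RoPE angles theta_i = 10000^(-2i/d); the helper name `build_rope_cache` is illustrative, not from the patch:

    import torch

    def build_rope_cache(seq_len: int, d: int, base: float = 10_000.0):
        # theta_i for each feature pair, as in the RoPE paper
        theta = 1.0 / (base ** (torch.arange(0, d, 2).float() / d))
        # angles m * theta_i for every position m -> [seq_len, d/2]
        idx_theta = torch.einsum('m,d->md', torch.arange(seq_len).float(), theta)
        # duplicate so both halves of the feature vector share an angle -> [seq_len, d]
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
        # the fix: the sine cache must use .sin(), not .cos()
        cos_cached = idx_theta2.cos()[:, None, None, :]  # [seq, 1, 1, d] for broadcasting
        sin_cached = idx_theta2.sin()[:, None, None, :]
        return cos_cached, sin_cached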
@@ -320,7 +320,7 @@
-177        neg_half_x = self._neg_half(x_rope)
+178        neg_half_x = self._neg_half(x_rope)
@@ -333,7 +333,7 @@
-189        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])
+190        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])
@@ -345,7 +345,7 @@
-192        return torch.cat((x_rope, x_pass), dim=-1)
+193        return torch.cat((x_rope, x_pass), dim=-1)
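These hunks renumber the rotation step itself: `_neg_half` pairs each feature with its rotation partner, and the rotated features become x * cos(m * theta) + neg_half(x) * sin(m * theta), with the remaining `x_pass` features untouched. A standalone sketch under the same `[seq_len, batch, heads, d]` layout, reusing `build_rope_cache` from the sketch above (names are illustrative):

    import torch

    def neg_half(x: torch.Tensor) -> torch.Tensor:
        # build [-x_{d/2+1..d}, x_{1..d/2}], the partner of x under a 2D rotation
        d_2 = x.shape[-1] // 2
        return torch.cat([-x[..., d_2:], x[..., :d_2]], dim=-1)

    def apply_rope(x, cos_cached, sin_cached, d_rope: int):
        # rotate only the first d_rope features; pass the rest through unchanged
        x_rope, x_pass = x[..., :d_rope], x[..., d_rope:]
        x_rope = (x_rope * cos_cached[:x.shape[0]]) + (neg_half(x_rope) * sin_cached[:x.shape[0]])
        return torch.cat((x_rope, x_pass), dim=-1)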
@@ -358,7 +358,7 @@
-195class RotaryPEMultiHeadAttention(MultiHeadAttention):
+196class RotaryPEMultiHeadAttention(MultiHeadAttention):
@@ -369,7 +369,7 @@
-202    def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.1):
+203    def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.1):
@@ -382,7 +382,7 @@
-206        super().__init__(heads, d_model, dropout_prob, bias=False)
+207        super().__init__(heads, d_model, dropout_prob, bias=False)
@@ -394,9 +394,9 @@
-209        d_rope = int(self.d_k * rope_percentage)
-210        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
-211        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)
+210        d_rope = int(self.d_k * rope_percentage)
+211        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
+212        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)
@@ -408,7 +408,7 @@
-213    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
+214    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
@@ -420,7 +420,7 @@
-219        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))
+220        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))
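For reference, `get_scores` contracts the feature dimension of the rotated queries and keys: `'ibhd,jbhd->ijbh'` yields one dot product per (query position i, key position j, batch b, head h). A quick shape check (tensor sizes are made up for illustration):

    import torch

    q = torch.randn(7, 2, 4, 16)   # [query_seq, batch, heads, d_k]
    k = torch.randn(9, 2, 4, 16)   # [key_seq, batch, heads, d_k]
    scores = torch.einsum('ibhd,jbhd->ijbh', q, k)
    assert scores.shape == (7, 9, 2, 4)  # [query_seq, key_seq, batch, heads]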
@@ -432,7 +432,7 @@
-222def _test_rotary():
+223def _test_rotary():
@@ -443,16 +443,16 @@
-226    x = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]], dtype=torch.float)
-227    x = x[:, None, None, :]
-228    inspect(x)
-229
-230    rotary_pe = RotaryPositionalEmbeddings(3)
-231    inspect(rotary_pe(x))
-232
+227    x = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]], dtype=torch.float)
+228    x = x[:, None, None, :]
+229    inspect(x)
+230
+231    rotary_pe = RotaryPositionalEmbeddings(3)
+232    inspect(rotary_pe(x))
 233
-234if __name__ == '__main__':
-235    _test_rotary()
+234
+235if __name__ == '__main__':
+236    _test_rotary()
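One property the corrected caches now satisfy, and which the renumbered test above can eyeball via `inspect`: at position m = 0 the rotation is the identity, since cos(0) = 1 and sin(0) = 0. A sketch using the hypothetical helpers above:

    import torch

    cos_c, sin_c = build_rope_cache(seq_len=3, d=4)
    x = torch.randn(3, 1, 1, 4)
    y = apply_rope(x, cos_c, sin_c, d_rope=4)
    assert torch.allclose(y[0], x[0])  # position 0 is left unchanged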
diff --git a/docs/transformers/rope/value_pe/index.html b/docs/transformers/rope/value_pe/index.html
--- a/docs/transformers/rope/value_pe/index.html
+++ b/docs/transformers/rope/value_pe/index.html
-142        neg_half_x = self._neg_half(x_rope)
+143        neg_half_x = self._neg_half(x_rope)
@@ -173,7 +173,7 @@
-158        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) - (neg_half_x * self.sin_cached[:x.shape[0]])
+159        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) - (neg_half_x * self.sin_cached[:x.shape[0]])
@@ -185,7 +185,7 @@
-161        return torch.cat((x_rope, x_pass), dim=-1)
+162        return torch.cat((x_rope, x_pass), dim=-1)
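These hunks renumber `ReverseRotaryPositionalEmbeddings.forward`, which rotates by -m * theta: cosine is even and sine is odd, so the only difference from the forward transform is the minus sign on the sine term, and reversing a forward rotation recovers the input. A round-trip sketch, again with the hypothetical helpers above:

    import torch

    def apply_rope_reverse(x, cos_cached, sin_cached, d_rope: int):
        x_rope, x_pass = x[..., :d_rope], x[..., d_rope:]
        # same formula as apply_rope, with sin negated (rotation by -m * theta)
        x_rope = (x_rope * cos_cached[:x.shape[0]]) - (neg_half(x_rope) * sin_cached[:x.shape[0]])
        return torch.cat((x_rope, x_pass), dim=-1)

    x = torch.randn(5, 2, 4, 8)
    cos_c, sin_c = build_rope_cache(seq_len=5, d=8)
    y = apply_rope_reverse(apply_rope(x, cos_c, sin_c, 8), cos_c, sin_c, 8)
    assert torch.allclose(y, x, atol=1e-5)  # rotating forward then back is the identity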
@@ -198,7 +198,7 @@
-164class RotaryValuePEMultiHeadAttention(MultiHeadAttention):
+165class RotaryValuePEMultiHeadAttention(MultiHeadAttention):
@@ -209,9 +209,9 @@
-171    def __init__(self, heads: int, d_model: int,
-172                 rope_percentage: float = 0.5, rope_value_percentage: float = 0.5,
-173                 dropout_prob: float = 0.1):
+172    def __init__(self, heads: int, d_model: int,
+173                 rope_percentage: float = 0.5, rope_value_percentage: float = 0.5,
+174                 dropout_prob: float = 0.1):
@@ -224,7 +224,7 @@
-177        super().__init__(heads, d_model, dropout_prob, bias=False)
+178        super().__init__(heads, d_model, dropout_prob, bias=False)
@@ -236,13 +236,13 @@
-180        d_rope = int(self.d_k * rope_percentage)
-181        d_rope_value = int(self.d_k * rope_value_percentage)
-182
-183        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
-184        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)
-185        self.value_rotary_pe = RotaryPositionalEmbeddings(d_rope_value)
-186        self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(d_rope_value)
+181        d_rope = int(self.d_k * rope_percentage)
+182        d_rope_value = int(self.d_k * rope_value_percentage)
+183
+184        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
+185        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)
+186        self.value_rotary_pe = RotaryPositionalEmbeddings(d_rope_value)
+187        self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(d_rope_value)
@@ -254,7 +254,7 @@
-188    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
+189    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
@@ -266,7 +266,7 @@
-194        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))
+195        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))
@@ -289,11 +289,11 @@
-196    def forward(self, *,
-197                query: torch.Tensor,
-198                key: torch.Tensor,
-199                value: torch.Tensor,
-200                mask: Optional[torch.Tensor] = None):
+197    def forward(self, *,
+198                query: torch.Tensor,
+199                key: torch.Tensor,
+200                value: torch.Tensor,
+201                mask: Optional[torch.Tensor] = None):
@@ -309,10 +309,10 @@
-212        seq_len, batch_size, _ = query.shape
-213
-214        if mask is not None:
-215            mask = self.prepare_mask(mask, query.shape, key.shape)
+213        seq_len, batch_size, _ = query.shape
+214
+215        if mask is not None:
+216            mask = self.prepare_mask(mask, query.shape, key.shape)
@@ -328,9 +328,9 @@
-219        query = self.query(query)
-220        key = self.key(key)
-221        value = self.value(value)
+220        query = self.query(query)
+221        key = self.key(key)
+222        value = self.value(value)
@@ -343,7 +343,7 @@
-225        scores = self.get_scores(query, key)
+226        scores = self.get_scores(query, key)
@@ -366,7 +366,7 @@
-228        scores *= self.scale
+229        scores *= self.scale
@@ -378,8 +378,8 @@
-231        if mask is not None:
-232            scores = scores.masked_fill(mask == 0, float('-inf'))
+232        if mask is not None:
+233            scores = scores.masked_fill(mask == 0, float('-inf'))
@@ -402,7 +402,7 @@
-236        attn = self.softmax(scores)
+237        attn = self.softmax(scores)
@@ -414,7 +414,7 @@
-239        attn = self.dropout(attn)
+240        attn = self.dropout(attn)
@@ -426,7 +426,7 @@
-242        value = self.value_rotary_pe(value)
+243        value = self.value_rotary_pe(value)
@@ -449,7 +449,7 @@
-246        x = torch.einsum("ijbh,jbhd->ibhd", attn, self.value_rotary_pe(value))
+247        x = torch.einsum("ijbh,jbhd->ibhd", attn, self.value_rotary_pe(value))
@@ -461,7 +461,7 @@
-249        x = self.value_reverse_rotary_pe(x)
+250        x = self.value_reverse_rotary_pe(x)
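The value path in the two hunks above rotates values by their position before the attention-weighted sum (`'ijbh,jbhd->ibhd'` sums over key positions j), then applies the reverse rotation to the result. A shape sketch of the weighted sum (sizes illustrative):

    import torch

    attn = torch.softmax(torch.randn(7, 9, 2, 4), dim=1)  # [query_seq, key_seq, batch, heads]
    v = torch.randn(9, 2, 4, 16)                          # [key_seq, batch, heads, d_k]
    out = torch.einsum('ijbh,jbhd->ibhd', attn, v)
    assert out.shape == (7, 2, 4, 16)                     # [query_seq, batch, heads, d_k]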
@@ -473,7 +473,7 @@
-252        self.attn = attn.detach()
+253        self.attn = attn.detach()
@@ -485,7 +485,7 @@
-255        x = x.reshape(seq_len, batch_size, -1)
+256        x = x.reshape(seq_len, batch_size, -1)
@@ -497,7 +497,7 @@
-258        return self.output(x)
+259        return self.output(x)
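To exercise the patched module end to end, a hedged usage sketch (assumes `labml_nn` at this revision is installed; class name, keyword-only arguments, and output shape come from the diff above):

    import torch
    from labml_nn.transformers.rope.value_pe import RotaryValuePEMultiHeadAttention

    mha = RotaryValuePEMultiHeadAttention(heads=4, d_model=64)
    x = torch.randn(10, 2, 64)  # [seq_len, batch_size, d_model]
    out = mha(query=x, key=x, value=x)
    assert out.shape == (10, 2, 64)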