diff --git a/docs/cfr/kuhn/index.html b/docs/cfr/kuhn/index.html
index e2f5f83d..a3d6f350 100644
--- a/docs/cfr/kuhn/index.html
+++ b/docs/cfr/kuhn/index.html
@@ -794,11 +794,11 @@
-Set models for saving
+Start the experiment
-242 experiment.add_model_savers({'info_sets': InfoSetSaver(conf.cfr.info_sets)})
+242 with experiment.start():
-244 with experiment.start():
+244 conf.cfr.iterate()
 Start iterating
-246 conf.cfr.iterate()
-250 if __name__ == '__main__':
-251 main()
+248 if __name__ == '__main__':
+249 main()
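For orientation, the rendered lines in this hunk are the tail of the Kuhn poker CFR experiment script: the model-saver registration is dropped and the lines after it are renumbered. A minimal sketch of that flow, assuming labml's `experiment` API and the `InfoSetSaver`/`Configs` objects named above; the import paths and the `experiment.create`/`experiment.configs` setup are my assumptions, not part of this hunk:

```python
from labml import experiment
# Assumed import paths; the hunk itself only names InfoSetSaver and conf.cfr
from labml_nn.cfr import InfoSetSaver
from labml_nn.cfr.kuhn import Configs


def main():
    # Experiment/configuration setup (sketch; this lives above the hunk)
    experiment.create(name='kuhn_poker')
    conf = Configs()
    experiment.configs(conf)

    # Set models for saving -- the registration removed in this diff
    experiment.add_model_savers({'info_sets': InfoSetSaver(conf.cfr.info_sets)})

    # Start the experiment
    with experiment.start():
        # Start iterating
        conf.cfr.iterate()


if __name__ == '__main__':
    main()
```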
-This can act as an encoder layer or a decoder layer.
-🤔 Some implementations, including the paper seem to have differences in where the layer-normalization is done. Here we do a layer normalization before attention and feed-forward networks, and add the original residual vectors. Alternative is to do a layer normalization after adding the residuals. But we found this to be less stable when training. We found a detailed discussion about this in the paper On Layer Normalization in the Transformer Architecture.
+This can act as an encoder layer or a decoder layer. We use pre-norm.
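The note removed above contrasts pre-norm (normalize before each sub-layer, then add the residual) with post-norm (add the residual, then normalize). A minimal, self-contained sketch of the two residual patterns; the sub-layer here is an arbitrary module standing in for the attention or feed-forward blocks used by `TransformerLayer`:

```python
import torch
import torch.nn as nn


class PreNormResidual(nn.Module):
    """Pre-norm: x + dropout(sublayer(norm(x))) -- the pattern in the hunk below."""
    def __init__(self, d_model: int, sublayer: nn.Module, dropout_prob: float = 0.1):
        super().__init__()
        self.norm = nn.LayerNorm([d_model])
        self.sublayer = sublayer
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.dropout(self.sublayer(self.norm(x)))


class PostNormResidual(nn.Module):
    """Post-norm (original paper): norm(x + dropout(sublayer(x)))."""
    def __init__(self, d_model: int, sublayer: nn.Module, dropout_prob: float = 0.1):
        super().__init__()
        self.norm = nn.LayerNorm([d_model])
        self.sublayer = sublayer
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(x + self.dropout(self.sublayer(x)))


# Example: wrap a feed-forward sub-layer around a [seq_len, batch, d_model] input
ffn = nn.Sequential(nn.Linear(512, 2048), nn.ReLU(), nn.Linear(2048, 512))
y = PreNormResidual(512, ffn)(torch.randn(10, 4, 512))
```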
-78 def __init__(self, *,
-79 d_model: int,
-80 self_attn: MultiHeadAttention,
-81 src_attn: MultiHeadAttention = None,
-82 feed_forward: FeedForward,
-83 dropout_prob: float):
+69 def __init__(self, *,
+70 d_model: int,
+71 self_attn: MultiHeadAttention,
+72 src_attn: MultiHeadAttention = None,
+73 feed_forward: FeedForward,
+74 dropout_prob: float):
-91 super().__init__()
-92 self.size = d_model
-93 self.self_attn = self_attn
-94 self.src_attn = src_attn
-95 self.feed_forward = feed_forward
-96 self.dropout = nn.Dropout(dropout_prob)
-97 self.norm_self_attn = nn.LayerNorm([d_model])
-98 if self.src_attn is not None:
-99 self.norm_src_attn = nn.LayerNorm([d_model])
-100 self.norm_ff = nn.LayerNorm([d_model])
+82 super().__init__()
+83 self.size = d_model
+84 self.self_attn = self_attn
+85 self.src_attn = src_attn
+86 self.feed_forward = feed_forward
+87 self.dropout = nn.Dropout(dropout_prob)
+88 self.norm_self_attn = nn.LayerNorm([d_model])
+89 if self.src_attn is not None:
+90 self.norm_src_attn = nn.LayerNorm([d_model])
+91 self.norm_ff = nn.LayerNorm([d_model])
-102 self.is_save_ff_input = False
+93 self.is_save_ff_input = False
-104 def forward(self, *,
-105 x: torch.Tensor,
-106 mask: torch.Tensor,
-107 src: torch.Tensor = None,
-108 src_mask: torch.Tensor = None):
+95 def forward(self, *,
+96 x: torch.Tensor,
+97 mask: torch.Tensor,
+98 src: torch.Tensor = None,
+99 src_mask: torch.Tensor = None):
-110 z = self.norm_self_attn(x)
+101 z = self.norm_self_attn(x)
-112 self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)
+103 self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)
-114 x = x + self.dropout(self_attn)
+105 x = x + self.dropout(self_attn)
-119 if src is not None:
+110 if src is not None:
-121 z = self.norm_src_attn(x)
+112 z = self.norm_src_attn(x)
-123 attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask)
+114 attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask)
-125 x = x + self.dropout(attn_src)
+116 x = x + self.dropout(attn_src)
-128 z = self.norm_ff(x)
+119 z = self.norm_ff(x)
-130 if self.is_save_ff_input:
-131 self.ff_input = z.clone()
+121 if self.is_save_ff_input:
+122 self.ff_input = z.clone()
-133 ff = self.feed_forward(z)
+124 ff = self.feed_forward(z)
-135 x = x + self.dropout(ff)
-136
-137 return x
+126 x = x + self.dropout(ff)
+127
+128 return x
-140 class Encoder(nn.Module):
+131 class Encoder(nn.Module):
-147 def __init__(self, layer: TransformerLayer, n_layers: int):
-148 super().__init__()
+138 def __init__(self, layer: TransformerLayer, n_layers: int):
+139 super().__init__()
-150 self.layers = clone_module_list(layer, n_layers)
+141 self.layers = clone_module_list(layer, n_layers)
-152 self.norm = nn.LayerNorm([layer.size])
+143 self.norm = nn.LayerNorm([layer.size])
-154 def forward(self, x: torch.Tensor, mask: torch.Tensor):
+145 def forward(self, x: torch.Tensor, mask: torch.Tensor):
-156 for layer in self.layers:
-157 x = layer(x=x, mask=mask)
+147 for layer in self.layers:
+148 x = layer(x=x, mask=mask)
-159 return self.norm(x)
+150 return self.norm(x)
-162 class Decoder(nn.Module):
+153 class Decoder(nn.Module):
-169 def __init__(self, layer: TransformerLayer, n_layers: int):
-170 super().__init__()
+160 def __init__(self, layer: TransformerLayer, n_layers: int):
+161 super().__init__()
-172 self.layers = clone_module_list(layer, n_layers)
+163 self.layers = clone_module_list(layer, n_layers)
-174 self.norm = nn.LayerNorm([layer.size])
+165 self.norm = nn.LayerNorm([layer.size])
-176 def forward(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
+167 def forward(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
-178 for layer in self.layers:
-179 x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask)
+169 for layer in self.layers:
+170 x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask)
-181 return self.norm(x)
+172 return self.norm(x)
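Both `Encoder` and `Decoder` above follow the same shape: clone a single `TransformerLayer` `n_layers` times, thread the input through each clone with keyword arguments, and finish with a `nn.LayerNorm`. A generic sketch of that pattern, using `copy.deepcopy` as a stand-in for `clone_module_list` (an assumption about what that helper does):

```python
import copy
import torch
import torch.nn as nn


class LayerStack(nn.Module):
    """Clone one layer n times, apply the clones in sequence, then a final layer norm."""
    def __init__(self, layer: nn.Module, n_layers: int, d_model: int):
        super().__init__()
        # Stand-in for clone_module_list: independent deep copies of the layer
        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)])
        self.norm = nn.LayerNorm([d_model])

    def forward(self, x: torch.Tensor, **layer_kwargs) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x=x, **layer_kwargs)
        return self.norm(x)


# Usage with a trivial layer that takes the same keyword-style arguments
class DummyLayer(nn.Module):
    def forward(self, *, x: torch.Tensor, mask: torch.Tensor = None):
        return x

stack = LayerStack(DummyLayer(), n_layers=6, d_model=512)
out = stack(torch.randn(10, 4, 512), mask=None)
```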
-184 class Generator(nn.Module):
+175 class Generator(nn.Module):
-194 def __init__(self, n_vocab: int, d_model: int):
-195 super().__init__()
-196 self.projection = nn.Linear(d_model, n_vocab)
+185 def __init__(self, n_vocab: int, d_model: int):
+186 super().__init__()
+187 self.projection = nn.Linear(d_model, n_vocab)
-198 def forward(self, x):
-199 return self.projection(x)
+189 def forward(self, x):
+190 return self.projection(x)
-202 class EncoderDecoder(nn.Module):
+193 class EncoderDecoder(nn.Module):
-209 def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: nn.Module, tgt_embed: nn.Module, generator: nn.Module):
-210 super().__init__()
-211 self.encoder = encoder
-212 self.decoder = decoder
-213 self.src_embed = src_embed
-214 self.tgt_embed = tgt_embed
-215 self.generator = generator
+200 def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: nn.Module, tgt_embed: nn.Module, generator: nn.Module):
+201 super().__init__()
+202 self.encoder = encoder
+203 self.decoder = decoder
+204 self.src_embed = src_embed
+205 self.tgt_embed = tgt_embed
+206 self.generator = generator
-219 for p in self.parameters():
-220 if p.dim() > 1:
-221 nn.init.xavier_uniform_(p)
+210 for p in self.parameters():
+211 if p.dim() > 1:
+212 nn.init.xavier_uniform_(p)
-223 def forward(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
+214 def forward(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
-225 enc = self.encode(src, src_mask)
+216 enc = self.encode(src, src_mask)
-227 return self.decode(enc, src_mask, tgt, tgt_mask)
+218 return self.decode(enc, src_mask, tgt, tgt_mask)
-229 def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
-230 return self.encoder(self.src_embed(src), src_mask)
+220 def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
+221 return self.encoder(self.src_embed(src), src_mask)
-232 def decode(self, memory: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
-233 return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
+223 def decode(self, memory: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
+224 return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
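Putting the pieces together: `EncoderDecoder.forward` embeds and encodes the masked source, then decodes the embedded target against that memory, and `Generator` maps decoder output to vocabulary logits. A hedged wiring sketch; the import locations, the `MultiHeadAttention`/`FeedForward` constructor arguments, and `EmbeddingsWithPositionalEncoding` are assumptions beyond what these hunks show:

```python
from labml_nn.transformers.models import (
    Encoder, Decoder, EncoderDecoder, Generator, TransformerLayer,
    EmbeddingsWithPositionalEncoding,
)
from labml_nn.transformers.mha import MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward

d_model, heads, d_ff, n_vocab, n_layers, dropout = 512, 8, 2048, 10_000, 6, 0.1


def make_layer(is_decoder: bool) -> TransformerLayer:
    # Decoder layers get a source-attention module; encoder layers pass None,
    # which is why norm_src_attn is only created when src_attn is not None.
    return TransformerLayer(
        d_model=d_model,
        self_attn=MultiHeadAttention(heads, d_model, dropout),
        src_attn=MultiHeadAttention(heads, d_model, dropout) if is_decoder else None,
        feed_forward=FeedForward(d_model, d_ff, dropout),
        dropout_prob=dropout,
    )


model = EncoderDecoder(
    encoder=Encoder(make_layer(is_decoder=False), n_layers),
    decoder=Decoder(make_layer(is_decoder=True), n_layers),
    src_embed=EmbeddingsWithPositionalEncoding(d_model, n_vocab),
    tgt_embed=EmbeddingsWithPositionalEncoding(d_model, n_vocab),
    generator=Generator(n_vocab, d_model),
)
```

Calling `model(src, tgt, src_mask, tgt_mask)` then runs `encode` followed by `decode`, and `model.generator(...)` projects the decoder output to vocabulary logits, matching the `forward`/`encode`/`decode` methods in the hunk above.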