📚 compressive transformer experiment

commit 661009953c
parent a1b1550245
Author: Varuna Jayasiri
Date: 2021-02-19 08:53:26 +05:30

2 changed files with 505 additions and 275 deletions

File diff suppressed because it is too large.


@@ -47,13 +47,14 @@ class AutoregressiveModel(Module):
         self.mask_mem = None

     def forward(self, x: torch.Tensor, mem: CompressedMemory):
-        # Length of the memory
+        # Get memory and compressed memory
         if mem is not None:
             mem, c_mem = mem.mem, mem.c_mem
         else:
             mem = []
             c_mem = []
+        # Length of the memory (for masks)
         m_len = len(mem[0]) if mem else 0
         if c_mem:
             m_len += len(c_mem[0])
@@ -69,7 +70,7 @@ class AutoregressiveModel(Module):
         # Concatenate the masks if there is memory
         if m_len:
             mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)
-        # Use the subsequent mask otherwise
+        # Use only the subsequent mask otherwise
         else:
             mask = self.mask_x[:len(x), :len(x)]
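
For intuition, the mask concatenation above can be sketched as a standalone example. This is not part of the commit: the sizes are made up, and it ignores the extra trailing dimension the actual masks carry.

import torch

seq_len, m_len = 4, 6  # hypothetical segment and memory lengths

# Causal mask over the current segment: position i attends to positions <= i
mask_x = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
# Every position may attend to all of the (compressed) memory
mask_mem = torch.ones(seq_len, m_len, dtype=torch.bool)

# Memory columns are prepended, matching torch.cat(..., dim=1) above
mask = torch.cat((mask_mem, mask_x), dim=1)
assert mask.shape == (seq_len, m_len + seq_len)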
@@ -87,7 +88,7 @@ class Configs(NLPAutoRegressionConfigs):
     """
     ## Configurations

-    The default configs can and will be over-ridden when we start the experiment
+    The default configs can and will be overridden when we start the experiment.
     """

     model: AutoregressiveModel
@@ -108,8 +109,8 @@ class Configs(NLPAutoRegressionConfigs):
     memory = SimpleStateModule()
     # Attention Reconstruction Loss
     attention_reconstruction_loss: AttentionReconstructionLoss
-    # Compression ratio
-    compression_ratio: int = 4
+    # Compression rate
+    compression_rate: int = 4
     # Compressed memory length
     c_mem_len: int = 128
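
A quick back-of-the-envelope on what these settings buy: each compressed memory stands in for `compression_rate` ordinary memories, so the effective attention span is `mem_len + c_mem_len * compression_rate` past positions. A sketch with the values above (`mem_len` is defined elsewhere in the configs; 128 is an assumed value):

mem_len = 128         # assumed; defined elsewhere in the configs
c_mem_len = 128       # compressed memory length, from above
compression_rate = 4  # each compressed memory summarizes 4 memories

# Effective number of past positions the model can attend to
effective_context = mem_len + c_mem_len * compression_rate
assert effective_context == 640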
@@ -117,6 +118,7 @@ class Configs(NLPAutoRegressionConfigs):
         # Set tracker configurations
         tracker.set_scalar("accuracy.*", True)
         tracker.set_scalar("loss.*", True)
+        # Do not print the attention reconstruction loss in the terminal
         tracker.set_scalar("ar_loss.*", False)
         # Add a hook to log module outputs
         hook_model_outputs(self.mode, self.model, 'model')
@@ -124,55 +126,73 @@ class Configs(NLPAutoRegressionConfigs):
         self.state_modules = [self.accuracy, self.memory]

     @torch.no_grad()
-    def merge_memory(self, mem: CompressedMemory, new_mem: List[torch.Tensor]) \
+    def merge_compress_memory(self, mem: CompressedMemory, new_mem: List[torch.Tensor]) \
             -> Tuple[CompressedMemory, List[torch.Tensor]]:
         """
-        Concatenate memories and remove old memories to keep a maximum of
-        `mem_len` memories.
+        Concatenate new memories and compress the oldest memories.
         """

         # If it's configured not to use memory
-        if self.mem_len == 0:
+        if self.mem_len == 0 and self.c_mem_len == 0:
             return CompressedMemory([], []), []

+        # Get memory and compressed memory
         if mem is not None:
             mem, c_mem = mem.mem, mem.c_mem
         else:
             mem, c_mem = [], []

-        # Concatenate with old memory
+        # Concatenate new memories with old memory
         if mem:
             mem = [torch.cat((m, x), dim=0) for m, x in zip(mem, new_mem)]
         else:
             mem = new_mem

+        # Compress the oldest memories if there are more memories than `mem_len`
         if len(mem[0]) > self.mem_len:
-            n_c_mem = (len(mem[0]) - self.mem_len + self.compression_ratio - 1) // self.compression_ratio
-            old_mem = []
-            trunc_mem = []
+            # Calculate the number of compressed memories to make,
+            # $n_{cm} = \bigg\lceil \frac{n'_m - N_m}{c} \bigg\rceil$,
+            # where $n'_m$ is the number of memories we have
+            # and $N_m$ is the maximum number of memories we maintain (`mem_len`).
+            n_c_mem = (len(mem[0]) - self.mem_len + self.compression_rate - 1) // self.compression_rate
+            # Number of memories to compress, $c n_{cm}$
+            n_old = n_c_mem * self.compression_rate
+            # A list to keep memories that need to be compressed for each layer.
+            mem_to_compress = []
+            # A list to keep the memories that do not get compressed for each layer.
+            uncompressed_mem = []
+            # Iterate through memories of each layer.
             for m in mem:
-                n_old = n_c_mem * self.compression_ratio
+                # Split the memories at $c n_{cm}$
                 cm, m = torch.split(m, [n_old, len(m) - n_old])
-                old_mem.append(cm)
-                trunc_mem.append(m)
-            mem = trunc_mem
+                # Collect memories to compress
+                mem_to_compress.append(cm)
+                # Collect remaining memories
+                uncompressed_mem.append(m)
+            # Update the memories
+            mem = uncompressed_mem

+            # Compress the memories
             new_c_mem = []
             for i, layer in enumerate(self.model.transformer.layers):
-                new_c_mem.append(layer.compress(old_mem[i]))
+                new_c_mem.append(layer.compress(mem_to_compress[i]))

+            # Concatenate newly compressed memories with old compressed memories
             if c_mem:
                 c_mem = [torch.cat((m, nm), dim=0) for m, nm in zip(c_mem, new_c_mem)]
+            # If there are no old compressed memories
             else:
                 c_mem = new_c_mem

             # Truncate old memories
             if len(c_mem[0]) > self.c_mem_len:
                 c_mem = [m[-self.c_mem_len:] for m in c_mem]
+        # No memories are compressed if the number of memories does not exceed `mem_len`
         else:
-            old_mem = []
+            mem_to_compress = []

-        #
-        return CompressedMemory(mem, c_mem), old_mem
+        # Return the memories and the memories that were compressed.
+        # The compressed-away memories are needed for the reconstruction loss computation.
+        return CompressedMemory(mem, c_mem), mem_to_compress

     def step(self, batch: any, batch_idx: BatchIndex):
         """
@@ -192,8 +212,8 @@ class Configs(NLPAutoRegressionConfigs):
         mem = self.memory.get()
         # Run the model
         output, new_mem = self.model(data, mem)
-        # Merge memory
-        mem, old_mem = self.merge_memory(mem, new_mem)
+        # Merge and compress memory
+        mem, mem_to_compress = self.merge_compress_memory(mem, new_mem)
         # Update memories
         self.memory.set(mem)
@@ -201,9 +221,13 @@ class Configs(NLPAutoRegressionConfigs):
         loss = self.loss_func(output, target)
         tracker.add("loss.", loss)

-        if old_mem:
-            ar_loss = self.attention_reconstruction_loss(new_mem, old_mem)
+        # Calculate attention reconstruction loss if memories were compressed in this step
+        if mem_to_compress:
+            # Get attention reconstruction loss
+            ar_loss = self.attention_reconstruction_loss(new_mem, mem_to_compress)
+            # Track attention reconstruction loss
             tracker.add("ar_loss.", ar_loss)
+            # Add attention reconstruction loss to loss
             loss = loss + ar_loss

         # Calculate and log accuracy
@@ -254,8 +278,8 @@ class Configs(NLPAutoRegressionConfigs):
             prompt = prompt[-1:]
             # Add the prediction for logging
             log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]
-            # Update memory
-            mem, _ = self.merge_memory(mem, new_mem)
+            # Update and compress memory
+            mem, _ = self.merge_compress_memory(mem, new_mem)

         # Print the sampled output
         logger.log(log)
@@ -273,14 +297,14 @@ def autoregressive_model(c: Configs):
         self_attn=RelativeMultiHeadAttention(c.heads, c.d_model, c.dropout),
         feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
         dropout_prob=c.dropout,
-        compress=Conv1dCompression(c.compression_ratio, c.d_model)), c.n_layers))
+        compress=Conv1dCompression(c.compression_rate, c.d_model)), c.n_layers))
     return m.to(c.device)


 @option(Configs.attention_reconstruction_loss)
 def attention_reconstruction_loss(c: Configs):
     """
-    ### Initialize the auto-regressive model
+    ### Initialize the attention reconstruction loss
     """
     return AttentionReconstructionLoss(c.model.transformer.layers)
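
The Conv1dCompression used above lives in the model file whose diff is suppressed. As a rough sketch (the class name, shapes, and layout here are assumptions, not the committed code): a convolution whose kernel size and stride both equal the compression rate shrinks the memory length by that factor.

import torch
from torch import nn

class Conv1dCompressionSketch(nn.Module):
    def __init__(self, compression_rate: int, d_model: int):
        super().__init__()
        # Kernel size == stride == compression rate, so every
        # `compression_rate` consecutive memories map to one output
        self.conv = nn.Conv1d(d_model, d_model,
                              kernel_size=compression_rate, stride=compression_rate)

    def forward(self, mem: torch.Tensor) -> torch.Tensor:
        # mem: [seq_len, batch, d_model] -> [batch, d_model, seq_len] for Conv1d
        x = mem.permute(1, 2, 0)
        c_mem = self.conv(x)
        # Back to [seq_len // compression_rate, batch, d_model]
        return c_mem.permute(2, 0, 1)

# 16 memories compress down to 4 with a rate of 4
out = Conv1dCompressionSketch(4, 512)(torch.zeros(16, 2, 512))
assert out.shape == (4, 2, 512)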