mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-26 08:41:23 +08:00
📚 compressive transformer experiment
@@ -47,13 +47,14 @@ class AutoregressiveModel(Module):
         self.mask_mem = None

     def forward(self, x: torch.Tensor, mem: CompressedMemory):
-        # Length of the memory
+        # Get memory and compressed memory
         if mem is not None:
             mem, c_mem = mem.mem, mem.c_mem
         else:
             mem = []
             c_mem = []

+        # Length of the memory (for masks)
         m_len = len(mem[0]) if mem else 0
         if c_mem:
             m_len += len(c_mem[0])
@@ -69,7 +70,7 @@ class AutoregressiveModel(Module):
         # Concatenate the masks if there is memory
         if m_len:
             mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)
-        # Use the subsequent mask otherwise
+        # Use only the subsequent mask otherwise
         else:
             mask = self.mask_x[:len(x), :len(x)]
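A quick aside, not part of the diff: the mask concatenation above gives each query position visibility over all `m_len` memory slots plus the usual causal view of the current sequence. A minimal 2D sketch of the shapes involved (the repository's actual masks carry an extra dimension; names here are local to this sketch):

import torch

seq_len, m_len = 3, 2
# Causal (subsequent) mask over the current sequence
mask_x = torch.tril(torch.ones(seq_len, seq_len))
# Memories are entirely in the past, so they are fully visible
mask_mem = torch.ones(seq_len, m_len)
# Memory mask goes first, matching the torch.cat order in forward()
mask = torch.cat((mask_mem, mask_x), dim=1)
assert mask.shape == (seq_len, m_len + seq_len)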
@@ -87,7 +88,7 @@ class Configs(NLPAutoRegressionConfigs):
     """
     ## Configurations

-    The default configs can and will be over-ridden when we start the experiment
+    The default configs can and will be over-ridden when we start the experiment.
     """

     model: AutoregressiveModel
@@ -108,8 +109,8 @@ class Configs(NLPAutoRegressionConfigs):
     memory = SimpleStateModule()
     # Attention Reconstruction Loss
     attention_reconstruction_loss: AttentionReconstructionLoss
-    # Compression ratio
-    compression_ratio: int = 4
+    # Compression rate
+    compression_rate: int = 4
     # Compressed memory length
     c_mem_len: int = 128
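For context, and assuming the inherited `mem_len` default is 128 (an assumption; it comes from the parent configs, not this diff), each compressed memory summarizes `compression_rate` steps, so the temporal range the model can attend over is roughly:

# Illustrative arithmetic only; mem_len = 128 is an assumed default here.
mem_len = 128          # uncompressed memories
c_mem_len = 128        # compressed memories (default above)
compression_rate = 4   # steps summarized per compressed memory

# Steps covered: recent raw memories + steps folded into compressed ones
temporal_range = mem_len + c_mem_len * compression_rate
print(temporal_range)  # 640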
@@ -117,6 +118,7 @@ class Configs(NLPAutoRegressionConfigs):
         # Set tracker configurations
         tracker.set_scalar("accuracy.*", True)
         tracker.set_scalar("loss.*", True)
+        # Do not print the attention reconstruction loss in the terminal
         tracker.set_scalar("ar_loss.*", False)
         # Add a hook to log module outputs
         hook_model_outputs(self.mode, self.model, 'model')
@@ -124,55 +126,73 @@ class Configs(NLPAutoRegressionConfigs):
         self.state_modules = [self.accuracy, self.memory]

     @torch.no_grad()
-    def merge_memory(self, mem: CompressedMemory, new_mem: List[torch.Tensor]) \
+    def merge_compress_memory(self, mem: CompressedMemory, new_mem: List[torch.Tensor]) \
             -> Tuple[CompressedMemory, List[torch.Tensor]]:
         """
-        Concatenate memories and remove old memories to keep a maximum of
-        `mem_len` memories.
+        Concatenate new memories and compress the oldest memories.
         """

         # If it's configured not to use memory
-        if self.mem_len == 0:
+        if self.mem_len == 0 and self.c_mem_len == 0:
             return CompressedMemory([], []), []

         # Get memory and compressed memory
         if mem is not None:
             mem, c_mem = mem.mem, mem.c_mem
         else:
             mem, c_mem = [], []
-        # Concatenate with old memory
+
+        # Concatenate new memories with old memory
         if mem:
             mem = [torch.cat((m, x), dim=0) for m, x in zip(mem, new_mem)]
         else:
             mem = new_mem

         # Compress the oldest memories if there are more memories than `mem_len`
         if len(mem[0]) > self.mem_len:
-            n_c_mem = (len(mem[0]) - self.mem_len + self.compression_ratio - 1) // self.compression_ratio
-            old_mem = []
-            trunc_mem = []
+            # Calculate the number of compressed memories to make,
+            # $n_{cm} = \bigg\lceil\frac{n'_m - N_m}{c}\bigg\rceil$,
+            # where $n'_m$ is the number of memories we have
+            # and $N_m$ is the maximum number of memories we maintain (`mem_len`).
+            n_c_mem = (len(mem[0]) - self.mem_len + self.compression_rate - 1) // self.compression_rate
+            # Number of memories to compress, $c n_{cm}$
+            n_old = n_c_mem * self.compression_rate
+            # A list to keep memories that need to be compressed for each layer.
+            mem_to_compress = []
+            # A list to keep the memories that do not get compressed for each layer.
+            uncompressed_mem = []
+            # Iterate through memories of each layer.
             for m in mem:
-                n_old = n_c_mem * self.compression_ratio
+                # Split the memories at $c n_{cm}$
                 cm, m = torch.split(m, [n_old, len(m) - n_old])
-                old_mem.append(cm)
-                trunc_mem.append(m)
-            mem = trunc_mem
+                # Collect memories to compress
+                mem_to_compress.append(cm)
+                # Collect remaining memories
+                uncompressed_mem.append(m)
+            # Update the memories
+            mem = uncompressed_mem

             # Compress the memories
             new_c_mem = []
             for i, layer in enumerate(self.model.transformer.layers):
-                new_c_mem.append(layer.compress(old_mem[i]))
+                new_c_mem.append(layer.compress(mem_to_compress[i]))

             # Concatenate newly compressed memories with old compressed memories
             if c_mem:
                 c_mem = [torch.cat((m, nm), dim=0) for m, nm in zip(c_mem, new_c_mem)]
+            # If there are no old compressed memories
             else:
                 c_mem = new_c_mem

             # Truncate old memories
             if len(c_mem[0]) > self.c_mem_len:
                 c_mem = [m[-self.c_mem_len:] for m in c_mem]
+        # No memories are compressed if the number of memories is less than `mem_len`
         else:
-            old_mem = []
+            mem_to_compress = []

-        #
-        return CompressedMemory(mem, c_mem), old_mem
+        # Return memories and the memories that were compressed.
+        # Memories that were compressed are needed for the reconstruction loss computation.
+        return CompressedMemory(mem, c_mem), mem_to_compress

     def step(self, batch: any, batch_idx: BatchIndex):
         """
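A worked example of the ceiling arithmetic in `merge_compress_memory`, not part of the diff: with `mem_len = 8`, `compression_rate = 4` and 15 accumulated memories, the 7 excess memories must be drained in multiples of the compression rate:

mem_len, compression_rate, n_mem = 8, 4, 15

# n_cm = ceil((n'_m - N_m) / c), written with integer arithmetic
n_c_mem = (n_mem - mem_len + compression_rate - 1) // compression_rate
n_old = n_c_mem * compression_rate

assert n_c_mem == 2        # ceil(7 / 4)
assert n_old == 8          # the 8 oldest memories get compressed into 2
assert n_mem - n_old == 7  # 7 uncompressed memories remain (<= mem_len)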
@@ -192,8 +212,8 @@ class Configs(NLPAutoRegressionConfigs):
         mem = self.memory.get()
         # Run the model
         output, new_mem = self.model(data, mem)
-        # Merge memory
-        mem, old_mem = self.merge_memory(mem, new_mem)
+        # Merge and compress memory
+        mem, mem_to_compress = self.merge_compress_memory(mem, new_mem)
         # Update memories
         self.memory.set(mem)
@@ -201,9 +221,13 @@ class Configs(NLPAutoRegressionConfigs):
         loss = self.loss_func(output, target)
         tracker.add("loss.", loss)

-        if old_mem:
-            ar_loss = self.attention_reconstruction_loss(new_mem, old_mem)
+        # Calculate attention reconstruction loss if memories were compressed in this step
+        if mem_to_compress:
+            # Get attention reconstruction loss
+            ar_loss = self.attention_reconstruction_loss(new_mem, mem_to_compress)
+            # Track attention reconstruction loss
             tracker.add("ar_loss.", ar_loss)
+            # Add attention reconstruction loss to loss
             loss = loss + ar_loss

         # Calculate and log accuracy
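For readers following along, the idea behind the attention reconstruction loss is to require that attending over compressed memories approximates attending over the originals. The repository's `AttentionReconstructionLoss` is more involved (it reuses each layer's attention with detached parameters); the following is only a simplified single-layer sketch with illustrative names:

import torch
import torch.nn.functional as F

def attention_reconstruction_sketch(h: torch.Tensor, mem: torch.Tensor,
                                    compress) -> torch.Tensor:
    """MSE between attention outputs over original memories and over
    their compressed version, with hidden states `h` as the queries."""
    c_mem = compress(mem)
    # Single-head dot-product attention, without scaling, for illustration
    attn_orig = F.softmax(h @ mem.transpose(-2, -1), dim=-1) @ mem
    attn_comp = F.softmax(h @ c_mem.transpose(-2, -1), dim=-1) @ c_mem
    return F.mse_loss(attn_orig, attn_comp)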
@@ -254,8 +278,8 @@ class Configs(NLPAutoRegressionConfigs):
             prompt = prompt[-1:]
             # Add the prediction for logging
             log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]
-            # Update memory
-            mem, _ = self.merge_memory(mem, new_mem)
+            # Update and compress memory
+            mem, _ = self.merge_compress_memory(mem, new_mem)

         # Print the sampled output
         logger.log(log)
@@ -273,14 +297,14 @@ def autoregressive_model(c: Configs):
         self_attn=RelativeMultiHeadAttention(c.heads, c.d_model, c.dropout),
         feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
         dropout_prob=c.dropout,
-        compress=Conv1dCompression(c.compression_ratio, c.d_model)), c.n_layers))
+        compress=Conv1dCompression(c.compression_rate, c.d_model)), c.n_layers))
     return m.to(c.device)


 @option(Configs.attention_reconstruction_loss)
 def attention_reconstruction_loss(c: Configs):
     """
-    ### Initialize the auto-regressive model
+    ### Initialize the attention reconstruction loss
     """
     return AttentionReconstructionLoss(c.model.transformer.layers)
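Finally, the `compress` module wired into each layer above is `Conv1dCompression`. Its essence is a 1D convolution whose kernel size and stride both equal the compression rate, so every `c` consecutive memories collapse into one. A sketch along those lines (not the verbatim implementation):

import torch
from torch import nn

class Conv1dCompressionSketch(nn.Module):
    """Compress memories with a 1D convolution; kernel size and stride
    both equal the compression rate, so `c` memories become one."""

    def __init__(self, compression_rate: int, d_model: int):
        super().__init__()
        self.conv = nn.Conv1d(d_model, d_model,
                              kernel_size=compression_rate, stride=compression_rate)

    def forward(self, mem: torch.Tensor) -> torch.Tensor:
        # mem: [mem_len, batch, d_model] -> [batch, d_model, mem_len] for conv1d
        mem = mem.permute(1, 2, 0)
        c_mem = self.conv(mem)
        # Back to [mem_len // compression_rate, batch, d_model]
        return c_mem.permute(2, 0, 1)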