Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
📚 compressive transformer experiment
(File diff suppressed because it is too large.)
@@ -47,13 +47,14 @@ class AutoregressiveModel(Module):
         self.mask_mem = None

     def forward(self, x: torch.Tensor, mem: CompressedMemory):
-        # Length of the memory
+        # Get memory and compressed memory
         if mem is not None:
             mem, c_mem = mem.mem, mem.c_mem
         else:
             mem = []
             c_mem = []

+        # Length of the memory (for masks)
         m_len = len(mem[0]) if mem else 0
         if c_mem:
             m_len += len(c_mem[0])
@@ -69,7 +70,7 @@ class AutoregressiveModel(Module):
         # Concatenate the masks if there is memory
         if m_len:
             mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)
-        # Use the subsequent mask otherwise
+        # Use only the subsequent mask otherwise
         else:
             mask = self.mask_x[:len(x), :len(x)]

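Note on the mask logic above: tokens attend causally within the current segment but may attend to every memory slot. Below is a minimal standalone sketch of that construction, using plain `torch.tril` in place of the labml mask helpers (the sizes and shapes here are illustrative assumptions, not the library's exact convention):

import torch

# Hypothetical sizes: a segment of 4 tokens attending to 6 memory slots.
seq_len, m_len = 4, 6
# Causal (subsequent) mask within the segment: position i sees positions <= i.
mask_x = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
# Memory mask: every position may attend to every memory slot.
mask_mem = torch.ones(seq_len, m_len, dtype=torch.bool)
# Concatenate along the key dimension, as in `forward` above.
mask = torch.cat((mask_mem[:seq_len, :m_len], mask_x[:seq_len, :seq_len]), dim=1)
assert mask.shape == (seq_len, m_len + seq_len)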
@@ -87,7 +88,7 @@ class Configs(NLPAutoRegressionConfigs):
     """
     ## Configurations

-    The default configs can and will be over-ridden when we start the experiment
+    The default configs can and will be over-ridden when we start the experiment.
     """

     model: AutoregressiveModel
@@ -108,8 +109,8 @@ class Configs(NLPAutoRegressionConfigs):
     memory = SimpleStateModule()
     # Attention Reconstruction Loss
     attention_reconstruction_loss: AttentionReconstructionLoss
-    # Compression ratio
-    compression_ratio: int = 4
+    # Compression rate
+    compression_rate: int = 4
     # Compressed memory length
     c_mem_len: int = 128

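A quick reading of these defaults: each compressed memory summarizes `compression_rate` raw memories, so the compressed buffer alone spans `compression_rate * c_mem_len` past steps. A tiny arithmetic check (`mem_len = 128` is a hypothetical value here; it is defined elsewhere in the configs):

compression_rate = 4  # from the configs above
c_mem_len = 128       # from the configs above
mem_len = 128         # hypothetical; set elsewhere in the configs

# Farthest past step still attendable, in raw time steps:
span = mem_len + compression_rate * c_mem_len
assert span == 640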
@@ -117,6 +118,7 @@ class Configs(NLPAutoRegressionConfigs):
         # Set tracker configurations
         tracker.set_scalar("accuracy.*", True)
         tracker.set_scalar("loss.*", True)
+        # Do not print the attention reconstruction loss in the terminal
         tracker.set_scalar("ar_loss.*", False)
         # Add a hook to log module outputs
         hook_model_outputs(self.mode, self.model, 'model')
@@ -124,55 +126,73 @@ class Configs(NLPAutoRegressionConfigs):
         self.state_modules = [self.accuracy, self.memory]

     @torch.no_grad()
-    def merge_memory(self, mem: CompressedMemory, new_mem: List[torch.Tensor]) \
+    def merge_compress_memory(self, mem: CompressedMemory, new_mem: List[torch.Tensor]) \
             -> Tuple[CompressedMemory, List[torch.Tensor]]:
         """
-        Concatenate memories and remove old memories to keep a maximum of
-        `mem_len` memories.
+        Concatenate new memories and compress the oldest memories.
         """

         # If it's configured not to use memory
-        if self.mem_len == 0:
+        if self.mem_len == 0 and self.c_mem_len == 0:
             return CompressedMemory([], []), []

+        # Get memory and compressed memory
         if mem is not None:
             mem, c_mem = mem.mem, mem.c_mem
         else:
             mem, c_mem = [], []
-        # Concatenate with old memory
+
+        # Concatenate new memories with old memory
         if mem:
             mem = [torch.cat((m, x), dim=0) for m, x in zip(mem, new_mem)]
         else:
             mem = new_mem

+        # Compress the oldest memories if there are more memories than `mem_len`
         if len(mem[0]) > self.mem_len:
-            n_c_mem = (len(mem[0]) - self.mem_len + self.compression_ratio - 1) // self.compression_ratio
-            old_mem = []
-            trunc_mem = []
+            # Calculate the number of compressed memories to make $n_{cm} = \bigg\lceil\frac{n'_m - N_m}{c}\bigg\rceil$,
+            # where $n'_m$ is the number of memories we have
+            # and $N_m$ is the maximum number of memories we maintain (`mem_len`).
+            n_c_mem = (len(mem[0]) - self.mem_len + self.compression_rate - 1) // self.compression_rate
+            # Number of memories to compress $c n_{cm}$
+            n_old = n_c_mem * self.compression_rate
+            # A list to keep memories that need to be compressed for each layer.
+            mem_to_compress = []
+            # A list to keep the memories that do not get compressed for each layer.
+            uncompressed_mem = []
+            # Iterate through memories of each layer.
             for m in mem:
-                n_old = n_c_mem * self.compression_ratio
+                # Split the memories at $c n_{cm}$
                 cm, m = torch.split(m, [n_old, len(m) - n_old])
-                old_mem.append(cm)
-                trunc_mem.append(m)
-            mem = trunc_mem
+                # Collect memories to compress
+                mem_to_compress.append(cm)
+                # Collect remaining memories
+                uncompressed_mem.append(m)
+            # Update the memories
+            mem = uncompressed_mem

+            # Compress the memories
             new_c_mem = []
             for i, layer in enumerate(self.model.transformer.layers):
-                new_c_mem.append(layer.compress(old_mem[i]))
+                new_c_mem.append(layer.compress(mem_to_compress[i]))

+            # Concatenate newly compressed memories with old compressed memories
             if c_mem:
                 c_mem = [torch.cat((m, nm), dim=0) for m, nm in zip(c_mem, new_c_mem)]
+            # If there are no old compressed memories
             else:
                 c_mem = new_c_mem

             # Truncate old memories
             if len(c_mem[0]) > self.c_mem_len:
                 c_mem = [m[-self.c_mem_len:] for m in c_mem]
+        # No memories are compressed if the number of memories is less than `mem_len`
         else:
-            old_mem = []
+            mem_to_compress = []

-        #
-        return CompressedMemory(mem, c_mem), old_mem
+        # Return memories and the memories that were compressed.
+        # Memories that were compressed are needed for the reconstruction loss computation.
+        return CompressedMemory(mem, c_mem), mem_to_compress

     def step(self, batch: any, batch_idx: BatchIndex):
         """
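A numeric walk-through of the split arithmetic in `merge_compress_memory` above, with hypothetical values `mem_len = 8`, `compression_rate = 4`, and a single layer holding 13 memories:

import torch

mem_len, compression_rate = 8, 4  # hypothetical config values
mem = [torch.randn(13, 1, 16)]    # 13 memories for one layer, d_model = 16

# n_cm = ceil((13 - 8) / 4) = 2 compressed memories, via ceiling division
n_c_mem = (len(mem[0]) - mem_len + compression_rate - 1) // compression_rate
# c * n_cm = 8 oldest memories get compressed
n_old = n_c_mem * compression_rate

cm, m = torch.split(mem[0], [n_old, len(mem[0]) - n_old])
assert len(cm) == 8 and len(m) == 5  # 8 go to the compressor, 5 stay uncompressed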
@@ -192,8 +212,8 @@ class Configs(NLPAutoRegressionConfigs):
         mem = self.memory.get()
         # Run the model
         output, new_mem = self.model(data, mem)
-        # Merge memory
-        mem, old_mem = self.merge_memory(mem, new_mem)
+        # Merge and compress memory
+        mem, mem_to_compress = self.merge_compress_memory(mem, new_mem)
         # Update memories
         self.memory.set(mem)

@@ -201,9 +221,13 @@ class Configs(NLPAutoRegressionConfigs):
         loss = self.loss_func(output, target)
         tracker.add("loss.", loss)

-        if old_mem:
-            ar_loss = self.attention_reconstruction_loss(new_mem, old_mem)
+        # Calculate attention reconstruction loss if memories were compressed in this step
+        if mem_to_compress:
+            # Get attention reconstruction loss
+            ar_loss = self.attention_reconstruction_loss(new_mem, mem_to_compress)
+            # Track attention reconstruction loss
             tracker.add("ar_loss.", ar_loss)
+            # Add attention reconstruction loss to loss
             loss = loss + ar_loss

         # Calculate and log accuracy
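For intuition: the attention reconstruction loss from the compressive transformer paper compares attention computed over the raw memories against attention computed over their compressed counterparts, and its gradients are meant to train only the compression function. A conceptual sketch, not the labml `AttentionReconstructionLoss` API:

import torch
import torch.nn.functional as F

def attention_reconstruction_sketch(h, old_mem, c_mem, attn):
    # `attn(query, memory)` is assumed to be a copy of the layer's attention
    # with detached parameters, so only the compression network gets gradients.
    # Target: attention over the uncompressed memories (no gradient).
    with torch.no_grad():
        target = attn(h, old_mem)
    # Prediction: attention over the compressed memories.
    pred = attn(h, c_mem)
    return F.mse_loss(pred, target)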
@@ -254,8 +278,8 @@ class Configs(NLPAutoRegressionConfigs):
             prompt = prompt[-1:]
             # Add the prediction for logging
             log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]
-            # Update memory
-            mem, _ = self.merge_memory(mem, new_mem)
+            # Update and compress memory
+            mem, _ = self.merge_compress_memory(mem, new_mem)

         # Print the sampled output
         logger.log(log)
@@ -273,14 +297,14 @@ def autoregressive_model(c: Configs):
         self_attn=RelativeMultiHeadAttention(c.heads, c.d_model, c.dropout),
         feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
         dropout_prob=c.dropout,
-        compress=Conv1dCompression(c.compression_ratio, c.d_model)), c.n_layers))
+        compress=Conv1dCompression(c.compression_rate, c.d_model)), c.n_layers))
     return m.to(c.device)


 @option(Configs.attention_reconstruction_loss)
 def attention_reconstruction_loss(c: Configs):
     """
-    ### Initialize the auto-regressive model
+    ### Initialize the attention reconstruction loss
     """
     return AttentionReconstructionLoss(c.model.transformer.layers)

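`Conv1dCompression` itself is defined in the model file rather than in this diff. A minimal sketch of such a compression function, assuming it follows the paper's best-performing choice of a 1D convolution with kernel size and stride both equal to the compression rate, so every block of `compression_rate` memories maps to one compressed memory:

import torch
from torch import nn

class Conv1dCompressionSketch(nn.Module):
    def __init__(self, compression_rate: int, d_model: int):
        super().__init__()
        self.conv = nn.Conv1d(d_model, d_model, kernel_size=compression_rate, stride=compression_rate)

    def forward(self, mem: torch.Tensor) -> torch.Tensor:
        # `mem` has shape `[seq_len, batch, d_model]`
        mem = mem.permute(1, 2, 0)     # -> `[batch, d_model, seq_len]` for Conv1d
        c_mem = self.conv(mem)         # -> `[batch, d_model, seq_len // compression_rate]`
        return c_mem.permute(2, 0, 1)  # -> `[seq_len // compression_rate, batch, d_model]`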