| """
 | |
| ---
 | |
| title: Arithmetic Dataset
 | |
| summary: >
 | |
|   This creates arithmetic problems.
 | |
| ---
 | |
| 
 | |
| *This is based on code by [Georges Harik (@gharik)](https://twitter.com/gharik).*
 | |
| """
 | |
| 
 | |
| import random
 | |
| import string
 | |
| from typing import List
 | |
| 
 | |
| import torch
 | |
| from labml.logger import Text
 | |
| from torch.utils.data import DataLoader, Dataset
 | |
| 
 | |
| from labml import monit, logger, tracker
 | |
| from labml.configs import option
 | |
| from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs, transpose_batch
 | |
| 
 | |
| 
 | |


class ArithmeticDataset(Dataset):
    """
    ## Arithmetic Dataset

    This creates arithmetic addition problems and solutions with workings.
    We've only implemented addition so far.

    It uses character-level tokenization.
    """

    def __init__(self, seq_len: int, max_digits: int, n_sequences: int):
        """
        :param seq_len: is the sequence length of generated math problems.
            We pack as many problems as possible up to this length
        :param max_digits: is the maximum number of digits in the operand integers
        :param n_sequences: is the number of sequences per epoch
        """
        self.n_sequences = n_sequences
        self.max_digits = max_digits
        self.seq_len = seq_len
        # Token id to string
        self.itos = list(string.digits + 'xe =\n?+;')
        # Character to token id
        self.stoi = {c: i for i, c in enumerate(self.itos)}
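
    # With this vocabulary there are 18 tokens: ids `0`-`9` are the digits,
    # followed by `x`, `e`, ` ` (space), `=`, `\n`, `?`, `+` and `;`.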

    @staticmethod
    def make_int(n_digits: int):
        """
        Generates an integer with `n_digits` digits
        """
        res = 0
        for i in range(n_digits):
            # Sample one digit; `randrange` excludes the upper bound,
            # so the leading digit is from 1-9 and the rest from 0-9
            d = random.randrange(1, 10) if i == 0 else random.randrange(0, 10)
            res = res * 10 + d

        return res
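
    # For example, `make_int(3)` returns a uniformly random integer
    # between 100 and 999.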

    @staticmethod
    def get_add_explanation(x: int, y: int):
        """
        Generates the workings for `x + y`.
        For example, for `11+29` it generates
        `1e0+9e0+0e0==10e0 1e1+2e1+1e1==4e1`.
        """

        carry = 0
        e = 0
        explanation = []
        # Add digit by digit, starting from the least significant position
        while x > 0 or y > 0 or carry > 0:
            # The current digits of `x` and `y`
            rx, ry = x % 10, y % 10
            # Digit sum including the carry
            total = rx + ry + carry
            explanation.append(f"{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}")
            # Drop the processed digits and carry over the rest
            x, y, carry = x // 10, y // 10, total // 10
            e += 1

        return ' '.join(explanation)

    def make_add_problem(self):
        """
        Creates an arithmetic addition problem with workings and answer.
        """
        # Sample the two operands
        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))

        # Generate the workings and assemble the full problem string
        explanation = self.get_add_explanation(x, y)
        return f"x={x}+{y}; {explanation} x=={x + y}\n"
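
    # For example, for operands `11` and `29` this returns
    # `x=11+29; 1e0+9e0+0e0==10e0 1e1+2e1+1e1==4e1 x==40\n`.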

    def get_qa(self):
        """
        Get an arithmetic problem and its answer. This is used for evaluation.
        """
        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))

        return f'x={x}+{y};', f'{x + y}'

    def get_packed_math_input(self):
        """
        Generate multiple problems and pack them into a sequence.
        """
        s_enc = []
        # Pack problems until we have at least `seq_len + 1` tokens,
        # enough for an input/target pair
        while len(s_enc) <= self.seq_len:
            # Generate a problem, prefixed with `?` to mark its start
            s_part = self.make_add_problem()
            s_part_enc = self.encode('?' + s_part)
            s_enc = s_enc + s_part_enc
        return s_enc
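
    # A packed sequence therefore looks like
    # `?x=3+4; 3e0+4e0+0e0==7e0 x==7\n?x=12+9; ...`
    # with `?` marking the start of each problem.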

    def encode(self, s: str):
        """
        Encode a given string
        """
        return [self.stoi[c] for c in s]
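
    # For example, `encode('x=1+2;')` gives `[10, 13, 1, 16, 2, 17]`.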

    def decode(self, arr: List[int]):
        """
        Decode a list of token ids
        """
        return ''.join([self.itos[c] for c in arr])

    def __getitem__(self, idx: int):
        """
        Get an input and target pair for auto-regressive modelling
        """
        s = torch.tensor(self.get_packed_math_input())
        return s[:self.seq_len], s[1:self.seq_len + 1]
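
    # The target is the input shifted one token to the right, so the model
    # learns to predict the next character at every position.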

    def __len__(self):
        """
        Number of sequences per epoch
        """
        return self.n_sequences


class ArithmeticAutoregression(NLPAutoRegressionConfigs):
    """
    ## Arithmetic Task Experiment Configurations
    """
    # Maximum number of digits per operand integer
    max_digits: int = 4
    # Number of training sequences per epoch
    train_sequences_per_epoch: int = 2 ** 12
    # Training data loader
    train_loader: DataLoader = 'arithmetic_train_loader'
    # Number of problems in evaluation
    n_tests: int = 64
    # No validation dataset is needed
    validator = None
    # Number of times to run evaluations per epoch
    inner_iterations = 4
    # Number of tokens in the vocabulary
    # (determined by instantiating a throwaway dataset)
    n_tokens = len(ArithmeticDataset(1, 1, 1).itos)

    @torch.no_grad()
    def sample(self):
        """
        ### Evaluation

        We use the sampling function to evaluate the model on a set of problems
        """

        # Skip in the first epoch
        if self.training_loop.idx < 1:
            return

        # Create a dataset to generate problems
        dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1)
        # Get a set of problems and answers
        qa = [dataset.get_qa() for _ in range(self.n_tests)]
        # Collect the problems only
        questions = [p[0] for p in qa]

        # Create a tensor with only the initial token of each question
        data = torch.tensor([[dataset.stoi[p[0]] for p in questions]])
        # Move to device
        data = data.to(self.device)

        # Number of sequences that have completed
        finished = torch.zeros((len(questions),)).bool().to(self.device)
        # Token id of the new line character - this marks the end of the answer
        new_line = dataset.stoi['\n']

        # Sampled results, initialized with the first character of each question
        results = [p[0] for p in questions]

        # Sample up to the sequence length
        for i in monit.iterate('Sample', self.seq_len - 1):
            # If all the sequences have completed we skip this
            if finished.sum() == len(finished):
                continue

            # Get the model output
            output, *_ = self.model(data)
            # Get the model prediction (greedy) at the last position
            output = output[-1].argmax(dim=-1)

            # Find which sequences have finished
            finished = finished | (output == new_line)
            # Skip if all have finished
            if finished.sum() == len(finished):
                continue

            # Override the prediction with the actual question text,
            # while we are still inside the question prefix
            for j, p in enumerate(questions):
                if len(p) > i + 1:
                    output[j] = dataset.stoi[p[i + 1]]

            # Add the next token to the input
            data = torch.cat([data, output[None, :]], dim=0)

            # Collect the sampled characters
            for j, c in enumerate(output):
                results[j] += dataset.itos[c]

        # Discard everything after the answer in the results
        results = [r.split('\n')[0] for r in results]

        # Log a sample
        res_sample = results[0].split(';')
        logger.log([(res_sample[0], Text.key), (';', Text.subtle), (';'.join(res_sample[1:]), Text.none)])

        # Extract the answers (the text after the final `x==`)
        results = [r.split('x==')[-1] for r in results]

        # Count the number of correct answers
        correct = 0
        for r, _qa in zip(results, qa):
            if r == _qa[1]:
                correct += 1

        # Log the score
        tracker.save('score', correct / len(results))
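
    # For instance, for the question `x=12+34;` the model is fed the question
    # characters one by one (its own predictions are overridden while inside
    # the prefix) and then generates the workings and `x==46` on its own.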


@option(ArithmeticAutoregression.train_loader)
def arithmetic_train_loader(c: ArithmeticAutoregression):
    """
    Training data loader
    """
    return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch),
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      num_workers=4)
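

# A minimal sketch of how these configurations could be wired into a labml
# experiment, assuming the usual labml experiment API; the experiment name
# and the overridden values here are illustrative assumptions.
def _experiment_sketch():
    from labml import experiment

    # Create the experiment and the configurations
    experiment.create(name='arithmetic_addition')
    conf = ArithmeticAutoregression()
    # Override a few hyper-parameters (illustrative values)
    experiment.configs(conf, {'seq_len': 128, 'batch_size': 32})
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()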


def _test():
    """
    Code to test generated problems
    """
    dataset = ArithmeticDataset(256, 8, 10)

    print(dataset.decode(dataset.get_packed_math_input()))
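

# A small extra sanity check (a sketch): the character-level tokenizer
# should round-trip a generated problem exactly.
def _test_encode_decode():
    dataset = ArithmeticDataset(256, 8, 10)
    s = dataset.make_add_problem()
    # Encoding then decoding must reproduce the original string
    assert dataset.decode(dataset.encode(s)) == s
    print('round-trip OK:', repr(s))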


#
if __name__ == '__main__':
    _test()