mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-31 10:48:49 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			119 lines
		
	
	
		
			2.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			119 lines
		
	
	
		
			2.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """
 | |
| ---
 | |
| title: Cache for Intermediate Activations
 | |
| summary: >
 | |
|     Cache for intermediate activations for faster inference.
 | |
| ---
 | |
| 
 | |
| # Cache for Intermediate Activations
 | |
| 
 | |
| During inference the model outputs token by token.
 | |
| We use this simple cache to store key's and value's attention layers,
 | |
| so that we don't have to recompute them for previous tokens.
 | |
| """
 | |
| 
 | |
| from typing import Any
 | |
| 
 | |
| 
 | |
| class Cache:
 | |
|     """
 | |
|     ## Cache
 | |
| 
 | |
|     This maintains a key-value cache and queues push values and pop them in the same order.
 | |
|     The queues are useful since we have multiple attention layers.
 | |
|     """
 | |
| 
 | |
|     def __init__(self):
 | |
|         self._cache = {}
 | |
| 
 | |
|     def clear_all(self):
 | |
|         """
 | |
|         ### Clear cache
 | |
|         """
 | |
|         self._cache = {}
 | |
| 
 | |
|     def push(self, name: str, value: Any):
 | |
|         """
 | |
|         ### Push a value to a queue
 | |
| 
 | |
|         :param name: is the name of the queue
 | |
|         :param value: is the value to be pushed
 | |
|         """
 | |
| 
 | |
|         # Create an empty queue if it's not present
 | |
|         if name not in self._cache:
 | |
|             self._cache[name] = []
 | |
| 
 | |
|         # Push to the queue
 | |
|         self._cache[name].append(value)
 | |
| 
 | |
|     def q_size(self, name):
 | |
|         """
 | |
|         ### Return the size of the queue
 | |
| 
 | |
|         :param name: is the name of the queue
 | |
|         :return: size of the queue if exists else None
 | |
|         """
 | |
| 
 | |
|         if name not in self._cache:
 | |
|             return None
 | |
| 
 | |
|         if type(self._cache[name]) != list:
 | |
|             return None
 | |
| 
 | |
|         return len(self._cache[name])
 | |
| 
 | |
|     def pop(self, name: str):
 | |
|         """
 | |
|         ### Pop from a queue
 | |
| 
 | |
|         :param name: is the name of the queue
 | |
|         :return: the value
 | |
|         """
 | |
|         return self._cache[name].pop(0)
 | |
| 
 | |
|     def set(self, key: str, value: Any):
 | |
|         """
 | |
|         ### Cache a value
 | |
| 
 | |
|         :param key: is the name of the value to be cached
 | |
|         :param value: is the value
 | |
|         """
 | |
|         self._cache[key] = value
 | |
| 
 | |
|     def get(self, key: str, default: Any = None):
 | |
|         """
 | |
|         ### Retrieve a value from cache
 | |
| 
 | |
|         :param key: is the name used when caching
 | |
|         :param default: is the default value if the cache is empty
 | |
|         :return: the cached value
 | |
|         """
 | |
|         return self._cache.get(key, default)
 | |
| 
 | |
|     def clear(self, key: str):
 | |
|         """
 | |
|         ### Clear a cache value
 | |
| 
 | |
|         :param key: is the name used when caching
 | |
|         """
 | |
|         del self._cache[key]
 | |
| 
 | |
| 
 | |
| # Singleton for cache
 | |
| _INSTANCE = None
 | |
| 
 | |
| 
 | |
| def get_cache() -> Cache:
 | |
|     """
 | |
|     ### Get the cache instance
 | |
| 
 | |
|     :return: the cache instance
 | |
|     """
 | |
|     global _INSTANCE
 | |
| 
 | |
|     if _INSTANCE is None:
 | |
|         _INSTANCE = Cache()
 | |
| 
 | |
|     return _INSTANCE
 | 
