Files
2022-08-11 15:44:13 +05:30

119 lines
2.6 KiB
Python

"""
---
title: Cache for Intermediate Activations
summary: >
Cache for intermediate activations for faster inference.
---
# Cache for Intermediate Activations
During inference the model outputs token by token.
We use this simple cache to store key's and value's attention layers,
so that we don't have to recompute them for previous tokens.
"""
from typing import Any
class Cache:
"""
## Cache
This maintains a key-value cache and queues push values and pop them in the same order.
The queues are useful since we have multiple attention layers.
"""
def __init__(self):
self._cache = {}
def clear_all(self):
"""
### Clear cache
"""
self._cache = {}
def push(self, name: str, value: Any):
"""
### Push a value to a queue
:param name: is the name of the queue
:param value: is the value to be pushed
"""
# Create an empty queue if it's not present
if name not in self._cache:
self._cache[name] = []
# Push to the queue
self._cache[name].append(value)
def q_size(self, name):
"""
### Return the size of the queue
:param name: is the name of the queue
:return: size of the queue if exists else None
"""
if name not in self._cache:
return None
if type(self._cache[name]) != list:
return None
return len(self._cache[name])
def pop(self, name: str):
"""
### Pop from a queue
:param name: is the name of the queue
:return: the value
"""
return self._cache[name].pop(0)
def set(self, key: str, value: Any):
"""
### Cache a value
:param key: is the name of the value to be cached
:param value: is the value
"""
self._cache[key] = value
def get(self, key: str, default: Any = None):
"""
### Retrieve a value from cache
:param key: is the name used when caching
:param default: is the default value if the cache is empty
:return: the cached value
"""
return self._cache.get(key, default)
def clear(self, key: str):
"""
### Clear a cache value
:param key: is the name used when caching
"""
del self._cache[key]
# Singleton for cache
_INSTANCE = None
def get_cache() -> Cache:
"""
### Get the cache instance
:return: the cache instance
"""
global _INSTANCE
if _INSTANCE is None:
_INSTANCE = Cache()
return _INSTANCE