mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-30 22:24:21 +00:00
[tutorial] edited hands-on practices (#1899)
* Add handson to ColossalAI. * Change names of handsons and edit sequence parallel example. * Edit wrong folder name * resolve conflict * delete readme
This commit is contained in:
64
examples/tutorial/opt/inference/cache.py
Normal file
64
examples/tutorial/opt/inference/cache.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from collections import OrderedDict
from contextlib import contextmanager
from threading import Lock
from typing import Any, Dict, Hashable, List, Optional
|
||||
|
||||
|
||||
class MissCacheError(Exception):
    """Signal a cache miss: the key is unknown or its value list is not yet full."""
|
||||
|
||||
|
||||
class ListCache:

    def __init__(self, cache_size: int, list_size: int, fixed_keys: Optional[List[Hashable]] = None) -> None:
        """Cache a list of values. The fixed keys won't be removed. For other keys, LRU is applied.

        When the value list is not full, a cache miss occurs. Otherwise, a cache hit occurs.
        Redundant values will be removed.

        Args:
            cache_size (int): Max size for LRU cache.
            list_size (int): Value list size.
            fixed_keys (Optional[List[Hashable]], optional): The keys which won't be removed.
                Defaults to None (no fixed keys). A mutable default (``[]``) is deliberately
                avoided here — shared mutable defaults are a classic Python pitfall.
        """
        self.cache_size = cache_size
        self.list_size = list_size
        # LRU portion: access order tracked by OrderedDict (most recent at the end).
        self.cache: OrderedDict[Hashable, List[Any]] = OrderedDict()
        # Pinned portion: entries here are never evicted.
        self.fixed_cache: Dict[Hashable, List[Any]] = {key: [] for key in (fixed_keys or [])}
        self._lock = Lock()

    def get(self, key: Hashable) -> List[Any]:
        """Return the full value list cached under ``key``.

        A hit requires the stored list to have reached ``list_size``; a hit on an
        LRU key also refreshes its recency.

        Args:
            key (Hashable): Lookup key.

        Returns:
            List[Any]: The cached value list (length >= ``list_size``).

        Raises:
            MissCacheError: If ``key`` is unknown, or its list is not yet full.
        """
        with self.lock():
            if key in self.fixed_cache:
                values = self.fixed_cache[key]
                if len(values) >= self.list_size:
                    return values
            elif key in self.cache:
                self.cache.move_to_end(key)  # mark as most recently used
                values = self.cache[key]
                if len(values) >= self.list_size:
                    return values
            # Unknown key, or list not yet full: miss either way.
            raise MissCacheError()

    def add(self, key: Hashable, value: Any) -> None:
        """Append ``value`` to ``key``'s list, deduplicated and capped at ``list_size``.

        Fixed keys are updated in place; LRU keys are refreshed. A brand-new key may
        evict the least recently used entry when the LRU cache is at capacity.

        Args:
            key (Hashable): Key to file the value under.
            value (Any): Value to append (silently dropped if already present or
                the list is full).
        """
        with self.lock():
            if key in self.fixed_cache:
                values = self.fixed_cache[key]
                if len(values) < self.list_size and value not in values:
                    values.append(value)
            elif key in self.cache:
                self.cache.move_to_end(key)  # refresh recency even when nothing is appended
                values = self.cache[key]
                if len(values) < self.list_size and value not in values:
                    values.append(value)
            else:
                if len(self.cache) >= self.cache_size:
                    self.cache.popitem(last=False)  # evict least recently used
                self.cache[key] = [value]

    @contextmanager
    def lock(self):
        """Hold the internal mutex for the duration of the ``with`` block."""
        try:
            self._lock.acquire()
            yield
        finally:
            self._lock.release()
|
Reference in New Issue
Block a user