feat(model): Support claude proxy models (#2155)

Author: Fangyin Cheng
Date: 2024-11-26 19:47:28 +08:00
Committed by: GitHub
Parent: 9d8673a02f
Commit: 61509dc5ea

20 changed files with 508 additions and 157 deletions


@@ -34,6 +34,25 @@ class ProxyTokenizer(ABC):
             List[int]: token count, -1 if failed
         """

+    def support_async(self) -> bool:
+        """Check whether the tokenizer supports asynchronous token counting.
+
+        Returns:
+            bool: True if supported, False otherwise
+        """
+        return False
+
+    async def count_token_async(self, model_name: str, prompts: List[str]) -> List[int]:
+        """Count tokens of the given prompts asynchronously.
+
+        Args:
+            model_name (str): model name
+            prompts (List[str]): prompts to count tokens for
+
+        Returns:
+            List[int]: token count, -1 if failed
+        """
+        raise NotImplementedError()
+
 class TiktokenProxyTokenizer(ProxyTokenizer):
     def __init__(self):
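
The two new hooks give ProxyTokenizer an opt-in asynchronous path. A minimal sketch of a subclass that uses it (the class name, the word-count heuristic, and the simulated I/O are hypothetical and not part of this commit; ProxyTokenizer is assumed importable from this module):

import asyncio
from typing import List

class RemoteCountProxyTokenizer(ProxyTokenizer):
    """Hypothetical tokenizer that counts tokens via a remote (async) API."""

    def support_async(self) -> bool:
        # Advertise async support so callers take the count_token_async path.
        return True

    async def count_token_async(self, model_name: str, prompts: List[str]) -> List[int]:
        async def _count_one(prompt: str) -> int:
            await asyncio.sleep(0)  # stand-in for a real network call
            return len(prompt.split())  # crude approximation, for illustration only

        return await asyncio.gather(*(_count_one(p) for p in prompts))

    def count_token(self, model_name: str, prompts: List[str]) -> List[int]:
        # Synchronous fallback required by the base class.
        return [len(p.split()) for p in prompts]
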
@@ -92,7 +111,7 @@ class ProxyLLMClient(LLMClient):
         self.model_names = model_names
         self.context_length = context_length
         self.executor = executor or ThreadPoolExecutor()
-        self.proxy_tokenizer = proxy_tokenizer or TiktokenProxyTokenizer()
+        self._proxy_tokenizer = proxy_tokenizer

     def __getstate__(self):
         """Customize the serialization of the object"""
@@ -105,6 +124,17 @@ class ProxyLLMClient(LLMClient):
         self.__dict__.update(state)
         self.executor = ThreadPoolExecutor()

+    @property
+    def proxy_tokenizer(self) -> ProxyTokenizer:
+        """Get the proxy tokenizer, creating a default one lazily if needed.
+
+        Returns:
+            ProxyTokenizer: proxy tokenizer
+        """
+        if not self._proxy_tokenizer:
+            self._proxy_tokenizer = TiktokenProxyTokenizer()
+        return self._proxy_tokenizer
+
     @classmethod
     @abstractmethod
     def new_client(
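
With the property in place, the default tokenizer is built only on first access, which keeps construction cheap and lets subclasses install their own tokenizer before a default ever exists. A sketch of the behavior, assuming the proxy_tokenizer constructor argument defaults to None:

client = MyClaudeLLMClient(model_names=["claude-3-5-sonnet-20241022"], context_length=200_000)
assert client._proxy_tokenizer is None         # nothing built at construction time
tokenizer = client.proxy_tokenizer             # first access creates the default
assert isinstance(tokenizer, TiktokenProxyTokenizer)
assert client.proxy_tokenizer is tokenizer     # later accesses reuse the same instance
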
@@ -257,6 +287,9 @@ class ProxyLLMClient(LLMClient):
         Returns:
             int: token count, -1 if failed
         """
+        if self.proxy_tokenizer.support_async():
+            cnts = await self.proxy_tokenizer.count_token_async(model, [prompt])
+            return cnts[0]
         counts = await blocking_func_to_async(
             self.executor, self.proxy_tokenizer.count_token, model, [prompt]
         )
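
Token counting now prefers the tokenizer's native async path and only falls back to running the blocking count_token on the client's thread pool via blocking_func_to_async. Callers are unaffected; a sketch, assuming the enclosing method is the client's async count_token(model, prompt) and the model name is illustrative:

async def estimate_tokens(client: ProxyLLMClient) -> int:
    # Takes the count_token_async fast path when support_async() is True,
    # otherwise offloads the synchronous tokenizer to client.executor.
    return await client.count_token("claude-3-5-sonnet-20241022", "Hello, Claude!")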