DB-GPT/dbgpt/model/proxy/llms/proxy_model.py
2024-01-14 21:01:37 +08:00

44 lines
1.3 KiB
Python

from __future__ import annotations
import logging
from typing import TYPE_CHECKING, List, Optional, Union
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.utils.token_utils import ProxyTokenizerWrapper
if TYPE_CHECKING:
from dbgpt.core.interface.message import BaseMessage, ModelMessage
logger = logging.getLogger(__name__)
class ProxyModel:
def __init__(
self,
model_params: ProxyModelParameters,
proxy_llm_client: Optional[ProxyLLMClient] = None,
) -> None:
self._model_params = model_params
self._tokenizer = ProxyTokenizerWrapper()
self.proxy_llm_client = proxy_llm_client
def get_params(self) -> ProxyModelParameters:
return self._model_params
def count_token(
self,
messages: Union[str, BaseMessage, ModelMessage, List[ModelMessage]],
model_name: Optional[int] = None,
) -> int:
"""Count token of given messages
Args:
messages (Union[str, BaseMessage, ModelMessage, List[ModelMessage]]): messages to count token
model_name (Optional[int], optional): model name. Defaults to None.
Returns:
int: token count, -1 if failed
"""
return self._tokenizer.count_token(messages, model_name)