mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-07 03:14:42 +00:00
114 lines
4.0 KiB
Python
114 lines
4.0 KiB
Python
import logging
|
|
from typing import Dict, List, Type, Union
|
|
|
|
from dbgpt.core import Embeddings, ModelMetadata, RerankEmbeddings
|
|
from dbgpt.model.adapter.embeddings_loader import (
|
|
EmbeddingLoader,
|
|
_parse_embedding_params,
|
|
)
|
|
from dbgpt.model.adapter.loader import _get_model_real_path
|
|
from dbgpt.model.cluster.worker_base import ModelWorker
|
|
from dbgpt.model.parameter import (
|
|
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
|
|
BaseEmbeddingModelParameters,
|
|
EmbeddingModelParameters,
|
|
WorkerType,
|
|
)
|
|
from dbgpt.util.model_utils import _clear_model_cache
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class EmbeddingsModelWorker(ModelWorker):
|
|
def __init__(self, rerank_model: bool = False) -> None:
|
|
self._embeddings_impl: Union[Embeddings, RerankEmbeddings, None] = None
|
|
self._model_params = None
|
|
self.model_name = None
|
|
self.model_path = None
|
|
self._rerank_model = rerank_model
|
|
self._loader = EmbeddingLoader()
|
|
|
|
def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
|
|
if model_path.endswith("/"):
|
|
model_path = model_path[:-1]
|
|
model_path = _get_model_real_path(model_name, model_path)
|
|
|
|
self.model_name = model_name
|
|
self.model_path = model_path
|
|
|
|
def worker_type(self) -> WorkerType:
|
|
return WorkerType.TEXT2VEC
|
|
|
|
def model_param_class(self) -> Type:
|
|
return EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
|
|
self.model_name, EmbeddingModelParameters
|
|
)
|
|
|
|
def parse_parameters(
|
|
self, command_args: List[str] = None
|
|
) -> BaseEmbeddingModelParameters:
|
|
param_cls = self.model_param_class()
|
|
return _parse_embedding_params(
|
|
model_name=self.model_name,
|
|
model_path=self.model_path,
|
|
command_args=command_args,
|
|
param_cls=param_cls,
|
|
)
|
|
|
|
def start(
|
|
self,
|
|
model_params: EmbeddingModelParameters = None,
|
|
command_args: List[str] = None,
|
|
) -> None:
|
|
"""Start model worker"""
|
|
if not model_params:
|
|
model_params = self.parse_parameters(command_args)
|
|
if self._rerank_model:
|
|
model_params.rerank = True # type: ignore
|
|
self._model_params = model_params
|
|
if model_params.is_rerank_model():
|
|
logger.info(f"Load rerank embeddings model: {self.model_name}")
|
|
self._embeddings_impl = self._loader.load_rerank_model(
|
|
self.model_name, model_params
|
|
)
|
|
else:
|
|
logger.info(f"Load embeddings model: {self.model_name}")
|
|
self._embeddings_impl = self._loader.load(self.model_name, model_params)
|
|
|
|
def __del__(self):
|
|
self.stop()
|
|
|
|
def stop(self) -> None:
|
|
if not self._embeddings_impl:
|
|
return
|
|
del self._embeddings_impl
|
|
self._embeddings_impl = None
|
|
_clear_model_cache(self._model_params.device)
|
|
|
|
def generate_stream(self, params: Dict):
|
|
"""Generate stream result, chat scene"""
|
|
raise NotImplementedError("Not supported generate_stream for embeddings model")
|
|
|
|
def generate(self, params: Dict):
|
|
"""Generate non stream result"""
|
|
raise NotImplementedError("Not supported generate for embeddings model")
|
|
|
|
def count_token(self, prompt: str) -> int:
|
|
raise NotImplementedError("Not supported count_token for embeddings model")
|
|
|
|
def get_model_metadata(self, params: Dict) -> ModelMetadata:
|
|
raise NotImplementedError(
|
|
"Not supported get_model_metadata for embeddings model"
|
|
)
|
|
|
|
def embeddings(self, params: Dict) -> List[List[float]]:
|
|
model = params.get("model")
|
|
logger.info(f"Receive embeddings request, model: {model}")
|
|
textx: List[str] = params["input"]
|
|
if isinstance(self._embeddings_impl, RerankEmbeddings):
|
|
query = params["query"]
|
|
scores: List[float] = self._embeddings_impl.predict(query, textx)
|
|
return [scores]
|
|
else:
|
|
return self._embeddings_impl.embed_documents(textx)
|