DB-GPT/dbgpt/model/cluster/worker/embedding_worker.py
2024-05-16 14:50:16 +08:00

114 lines
4.0 KiB
Python

import logging
from typing import Dict, List, Optional, Type, Union

from dbgpt.core import Embeddings, ModelMetadata, RerankEmbeddings
from dbgpt.model.adapter.embeddings_loader import (
    EmbeddingLoader,
    _parse_embedding_params,
)
from dbgpt.model.adapter.loader import _get_model_real_path
from dbgpt.model.cluster.worker_base import ModelWorker
from dbgpt.model.parameter import (
    EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
    BaseEmbeddingModelParameters,
    EmbeddingModelParameters,
    WorkerType,
)
from dbgpt.util.model_utils import _clear_model_cache
logger = logging.getLogger(__name__)
class EmbeddingsModelWorker(ModelWorker):
    """Model worker serving embedding (text2vec) and rerank models.

    Only the ``embeddings`` task is supported; the text-generation entry
    points (``generate``, ``generate_stream``, ``count_token``,
    ``get_model_metadata``) all raise ``NotImplementedError``.
    """

    def __init__(self, rerank_model: bool = False) -> None:
        """Create the worker.

        Args:
            rerank_model: If True, force the model to be loaded as a
                rerank model regardless of the parsed parameters.
        """
        # Loaded implementation; set by start(), cleared by stop().
        self._embeddings_impl: Union[Embeddings, RerankEmbeddings, None] = None
        self._model_params: Optional[BaseEmbeddingModelParameters] = None
        self.model_name: Optional[str] = None
        self.model_path: Optional[str] = None
        self._rerank_model = rerank_model
        self._loader = EmbeddingLoader()

    def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
        """Record the model name and resolve the real model path.

        The model itself is not loaded until ``start()`` is called.
        """
        # Normalize trailing "/" (handles "path//" too) before resolving.
        model_path = model_path.rstrip("/")
        model_path = _get_model_real_path(model_name, model_path)
        self.model_name = model_name
        self.model_path = model_path

    def worker_type(self) -> WorkerType:
        """Return the worker type (text-to-vector)."""
        return WorkerType.TEXT2VEC

    def model_param_class(self) -> Type:
        """Return the parameter class registered for this model name.

        Falls back to the generic ``EmbeddingModelParameters`` when the
        model has no dedicated configuration class.
        """
        return EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
            self.model_name, EmbeddingModelParameters
        )

    def parse_parameters(
        self, command_args: Optional[List[str]] = None
    ) -> BaseEmbeddingModelParameters:
        """Parse model parameters from CLI-style arguments.

        Args:
            command_args: Argument list to parse; the loader's own
                defaults apply when omitted.
        """
        param_cls = self.model_param_class()
        return _parse_embedding_params(
            model_name=self.model_name,
            model_path=self.model_path,
            command_args=command_args,
            param_cls=param_cls,
        )

    def start(
        self,
        model_params: Optional[EmbeddingModelParameters] = None,
        command_args: Optional[List[str]] = None,
    ) -> None:
        """Start the model worker: load the embedding or rerank model.

        Args:
            model_params: Pre-parsed parameters; when omitted they are
                parsed from ``command_args``.
            command_args: CLI-style arguments, used only when
                ``model_params`` is not given.
        """
        if not model_params:
            model_params = self.parse_parameters(command_args)
        if self._rerank_model:
            # Worker was created explicitly for reranking: override
            # whatever the parsed parameters say.
            model_params.rerank = True  # type: ignore
        self._model_params = model_params
        if model_params.is_rerank_model():
            logger.info(f"Load rerank embeddings model: {self.model_name}")
            self._embeddings_impl = self._loader.load_rerank_model(
                self.model_name, model_params
            )
        else:
            logger.info(f"Load embeddings model: {self.model_name}")
            self._embeddings_impl = self._loader.load(self.model_name, model_params)

    def __del__(self):
        # Best-effort cleanup at GC time; stop() is safe even when
        # __init__ did not complete or start() was never called.
        self.stop()

    def stop(self) -> None:
        """Release the loaded model and clear the device model cache.

        Idempotent, and safe to call on a partially initialized instance
        (e.g. from ``__del__`` after a failed ``__init__``).
        """
        # getattr guard: __del__ may run before these attributes exist.
        if not getattr(self, "_embeddings_impl", None):
            return
        del self._embeddings_impl
        self._embeddings_impl = None
        model_params = getattr(self, "_model_params", None)
        if model_params is not None:
            _clear_model_cache(model_params.device)

    def generate_stream(self, params: Dict):
        """Generate stream result, chat scene"""
        raise NotImplementedError("Not supported generate_stream for embeddings model")

    def generate(self, params: Dict):
        """Generate non stream result"""
        raise NotImplementedError("Not supported generate for embeddings model")

    def count_token(self, prompt: str) -> int:
        """Token counting is not supported for embedding models."""
        raise NotImplementedError("Not supported count_token for embeddings model")

    def get_model_metadata(self, params: Dict) -> ModelMetadata:
        """Model metadata is not supported for embedding models."""
        raise NotImplementedError(
            "Not supported get_model_metadata for embeddings model"
        )

    def embeddings(self, params: Dict) -> List[List[float]]:
        """Compute embeddings (or rerank scores) for a request.

        Args:
            params: Request payload; ``input`` holds the list of texts,
                and for rerank models ``query`` holds the query string.

        Returns:
            For embedding models, one vector per input text.  For rerank
            models, a single-element list wrapping the per-text relevance
            scores against ``query``.
        """
        model = params.get("model")
        logger.info(f"Receive embeddings request, model: {model}")
        texts: List[str] = params["input"]
        if isinstance(self._embeddings_impl, RerankEmbeddings):
            query = params["query"]
            scores: List[float] = self._embeddings_impl.predict(query, texts)
            return [scores]
        return self._embeddings_impl.embed_documents(texts)