DB-GPT/dbgpt/model/adapter/embeddings_loader.py
2024-08-29 16:37:31 +08:00

142 lines
5.9 KiB
Python

from __future__ import annotations
import logging
from typing import List, Optional, Type, cast
from dbgpt.configs.model_config import get_device
from dbgpt.core import Embeddings, RerankEmbeddings
from dbgpt.model.parameter import (
BaseEmbeddingModelParameters,
EmbeddingModelParameters,
ProxyEmbeddingParameters,
)
from dbgpt.util.parameter_utils import EnvArgumentParser, _get_dict_from_obj
from dbgpt.util.system_utils import get_system_info
from dbgpt.util.tracer import SpanType, SpanTypeRunName, root_tracer
logger = logging.getLogger(__name__)
class EmbeddingLoader:
    """Build embedding and rerank model objects from a model name.

    The model name selects the implementation class. Names that match none of
    the known proxy identifiers fall through to a model constructed from
    ``param.model_path``. Every load runs inside a ``root_tracer`` span whose
    metadata records the model name, the parameters and the system info.
    """

    def __init__(self) -> None:
        """Create a loader; nothing is stored on the instance."""

    @staticmethod
    def _span_metadata(model_name: str, param: BaseEmbeddingModelParameters) -> dict:
        """Assemble the metadata attached to the tracing span of a load call."""
        return {
            "model_name": model_name,
            "run_service": SpanTypeRunName.EMBEDDING_MODEL.value,
            "params": _get_dict_from_obj(param),
            "sys_infos": _get_dict_from_obj(get_system_info()),
        }

    @staticmethod
    def _openapi_kwargs(param: BaseEmbeddingModelParameters) -> dict:
        """Collect the OpenAPI-style constructor arguments that are set.

        Only truthy values are kept, so unset fields are simply not passed to
        the embedding constructor.
        """
        proxy_param = cast(ProxyEmbeddingParameters, param)
        candidates = {
            "api_url": proxy_param.proxy_server_url,
            "api_key": proxy_param.proxy_api_key,
            "model_name": proxy_param.proxy_backend,
        }
        return {key: value for key, value in candidates.items() if value}

    def load(self, model_name: str, param: BaseEmbeddingModelParameters) -> Embeddings:
        """Instantiate the embedding model selected by ``model_name``.

        Args:
            model_name: Identifier of the model. ``proxy_openai`` /
                ``proxy_azure``, ``proxy_http_openapi``, ``proxy_tongyi``,
                ``proxy_qianfan`` and ``proxy_ollama`` select the matching
                proxy implementation; any other value is loaded through
                ``HuggingFaceEmbeddings`` using ``param.model_path``.
            param: Model parameters. The OpenAPI, TongYi, QianFan and Ollama
                branches read its ``proxy_*`` attributes; the remaining
                branches call ``param.build_kwargs``.

        Returns:
            The constructed embeddings object.
        """
        with root_tracer.start_span(
            "EmbeddingLoader.load",
            span_type=SpanType.RUN,
            metadata=self._span_metadata(model_name, param),
        ):
            # add more models
            # Implementation modules are imported inside their branch, so only
            # the selected backend's module is imported.
            if model_name in ("proxy_openai", "proxy_azure"):
                from langchain.embeddings import OpenAIEmbeddings

                from dbgpt.rag.embedding._wrapped import WrappedEmbeddings

                return WrappedEmbeddings(OpenAIEmbeddings(**param.build_kwargs()))

            if model_name == "proxy_http_openapi":
                from dbgpt.rag.embedding import OpenAPIEmbeddings

                return OpenAPIEmbeddings(**self._openapi_kwargs(param))

            if model_name == "proxy_tongyi":
                from dbgpt.rag.embedding import TongYiEmbeddings

                proxy_param = cast(ProxyEmbeddingParameters, param)
                tongyi_param = {"api_key": proxy_param.proxy_api_key}
                if proxy_param.proxy_backend:
                    tongyi_param["model_name"] = proxy_param.proxy_backend
                return TongYiEmbeddings(**tongyi_param)

            if model_name == "proxy_qianfan":
                from dbgpt.rag.embedding import QianFanEmbeddings

                proxy_param = cast(ProxyEmbeddingParameters, param)
                # The secret is a credential just like the key, so it is always
                # forwarded instead of depending on a backend name being set.
                qianfan_param = {
                    "api_key": proxy_param.proxy_api_key,
                    "api_secret": proxy_param.proxy_api_secret,
                }
                if proxy_param.proxy_backend:
                    qianfan_param["model_name"] = proxy_param.proxy_backend
                return QianFanEmbeddings(**qianfan_param)

            if model_name == "proxy_ollama":
                from dbgpt.rag.embedding import OllamaEmbeddings

                proxy_param = cast(ProxyEmbeddingParameters, param)
                ollama_param = {}
                if proxy_param.proxy_server_url:
                    ollama_param["api_url"] = proxy_param.proxy_server_url
                if proxy_param.proxy_backend:
                    ollama_param["model_name"] = proxy_param.proxy_backend
                return OllamaEmbeddings(**ollama_param)

            # Fallback: any other name is built from the configured model path.
            from dbgpt.rag.embedding import HuggingFaceEmbeddings

            kwargs = param.build_kwargs(model_name=param.model_path)
            return HuggingFaceEmbeddings(**kwargs)

    def load_rerank_model(
        self, model_name: str, param: BaseEmbeddingModelParameters
    ) -> RerankEmbeddings:
        """Instantiate the rerank model selected by ``model_name``.

        Args:
            model_name: ``rerank_proxy_http_openapi`` selects
                ``OpenAPIRerankEmbeddings``; any other value is loaded through
                ``CrossEncoderRerankEmbeddings`` using ``param.model_path``.
            param: Model parameters. The proxy branch reads its ``proxy_*``
                attributes; the fallback calls ``param.build_kwargs``.

        Returns:
            The constructed rerank embeddings object.
        """
        with root_tracer.start_span(
            "EmbeddingLoader.load_rerank_model",
            span_type=SpanType.RUN,
            metadata=self._span_metadata(model_name, param),
        ):
            if model_name == "rerank_proxy_http_openapi":
                from dbgpt.rag.embedding.rerank import OpenAPIRerankEmbeddings

                return OpenAPIRerankEmbeddings(**self._openapi_kwargs(param))

            from dbgpt.rag.embedding.rerank import CrossEncoderRerankEmbeddings

            kwargs = param.build_kwargs(model_name=param.model_path)
            return CrossEncoderRerankEmbeddings(**kwargs)
def _parse_embedding_params(
    model_name: Optional[str] = None,
    model_path: Optional[str] = None,
    command_args: Optional[List[str]] = None,
    param_cls: Optional[Type] = EmbeddingModelParameters,
    **kwargs,
) -> BaseEmbeddingModelParameters:
    """Build the parameter object for an embedding model.

    Delegates to ``EnvArgumentParser.parse_args_into_dataclass`` with an
    environment prefix derived from ``model_name``, then fills in the device
    when the parsed parameters do not carry one.

    Args:
        model_name: Name of the model; also used to derive the env prefix.
        model_path: Path of the model, forwarded to the parser.
        command_args: Optional argument list forwarded to the parser.
        param_cls: Parameter class to populate. ``None`` is treated as
            "use the default" (``EmbeddingModelParameters``) rather than being
            forwarded to the parser.
        **kwargs: Extra values forwarded to the parser unchanged.

    Returns:
        The populated parameter object, with ``device`` always set.
    """
    if param_cls is None:
        param_cls = EmbeddingModelParameters

    model_args = EnvArgumentParser()
    env_prefix = EnvArgumentParser.get_env_prefix(model_name)
    model_params: BaseEmbeddingModelParameters = model_args.parse_args_into_dataclass(
        param_cls,
        env_prefixes=[env_prefix],
        command_args=command_args,
        model_name=model_name,
        model_path=model_path,
        **kwargs,
    )
    if not model_params.device:
        # No device configured: fall back to whatever get_device() reports.
        model_params.device = get_device()
        logger.info(
            f"[EmbeddingsModelWorker] Parameters of device is None, use {model_params.device}"
        )
    return model_params