# Mirror of https://github.com/csunny/DB-GPT.git
# Synced 2025-07-26 13:27:46 +00:00
# 142 lines, 5.9 KiB, Python
from __future__ import annotations
|
|
|
|
import logging
|
|
from typing import List, Optional, Type, cast
|
|
|
|
from dbgpt.configs.model_config import get_device
|
|
from dbgpt.core import Embeddings, RerankEmbeddings
|
|
from dbgpt.model.parameter import (
|
|
BaseEmbeddingModelParameters,
|
|
EmbeddingModelParameters,
|
|
ProxyEmbeddingParameters,
|
|
)
|
|
from dbgpt.util.parameter_utils import EnvArgumentParser, _get_dict_from_obj
|
|
from dbgpt.util.system_utils import get_system_info
|
|
from dbgpt.util.tracer import SpanType, SpanTypeRunName, root_tracer
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class EmbeddingLoader:
    """Build embedding and rerank model clients selected by model name.

    Every load runs inside a ``root_tracer`` span whose metadata records the
    model name, the parameters and the system information.
    """

    def __init__(self) -> None:
        pass

    @staticmethod
    def _span_metadata(model_name: str, param: BaseEmbeddingModelParameters) -> dict:
        """Return the tracing metadata attached to a model-loading span."""
        return {
            "model_name": model_name,
            "run_service": SpanTypeRunName.EMBEDDING_MODEL.value,
            "params": _get_dict_from_obj(param),
            "sys_infos": _get_dict_from_obj(get_system_info()),
        }

    @staticmethod
    def _openapi_kwargs(param: BaseEmbeddingModelParameters) -> dict:
        """Return OpenAPI client kwargs, skipping proxy settings that are unset."""
        proxy_param = cast(ProxyEmbeddingParameters, param)
        openapi_param = {}
        if proxy_param.proxy_server_url:
            openapi_param["api_url"] = proxy_param.proxy_server_url
        if proxy_param.proxy_api_key:
            openapi_param["api_key"] = proxy_param.proxy_api_key
        if proxy_param.proxy_backend:
            openapi_param["model_name"] = proxy_param.proxy_backend
        return openapi_param

    def load(self, model_name: str, param: BaseEmbeddingModelParameters) -> Embeddings:
        """Create the embedding client registered under ``model_name``.

        Args:
            model_name: Selects the implementation. The ``proxy_*`` names map
                to remote providers; any other name falls back to
                ``HuggingFaceEmbeddings``.
            param: Model parameters. For the ``proxy_http_openapi``,
                ``proxy_tongyi``, ``proxy_qianfan`` and ``proxy_ollama``
                names it is treated as ``ProxyEmbeddingParameters``.

        Returns:
            The constructed embeddings object.
        """
        metadata = self._span_metadata(model_name, param)
        with root_tracer.start_span(
            "EmbeddingLoader.load", span_type=SpanType.RUN, metadata=metadata
        ):
            # Provider packages are imported lazily so that only the selected
            # backend's dependencies have to be installed.
            # add more models
            if model_name in ["proxy_openai", "proxy_azure"]:
                from langchain.embeddings import OpenAIEmbeddings

                from dbgpt.rag.embedding._wrapped import WrappedEmbeddings

                return WrappedEmbeddings(OpenAIEmbeddings(**param.build_kwargs()))

            if model_name == "proxy_http_openapi":
                from dbgpt.rag.embedding import OpenAPIEmbeddings

                return OpenAPIEmbeddings(**self._openapi_kwargs(param))

            if model_name == "proxy_tongyi":
                from dbgpt.rag.embedding import TongYiEmbeddings

                proxy_param = cast(ProxyEmbeddingParameters, param)
                tongyi_param = {"api_key": proxy_param.proxy_api_key}
                if proxy_param.proxy_backend:
                    tongyi_param["model_name"] = proxy_param.proxy_backend
                return TongYiEmbeddings(**tongyi_param)

            if model_name == "proxy_qianfan":
                from dbgpt.rag.embedding import QianFanEmbeddings

                proxy_param = cast(ProxyEmbeddingParameters, param)
                # Both credentials are always forwarded; the secret must not
                # depend on whether a backend model name was configured.
                qianfan_param = {
                    "api_key": proxy_param.proxy_api_key,
                    "api_secret": proxy_param.proxy_api_secret,
                }
                if proxy_param.proxy_backend:
                    qianfan_param["model_name"] = proxy_param.proxy_backend
                return QianFanEmbeddings(**qianfan_param)

            if model_name == "proxy_ollama":
                from dbgpt.rag.embedding import OllamaEmbeddings

                proxy_param = cast(ProxyEmbeddingParameters, param)
                ollama_param = {}
                if proxy_param.proxy_server_url:
                    ollama_param["api_url"] = proxy_param.proxy_server_url
                if proxy_param.proxy_backend:
                    ollama_param["model_name"] = proxy_param.proxy_backend
                return OllamaEmbeddings(**ollama_param)

            # Fallback: the model path is passed as the model name.
            from dbgpt.rag.embedding import HuggingFaceEmbeddings

            kwargs = param.build_kwargs(model_name=param.model_path)
            return HuggingFaceEmbeddings(**kwargs)

    def load_rerank_model(
        self, model_name: str, param: BaseEmbeddingModelParameters
    ) -> RerankEmbeddings:
        """Create the rerank client registered under ``model_name``.

        Args:
            model_name: ``rerank_proxy_http_openapi`` selects
                ``OpenAPIRerankEmbeddings``; any other name falls back to
                ``CrossEncoderRerankEmbeddings``.
            param: Model parameters; treated as ``ProxyEmbeddingParameters``
                for the OpenAPI proxy.

        Returns:
            The constructed rerank embeddings object.
        """
        metadata = self._span_metadata(model_name, param)
        with root_tracer.start_span(
            "EmbeddingLoader.load_rerank_model",
            span_type=SpanType.RUN,
            metadata=metadata,
        ):
            if model_name == "rerank_proxy_http_openapi":
                from dbgpt.rag.embedding.rerank import OpenAPIRerankEmbeddings

                return OpenAPIRerankEmbeddings(**self._openapi_kwargs(param))

            # Fallback: the model path is passed as the model name.
            from dbgpt.rag.embedding.rerank import CrossEncoderRerankEmbeddings

            kwargs = param.build_kwargs(model_name=param.model_path)
            return CrossEncoderRerankEmbeddings(**kwargs)
|
|
|
|
|
|
def _parse_embedding_params(
    model_name: Optional[str] = None,
    model_path: Optional[str] = None,
    command_args: Optional[List[str]] = None,
    param_cls: Optional[Type] = EmbeddingModelParameters,
    **kwargs,
) -> BaseEmbeddingModelParameters:
    """Parse embedding model parameters into a ``param_cls`` instance.

    Args:
        model_name: Model name; also used to derive the environment-variable
            prefix via ``EnvArgumentParser.get_env_prefix``.
        model_path: Model path forwarded to the parser.
        command_args: Optional command-line arguments forwarded to the parser.
        param_cls: Parameter dataclass to populate.
        **kwargs: Extra values forwarded to ``parse_args_into_dataclass``.

    Returns:
        The parsed parameters. If no device was set, ``device`` is filled in
        from ``get_device()``.
    """
    model_args = EnvArgumentParser()
    env_prefix = EnvArgumentParser.get_env_prefix(model_name)
    model_params: BaseEmbeddingModelParameters = model_args.parse_args_into_dataclass(
        param_cls,
        env_prefixes=[env_prefix],
        command_args=command_args,
        model_name=model_name,
        model_path=model_path,
        **kwargs,
    )
    if not model_params.device:
        model_params.device = get_device()
        # Lazy %-style arguments: the message is only formatted when the
        # record is actually emitted.
        logger.info(
            "[EmbeddingsModelWorker] Parameters of device is None, use %s",
            model_params.device,
        )
    return model_params
|