mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-01 08:11:45 +00:00
289 lines
9.1 KiB
Python
289 lines
9.1 KiB
Python
"""EmbeddingFactory class and DefaultEmbeddingFactory class."""
|
|
|
|
import logging
|
|
import os
|
|
from abc import ABC, abstractmethod
|
|
from typing import Any, List, Optional, Type
|
|
|
|
from dbgpt.component import BaseComponent, SystemApp
|
|
from dbgpt.core import Embeddings, RerankEmbeddings
|
|
from dbgpt.core.awel import DAGVar
|
|
from dbgpt.core.awel.flow import ResourceCategory, register_resource
|
|
from dbgpt.util.i18n_utils import _
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class EmbeddingFactory(BaseComponent, ABC):
    """Component interface for objects that hand out ``Embeddings`` instances.

    Concrete factories subclass this and implement :meth:`create`. The
    component is registered under the name ``"embedding_factory"``.
    """

    name = "embedding_factory"

    @abstractmethod
    def create(
        self, model_name: Optional[str] = None, embedding_cls: Optional[Type] = None
    ) -> Embeddings:
        """Build (or return) an embeddings object.

        Args:
            model_name (Optional[str]): Name of the embedding model requested.
            embedding_cls (Optional[Type]): Embedding class requested.

        Returns:
            Embeddings: The embeddings instance produced by the factory.
        """
|
|
|
|
|
|
class RerankEmbeddingFactory(BaseComponent, ABC):
    """Component interface for objects that hand out ``RerankEmbeddings``.

    Concrete factories subclass this and implement :meth:`create`. The
    component is registered under the name ``"rerank_embedding_factory"``.
    """

    name = "rerank_embedding_factory"

    @abstractmethod
    def create(
        self, model_name: Optional[str] = None, embedding_cls: Optional[Type] = None
    ) -> RerankEmbeddings:
        """Build (or return) a rerank embeddings object.

        Args:
            model_name (Optional[str]): Name of the rerank model requested.
            embedding_cls (Optional[Type]): Embedding class requested.

        Returns:
            RerankEmbeddings: The rerank embeddings instance produced by the
                factory.
        """
|
|
|
|
|
|
class DefaultEmbeddingFactory(EmbeddingFactory):
    """The default embedding factory.

    The factory loads a single embedding model when it is constructed and
    returns that same instance from every :meth:`create` call. It also offers
    class-method shortcuts (:meth:`openai`, :meth:`default`, :meth:`remote`)
    that return a ready-to-use ``Embeddings`` object directly.
    """

    def __init__(
        self,
        system_app: Optional[SystemApp] = None,
        default_model_name: Optional[str] = None,
        default_model_path: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Create a new DefaultEmbeddingFactory and load its model.

        Args:
            system_app (Optional[SystemApp]): The system app passed to the base
                component.
            default_model_name (Optional[str]): The model name. Falls back to
                ``default_model_path`` when empty.
            default_model_path (Optional[str]): The model path. Falls back to
                ``default_model_name`` when empty.
            **kwargs: Extra options forwarded to the embedding parameter parser.

        Raises:
            ValueError: If no model name can be determined while loading.
        """
        super().__init__(system_app=system_app)
        # Name and path stand in for each other when only one is supplied.
        if not default_model_path:
            default_model_path = default_model_name
        if not default_model_name:
            default_model_name = default_model_path
        self._default_model_name = default_model_name
        self._default_model_path = default_model_path
        self._kwargs = kwargs
        # The model is loaded eagerly; create() only hands it back.
        self._model = self._load_model()

    def init_app(self, system_app):
        """Init the app (nothing to do for this factory)."""
        pass

    def create(
        self, model_name: Optional[str] = None, embedding_cls: Optional[Type] = None
    ) -> Embeddings:
        """Return the embedding instance loaded at construction time.

        Args:
            model_name (Optional[str]): The model name. Not used by this
                factory; the model chosen at construction is always returned.
            embedding_cls (Optional[Type]): The embedding class. Not supported.

        Returns:
            Embeddings: The pre-loaded embedding instance.

        Raises:
            NotImplementedError: If ``embedding_cls`` is provided.
        """
        if embedding_cls:
            raise NotImplementedError(
                "DefaultEmbeddingFactory does not support a custom embedding_cls."
            )
        return self._model

    def _load_model(self) -> Embeddings:
        """Parse the embedding parameters and load the model.

        Returns:
            Embeddings: The loaded embedding instance.

        Raises:
            ValueError: If neither the configured default nor the parsed
                parameters supply a model name.
        """
        # Imported lazily, inside the method, as in the original code.
        from dbgpt.model.adapter.embeddings_loader import (
            EmbeddingLoader,
            _parse_embedding_params,
        )
        from dbgpt.model.parameter import (
            EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
            BaseEmbeddingModelParameters,
            EmbeddingModelParameters,
        )

        # Pick the parameter class registered for this model name, if any.
        param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
            self._default_model_name, EmbeddingModelParameters
        )
        model_params: BaseEmbeddingModelParameters = _parse_embedding_params(
            model_name=self._default_model_name,
            model_path=self._default_model_path,
            param_cls=param_cls,
            **self._kwargs,
        )
        logger.info(model_params)
        loader = EmbeddingLoader()
        # Ignore model_name args
        model_name = self._default_model_name or model_params.model_name
        if not model_name:
            raise ValueError("model_name must be provided.")
        return loader.load(model_name, model_params)

    @classmethod
    def openai(
        cls,
        api_url: Optional[str] = None,
        api_key: Optional[str] = None,
        model_name: str = "text-embedding-3-small",
        timeout: int = 60,
        **kwargs: Any,
    ) -> Embeddings:
        """Create an OpenAI embeddings.

        If api_url and api_key are not provided, we will try to get them from
        environment variables (``OPENAI_API_BASE`` and ``OPENAI_API_KEY``).

        Args:
            api_url (Optional[str], optional): The api url. Defaults to None, in
                which case it is built as ``<OPENAI_API_BASE>/embeddings`` with
                ``https://api.openai.com/v1`` as the fallback base.
            api_key (Optional[str], optional): The api key. Defaults to None.
            model_name (str, optional): The model name.
                Defaults to "text-embedding-3-small".
            timeout (int, optional): The timeout. Defaults to 60.
            **kwargs: Extra options forwarded to :meth:`remote`.

        Returns:
            Embeddings: The embeddings instance.

        Raises:
            ValueError: If no api key is given and ``OPENAI_API_KEY`` is unset
                or empty.
        """
        if not api_url:
            # An unset *or empty* OPENAI_API_BASE both mean "use the default".
            api_base = os.getenv("OPENAI_API_BASE") or "https://api.openai.com/v1"
            # Drop trailing slashes so the result never contains "//embeddings".
            api_url = api_base.rstrip("/") + "/embeddings"
        api_key = api_key or os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("api_key must be provided.")
        return cls.remote(
            api_url=api_url,
            api_key=api_key,
            model_name=model_name,
            timeout=timeout,
            **kwargs,
        )

    @classmethod
    def default(
        cls, model_name: str, model_path: Optional[str] = None, **kwargs: Any
    ) -> Embeddings:
        """Create a default embeddings.

        It will try to load the model from the model name or model path.

        Args:
            model_name (str): The model name.
            model_path (Optional[str], optional): The model path. Defaults to None.
                if not provided, it will use the model name as the model path to load
                the model.
            **kwargs: Extra options forwarded to the factory constructor.

        Returns:
            Embeddings: The embeddings instance.
        """
        return cls(
            default_model_name=model_name, default_model_path=model_path, **kwargs
        ).create()

    @classmethod
    def remote(
        cls,
        api_url: str = "http://localhost:8100/api/v1/embeddings",
        api_key: Optional[str] = None,
        model_name: str = "text2vec",
        timeout: int = 60,
        **kwargs: Any,
    ) -> Embeddings:
        """Create a remote embeddings.

        Create a remote embeddings which API compatible with the OpenAI's API. So if
        your model is compatible with OpenAI's API, you can use this method to create
        a remote embeddings.

        Args:
            api_url (str, optional): The api url. Defaults to
                "http://localhost:8100/api/v1/embeddings".
            api_key (Optional[str], optional): The api key. Defaults to None.
            model_name (str, optional): The model name. Defaults to "text2vec".
            timeout (int, optional): The timeout. Defaults to 60.
            **kwargs: Extra options forwarded to ``OpenAPIEmbeddings``.

        Returns:
            Embeddings: The ``OpenAPIEmbeddings`` instance.
        """
        from .embeddings import OpenAPIEmbeddings

        return OpenAPIEmbeddings(
            api_url=api_url,
            api_key=api_key,
            model_name=model_name,
            timeout=timeout,
            **kwargs,
        )
|
|
|
|
|
|
class WrappedEmbeddingFactory(EmbeddingFactory):
    """Embedding factory that wraps an already-built ``Embeddings`` instance.

    The instance given to the constructor is returned unchanged from every
    :meth:`create` call.
    """

    def __init__(
        self,
        system_app: Optional[SystemApp] = None,
        embeddings: Optional[Embeddings] = None,
        **kwargs: Any,
    ) -> None:
        """Create a new WrappedEmbeddingFactory.

        Args:
            system_app (Optional[SystemApp]): The system app passed to the base
                component.
            embeddings (Optional[Embeddings]): The embeddings instance to wrap.
                Required despite the ``None`` default.
            **kwargs: Accepted for signature compatibility; not used here.

        Raises:
            ValueError: If ``embeddings`` is not provided.
        """
        super().__init__(system_app=system_app)
        # Compare against None explicitly: a valid embeddings object must not be
        # rejected just because it happens to evaluate as falsy.
        if embeddings is None:
            raise ValueError("embeddings must be provided.")
        self._model = embeddings

    def init_app(self, system_app):
        """Init the app (nothing to do for this factory)."""
        pass

    def create(
        self, model_name: Optional[str] = None, embedding_cls: Optional[Type] = None
    ) -> Embeddings:
        """Return the wrapped embedding instance.

        Args:
            model_name (Optional[str]): The model name. Not used by this
                factory; the wrapped instance is always returned.
            embedding_cls (Optional[Type]): The embedding class. Not supported.

        Returns:
            Embeddings: The wrapped embedding instance.

        Raises:
            NotImplementedError: If ``embedding_cls`` is provided.
        """
        if embedding_cls:
            raise NotImplementedError(
                "WrappedEmbeddingFactory does not support a custom embedding_cls."
            )
        return self._model
|
|
|
|
|
|
@register_resource(
    label=_("Default Embeddings"),
    name="default_embeddings",
    category=ResourceCategory.EMBEDDINGS,
    description=_(
        "Default embeddings(using default embedding model of current system)"
    ),
)
class DefaultEmbeddings(Embeddings):
    """Embeddings that delegate to whatever the embedding factory creates.

    When no factory is supplied, one is looked up from the current system app
    the first time the embeddings are needed.
    """

    def __init__(self, embedding_factory: Optional[EmbeddingFactory] = None) -> None:
        """Store the optional factory used to obtain the real embeddings."""
        self._embedding_factory = embedding_factory

    @property
    def embeddings(self) -> Embeddings:
        """Return the embeddings produced by the factory.

        Raises:
            ValueError: If no factory was given and no current system app is
                available to look one up.
        """
        factory = self._embedding_factory
        if not factory:
            # No factory yet: resolve it from the current system app and keep
            # it for later calls.
            current_app = DAGVar.get_current_system_app()
            if not current_app:
                raise ValueError("System app is not initialized")
            factory = EmbeddingFactory.get_instance(current_app)
            self._embedding_factory = factory
        return factory.create()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents via the underlying embeddings."""
        delegate = self.embeddings
        return delegate.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text via the underlying embeddings."""
        delegate = self.embeddings
        return delegate.embed_query(text)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed a list of documents."""
        delegate = self.embeddings
        return await delegate.aembed_documents(texts)

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed a single query text."""
        delegate = self.embeddings
        return await delegate.aembed_query(text)
|