mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-15 05:59:59 +00:00
feat(rag): Support RAG SDK (#1322)
This commit is contained in:
@@ -1,6 +1,10 @@
|
||||
"""Module for embedding related classes and functions."""
|
||||
|
||||
from .embedding_factory import DefaultEmbeddingFactory, EmbeddingFactory # noqa: F401
|
||||
from .embedding_factory import ( # noqa: F401
|
||||
DefaultEmbeddingFactory,
|
||||
EmbeddingFactory,
|
||||
WrappedEmbeddingFactory,
|
||||
)
|
||||
from .embeddings import ( # noqa: F401
|
||||
Embeddings,
|
||||
HuggingFaceBgeEmbeddings,
|
||||
@@ -21,4 +25,5 @@ __ALL__ = [
|
||||
"OpenAPIEmbeddings",
|
||||
"DefaultEmbeddingFactory",
|
||||
"EmbeddingFactory",
|
||||
"WrappedEmbeddingFactory",
|
||||
]
|
||||
|
32
dbgpt/rag/embedding/_wrapped.py
Normal file
32
dbgpt/rag/embedding/_wrapped.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Wraps the third-party language model embeddings to the common interface."""
|
||||
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from dbgpt.core import Embeddings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.embeddings.base import Embeddings as LangChainEmbeddings
|
||||
|
||||
|
||||
class WrappedEmbeddings(Embeddings):
    """Adapter that exposes a third-party embeddings object via ``Embeddings``.

    Each method forwards its argument unchanged to the same-named method of
    the wrapped object and returns whatever that method produces.
    """

    def __init__(self, embeddings: "LangChainEmbeddings") -> None:
        """Remember the embeddings object that every call is delegated to.

        Args:
            embeddings: The third-party embeddings object to wrap.
        """
        self._embeddings = embeddings

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents by delegating to the wrapped object.

        Args:
            texts: The documents to embed.

        Returns:
            The result of the wrapped object's ``embed_documents``.
        """
        vectors = self._embeddings.embed_documents(texts)
        return vectors

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query by delegating to the wrapped object.

        Args:
            text: The query text to embed.

        Returns:
            The result of the wrapped object's ``embed_query``.
        """
        vector = self._embeddings.embed_query(text)
        return vector

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed documents via the wrapped object.

        Args:
            texts: The documents to embed.

        Returns:
            The awaited result of the wrapped object's ``aembed_documents``.
        """
        vectors = await self._embeddings.aembed_documents(texts)
        return vectors

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed a single query via the wrapped object.

        Args:
            text: The query text to embed.

        Returns:
            The awaited result of the wrapped object's ``aembed_query``.
        """
        vector = await self._embeddings.aembed_query(text)
        return vector
|
@@ -1,15 +1,14 @@
|
||||
"""EmbeddingFactory class and DefaultEmbeddingFactory class."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Optional, Type
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
from dbgpt.component import BaseComponent, SystemApp
|
||||
from dbgpt.rag.embedding.embeddings import HuggingFaceEmbeddings
|
||||
from dbgpt.core import Embeddings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dbgpt.rag.embedding.embeddings import Embeddings
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmbeddingFactory(BaseComponent, ABC):
|
||||
@@ -20,7 +19,7 @@ class EmbeddingFactory(BaseComponent, ABC):
|
||||
@abstractmethod
|
||||
def create(
|
||||
self, model_name: Optional[str] = None, embedding_cls: Optional[Type] = None
|
||||
) -> "Embeddings":
|
||||
) -> Embeddings:
|
||||
"""Create an embedding instance.
|
||||
|
||||
Args:
|
||||
@@ -39,12 +38,19 @@ class DefaultEmbeddingFactory(EmbeddingFactory):
|
||||
self,
|
||||
system_app: Optional[SystemApp] = None,
|
||||
default_model_name: Optional[str] = None,
|
||||
default_model_path: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create a new DefaultEmbeddingFactory."""
|
||||
super().__init__(system_app=system_app)
|
||||
if not default_model_path:
|
||||
default_model_path = default_model_name
|
||||
if not default_model_name:
|
||||
default_model_name = default_model_path
|
||||
self._default_model_name = default_model_name
|
||||
self.kwargs = kwargs
|
||||
self._default_model_path = default_model_path
|
||||
self._kwargs = kwargs
|
||||
self._model = self._load_model()
|
||||
|
||||
def init_app(self, system_app):
|
||||
"""Init the app."""
|
||||
@@ -52,20 +58,166 @@ class DefaultEmbeddingFactory(EmbeddingFactory):
|
||||
|
||||
def create(
|
||||
self, model_name: Optional[str] = None, embedding_cls: Optional[Type] = None
|
||||
) -> "Embeddings":
|
||||
) -> Embeddings:
|
||||
"""Create an embedding instance.
|
||||
|
||||
Args:
|
||||
model_name (str): The model name.
|
||||
embedding_cls (Type): The embedding class.
|
||||
"""
|
||||
if not model_name:
|
||||
model_name = self._default_model_name
|
||||
|
||||
new_kwargs = {k: v for k, v in self.kwargs.items()}
|
||||
new_kwargs["model_name"] = model_name
|
||||
|
||||
if embedding_cls:
|
||||
return embedding_cls(**new_kwargs)
|
||||
else:
|
||||
return HuggingFaceEmbeddings(**new_kwargs)
|
||||
raise NotImplementedError
|
||||
return self._model
|
||||
|
||||
def _load_model(self) -> Embeddings:
    """Load the embeddings model described by this factory's defaults.

    The parameter class is looked up by the default model name, falling
    back to ``EmbeddingModelParameters``. The parameters are then parsed
    from the default name, the default path and the stored keyword
    arguments, logged, and handed to an ``EmbeddingLoader``.

    Returns:
        Embeddings: Whatever ``EmbeddingLoader.load`` returns.

    Raises:
        ValueError: If neither the factory default nor the parsed
            parameters supply a model name.
    """
    # These imports are function-local in the original code as well.
    from dbgpt.model.adapter.embeddings_loader import (
        EmbeddingLoader,
        _parse_embedding_params,
    )
    from dbgpt.model.parameter import (
        EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
        BaseEmbeddingModelParameters,
        EmbeddingModelParameters,
    )

    default_name = self._default_model_name
    params_type = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
        default_name, EmbeddingModelParameters
    )
    parsed_params: BaseEmbeddingModelParameters = _parse_embedding_params(
        model_name=default_name,
        model_path=self._default_model_path,
        param_cls=params_type,
        **self._kwargs,
    )
    logger.info(parsed_params)
    embedding_loader = EmbeddingLoader()

    # The factory's own default name wins; the parsed parameters only
    # supply the name when the default is empty.
    resolved_name = default_name or parsed_params.model_name
    if resolved_name:
        return embedding_loader.load(resolved_name, parsed_params)
    raise ValueError("model_name must be provided.")
|
||||
|
||||
@classmethod
def openai(
    cls,
    api_url: Optional[str] = None,
    api_key: Optional[str] = None,
    model_name: str = "text-embedding-3-small",
    timeout: int = 60,
    **kwargs: Any,
) -> Embeddings:
    """Create an OpenAI embeddings.

    If api_url and api_key are not provided, we will try to get them from
    environment variables: the URL is derived from ``OPENAI_API_BASE``
    (default ``https://api.openai.com/v1``) with ``/embeddings`` appended,
    and the key is read from ``OPENAI_API_KEY``.

    Args:
        api_url (Optional[str], optional): The full embeddings endpoint url.
            Defaults to None.
        api_key (Optional[str], optional): The api key. Defaults to None.
        model_name (str, optional): The model name.
            Defaults to "text-embedding-3-small".
        timeout (int, optional): The timeout. Defaults to 60.
        **kwargs: Extra keyword arguments forwarded to ``cls.remote``.

    Returns:
        Embeddings: The embeddings instance.

    Raises:
        ValueError: If no api key is passed and ``OPENAI_API_KEY`` is unset
            or empty.
    """
    if not api_url:
        # Treat an empty OPENAI_API_BASE like an unset one, and strip any
        # trailing slash so the joined URL never contains "//embeddings".
        api_base = os.getenv("OPENAI_API_BASE") or "https://api.openai.com/v1"
        api_url = api_base.rstrip("/") + "/embeddings"
    api_key = api_key or os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise ValueError("api_key must be provided.")
    return cls.remote(
        api_url=api_url,
        api_key=api_key,
        model_name=model_name,
        timeout=timeout,
        **kwargs,
    )
|
||||
|
||||
@classmethod
def default(
    cls, model_name: str, model_path: Optional[str] = None, **kwargs: Any
) -> Embeddings:
    """Build a factory from a model name/path and return its embeddings.

    A factory instance is constructed with the given name and path as its
    defaults, and the result of its ``create()`` call is returned.

    Args:
        model_name (str): The model name.
        model_path (Optional[str], optional): The model path. Defaults to
            None, in which case the model name is used as the model path
            to load the model.
        **kwargs: Extra keyword arguments forwarded to the factory
            constructor.

    Returns:
        Embeddings: The embeddings instance.
    """
    factory = cls(
        default_model_name=model_name,
        default_model_path=model_path,
        **kwargs,
    )
    return factory.create()
|
||||
|
||||
@classmethod
def remote(
    cls,
    api_url: str = "http://localhost:8100/api/v1/embeddings",
    api_key: Optional[str] = None,
    model_name: str = "text2vec",
    timeout: int = 60,
    **kwargs: Any,
) -> Embeddings:
    """Create embeddings backed by a remote, OpenAI-compatible API.

    Use this method for any embedding service whose API is compatible
    with OpenAI's API.

    Args:
        api_url (str, optional): The api url. Defaults to
            "http://localhost:8100/api/v1/embeddings".
        api_key (Optional[str], optional): The api key. Defaults to None.
        model_name (str, optional): The model name. Defaults to "text2vec".
        timeout (int, optional): The timeout. Defaults to 60.
        **kwargs: Extra keyword arguments passed to ``OpenAPIEmbeddings``.

    Returns:
        Embeddings: The ``OpenAPIEmbeddings`` instance.
    """
    from .embeddings import OpenAPIEmbeddings

    # The named parameters can never appear in ``kwargs``, so merging them
    # into one mapping is equivalent to passing them side by side.
    client_options = {
        "api_url": api_url,
        "api_key": api_key,
        "model_name": model_name,
        "timeout": timeout,
    }
    client_options.update(kwargs)
    return OpenAPIEmbeddings(**client_options)
|
||||
|
||||
|
||||
class WrappedEmbeddingFactory(EmbeddingFactory):
    """Embedding factory that serves a single, pre-built embeddings instance."""

    def __init__(
        self,
        system_app: Optional[SystemApp] = None,
        embeddings: Optional[Embeddings] = None,
        **kwargs: Any,
    ) -> None:
        """Create a new WrappedEmbeddingFactory.

        Args:
            system_app (Optional[SystemApp], optional): Passed to the base
                class initializer. Defaults to None.
            embeddings (Optional[Embeddings], optional): The instance that
                ``create`` returns. Required, even though the signature
                defaults it to None.
            **kwargs: Accepted for signature compatibility; not used.

        Raises:
            ValueError: If ``embeddings`` is None.
        """
        super().__init__(system_app=system_app)
        # Compare against None explicitly so that an embeddings object which
        # happens to be falsy (e.g. defines __len__/__bool__) is not rejected.
        if embeddings is None:
            raise ValueError("embeddings must be provided.")
        self._model = embeddings

    def init_app(self, system_app: SystemApp) -> None:
        """Init the app (nothing to do for this factory)."""

    def create(
        self, model_name: Optional[str] = None, embedding_cls: Optional[Type] = None
    ) -> Embeddings:
        """Return the wrapped embedding instance.

        Args:
            model_name (Optional[str]): Ignored; the wrapped instance is
                always returned.
            embedding_cls (Optional[Type]): Must be None; choosing an
                embedding class is not supported by this factory.

        Returns:
            Embeddings: The instance given to the constructor.

        Raises:
            NotImplementedError: If ``embedding_cls`` is provided.
        """
        if embedding_cls is not None:
            raise NotImplementedError(
                "WrappedEmbeddingFactory does not support embedding_cls."
            )
        return self._model
|
||||
|
Reference in New Issue
Block a user