feat(rag): Support RAG SDK (#1322)

This commit is contained in:
Fangyin Cheng
2024-03-22 15:36:57 +08:00
committed by GitHub
parent e65732d6e4
commit 8a17099dd2
69 changed files with 1332 additions and 558 deletions

View File

@@ -1,6 +1,10 @@
"""Module for embedding related classes and functions."""
from .embedding_factory import DefaultEmbeddingFactory, EmbeddingFactory # noqa: F401
from .embedding_factory import ( # noqa: F401
DefaultEmbeddingFactory,
EmbeddingFactory,
WrappedEmbeddingFactory,
)
from .embeddings import ( # noqa: F401
Embeddings,
HuggingFaceBgeEmbeddings,
@@ -21,4 +25,5 @@ __ALL__ = [
"OpenAPIEmbeddings",
"DefaultEmbeddingFactory",
"EmbeddingFactory",
"WrappedEmbeddingFactory",
]

View File

@@ -0,0 +1,32 @@
"""Wraps the third-party language model embeddings to the common interface."""
from typing import TYPE_CHECKING, List
from dbgpt.core import Embeddings
if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings as LangChainEmbeddings
class WrappedEmbeddings(Embeddings):
    """Adapter exposing a LangChain embeddings object via the common interface.

    Every call is forwarded unchanged to the wrapped third-party object, so
    the adapter adds no behavior of its own.
    """

    def __init__(self, embeddings: "LangChainEmbeddings") -> None:
        """Keep a reference to the third-party embeddings to delegate to."""
        self._delegate = embeddings

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs by delegating to the wrapped implementation."""
        return self._delegate.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed query text by delegating to the wrapped implementation."""
        return self._delegate.embed_query(text)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed search docs via the wrapped implementation."""
        return await self._delegate.aembed_documents(texts)

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed query text via the wrapped implementation."""
        return await self._delegate.aembed_query(text)

View File

@@ -1,15 +1,14 @@
"""EmbeddingFactory class and DefaultEmbeddingFactory class."""
from __future__ import annotations
import logging
import os
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional, Type
from typing import Any, Optional, Type
from dbgpt.component import BaseComponent, SystemApp
from dbgpt.rag.embedding.embeddings import HuggingFaceEmbeddings
from dbgpt.core import Embeddings
if TYPE_CHECKING:
from dbgpt.rag.embedding.embeddings import Embeddings
logger = logging.getLogger(__name__)
class EmbeddingFactory(BaseComponent, ABC):
@@ -20,7 +19,7 @@ class EmbeddingFactory(BaseComponent, ABC):
@abstractmethod
def create(
self, model_name: Optional[str] = None, embedding_cls: Optional[Type] = None
) -> "Embeddings":
) -> Embeddings:
"""Create an embedding instance.
Args:
@@ -39,12 +38,19 @@ class DefaultEmbeddingFactory(EmbeddingFactory):
self,
system_app: Optional[SystemApp] = None,
default_model_name: Optional[str] = None,
default_model_path: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Create a new DefaultEmbeddingFactory."""
super().__init__(system_app=system_app)
if not default_model_path:
default_model_path = default_model_name
if not default_model_name:
default_model_name = default_model_path
self._default_model_name = default_model_name
self.kwargs = kwargs
self._default_model_path = default_model_path
self._kwargs = kwargs
self._model = self._load_model()
def init_app(self, system_app):
"""Init the app."""
@@ -52,20 +58,166 @@ class DefaultEmbeddingFactory(EmbeddingFactory):
def create(
self, model_name: Optional[str] = None, embedding_cls: Optional[Type] = None
) -> "Embeddings":
) -> Embeddings:
"""Create an embedding instance.
Args:
model_name (str): The model name.
embedding_cls (Type): The embedding class.
"""
if not model_name:
model_name = self._default_model_name
new_kwargs = {k: v for k, v in self.kwargs.items()}
new_kwargs["model_name"] = model_name
if embedding_cls:
return embedding_cls(**new_kwargs)
else:
return HuggingFaceEmbeddings(**new_kwargs)
raise NotImplementedError
return self._model
def _load_model(self) -> Embeddings:
    """Load the local embedding model configured for this factory.

    Resolves the parameter class registered for ``self._default_model_name``,
    parses ``self._kwargs`` into embedding-model parameters, and delegates the
    actual loading to ``EmbeddingLoader``.

    Returns:
        Embeddings: The loaded embeddings instance.

    Raises:
        ValueError: If neither the factory's default model name nor the parsed
            parameters yield a model name.
    """
    # Imported lazily so this module does not require dbgpt.model unless a
    # local model is actually loaded.
    from dbgpt.model.adapter.embeddings_loader import (
        EmbeddingLoader,
        _parse_embedding_params,
    )
    from dbgpt.model.parameter import (
        EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
        BaseEmbeddingModelParameters,
        EmbeddingModelParameters,
    )
    # Fall back to the generic parameter class when no dedicated parameter
    # configuration is registered for this model name.
    param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
        self._default_model_name, EmbeddingModelParameters
    )
    model_params: BaseEmbeddingModelParameters = _parse_embedding_params(
        model_name=self._default_model_name,
        model_path=self._default_model_path,
        param_cls=param_cls,
        **self._kwargs,
    )
    logger.info(model_params)
    loader = EmbeddingLoader()
    # Ignore model_name args
    model_name = self._default_model_name or model_params.model_name
    if not model_name:
        raise ValueError("model_name must be provided.")
    return loader.load(model_name, model_params)
@classmethod
def openai(
    cls,
    api_url: Optional[str] = None,
    api_key: Optional[str] = None,
    model_name: str = "text-embedding-3-small",
    timeout: int = 60,
    **kwargs: Any,
) -> Embeddings:
    """Create an OpenAI embeddings.

    If api_url and api_key are not provided, we will try to get them from
    the ``OPENAI_API_BASE`` and ``OPENAI_API_KEY`` environment variables.

    Args:
        api_url (Optional[str], optional): The api url. Defaults to None.
        api_key (Optional[str], optional): The api key. Defaults to None.
        model_name (str, optional): The model name.
            Defaults to "text-embedding-3-small".
        timeout (int, optional): The timeout. Defaults to 60.

    Returns:
        Embeddings: The embeddings instance.

    Raises:
        ValueError: If no api key is given and ``OPENAI_API_KEY`` is unset.
    """
    if not api_url:
        base_url = os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1")
        # Strip a trailing slash so a user-supplied "https://host/v1/" does
        # not produce "https://host/v1//embeddings".
        api_url = base_url.rstrip("/") + "/embeddings"
    api_key = api_key or os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise ValueError("api_key must be provided.")
    return cls.remote(
        api_url=api_url,
        api_key=api_key,
        model_name=model_name,
        timeout=timeout,
        **kwargs,
    )
@classmethod
def default(
    cls, model_name: str, model_path: Optional[str] = None, **kwargs: Any
) -> Embeddings:
    """Build embeddings by loading a local model.

    The model is resolved from the given name; when ``model_path`` is omitted,
    the name also serves as the path to load from.

    Args:
        model_name (str): The model name.
        model_path (Optional[str], optional): The model path. Defaults to
            None, in which case the model name is used as the load path.

    Returns:
        Embeddings: The embeddings instance.
    """
    factory = cls(
        default_model_name=model_name,
        default_model_path=model_path,
        **kwargs,
    )
    return factory.create()
@classmethod
def remote(
    cls,
    api_url: str = "http://localhost:8100/api/v1/embeddings",
    api_key: Optional[str] = None,
    model_name: str = "text2vec",
    timeout: int = 60,
    **kwargs: Any,
) -> Embeddings:
    """Create a remote embeddings.

    Builds an embeddings client for any endpoint that speaks the OpenAI
    embeddings API; if your model serves a compatible API, this method will
    talk to it.

    Args:
        api_url (str, optional): The api url. Defaults to
            "http://localhost:8100/api/v1/embeddings".
        api_key (Optional[str], optional): The api key. Defaults to None.
        model_name (str, optional): The model name. Defaults to "text2vec".
        timeout (int, optional): The timeout. Defaults to 60.
    """
    # Imported lazily to avoid a circular import at module load time.
    from .embeddings import OpenAPIEmbeddings

    remote_embeddings = OpenAPIEmbeddings(
        api_url=api_url,
        api_key=api_key,
        model_name=model_name,
        timeout=timeout,
        **kwargs,
    )
    return remote_embeddings
class WrappedEmbeddingFactory(EmbeddingFactory):
    """Factory that always hands out a pre-built ``Embeddings`` instance."""

    def __init__(
        self,
        system_app: Optional[SystemApp] = None,
        embeddings: Optional[Embeddings] = None,
        **kwargs: Any,
    ) -> None:
        """Create a new WrappedEmbeddingFactory.

        Args:
            system_app (Optional[SystemApp], optional): The system app.
                Defaults to None.
            embeddings (Optional[Embeddings], optional): The embeddings
                instance this factory will return. Required at runtime
                despite the Optional annotation.

        Raises:
            ValueError: If ``embeddings`` is not provided.
        """
        super().__init__(system_app=system_app)
        if not embeddings:
            raise ValueError("embeddings must be provided.")
        self._model = embeddings

    def init_app(self, system_app):
        """Init the app. No-op for this factory."""
        pass

    def create(
        self, model_name: Optional[str] = None, embedding_cls: Optional[Type] = None
    ) -> Embeddings:
        """Return the wrapped embeddings instance.

        Args:
            model_name (str): Ignored; the factory always returns the
                instance it was constructed with.
            embedding_cls (Type): Not supported by this factory.

        Raises:
            NotImplementedError: If ``embedding_cls`` is given.
        """
        if embedding_cls:
            raise NotImplementedError
        return self._model