refactor: RAG Refactor (#985)

Co-authored-by: Aralhi <xiaoping0501@gmail.com>
Co-authored-by: csunny <cfqsunny@163.com>
This commit is contained in:
Aries-ckt
2024-01-03 09:45:26 +08:00
committed by GitHub
parent 90775aad50
commit 9ad70a2961
206 changed files with 5766 additions and 2419 deletions

View File

@@ -0,0 +1,47 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Type, TYPE_CHECKING
from dbgpt.component import BaseComponent
from dbgpt.rag.embedding.embeddings import HuggingFaceEmbeddings
if TYPE_CHECKING:
from dbgpt.rag.embedding.embeddings import Embeddings
class EmbeddingFactory(BaseComponent, ABC):
"""Abstract base class for EmbeddingFactory."""
name = "embedding_factory"
@abstractmethod
def create(
self, model_name: str = None, embedding_cls: Type = None
) -> "Embeddings":
"""Create embedding"""
class DefaultEmbeddingFactory(EmbeddingFactory):
def __init__(
self, system_app=None, default_model_name: str = None, **kwargs: Any
) -> None:
super().__init__(system_app=system_app)
self._default_model_name = default_model_name
self.kwargs = kwargs
def init_app(self, system_app):
pass
def create(
self, model_name: str = None, embedding_cls: Type = None
) -> "Embeddings":
if not model_name:
model_name = self._default_model_name
new_kwargs = {k: v for k, v in self.kwargs.items()}
new_kwargs["model_name"] = model_name
if embedding_cls:
return embedding_cls(**new_kwargs)
else:
return HuggingFaceEmbeddings(**new_kwargs)