mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-30 22:10:10 +00:00
46 lines
1.2 KiB
Python
46 lines
1.2 KiB
Python
from abc import ABC, abstractmethod
|
|
from typing import Any, Type, TYPE_CHECKING
|
|
|
|
from pilot.componet import BaseComponet
|
|
|
|
if TYPE_CHECKING:
|
|
from langchain.embeddings.base import Embeddings
|
|
|
|
|
|
class EmbeddingFactory(BaseComponet, ABC):
    """Abstract component interface for objects that build embeddings.

    Subclasses must implement :meth:`create`.
    """

    # NOTE(review): presumably the identifier this component is looked up by;
    # confirm against BaseComponet, which is not visible from this file.
    name = "embedding_factory"

    @abstractmethod
    def create(
        self, model_name: str = None, embedding_cls: Type = None
    ) -> "Embeddings":
        """Build and return an embedding object.

        Args:
            model_name: Name of the embedding model to use. May be omitted;
                how a missing value is handled is up to the implementation.
            embedding_cls: Optional embedding class to instantiate. May be
                omitted; the fallback is up to the implementation.

        Returns:
            The created embedding object.
        """
|
|
|
|
|
|
class DefaultEmbeddingFactory(EmbeddingFactory):
    """Embedding factory that falls back to ``HuggingFaceEmbeddings``.

    Keyword arguments captured at construction time are forwarded to every
    embedding class instantiated by :meth:`create`.
    """

    def __init__(
        self, system_app=None, default_model_name: str = None, **kwargs: Any
    ) -> None:
        """Remember the fallback model name and the extra constructor kwargs.

        Args:
            system_app: Passed through unchanged to the base initializer.
            default_model_name: Model name used by :meth:`create` when it is
                called without a truthy ``model_name``.
            **kwargs: Extra keyword arguments forwarded to the embedding
                class on each :meth:`create` call.
        """
        super().__init__(system_app=system_app)
        self._default_model_name = default_model_name
        self.kwargs = kwargs

    def init_app(self, system_app):
        """Do nothing; this factory performs no work with ``system_app`` here."""

    def create(
        self, model_name: str = None, embedding_cls: Type = None
    ) -> "Embeddings":
        """Instantiate an embedding object.

        Args:
            model_name: Model name to use. When falsy, the default model name
                given at construction is used instead (which may be ``None``).
            embedding_cls: Class to instantiate. When falsy,
                ``langchain.embeddings.HuggingFaceEmbeddings`` is used.

        Returns:
            An instance of the selected class, built from the stored kwargs
            plus a ``model_name`` entry.
        """
        resolved_name = model_name or self._default_model_name

        # Merge into a fresh dict so the stored kwargs are never mutated; the
        # resolved name replaces any "model_name" entry already stored there.
        init_kwargs = {**self.kwargs, "model_name": resolved_name}

        if not embedding_cls:
            # Imported lazily so langchain is only required on the fallback path.
            from langchain.embeddings import HuggingFaceEmbeddings

            return HuggingFaceEmbeddings(**init_kwargs)

        return embedding_cls(**init_kwargs)
|