Files
DB-GPT/pilot/graph_engine/graph_factory.py
2023-10-13 17:13:51 +08:00

35 lines
1.0 KiB
Python

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Optional, Type

from pilot.component import BaseComponent, ComponentType
class RAGGraphFactory(BaseComponent, ABC):
    """Abstract component whose job is to create RAG graph engines.

    The component ``name`` is ``ComponentType.RAG_GRAPH_DEFAULT.value``;
    concrete subclasses must implement :meth:`create`.
    """

    name = ComponentType.RAG_GRAPH_DEFAULT.value

    @abstractmethod
    def create(
        self,
        model_name: Optional[str] = None,
        embedding_cls: Optional[Type] = None,
    ):
        """Create RAG Graph Engine.

        Args:
            model_name: Name of the model the engine should use. ``None``
                presumably means "use the implementation's default" —
                confirm per subclass.
            embedding_cls: Optional class argument whose meaning is left to
                the implementation.
                NOTE(review): ``DefaultRAGGraphFactory.create`` names this
                parameter ``rag_cls``, so passing it by keyword is not
                portable across implementations; pass it positionally.
        """
class DefaultRAGGraphFactory(RAGGraphFactory):
    """Default ``RAGGraphFactory`` that hands out one shared ``RAGGraphEngine``.

    The engine is constructed once in ``__init__``; every call to
    :meth:`create` returns that same instance.
    """

    def __init__(
        self,
        system_app=None,
        default_model_name: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Initialise the component and build the shared graph engine.

        Args:
            system_app: Application object forwarded to the parent
                constructor.
            default_model_name: Stored on ``self._default_model_name``.
                NOTE(review): it is NOT used to build the engine (see the
                hard-coded model name below).
            **kwargs: Extra options, stored on ``self.kwargs`` and not
                otherwise used in this class.
        """
        super().__init__(system_app=system_app)
        self._default_model_name = default_model_name
        self.kwargs = kwargs
        # Deferred import — presumably to avoid an import cycle or a heavy
        # import at module load time; TODO confirm before hoisting it.
        from pilot.graph_engine.graph_engine import RAGGraphEngine

        # NOTE(review): the model name is hard-coded and
        # ``default_model_name`` is ignored. Confirm whether the graph engine
        # only supports the proxy LLM before wiring ``default_model_name``
        # through.
        self.rag_engine = RAGGraphEngine(model_name="proxyllm")

    def init_app(self, system_app):
        """No-op hook: this factory performs no setup against ``system_app``."""

    def create(
        self,
        model_name: Optional[str] = None,
        rag_cls: Optional[Type] = None,
    ):
        """Return the shared RAG graph engine built in ``__init__``.

        Args:
            model_name: Accepted for interface compatibility; currently
                ignored — the same engine is returned for every value.
            rag_cls: Accepted for interface compatibility; currently ignored.
                (The abstract base names this parameter ``embedding_cls``.)

        Returns:
            The ``RAGGraphEngine`` instance held in ``self.rag_engine``.
        """
        return self.rag_engine