feat: knowledge RAG graph

This commit is contained in:
aries_ckt 2023-10-10 20:25:51 +05:00
parent 813e2260a6
commit fc656e1c61
7 changed files with 58 additions and 6 deletions

View File

@@ -41,7 +41,11 @@ class KnowledgeType(Enum):
def get_knowledge_embedding(
knowledge_type, knowledge_source, vector_store_config, source_reader, text_splitter
knowledge_type,
knowledge_source,
vector_store_config=None,
source_reader=None,
text_splitter=None,
):
match knowledge_type:
case KnowledgeType.DOCUMENT.value:

View File

@@ -31,11 +31,11 @@ class SourceEmbedding(ABC):
):
"""Initialize with Loader url, model_name, vector_store_config"""
self.file_path = file_path
self.vector_store_config = vector_store_config
self.vector_store_config = vector_store_config or {}
self.source_reader = source_reader or None
self.text_splitter = text_splitter or None
self.embedding_args = embedding_args
self.embeddings = vector_store_config["embeddings"]
self.embeddings = self.vector_store_config.get("embeddings", None)
@abstractmethod
@register

View File

@@ -27,7 +27,7 @@ class URLEmbedding(SourceEmbedding):
file_path, vector_store_config, source_reader=None, text_splitter=None
)
self.file_path = file_path
self.vector_store_config = vector_store_config
self.vector_store_config = vector_store_config or None
self.source_reader = source_reader or None
self.text_splitter = text_splitter or None

View File

@@ -75,6 +75,16 @@ class ChatScene(Enum):
"Dialogue through natural language and private documents and knowledge bases.",
["Knowledge Space Select"],
)
ExtractTriplet = Scene(
"extract_triplet",
"Extract Triplet",
"Extract Triplet",
["Extract Select"],
True,
)
ExtractEntity = Scene(
"extract_entity", "Extract Entity", "Extract Entity", ["Extract Select"], True
)
@staticmethod
def of_mode(mode):

View File

@@ -205,6 +205,40 @@ class BaseChat(ABC):
self.memory.append(self.current_message)
return self.current_ai_response()
async def get_llm_response(self):
    """Send the current prompt payload to the model worker and parse the reply.

    Builds the request payload from the conversation state, runs a
    non-streaming generation via the worker manager, records the raw AI
    message on the current conversation, and parses it into a structured
    prompt response.

    Returns:
        The value produced by ``parse_prompt_response``, or ``None`` when
        generation or parsing failed (the error is appended to the
        conversation as a view message instead of being re-raised).
    """
    payload = self.__call_base()
    logger.info(f"Request: \n{payload}")
    ai_response_text = ""
    # BUGFIX: initialize before the try-block. The original only assigned
    # this inside `try`, so any exception caused `return` below to raise
    # UnboundLocalError, masking the error that was just handled.
    prompt_define_response = None
    try:
        # Local import to avoid a circular dependency at module load time.
        from pilot.model.cluster import WorkerManagerFactory

        worker_manager = CFG.SYSTEM_APP.get_component(
            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
        ).create()
        model_output = await worker_manager.generate(payload)
        ### output parse
        ai_response_text = (
            self.prompt_template.output_parser.parse_model_nostream_resp(
                model_output, self.prompt_template.sep
            )
        )
        ### model result deal
        self.current_message.add_ai_message(ai_response_text)
        prompt_define_response = (
            self.prompt_template.output_parser.parse_prompt_response(
                ai_response_text
            )
        )
    except Exception as e:
        print(traceback.format_exc())
        logger.error("model response parse failed" + str(e))
        # Surface the failure to the user as a view message rather than
        # crashing the chat turn (deliberate best-effort behavior).
        self.current_message.add_view_message(
            f"""model response parse failed{str(e)}\n {ai_response_text} """
        )
    return prompt_define_response
def _blocking_stream_call(self):
logger.warn(
"_blocking_stream_call is only temporarily used in webserver and will be deleted soon, please use stream_call to replace it for higher performance"

View File

@@ -13,6 +13,8 @@ class ChatFactory(metaclass=Singleton):
from pilot.scene.chat_dashboard.chat import ChatDashboard
from pilot.scene.chat_knowledge.v1.chat import ChatKnowledge
from pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary
from pilot.scene.chat_knowledge.extract_triplet.chat import ExtractTriplet
from pilot.scene.chat_knowledge.extract_entity.chat import ExtractEntity
from pilot.scene.chat_data.chat_excel.excel_analyze.chat import ChatExcel
chat_classes = BaseChat.__subclasses__()

View File

@@ -168,10 +168,12 @@ def get_or_create_event_loop() -> asyncio.BaseEventLoop:
assert loop is not None
return loop
except RuntimeError as e:
if not "no running event loop" in str(e):
if not "no running event loop" in str(e) and not "no current event loop" in str(
e
):
raise e
logging.warning("Cant not get running event loop, create new event loop now")
return asyncio.get_event_loop_policy().get_event_loop()
return asyncio.get_event_loop_policy().new_event_loop()
def logging_str_to_uvicorn_level(log_level_str):