From fc656e1c61634c604300e4e489aa1b44b20c3ce0 Mon Sep 17 00:00:00 2001
From: aries_ckt <916701291@qq.com>
Date: Tue, 10 Oct 2023 20:25:51 +0500
Subject: [PATCH 01/57] feat:knowledge rag graph
---
pilot/embedding_engine/knowledge_type.py | 6 +++-
pilot/embedding_engine/source_embedding.py | 4 +--
pilot/embedding_engine/url_embedding.py | 2 +-
pilot/scene/base.py | 10 +++++++
pilot/scene/base_chat.py | 34 ++++++++++++++++++++++
pilot/scene/chat_factory.py | 2 ++
pilot/utils/utils.py | 6 ++--
7 files changed, 58 insertions(+), 6 deletions(-)
diff --git a/pilot/embedding_engine/knowledge_type.py b/pilot/embedding_engine/knowledge_type.py
index 77fb98666..c4c278be4 100644
--- a/pilot/embedding_engine/knowledge_type.py
+++ b/pilot/embedding_engine/knowledge_type.py
@@ -41,7 +41,11 @@ class KnowledgeType(Enum):
def get_knowledge_embedding(
- knowledge_type, knowledge_source, vector_store_config, source_reader, text_splitter
+ knowledge_type,
+ knowledge_source,
+ vector_store_config=None,
+ source_reader=None,
+ text_splitter=None,
):
match knowledge_type:
case KnowledgeType.DOCUMENT.value:
diff --git a/pilot/embedding_engine/source_embedding.py b/pilot/embedding_engine/source_embedding.py
index 24bae97b2..5b1e57ae2 100644
--- a/pilot/embedding_engine/source_embedding.py
+++ b/pilot/embedding_engine/source_embedding.py
@@ -31,11 +31,11 @@ class SourceEmbedding(ABC):
):
"""Initialize with Loader url, model_name, vector_store_config"""
self.file_path = file_path
- self.vector_store_config = vector_store_config
+ self.vector_store_config = vector_store_config or {}
self.source_reader = source_reader or None
self.text_splitter = text_splitter or None
self.embedding_args = embedding_args
- self.embeddings = vector_store_config["embeddings"]
+ self.embeddings = self.vector_store_config.get("embeddings", None)
@abstractmethod
@register
diff --git a/pilot/embedding_engine/url_embedding.py b/pilot/embedding_engine/url_embedding.py
index e00cf84e2..39e7bf1dc 100644
--- a/pilot/embedding_engine/url_embedding.py
+++ b/pilot/embedding_engine/url_embedding.py
@@ -27,7 +27,7 @@ class URLEmbedding(SourceEmbedding):
file_path, vector_store_config, source_reader=None, text_splitter=None
)
self.file_path = file_path
- self.vector_store_config = vector_store_config
+ self.vector_store_config = vector_store_config or {}
self.source_reader = source_reader or None
self.text_splitter = text_splitter or None
diff --git a/pilot/scene/base.py b/pilot/scene/base.py
index 162759e3c..489e98d0b 100644
--- a/pilot/scene/base.py
+++ b/pilot/scene/base.py
@@ -75,6 +75,16 @@ class ChatScene(Enum):
"Dialogue through natural language and private documents and knowledge bases.",
["Knowledge Space Select"],
)
+ ExtractTriplet = Scene(
+ "extract_triplet",
+ "Extract Triplet",
+ "Extract Triplet",
+ ["Extract Select"],
+ True,
+ )
+ ExtractEntity = Scene(
+ "extract_entity", "Extract Entity", "Extract Entity", ["Extract Select"], True
+ )
@staticmethod
def of_mode(mode):
diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index daab56964..5c16ee286 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -205,6 +205,40 @@ class BaseChat(ABC):
self.memory.append(self.current_message)
return self.current_ai_response()
+ async def get_llm_response(self):
+ payload = self.__call_base()
+ logger.info(f"Request: \n{payload}")
+ ai_response_text = ""
+ prompt_define_response = None
+ try:
+ from pilot.model.cluster import WorkerManagerFactory
+
+ worker_manager = CFG.SYSTEM_APP.get_component(
+ ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
+ ).create()
+
+ model_output = await worker_manager.generate(payload)
+
+ ### output parse
+ ai_response_text = (
+ self.prompt_template.output_parser.parse_model_nostream_resp(
+ model_output, self.prompt_template.sep
+ )
+ )
+ ### model result deal
+ self.current_message.add_ai_message(ai_response_text)
+ prompt_define_response = (
+ self.prompt_template.output_parser.parse_prompt_response(
+ ai_response_text
+ )
+ )
+ except Exception as e:
+ print(traceback.format_exc())
+ logger.error("model response parse failed! " + str(e))
+ self.current_message.add_view_message(
+ f"""model response parse failed!{str(e)}\n {ai_response_text} """
+ )
+ return prompt_define_response
+
def _blocking_stream_call(self):
logger.warn(
"_blocking_stream_call is only temporarily used in webserver and will be deleted soon, please use stream_call to replace it for higher performance"
diff --git a/pilot/scene/chat_factory.py b/pilot/scene/chat_factory.py
index ad2da3b3f..fc47b7468 100644
--- a/pilot/scene/chat_factory.py
+++ b/pilot/scene/chat_factory.py
@@ -13,6 +13,8 @@ class ChatFactory(metaclass=Singleton):
from pilot.scene.chat_dashboard.chat import ChatDashboard
from pilot.scene.chat_knowledge.v1.chat import ChatKnowledge
from pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary
+ from pilot.scene.chat_knowledge.extract_triplet.chat import ExtractTriplet
+ from pilot.scene.chat_knowledge.extract_entity.chat import ExtractEntity
from pilot.scene.chat_data.chat_excel.excel_analyze.chat import ChatExcel
chat_classes = BaseChat.__subclasses__()
diff --git a/pilot/utils/utils.py b/pilot/utils/utils.py
index b72745a33..ebb3534e0 100644
--- a/pilot/utils/utils.py
+++ b/pilot/utils/utils.py
@@ -168,10 +168,12 @@ def get_or_create_event_loop() -> asyncio.BaseEventLoop:
assert loop is not None
return loop
except RuntimeError as e:
- if not "no running event loop" in str(e):
+ if not "no running event loop" in str(e) and not "no current event loop" in str(
+ e
+ ):
raise e
logging.warning("Cant not get running event loop, create new event loop now")
- return asyncio.get_event_loop_policy().get_event_loop()
+ return asyncio.get_event_loop_policy().new_event_loop()
def logging_str_to_uvicorn_level(log_level_str):
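For reference, a minimal sketch of how the new non-streaming path might be exercised from synchronous code (not part of the patch; it assumes an initialized DB-GPT system app so the worker manager component can be resolved, and the session id, input text, and model name are illustrative):

    import uuid

    from pilot.scene.base import ChatScene
    from pilot.scene.chat_factory import ChatFactory
    from pilot.utils import utils

    chat_param = {
        "chat_session_id": uuid.uuid1(),
        "current_user_input": "Alice is Bob's mother.",
        "select_param": "triplet",
        "model_name": "proxyllm",
    }
    chat = ChatFactory().get_implementation(
        ChatScene.ExtractTriplet.value(), **{"chat_param": chat_param}
    )
    # new_event_loop() in the utils.py change above avoids reusing a closed loop
    loop = utils.get_or_create_event_loop()
    triplets = loop.run_until_complete(chat.get_llm_response())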
From eb2c220d227464006d203306665e4602be597284 Mon Sep 17 00:00:00 2001
From: aries_ckt <916701291@qq.com>
Date: Thu, 12 Oct 2023 11:00:21 +0500
Subject: [PATCH 02/57] feat: rag graph component
---
pilot/component.py | 1 +
pilot/server/component_configs.py | 4 ++++
2 files changed, 5 insertions(+)
diff --git a/pilot/component.py b/pilot/component.py
index 3179fa696..74af0e150 100644
--- a/pilot/component.py
+++ b/pilot/component.py
@@ -47,6 +47,7 @@ class ComponentType(str, Enum):
WORKER_MANAGER_FACTORY = "dbgpt_worker_manager_factory"
MODEL_CONTROLLER = "dbgpt_model_controller"
EXECUTOR_DEFAULT = "dbgpt_thread_pool_default"
+ RAG_GRAPH_DEFAULT = "dbgpt_rag_engine_default"
class BaseComponent(LifeCycle, ABC):
diff --git a/pilot/server/component_configs.py b/pilot/server/component_configs.py
index 71ef797d9..f7bcde332 100644
--- a/pilot/server/component_configs.py
+++ b/pilot/server/component_configs.py
@@ -28,6 +28,10 @@ def initialize_components(
system_app.register_instance(controller)
+ # Register global default RAGGraphFactory
+ from pilot.graph_engine.graph_factory import DefaultRAGGraphFactory
+ system_app.register(DefaultRAGGraphFactory)
+
_initialize_embedding_model(
param, system_app, embedding_model_name, embedding_model_path
)
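For context, a minimal sketch of how the component registered above is resolved later in this series (illustrative; it assumes initialize_components() has completed on the system app):

    from pilot.component import ComponentType
    from pilot.configs.config import Config
    from pilot.graph_engine.graph_factory import RAGGraphFactory

    CFG = Config()
    # Look up the DefaultRAGGraphFactory registered above and build an engine
    rag_engine = CFG.SYSTEM_APP.get_component(
        ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
    ).create()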
From fa6a9040d5e0670ba491b5574101e5de3e8f046a Mon Sep 17 00:00:00 2001
From: aries_ckt <916701291@qq.com>
Date: Fri, 13 Oct 2023 14:22:46 +0800
Subject: [PATCH 03/57] feat:knowledge rag graph
---
pilot/scene/chat_knowledge/v1/chat.py | 25 ++++++++++++++++
pilot/server/component_configs.py | 1 +
pilot/server/knowledge/api.py | 35 +++++++++++++++++++++++
pilot/server/knowledge/request/request.py | 7 +++++
pilot/server/knowledge/service.py | 22 +++++++++++++-
5 files changed, 89 insertions(+), 1 deletion(-)
diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py
index 8177a1a5a..ebecddd19 100644
--- a/pilot/scene/chat_knowledge/v1/chat.py
+++ b/pilot/scene/chat_knowledge/v1/chat.py
@@ -1,6 +1,7 @@
import os
from typing import Dict
+from pilot.component import ComponentType
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.configs.config import Config
@@ -47,6 +48,11 @@ class ChatKnowledge(BaseChat):
"vector_store_name": self.knowledge_space,
"vector_store_type": CFG.VECTOR_STORE_TYPE,
}
+ from pilot.graph_engine.graph_factory import RAGGraphFactory
+
+ self.rag_engine = CFG.SYSTEM_APP.get_component(
+ ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
+ ).create()
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
@@ -82,6 +88,25 @@ class ChatKnowledge(BaseChat):
if self.space_context:
self.prompt_template.template_define = self.space_context["prompt"]["scene"]
self.prompt_template.template = self.space_context["prompt"]["template"]
+ # docs = self.rag_engine.search(query=self.current_user_input)
+ # import httpx
+ # with httpx.Client() as client:
+ # request = client.build_request(
+ # "post",
+ # "http://127.0.0.1/api/knowledge/entities/extract",
+ # json="application/json", # using json for data to ensure it sends as application/json
+ # params={"text": self.current_user_input},
+ # headers={},
+ # )
+ #
+ # response = client.send(request)
+ # if response.status_code != 200:
+ # error_msg = f"request /api/knowledge/entities/extract failed, error: {response.text}"
+ # raise Exception(error_msg)
+ # docs = response.json()
+ # import requests
+ # docs = requests.post("http://127.0.0.1:5000/api/knowledge/entities/extract", headers={}, json={"text": self.current_user_input})
+
docs = self.knowledge_embedding_client.similar_search(
self.current_user_input, self.top_k
)
diff --git a/pilot/server/component_configs.py b/pilot/server/component_configs.py
index 7d306ada1..ba5c35ec6 100644
--- a/pilot/server/component_configs.py
+++ b/pilot/server/component_configs.py
@@ -31,6 +31,7 @@ def initialize_components(
# Register global default RAGGraphFactory
from pilot.graph_engine.graph_factory import DefaultRAGGraphFactory
+
system_app.register(DefaultRAGGraphFactory)
_initialize_embedding_model(
diff --git a/pilot/server/knowledge/api.py b/pilot/server/knowledge/api.py
index 57fadb21e..e0f31031e 100644
--- a/pilot/server/knowledge/api.py
+++ b/pilot/server/knowledge/api.py
@@ -24,6 +24,7 @@ from pilot.server.knowledge.request.request import (
ChunkQueryRequest,
DocumentQueryRequest,
SpaceArgumentRequest,
+ EntityExtractRequest,
)
from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
@@ -198,3 +199,37 @@ def similar_query(space_name: str, query_request: KnowledgeQueryRequest):
for d in docs
]
return {"response": res}
+
+
+@router.post("/knowledge/entity/extract")
+async def entity_extract(request: EntityExtractRequest):
+ logger.info(f"Received params: {request}")
+ try:
+ # from pilot.graph_engine.graph_factory import RAGGraphFactory
+ # from pilot.component import ComponentType
+ # rag_engine = CFG.SYSTEM_APP.get_component(
+ # ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
+ # ).create()
+ # return Result.succ(await rag_engine.search(request.text))
+ from pilot.scene.base import ChatScene
+ from pilot.common.chat_util import llm_chat_response_nostream
+ import uuid
+
+ chat_param = {
+ "chat_session_id": uuid.uuid1(),
+ "current_user_input": request.text,
+ "select_param": "entity",
+ "model_name": request.model_name,
+ }
+
+ # import nest_asyncio
+ # nest_asyncio.apply()
+ # loop = asyncio.get_event_loop()
+ # loop.stop()
+ # loop = utils.get_or_create_event_loop()
+ res = await llm_chat_response_nostream(
+ ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
+ )
+ return Result.succ(res)
+ except Exception as e:
+ return Result.faild(code="E000X", msg=f"entity extract error {e}")
diff --git a/pilot/server/knowledge/request/request.py b/pilot/server/knowledge/request/request.py
index b83165c19..c6b94ff0d 100644
--- a/pilot/server/knowledge/request/request.py
+++ b/pilot/server/knowledge/request/request.py
@@ -104,3 +104,10 @@ class SpaceArgumentRequest(BaseModel):
"""argument: argument"""
argument: str
+
+
+class EntityExtractRequest(BaseModel):
+ """argument: argument"""
+
+ text: str
+ model_name: str
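A hypothetical client call against the new endpoint (the host, port, and "/api" prefix are assumptions based on the commented-out calls earlier in this patch):

    import requests

    resp = requests.post(
        "http://127.0.0.1:5000/api/knowledge/entity/extract",
        json={"text": "Alice is Bob's mother.", "model_name": "proxyllm"},
    )
    print(resp.json())  # Result payload with the extracted entities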
diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py
index c11fc3b46..f4150fa73 100644
--- a/pilot/server/knowledge/service.py
+++ b/pilot/server/knowledge/service.py
@@ -58,7 +58,11 @@ class SyncStatus(Enum):
# @singleton
class KnowledgeService:
def __init__(self):
- pass
+ from pilot.graph_engine.graph_engine import RAGGraphEngine
+
+ # source = "/Users/chenketing/Desktop/project/llama_index/examples/paul_graham_essay/data/test/test_kg_text.txt"
+
+ # pass
"""create knowledge space"""
@@ -229,6 +233,10 @@ class KnowledgeService:
pre_separator=sync_request.pre_separator,
text_splitter_impl=text_splitter,
)
+ from pilot.graph_engine.graph_engine import RAGGraphEngine
+
+ # source = "/Users/chenketing/Desktop/project/llama_index/examples/paul_graham_essay/data/test/test_kg_text.txt"
+ # engine = RAGGraphEngine(knowledge_source=source, model_name="proxyllm", text_splitter=text_splitter)
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
@@ -244,6 +252,18 @@ class KnowledgeService:
embedding_factory=embedding_factory,
)
chunk_docs = client.read()
+ from pilot.graph_engine.graph_factory import RAGGraphFactory
+
+ rag_engine = CFG.SYSTEM_APP.get_component(
+ ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
+ ).create()
+ rag_engine.knowledge_graph(docs=chunk_docs)
+ # docs = engine.search(
+ # "Comparing Curry and James in terms of their positions, playing styles, and achievements in the NBA"
+ # )
+ embedding_factory = CFG.SYSTEM_APP.get_component(
+ "embedding_factory", EmbeddingFactory
+ )
# update document status
doc.status = SyncStatus.RUNNING.name
doc.chunk_size = len(chunk_docs)
From 2f82f98e315d81129987c28196fd66e62a75f56b Mon Sep 17 00:00:00 2001
From: aries_ckt <916701291@qq.com>
Date: Fri, 13 Oct 2023 17:13:51 +0800
Subject: [PATCH 04/57] feat:knowledge rag graph
---
pilot/common/chat_util.py | 20 +
pilot/graph_engine/__init__.py | 0
pilot/graph_engine/graph_engine.py | 137 +++++
pilot/graph_engine/graph_factory.py | 34 ++
pilot/graph_engine/graph_search.py | 193 ++++++
pilot/graph_engine/index_struct.py | 259 ++++++++
pilot/graph_engine/index_type.py | 48 ++
pilot/graph_engine/kv_index.py | 74 +++
pilot/graph_engine/node.py | 569 ++++++++++++++++++
pilot/graph_engine/search.py | 44 ++
.../chat_knowledge/extract_entity/__init__.py | 0
.../chat_knowledge/extract_entity/chat.py | 35 ++
.../extract_entity/out_parser.py | 39 ++
.../chat_knowledge/extract_entity/prompt.py | 52 ++
.../extract_triplet/__init__.py | 0
.../chat_knowledge/extract_triplet/chat.py | 35 ++
.../extract_triplet/out_parser.py | 57 ++
.../chat_knowledge/extract_triplet/prompt.py | 57 ++
pilot/scene/chat_knowledge/v1/chat.py | 20 +-
pilot/server/knowledge/api.py | 11 -
20 files changed, 1654 insertions(+), 30 deletions(-)
create mode 100644 pilot/common/chat_util.py
create mode 100644 pilot/graph_engine/__init__.py
create mode 100644 pilot/graph_engine/graph_engine.py
create mode 100644 pilot/graph_engine/graph_factory.py
create mode 100644 pilot/graph_engine/graph_search.py
create mode 100644 pilot/graph_engine/index_struct.py
create mode 100644 pilot/graph_engine/index_type.py
create mode 100644 pilot/graph_engine/kv_index.py
create mode 100644 pilot/graph_engine/node.py
create mode 100644 pilot/graph_engine/search.py
create mode 100644 pilot/scene/chat_knowledge/extract_entity/__init__.py
create mode 100644 pilot/scene/chat_knowledge/extract_entity/chat.py
create mode 100644 pilot/scene/chat_knowledge/extract_entity/out_parser.py
create mode 100644 pilot/scene/chat_knowledge/extract_entity/prompt.py
create mode 100644 pilot/scene/chat_knowledge/extract_triplet/__init__.py
create mode 100644 pilot/scene/chat_knowledge/extract_triplet/chat.py
create mode 100644 pilot/scene/chat_knowledge/extract_triplet/out_parser.py
create mode 100644 pilot/scene/chat_knowledge/extract_triplet/prompt.py
diff --git a/pilot/common/chat_util.py b/pilot/common/chat_util.py
new file mode 100644
index 000000000..159db99d0
--- /dev/null
+++ b/pilot/common/chat_util.py
@@ -0,0 +1,20 @@
+import asyncio
+
+from starlette.responses import StreamingResponse
+
+from pilot.scene.base_chat import BaseChat
+from pilot.scene.chat_factory import ChatFactory
+
+chat_factory = ChatFactory()
+
+
+async def llm_chat_response_nostream(chat_scene: str, **chat_param):
+ """ llm_chat_response_nostream """
+ chat: BaseChat = chat_factory.get_implementation(chat_scene, **chat_param)
+ res = await chat.get_llm_response()
+ return res
+
+
+async def llm_chat_response(chat_scene: str, **chat_param):
+ chat: BaseChat = chat_factory.get_implementation(chat_scene, **chat_param)
+ return chat.stream_call()
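A short usage sketch of the non-streaming helper (illustrative; the scene, session id, and model name mirror the calls made in later patches of this series):

    import asyncio
    import uuid

    from pilot.common.chat_util import llm_chat_response_nostream
    from pilot.scene.base import ChatScene

    async def extract_entities(text: str):
        chat_param = {
            "chat_session_id": uuid.uuid1(),
            "current_user_input": text,
            "select_param": "entity",
            "model_name": "proxyllm",
        }
        # awaits the parsed model output; llm_chat_response returns a stream instead
        return await llm_chat_response_nostream(
            ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
        )

    # entities = asyncio.run(extract_entities("Alice is Bob's mother."))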
diff --git a/pilot/graph_engine/__init__.py b/pilot/graph_engine/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/graph_engine/graph_engine.py b/pilot/graph_engine/graph_engine.py
new file mode 100644
index 000000000..c20142123
--- /dev/null
+++ b/pilot/graph_engine/graph_engine.py
@@ -0,0 +1,137 @@
+import logging
+from typing import Any, Optional, Callable, Tuple, List
+
+from langchain.schema import Document
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+from pilot.embedding_engine import KnowledgeType
+from pilot.embedding_engine.knowledge_type import get_knowledge_embedding
+from pilot.graph_engine.index_struct import KG
+from pilot.graph_engine.node import TextNode
+from pilot.utils import utils
+
+logger = logging.getLogger(__name__)
+
+
+class RAGGraphEngine:
+ """Knowledge RAG Graph Engine.
+ Build a KG by extracting triplets, and leveraging the KG during query-time.
+ Args:
+ knowledge_type (Optional[str]): The knowledge source type used when
+ extracting triplets. Defaults to KnowledgeType.DOCUMENT.value.
+ graph_store (Optional[GraphStore]): The graph store to use. Reference: llama-index.
+ include_embeddings (bool): Whether to include embeddings in the index.
+ Defaults to False.
+ max_object_length (int): The maximum length of the object in a triplet.
+ Defaults to 128.
+ extract_triplet_fn (Optional[Callable]): The function to use for
+ extracting triplets. Defaults to None.
+ """
+
+ index_struct_cls = KG
+
+ def __init__(
+ self,
+ knowledge_type: Optional[str] = KnowledgeType.DOCUMENT.value,
+ knowledge_source: Optional[str] = None,
+ text_splitter=None,
+ graph_store=None,
+ index_struct: Optional[KG] = None,
+ model_name: Optional[str] = None,
+ max_triplets_per_chunk: int = 10,
+ include_embeddings: bool = False,
+ max_object_length: int = 128,
+ extract_triplet_fn: Optional[Callable] = None,
+ **kwargs: Any,
+ ) -> None:
+ """Initialize params."""
+ # from llama_index.graph_stores import SimpleGraphStore
+ # from llama_index.graph_stores.types import GraphStore
+
+ # need to set parameters before building index in base class.
+ self.knowledge_source = knowledge_source
+ self.knowledge_type = knowledge_type
+ self.model_name = model_name
+ self.text_splitter = text_splitter
+ self.index_struct = index_struct
+ self.include_embeddings = include_embeddings
+ # self.graph_store = graph_store or SimpleGraphStore()
+ self.graph_store = graph_store
+ self.max_triplets_per_chunk = max_triplets_per_chunk
+ self._max_object_length = max_object_length
+ self._extract_triplet_fn = extract_triplet_fn
+
+ def knowledge_graph(self, docs=None):
+ """knowledge docs into graph store"""
+ if not docs:
+ if not self.text_splitter:
+ self.text_splitter = RecursiveCharacterTextSplitter(
+ chunk_size=2000, chunk_overlap=100
+ )
+ knowledge_source = get_knowledge_embedding(
+ knowledge_type=self.knowledge_type,
+ knowledge_source=self.knowledge_source,
+ text_splitter=self.text_splitter,
+ )
+ docs = knowledge_source.read()
+ if self.index_struct is None:
+ self.index_struct = self._build_index_from_docs(docs)
+
+ def _extract_triplets(self, text: str) -> List[Tuple[str, str, str]]:
+ """Extract triplets from text by function or llm"""
+ if self._extract_triplet_fn is not None:
+ return self._extract_triplet_fn(text)
+ else:
+ return self._llm_extract_triplets(text)
+
+ def _llm_extract_triplets(self, text: str) -> List[Tuple[str, str, str]]:
+ """Extract triplets from text by llm"""
+ from pilot.scene.base import ChatScene
+ from pilot.common.chat_util import llm_chat_response_nostream
+ import uuid
+
+ chat_param = {
+ "chat_session_id": uuid.uuid1(),
+ "current_user_input": text,
+ "select_param": "triplet",
+ "model_name": self.model_name,
+ }
+ loop = utils.get_or_create_event_loop()
+ triplets = loop.run_until_complete(
+ llm_chat_response_nostream(
+ ChatScene.ExtractTriplet.value(), **{"chat_param": chat_param}
+ )
+ )
+ return triplets
+ # response = self._service_context.llm_predictor.predict(
+ # self.kg_triple_extract_template,
+ # text=text,
+ # )
+ # print(response, flush=True)
+ # return self._parse_triplet_response(
+ # response, max_length=self._max_object_length
+ # )
+
+ def _build_index_from_docs(self, documents: List[Document]) -> KG:
+ """Build the index from nodes."""
+ index_struct = self.index_struct_cls()
+ for doc in documents:
+ triplets = self._extract_triplets(doc.page_content)
+ if len(triplets) == 0:
+ continue
+ text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
+ logger.info(f"extracted knowledge triplets: {triplets}")
+ for triplet in triplets:
+ subj, _, obj = triplet
+ self.graph_store.upsert_triplet(*triplet)
+ index_struct.add_node([subj, obj], text_node)
+
+
+ return index_struct
+
+ def search(self, query):
+ from pilot.graph_engine.graph_search import RAGGraphSearch
+
+ graph_search = RAGGraphSearch(graph_engine=self)
+ return graph_search.search(query)
+
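A hypothetical end-to-end sketch of the engine above. A graph store exposing upsert_triplet() and get_rel_map() must be supplied; llama-index's SimpleGraphStore (hinted at by the commented-out import) is assumed here, and the source path is illustrative:

    from llama_index.graph_stores import SimpleGraphStore

    from pilot.embedding_engine import KnowledgeType
    from pilot.graph_engine.graph_engine import RAGGraphEngine

    engine = RAGGraphEngine(
        knowledge_type=KnowledgeType.DOCUMENT.value,
        knowledge_source="./data/test_kg_text.txt",  # hypothetical file
        model_name="proxyllm",
        graph_store=SimpleGraphStore(),
    )
    engine.knowledge_graph()  # extract triplets per chunk and upsert them
    nodes = engine.search("Who is Bob's mother?")  # RAGGraphSearch over the KG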
diff --git a/pilot/graph_engine/graph_factory.py b/pilot/graph_engine/graph_factory.py
new file mode 100644
index 000000000..3a8b99c17
--- /dev/null
+++ b/pilot/graph_engine/graph_factory.py
@@ -0,0 +1,34 @@
+from __future__ import annotations
+from abc import ABC, abstractmethod
+from typing import Any, Type
+
+from pilot.component import BaseComponent, ComponentType
+
+
+class RAGGraphFactory(BaseComponent, ABC):
+ name = ComponentType.RAG_GRAPH_DEFAULT.value
+
+ @abstractmethod
+ def create(self, model_name: str = None, embedding_cls: Type = None):
+ """Create RAG Graph Engine"""
+
+
+class DefaultRAGGraphFactory(RAGGraphFactory):
+ def __init__(
+ self, system_app=None, default_model_name: str = None, **kwargs: Any
+ ) -> None:
+ super().__init__(system_app=system_app)
+ self._default_model_name = default_model_name
+ self.kwargs = kwargs
+ from pilot.graph_engine.graph_engine import RAGGraphEngine
+
+ self.rag_engine = RAGGraphEngine(model_name=self._default_model_name or "proxyllm")
+
+ def init_app(self, system_app):
+ pass
+
+ def create(self, model_name: str = None, rag_cls: Type = None):
+ if not model_name:
+ model_name = self._default_model_name
+
+ return self.rag_engine
diff --git a/pilot/graph_engine/graph_search.py b/pilot/graph_engine/graph_search.py
new file mode 100644
index 000000000..9b06fd234
--- /dev/null
+++ b/pilot/graph_engine/graph_search.py
@@ -0,0 +1,193 @@
+import logging
+import os
+from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor
+from typing import List, Optional, Dict, Any, Set, Callable
+
+from pilot.graph_engine.node import BaseNode, TextNode, NodeWithScore
+from pilot.graph_engine.search import BaseSearch, SearchMode
+from pilot.utils import utils
+
+logger = logging.getLogger(__name__)
+DEFAULT_NODE_SCORE = 1000.0
+GLOBAL_EXPLORE_NODE_LIMIT = 3
+REL_TEXT_LIMIT = 30
+
+
+class RAGGraphSearch(BaseSearch):
+ """RAG Graph Search.
+
+ Args:
+ graph_engine (RAGGraphEngine): The RAG graph engine to search over.
+ model_name (str): The model name used for keyword/entity extraction.
+ max_keywords_per_query (int): Maximum number of keywords to extract from query.
+ num_chunks_per_query (int): Maximum number of text chunks to query.
+ search_mode (Optional[SearchMode]): Specifies whether to use keywords,
+ embeddings, or both to find relevant triplets. Should be one of "keyword",
+ "embedding", or "hybrid". Defaults to SearchMode.KEYWORD.
+ graph_store_query_depth (int): The depth of the graph store query.
+ extract_subject_entities_fn (Optional[Callable]): Callback used to extract subject entities.
+ """
+
+ def __init__(
+ self,
+ graph_engine,
+ model_name: str = None,
+ max_keywords_per_query: int = 10,
+ num_chunks_per_query: int = 10,
+ search_mode: Optional[SearchMode] = SearchMode.KEYWORD,
+ graph_store_query_depth: int = 2,
+ extract_subject_entities_fn: Optional[Callable] = None,
+ **kwargs: Any,
+ ) -> None:
+ """Initialize params."""
+ from pilot.graph_engine.graph_engine import RAGGraphEngine
+
+ self.graph_engine: RAGGraphEngine = graph_engine
+ self.model_name = model_name or self.graph_engine.model_name
+ self._index_struct = self.graph_engine.index_struct
+ self.max_keywords_per_query = max_keywords_per_query
+ self.num_chunks_per_query = num_chunks_per_query
+ self._search_mode = search_mode
+
+ self._graph_store = self.graph_engine.graph_store
+ self.graph_store_query_depth = graph_store_query_depth
+ self._verbose = kwargs.get("verbose", False)
+ self._include_text = kwargs.get("include_text", False)
+ refresh_schema = kwargs.get("refresh_schema", False)
+ self.extract_subject_entities_fn = extract_subject_entities_fn
+ self.executor = ThreadPoolExecutor(max_workers=os.cpu_count() * 5)
+ try:
+ self._graph_schema = self._graph_store.get_schema(refresh=refresh_schema)
+ except NotImplementedError:
+ self._graph_schema = ""
+ except Exception as e:
+ logger.warn(f"can not to find graph schema: {e}")
+ self._graph_schema = ""
+
+ def _extract_subject_entities(self, query_str: str) -> Set[str]:
+ """extract subject entities."""
+ if self.extract_subject_entities_fn is not None:
+ return self.extract_subject_entities_fn(query_str)
+ else:
+ return self._extract_entities_by_llm(query_str)
+
+ def _extract_entities_by_llm(self, text: str) -> Set[str]:
+ """extract subject entities from text by llm"""
+ from pilot.scene.base import ChatScene
+ from pilot.common.chat_util import llm_chat_response_nostream
+ import uuid
+
+ chat_param = {
+ "chat_session_id": uuid.uuid1(),
+ "current_user_input": text,
+ "select_param": "entity",
+ "model_name": self.model_name,
+ }
+ loop = utils.get_or_create_event_loop()
+ entities = loop.run_until_complete(
+ llm_chat_response_nostream(
+ ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
+ )
+ )
+ return entities
+
+ def _search(
+ self,
+ query_str: str,
+ ) -> List[NodeWithScore]:
+ """Get nodes for response."""
+ node_visited = set()
+ keywords = self._extract_subject_entities(query_str)
+ print(f"extract entities: {keywords}\n")
+ rel_texts = []
+ cur_rel_map = {}
+ chunk_indices_count: Dict[str, int] = defaultdict(int)
+ if self._search_mode != SearchMode.EMBEDDING:
+ for keyword in keywords:
+ keyword = keyword.lower()
+ subjs = set((keyword,))
+ node_ids = self._index_struct.search_node_by_keyword(keyword)
+ for node_id in node_ids[:GLOBAL_EXPLORE_NODE_LIMIT]:
+ if node_id in node_visited:
+ continue
+
+ if self._include_text:
+ chunk_indices_count[node_id] += 1
+
+ node_visited.add(node_id)
+
+ rel_map = self._graph_store.get_rel_map(
+ list(subjs), self.graph_store_query_depth
+ )
+ logger.debug(f"rel_map: {rel_map}")
+
+ if not rel_map:
+ continue
+ rel_texts.extend(
+ [
+ str(rel_obj)
+ for rel_objs in rel_map.values()
+ for rel_obj in rel_objs
+ ]
+ )
+ cur_rel_map.update(rel_map)
+
+ sorted_nodes_with_scores = []
+ if not rel_texts:
+ logger.info("> No relationships found, returning nodes found by keywords.")
+ if len(sorted_nodes_with_scores) == 0:
+ logger.info("> No nodes found by keywords, returning empty response.")
+ return [
+ NodeWithScore(node=TextNode(text="No relationships found."), score=1.0)
+ ]
+
+ # add relationships as Node
+ # TODO: make initial text customizable
+ rel_initial_text = (
+ f"The following are knowledge sequence in max depth"
+ f" {self.graph_store_query_depth} "
+ f"in the form of directed graph like:\n"
+ f"`subject -[predicate]->, object, <-[predicate_next_hop]-,"
+ f" object_next_hop ...`"
+ )
+ rel_info = [rel_initial_text] + rel_texts
+ rel_node_info = {
+ "kg_rel_texts": rel_texts,
+ "kg_rel_map": cur_rel_map,
+ }
+ if self._graph_schema != "":
+ rel_node_info["kg_schema"] = {"schema": self._graph_schema}
+ rel_info_text = "\n".join(
+ [
+ str(item)
+ for sublist in rel_info
+ for item in (sublist if isinstance(sublist, list) else [sublist])
+ ]
+ )
+ if self._verbose:
+ print(f"KG context:\n{rel_info_text}\n", color="blue")
+ rel_text_node = TextNode(
+ text=rel_info_text,
+ metadata=rel_node_info,
+ excluded_embed_metadata_keys=["kg_rel_map", "kg_rel_texts"],
+ excluded_llm_metadata_keys=["kg_rel_map", "kg_rel_texts"],
+ )
+ # this node is constructed from rel_texts, give high confidence to avoid cutoff
+ sorted_nodes_with_scores.append(
+ NodeWithScore(node=rel_text_node, score=DEFAULT_NODE_SCORE)
+ )
+
+ return sorted_nodes_with_scores
+
+ def _get_metadata_for_response(
+ self, nodes: List[BaseNode]
+ ) -> Optional[Dict[str, Any]]:
+ """Get metadata for response."""
+ for node in nodes:
+ if node.metadata is None or "kg_rel_map" not in node.metadata:
+ continue
+ return node.metadata
+ raise ValueError("kg_rel_map must be found in at least one Node.")
\ No newline at end of file
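Continuing the engine sketch above, the search result is a list of NodeWithScore objects whose metadata carries the relationship texts and relationship map (illustrative):

    for node_with_score in engine.search("Who is Bob's mother?"):
        print(node_with_score.score)
        print(node_with_score.get_content())
        print(node_with_score.metadata.get("kg_rel_texts", []))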
diff --git a/pilot/graph_engine/index_struct.py b/pilot/graph_engine/index_struct.py
new file mode 100644
index 000000000..edc47a7ac
--- /dev/null
+++ b/pilot/graph_engine/index_struct.py
@@ -0,0 +1,259 @@
+"""Data structures.
+
+Nodes are decoupled from the indices.
+
+"""
+
+import uuid
+from abc import abstractmethod
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Sequence, Set
+
+from dataclasses_json import DataClassJsonMixin
+
+
+from pilot.graph_engine.index_type import IndexStructType
+from pilot.graph_engine.node import TextNode, BaseNode
+
+# TODO: legacy backport of old Node class
+Node = TextNode
+
+
+@dataclass
+class IndexStruct(DataClassJsonMixin):
+ """A base data struct for a LlamaIndex."""
+
+ index_id: str = field(default_factory=lambda: str(uuid.uuid4()))
+ summary: Optional[str] = None
+
+ def get_summary(self) -> str:
+ """Get text summary."""
+ if self.summary is None:
+ raise ValueError("summary field of the index_struct not set.")
+ return self.summary
+
+ @classmethod
+ @abstractmethod
+ def get_type(cls):
+ """Get index struct type."""
+
+
+@dataclass
+class IndexGraph(IndexStruct):
+ """A graph representing the tree-structured index."""
+
+ # mapping from index in tree to Node doc id.
+ all_nodes: Dict[int, str] = field(default_factory=dict)
+ root_nodes: Dict[int, str] = field(default_factory=dict)
+ node_id_to_children_ids: Dict[str, List[str]] = field(default_factory=dict)
+
+ @property
+ def node_id_to_index(self) -> Dict[str, int]:
+ """Map from node id to index."""
+ return {node_id: index for index, node_id in self.all_nodes.items()}
+
+ @property
+ def size(self) -> int:
+ """Get the size of the graph."""
+ return len(self.all_nodes)
+
+ def get_index(self, node: BaseNode) -> int:
+ """Get index of node."""
+ return self.node_id_to_index[node.node_id]
+
+ def insert(
+ self,
+ node: BaseNode,
+ index: Optional[int] = None,
+ children_nodes: Optional[Sequence[BaseNode]] = None,
+ ) -> None:
+ """Insert node."""
+ index = index or self.size
+ node_id = node.node_id
+
+ self.all_nodes[index] = node_id
+
+ if children_nodes is None:
+ children_nodes = []
+ children_ids = [n.node_id for n in children_nodes]
+ self.node_id_to_children_ids[node_id] = children_ids
+
+ def get_children(self, parent_node: Optional[BaseNode]) -> Dict[int, str]:
+ """Get children nodes."""
+ if parent_node is None:
+ return self.root_nodes
+ else:
+ parent_id = parent_node.node_id
+ children_ids = self.node_id_to_children_ids[parent_id]
+ return {
+ self.node_id_to_index[child_id]: child_id for child_id in children_ids
+ }
+
+ def insert_under_parent(
+ self,
+ node: BaseNode,
+ parent_node: Optional[BaseNode],
+ new_index: Optional[int] = None,
+ ) -> None:
+ """Insert under parent node."""
+ new_index = new_index or self.size
+ if parent_node is None:
+ self.root_nodes[new_index] = node.node_id
+ self.node_id_to_children_ids[node.node_id] = []
+ else:
+ if parent_node.node_id not in self.node_id_to_children_ids:
+ self.node_id_to_children_ids[parent_node.node_id] = []
+ self.node_id_to_children_ids[parent_node.node_id].append(node.node_id)
+
+ self.all_nodes[new_index] = node.node_id
+
+ @classmethod
+ def get_type(cls) -> IndexStructType:
+ """Get type."""
+ return IndexStructType.TREE
+
+
+@dataclass
+class KeywordTable(IndexStruct):
+ """A table of keywords mapping keywords to text chunks."""
+
+ table: Dict[str, Set[str]] = field(default_factory=dict)
+
+ def add_node(self, keywords: List[str], node: BaseNode) -> None:
+ """Add text to table."""
+ for keyword in keywords:
+ if keyword not in self.table:
+ self.table[keyword] = set()
+ self.table[keyword].add(node.node_id)
+
+ @property
+ def node_ids(self) -> Set[str]:
+ """Get all node ids."""
+ return set.union(*self.table.values())
+
+ @property
+ def keywords(self) -> Set[str]:
+ """Get all keywords in the table."""
+ return set(self.table.keys())
+
+ @property
+ def size(self) -> int:
+ """Get the size of the table."""
+ return len(self.table)
+
+ @classmethod
+ def get_type(cls) -> IndexStructType:
+ """Get type."""
+ return IndexStructType.KEYWORD_TABLE
+
+
+@dataclass
+class IndexList(IndexStruct):
+ """A list of documents."""
+
+ nodes: List[str] = field(default_factory=list)
+
+ def add_node(self, node: BaseNode) -> None:
+ """Add text to table, return current position in list."""
+ # don't worry about child indices for now, nodes are all in order
+ self.nodes.append(node.node_id)
+
+ @classmethod
+ def get_type(cls) -> IndexStructType:
+ """Get type."""
+ return IndexStructType.LIST
+
+
+@dataclass
+class IndexDict(IndexStruct):
+ """A simple dictionary of documents."""
+
+ # TODO: slightly deprecated, should likely be a list or set now
+ # mapping from vector store id to node doc_id
+ nodes_dict: Dict[str, str] = field(default_factory=dict)
+
+ # TODO: deprecated, not used
+ # mapping from node doc_id to vector store id
+ doc_id_dict: Dict[str, List[str]] = field(default_factory=dict)
+
+ # TODO: deprecated, not used
+ # this should be empty for all other indices
+ embeddings_dict: Dict[str, List[float]] = field(default_factory=dict)
+
+ def add_node(
+ self,
+ node: BaseNode,
+ text_id: Optional[str] = None,
+ ) -> str:
+ """Add text to table, return current position in list."""
+ # # don't worry about child indices for now, nodes are all in order
+ # self.nodes_dict[int_id] = node
+ vector_id = text_id if text_id is not None else node.node_id
+ self.nodes_dict[vector_id] = node.node_id
+
+ return vector_id
+
+ def delete(self, doc_id: str) -> None:
+ """Delete a Node."""
+ del self.nodes_dict[doc_id]
+
+ @classmethod
+ def get_type(cls) -> IndexStructType:
+ """Get type."""
+ return IndexStructType.VECTOR_STORE
+
+
+@dataclass
+class KG(IndexStruct):
+ """A table of keywords mapping keywords to text chunks."""
+
+ # Unidirectional
+
+ # table of keywords to node ids
+ table: Dict[str, Set[str]] = field(default_factory=dict)
+
+ # TODO: legacy attribute, remove in future releases
+ rel_map: Dict[str, List[List[str]]] = field(default_factory=dict)
+
+ # TBD, should support vector store, now we just persist the embedding memory
+ # maybe chainable abstractions for *_stores could be designed
+ embedding_dict: Dict[str, List[float]] = field(default_factory=dict)
+
+ @property
+ def node_ids(self) -> Set[str]:
+ """Get all node ids."""
+ return set.union(*self.table.values())
+
+ def add_to_embedding_dict(self, triplet_str: str, embedding: List[float]) -> None:
+ """Add embedding to dict."""
+ self.embedding_dict[triplet_str] = embedding
+
+ def add_node(self, keywords: List[str], node: BaseNode) -> None:
+ """Add text to table."""
+ node_id = node.node_id
+ for keyword in keywords:
+ keyword = keyword.lower()
+ if keyword not in self.table:
+ self.table[keyword] = set()
+ self.table[keyword].add(node_id)
+
+ def search_node_by_keyword(self, keyword: str) -> List[str]:
+ """Search for nodes by keyword."""
+ if keyword not in self.table:
+ return []
+ return list(self.table[keyword])
+
+ @classmethod
+ def get_type(cls) -> IndexStructType:
+ """Get type."""
+ return IndexStructType.KG
+
+
+@dataclass
+class EmptyIndexStruct(IndexStruct):
+ """Empty index."""
+
+ @classmethod
+ def get_type(cls) -> IndexStructType:
+ """Get type."""
+ return IndexStructType.EMPTY
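A small sketch of the KG index struct used by RAGGraphEngine: keywords (lower-cased on insert) map to the ids of the TextNodes they were extracted from (illustrative values):

    from pilot.graph_engine.index_struct import KG
    from pilot.graph_engine.node import TextNode

    kg = KG()
    node = TextNode(text="Alice is Bob's mother.")
    kg.add_node(["Alice", "Bob"], node)  # stored as "alice", "bob"
    assert kg.search_node_by_keyword("alice") == [node.node_id]
    assert kg.search_node_by_keyword("unknown") == []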
diff --git a/pilot/graph_engine/index_type.py b/pilot/graph_engine/index_type.py
new file mode 100644
index 000000000..939066be9
--- /dev/null
+++ b/pilot/graph_engine/index_type.py
@@ -0,0 +1,48 @@
+"""IndexStructType class."""
+
+from enum import Enum
+
+
+class IndexStructType(str, Enum):
+ """Index struct type. Identifier for a "type" of index.
+
+ Attributes:
+ TREE ("tree"): Tree index. See :ref:`Ref-Indices-Tree` for tree indices.
+ LIST ("list"): Summary index. See :ref:`Ref-Indices-List` for summary indices.
+ KEYWORD_TABLE ("keyword_table"): Keyword table index. See
+ :ref:`Ref-Indices-Table`
+ for keyword table indices.
+ DICT ("dict"): Faiss Vector Store Index. See
+ :ref:`Ref-Indices-VectorStore`
+ for more information on the faiss vector store index.
+ SIMPLE_DICT ("simple_dict"): Simple Vector Store Index. See
+ :ref:`Ref-Indices-VectorStore`
+ for more information on the simple vector store index.
+ KG ("kg"): Knowledge Graph index.
+ See :ref:`Ref-Indices-Knowledge-Graph` for KG indices.
+ DOCUMENT_SUMMARY ("document_summary"): Document Summary Index.
+ See :ref:`Ref-Indices-Document-Summary` for Summary Indices.
+
+ """
+
+ # TODO: refactor so these are properties on the base class
+
+ NODE = "node"
+ TREE = "tree"
+ LIST = "list"
+ KEYWORD_TABLE = "keyword_table"
+
+ DICT = "dict"
+ # simple
+ SIMPLE_DICT = "simple_dict"
+ # for KG index
+ KG = "kg"
+ SIMPLE_KG = "simple_kg"
+ NEBULAGRAPH = "nebulagraph"
+ FALKORDB = "falkordb"
+
+ # EMPTY
+ EMPTY = "empty"
+ COMPOSITE = "composite"
+
+ DOCUMENT_SUMMARY = "document_summary"
diff --git a/pilot/graph_engine/kv_index.py b/pilot/graph_engine/kv_index.py
new file mode 100644
index 000000000..7b44b7d04
--- /dev/null
+++ b/pilot/graph_engine/kv_index.py
@@ -0,0 +1,74 @@
+from typing import List, Optional
+from llama_index.data_structs.data_structs import IndexStruct
+from llama_index.storage.index_store.utils import (
+ index_struct_to_json,
+ json_to_index_struct,
+)
+from llama_index.storage.kvstore.types import BaseKVStore
+
+DEFAULT_NAMESPACE = "index_store"
+
+
+class KVIndexStore:
+ """Key-Value Index store.
+
+ Args:
+ kvstore (BaseKVStore): key-value store
+ namespace (str): namespace for the index store
+
+ """
+
+ def __init__(self, kvstore: BaseKVStore, namespace: Optional[str] = None) -> None:
+ """Init a KVIndexStore."""
+ self._kvstore = kvstore
+ self._namespace = namespace or DEFAULT_NAMESPACE
+ self._collection = f"{self._namespace}/data"
+
+ def add_index_struct(self, index_struct: IndexStruct) -> None:
+ """Add an index struct.
+
+ Args:
+ index_struct (IndexStruct): index struct
+
+ """
+ key = index_struct.index_id
+ data = index_struct_to_json(index_struct)
+ self._kvstore.put(key, data, collection=self._collection)
+
+ def delete_index_struct(self, key: str) -> None:
+ """Delete an index struct.
+
+ Args:
+ key (str): index struct key
+
+ """
+ self._kvstore.delete(key, collection=self._collection)
+
+ def get_index_struct(
+ self, struct_id: Optional[str] = None
+ ) -> Optional[IndexStruct]:
+ """Get an index struct.
+
+ Args:
+ struct_id (Optional[str]): index struct id
+
+ """
+ if struct_id is None:
+ structs = self.index_structs()
+ assert len(structs) == 1
+ return structs[0]
+ else:
+ json = self._kvstore.get(struct_id, collection=self._collection)
+ if json is None:
+ return None
+ return json_to_index_struct(json)
+
+ def index_structs(self) -> List[IndexStruct]:
+ """Get all index structs.
+
+ Returns:
+ List[IndexStruct]: index structs
+
+ """
+ jsons = self._kvstore.get_all(collection=self._collection)
+ return [json_to_index_struct(json) for json in jsons.values()]
diff --git a/pilot/graph_engine/node.py b/pilot/graph_engine/node.py
new file mode 100644
index 000000000..6f6d45ae4
--- /dev/null
+++ b/pilot/graph_engine/node.py
@@ -0,0 +1,569 @@
+"""Base schema for data structures."""
+import json
+import textwrap
+import uuid
+from abc import abstractmethod
+from enum import Enum, auto
+from hashlib import sha256
+from typing import Any, Dict, List, Optional, Union
+
+from langchain.schema import Document
+from pydantic import BaseModel, Field, root_validator
+from typing_extensions import Self
+
+
+DEFAULT_TEXT_NODE_TMPL = "{metadata_str}\n\n{content}"
+DEFAULT_METADATA_TMPL = "{key}: {value}"
+# NOTE: for pretty printing
+TRUNCATE_LENGTH = 350
+WRAP_WIDTH = 70
+
+
+def truncate_text(text: str, max_length: int) -> str:
+ """Truncate text to a maximum length."""
+ if len(text) <= max_length:
+ return text
+ return text[: max_length - 3] + "..."
+
+
+class BaseComponent(BaseModel):
+ """Base component object to caputure class names."""
+ """reference llama-index"""
+
+ @classmethod
+ @abstractmethod
+ def class_name(cls) -> str:
+ """Get class name."""
+
+ def to_dict(self, **kwargs: Any) -> Dict[str, Any]:
+ data = self.dict(**kwargs)
+ data["class_name"] = self.class_name()
+ return data
+
+ def to_json(self, **kwargs: Any) -> str:
+ data = self.to_dict(**kwargs)
+ return json.dumps(data)
+
+ # TODO: return type here not supported by current mypy version
+ @classmethod
+ def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self: # type: ignore
+ if isinstance(kwargs, dict):
+ data.update(kwargs)
+
+ data.pop("class_name", None)
+ return cls(**data)
+
+ @classmethod
+ def from_json(cls, data_str: str, **kwargs: Any) -> Self: # type: ignore
+ data = json.loads(data_str)
+ return cls.from_dict(data, **kwargs)
+
+
+class NodeRelationship(str, Enum):
+ """Node relationships used in `BaseNode` class.
+
+ Attributes:
+ SOURCE: The node is the source document.
+ PREVIOUS: The node is the previous node in the document.
+ NEXT: The node is the next node in the document.
+ PARENT: The node is the parent node in the document.
+ CHILD: The node is a child node in the document.
+
+ """
+
+ SOURCE = auto()
+ PREVIOUS = auto()
+ NEXT = auto()
+ PARENT = auto()
+ CHILD = auto()
+
+
+class ObjectType(str, Enum):
+ TEXT = auto()
+ IMAGE = auto()
+ INDEX = auto()
+ DOCUMENT = auto()
+
+
+class MetadataMode(str, Enum):
+ ALL = auto()
+ EMBED = auto()
+ LLM = auto()
+ NONE = auto()
+
+
+class RelatedNodeInfo(BaseComponent):
+ node_id: str
+ node_type: Optional[ObjectType] = None
+ metadata: Dict[str, Any] = Field(default_factory=dict)
+ hash: Optional[str] = None
+
+ @classmethod
+ def class_name(cls) -> str:
+ """Get class name."""
+ return "RelatedNodeInfo"
+
+
+RelatedNodeType = Union[RelatedNodeInfo, List[RelatedNodeInfo]]
+
+
+# Node classes for indexes
+class BaseNode(BaseComponent):
+ """Base node Object.
+
+ Generic abstract interface for retrievable nodes
+
+ """
+
+ class Config:
+ allow_population_by_field_name = True
+
+ id_: str = Field(
+ default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the node."
+ )
+ embedding: Optional[List[float]] = Field(
+ default=None, description="Embedding of the node."
+ )
+
+ """"
+ metadata fields
+ - injected as part of the text shown to LLMs as context
+ - injected as part of the text for generating embeddings
+ - used by vector DBs for metadata filtering
+
+ """
+ metadata: Dict[str, Any] = Field(
+ default_factory=dict,
+ description="A flat dictionary of metadata fields",
+ alias="extra_info",
+ )
+ excluded_embed_metadata_keys: List[str] = Field(
+ default_factory=list,
+ description="Metadata keys that are exluded from text for the embed model.",
+ )
+ excluded_llm_metadata_keys: List[str] = Field(
+ default_factory=list,
+ description="Metadata keys that are exluded from text for the LLM.",
+ )
+ relationships: Dict[NodeRelationship, RelatedNodeType] = Field(
+ default_factory=dict,
+ description="A mapping of relationships to other node information.",
+ )
+ hash: str = Field(default="", description="Hash of the node content.")
+
+ @classmethod
+ @abstractmethod
+ def get_type(cls) -> str:
+ """Get Object type."""
+
+ @abstractmethod
+ def get_content(self, metadata_mode: MetadataMode = MetadataMode.ALL) -> str:
+ """Get object content."""
+
+ @abstractmethod
+ def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
+ """Metadata string."""
+
+ @abstractmethod
+ def set_content(self, value: Any) -> None:
+ """Set the content of the node."""
+
+ @property
+ def node_id(self) -> str:
+ return self.id_
+
+ @node_id.setter
+ def node_id(self, value: str) -> None:
+ self.id_ = value
+
+ @property
+ def source_node(self) -> Optional[RelatedNodeInfo]:
+ """Source object node.
+
+ Extracted from the relationships field.
+
+ """
+ if NodeRelationship.SOURCE not in self.relationships:
+ return None
+
+ relation = self.relationships[NodeRelationship.SOURCE]
+ if isinstance(relation, list):
+ raise ValueError("Source object must be a single RelatedNodeInfo object")
+ return relation
+
+ @property
+ def prev_node(self) -> Optional[RelatedNodeInfo]:
+ """Prev node."""
+ if NodeRelationship.PREVIOUS not in self.relationships:
+ return None
+
+ relation = self.relationships[NodeRelationship.PREVIOUS]
+ if not isinstance(relation, RelatedNodeInfo):
+ raise ValueError("Previous object must be a single RelatedNodeInfo object")
+ return relation
+
+ @property
+ def next_node(self) -> Optional[RelatedNodeInfo]:
+ """Next node."""
+ if NodeRelationship.NEXT not in self.relationships:
+ return None
+
+ relation = self.relationships[NodeRelationship.NEXT]
+ if not isinstance(relation, RelatedNodeInfo):
+ raise ValueError("Next object must be a single RelatedNodeInfo object")
+ return relation
+
+ @property
+ def parent_node(self) -> Optional[RelatedNodeInfo]:
+ """Parent node."""
+ if NodeRelationship.PARENT not in self.relationships:
+ return None
+
+ relation = self.relationships[NodeRelationship.PARENT]
+ if not isinstance(relation, RelatedNodeInfo):
+ raise ValueError("Parent object must be a single RelatedNodeInfo object")
+ return relation
+
+ @property
+ def child_nodes(self) -> Optional[List[RelatedNodeInfo]]:
+ """Child nodes."""
+ if NodeRelationship.CHILD not in self.relationships:
+ return None
+
+ relation = self.relationships[NodeRelationship.CHILD]
+ if not isinstance(relation, list):
+ raise ValueError("Child objects must be a list of RelatedNodeInfo objects.")
+ return relation
+
+ @property
+ def ref_doc_id(self) -> Optional[str]:
+ """Deprecated: Get ref doc id."""
+ source_node = self.source_node
+ if source_node is None:
+ return None
+ return source_node.node_id
+
+ @property
+ def extra_info(self) -> Dict[str, Any]:
+ """TODO: DEPRECATED: Extra info."""
+ return self.metadata
+
+ def __str__(self) -> str:
+ source_text_truncated = truncate_text(
+ self.get_content().strip(), TRUNCATE_LENGTH
+ )
+ source_text_wrapped = textwrap.fill(
+ f"Text: {source_text_truncated}\n", width=WRAP_WIDTH
+ )
+ return f"Node ID: {self.node_id}\n{source_text_wrapped}"
+
+ def get_embedding(self) -> List[float]:
+ """Get embedding.
+
+ Errors if embedding is None.
+
+ """
+ if self.embedding is None:
+ raise ValueError("embedding not set.")
+ return self.embedding
+
+ def as_related_node_info(self) -> RelatedNodeInfo:
+ """Get node as RelatedNodeInfo."""
+ return RelatedNodeInfo(
+ node_id=self.node_id, metadata=self.metadata, hash=self.hash
+ )
+
+
+class TextNode(BaseNode):
+ text: str = Field(default="", description="Text content of the node.")
+ start_char_idx: Optional[int] = Field(
+ default=None, description="Start char index of the node."
+ )
+ end_char_idx: Optional[int] = Field(
+ default=None, description="End char index of the node."
+ )
+ text_template: str = Field(
+ default=DEFAULT_TEXT_NODE_TMPL,
+ description=(
+ "Template for how text is formatted, with {content} and "
+ "{metadata_str} placeholders."
+ ),
+ )
+ metadata_template: str = Field(
+ default=DEFAULT_METADATA_TMPL,
+ description=(
+ "Template for how metadata is formatted, with {key} and "
+ "{value} placeholders."
+ ),
+ )
+ metadata_seperator: str = Field(
+ default="\n",
+ description="Seperator between metadata fields when converting to string.",
+ )
+
+ @classmethod
+ def class_name(cls) -> str:
+ """Get class name."""
+ return "TextNode"
+
+ @root_validator
+ def _check_hash(cls, values: dict) -> dict:
+ """Generate a hash to represent the node."""
+ text = values.get("text", "")
+ metadata = values.get("metadata", {})
+ doc_identity = str(text) + str(metadata)
+ values["hash"] = str(
+ sha256(doc_identity.encode("utf-8", "surrogatepass")).hexdigest()
+ )
+ return values
+
+ @classmethod
+ def get_type(cls) -> str:
+ """Get Object type."""
+ return ObjectType.TEXT
+
+ def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
+ """Get object content."""
+ metadata_str = self.get_metadata_str(mode=metadata_mode).strip()
+ if not metadata_str:
+ return self.text
+
+ return self.text_template.format(
+ content=self.text, metadata_str=metadata_str
+ ).strip()
+
+ def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
+ """metadata info string."""
+ if mode == MetadataMode.NONE:
+ return ""
+
+ usable_metadata_keys = set(self.metadata.keys())
+ if mode == MetadataMode.LLM:
+ for key in self.excluded_llm_metadata_keys:
+ if key in usable_metadata_keys:
+ usable_metadata_keys.remove(key)
+ elif mode == MetadataMode.EMBED:
+ for key in self.excluded_embed_metadata_keys:
+ if key in usable_metadata_keys:
+ usable_metadata_keys.remove(key)
+
+ return self.metadata_seperator.join(
+ [
+ self.metadata_template.format(key=key, value=str(value))
+ for key, value in self.metadata.items()
+ if key in usable_metadata_keys
+ ]
+ )
+
+ def set_content(self, value: str) -> None:
+ """Set the content of the node."""
+ self.text = value
+
+ def get_node_info(self) -> Dict[str, Any]:
+ """Get node info."""
+ return {"start": self.start_char_idx, "end": self.end_char_idx}
+
+ def get_text(self) -> str:
+ return self.get_content(metadata_mode=MetadataMode.NONE)
+
+ @property
+ def node_info(self) -> Dict[str, Any]:
+ """Deprecated: Get node info."""
+ return self.get_node_info()
+
+
+# TODO: legacy backport of old Node class
+Node = TextNode
+
+
+class ImageNode(TextNode):
+ """Node with image."""
+
+ # TODO: store reference instead of actual image
+ # base64 encoded image str
+ image: Optional[str] = None
+
+ @classmethod
+ def get_type(cls) -> str:
+ return ObjectType.IMAGE
+
+ @classmethod
+ def class_name(cls) -> str:
+ """Get class name."""
+ return "ImageNode"
+
+
+class IndexNode(TextNode):
+ """Node with reference to any object.
+
+ This can include other indices, query engines, retrievers.
+
+ This can also include other nodes (though this is overlapping with `relationships`
+ on the Node class).
+
+ """
+
+ index_id: str
+
+ @classmethod
+ def from_text_node(
+ cls,
+ node: TextNode,
+ index_id: str,
+ ) -> "IndexNode":
+ """Create index node from text node."""
+ # copy all attributes from text node, add index id
+ return cls(
+ **node.dict(),
+ index_id=index_id,
+ )
+
+ @classmethod
+ def get_type(cls) -> str:
+ return ObjectType.INDEX
+
+ @classmethod
+ def class_name(cls) -> str:
+ """Get class name."""
+ return "IndexNode"
+
+
+class NodeWithScore(BaseComponent):
+ node: BaseNode
+ score: Optional[float] = None
+
+ def __str__(self) -> str:
+ return f"{self.node}\nScore: {self.score: 0.3f}\n"
+
+ def get_score(self, raise_error: bool = False) -> float:
+ """Get score."""
+ if self.score is None:
+ if raise_error:
+ raise ValueError("Score not set.")
+ else:
+ return 0.0
+ else:
+ return self.score
+
+ @classmethod
+ def class_name(cls) -> str:
+ """Get class name."""
+ return "NodeWithScore"
+
+ ##### pass through methods to BaseNode #####
+ @property
+ def node_id(self) -> str:
+ return self.node.node_id
+
+ @property
+ def id_(self) -> str:
+ return self.node.id_
+
+ @property
+ def text(self) -> str:
+ if isinstance(self.node, TextNode):
+ return self.node.text
+ else:
+ raise ValueError("Node must be a TextNode to get text.")
+
+ @property
+ def metadata(self) -> Dict[str, Any]:
+ return self.node.metadata
+
+ @property
+ def embedding(self) -> Optional[List[float]]:
+ return self.node.embedding
+
+ def get_text(self) -> str:
+ if isinstance(self.node, TextNode):
+ return self.node.get_text()
+ else:
+ raise ValueError("Node must be a TextNode to get text.")
+
+ def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
+ return self.node.get_content(metadata_mode=metadata_mode)
+
+ def get_embedding(self) -> List[float]:
+ return self.node.get_embedding()
+
+
+# Document Classes for Readers
+
+
+class Document(TextNode):
+ """Generic interface for a data document.
+
+ This document connects to data sources.
+
+ """
+
+ # TODO: A lot of backwards compatibility logic here, clean up
+ id_: str = Field(
+ default_factory=lambda: str(uuid.uuid4()),
+ description="Unique ID of the node.",
+ alias="doc_id",
+ )
+
+ _compat_fields = {"doc_id": "id_", "extra_info": "metadata"}
+
+ @classmethod
+ def get_type(cls) -> str:
+ """Get Document type."""
+ return ObjectType.DOCUMENT
+
+ @property
+ def doc_id(self) -> str:
+ """Get document ID."""
+ return self.id_
+
+ def __str__(self) -> str:
+ source_text_truncated = truncate_text(
+ self.get_content().strip(), TRUNCATE_LENGTH
+ )
+ source_text_wrapped = textwrap.fill(
+ f"Text: {source_text_truncated}\n", width=WRAP_WIDTH
+ )
+ return f"Doc ID: {self.doc_id}\n{source_text_wrapped}"
+
+ def get_doc_id(self) -> str:
+ """TODO: Deprecated: Get document ID."""
+ return self.id_
+
+ def __setattr__(self, name: str, value: object) -> None:
+ if name in self._compat_fields:
+ name = self._compat_fields[name]
+ super().__setattr__(name, value)
+
+ def to_langchain_format(self) -> Document:
+ """Convert struct to LangChain document format."""
+ from langchain.schema import Document as LCDocument
+
+ metadata = self.metadata or {}
+ return LCDocument(page_content=self.text, metadata=metadata)
+
+ @classmethod
+ def from_langchain_format(cls, doc: Document) -> "Document":
+ """Convert struct from LangChain document format."""
+ return cls(text=doc.page_content, metadata=doc.metadata)
+
+ @classmethod
+ def example(cls) -> "Document":
+ document = Document(
+ text="",
+ metadata={"filename": "README.md", "category": "codebase"},
+ )
+ return document
+
+ @classmethod
+ def class_name(cls) -> str:
+ """Get class name."""
+ return "Document"
+
+
+class ImageDocument(Document):
+ """Data document containing an image."""
+
+ # base64 encoded image str
+ image: Optional[str] = None
+
+ @classmethod
+ def class_name(cls) -> str:
+ """Get class name."""
+ return "ImageDocument"
diff --git a/pilot/graph_engine/search.py b/pilot/graph_engine/search.py
new file mode 100644
index 000000000..8db837278
--- /dev/null
+++ b/pilot/graph_engine/search.py
@@ -0,0 +1,44 @@
+from abc import ABC, abstractmethod
+from enum import Enum
+
+
+class SearchMode(str, Enum):
+ """Query mode enum for Knowledge Graphs.
+
+ Can be passed as the enum struct, or as the underlying string.
+
+ Attributes:
+ KEYWORD ("keyword"): Default query mode, using keywords to find triplets.
+ EMBEDDING ("embedding"): Embedding mode, using embeddings to find
+ similar triplets.
+ HYBRID ("hybrid"): Hyrbid mode, combining both keywords and embeddings
+ to find relevant triplets.
+ """
+
+ KEYWORD = "keyword"
+ EMBEDDING = "embedding"
+ HYBRID = "hybrid"
+
+
+class BaseSearch(ABC):
+ """Base Search."""
+
+ def search(self, query: str):
+ """Retrieve nodes given query.
+
+ Args:
+ query (QueryType): Either a query string or
+ a QueryBundle object.
+
+ """
+ # if isinstance(query, str):
+ return self._search(query)
+
+ @abstractmethod
+ def _search(self, query: str):
+ """search nodes given query.
+
+ Implemented by the user.
+
+ """
+ pass
diff --git a/pilot/scene/chat_knowledge/extract_entity/__init__.py b/pilot/scene/chat_knowledge/extract_entity/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/scene/chat_knowledge/extract_entity/chat.py b/pilot/scene/chat_knowledge/extract_entity/chat.py
new file mode 100644
index 000000000..bb52961b5
--- /dev/null
+++ b/pilot/scene/chat_knowledge/extract_entity/chat.py
@@ -0,0 +1,35 @@
+from typing import Dict
+
+from pilot.scene.base_chat import BaseChat
+from pilot.scene.base import ChatScene
+from pilot.configs.config import Config
+
+from pilot.scene.chat_knowledge.extract_entity.prompt import prompt
+
+CFG = Config()
+
+
+class ExtractEntity(BaseChat):
+ chat_scene: str = ChatScene.ExtractEntity.value()
+
+ """extracting entities by llm"""
+
+ def __init__(self, chat_param: Dict):
+ """ """
+ chat_param["chat_mode"] = ChatScene.ExtractEntity
+ super().__init__(
+ chat_param=chat_param,
+ )
+
+ self.user_input = chat_param["current_user_input"]
+ self.extract_mode = chat_param["select_param"]
+
+ def generate_input_values(self):
+ input_values = {
+ "text": self.user_input,
+ }
+ return input_values
+
+ @property
+ def chat_type(self) -> str:
+ return ChatScene.ExtractEntity.value
diff --git a/pilot/scene/chat_knowledge/extract_entity/out_parser.py b/pilot/scene/chat_knowledge/extract_entity/out_parser.py
new file mode 100644
index 000000000..4093e460f
--- /dev/null
+++ b/pilot/scene/chat_knowledge/extract_entity/out_parser.py
@@ -0,0 +1,39 @@
+import json
+import logging
+from typing import Set
+
+from pilot.out_parser.base import BaseOutputParser, T
+from pilot.configs.config import Config
+
+CFG = Config()
+
+
+logger = logging.getLogger(__name__)
+
+
+class ExtractEntityParser(BaseOutputParser):
+ def __init__(self, sep: str, is_stream_out: bool):
+ super().__init__(sep=sep, is_stream_out=is_stream_out)
+
+ def parse_prompt_response(self, response, max_length: int = 128) -> Set[str]:
+ lowercase = True
+ # clean_str = super().parse_prompt_response(response)
+ print("clean prompt response:", response)
+
+ results = []
+ response = response.strip() # Strip newlines from responses.
+
+ if response.startswith("KEYWORDS:"):
+ response = response[len("KEYWORDS:") :]
+
+ keywords = response.split(",")
+ for k in keywords:
+ rk = k
+ if lowercase:
+ rk = rk.lower()
+ results.append(rk.strip())
+
+ return set(results)
+
+ def parse_view_response(self, speak, data) -> str:
+ return data
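
To make the parsing behaviour above concrete, a quick sketch of what the parser returns for a well-formed model reply (the `sep` value is just a plausible placeholder for whatever `BaseOutputParser` expects).

```python
from pilot.scene.chat_knowledge.extract_entity.out_parser import ExtractEntityParser

parser = ExtractEntityParser(sep="\n", is_stream_out=False)
response = "KEYWORDS: Alice, mother, Bob"

# The "KEYWORDS:" prefix is stripped, the rest is split on commas, and each
# keyword is lowercased and stripped before being collected into a set.
print(parser.parse_prompt_response(response))  # {'alice', 'mother', 'bob'}
```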
diff --git a/pilot/scene/chat_knowledge/extract_entity/prompt.py b/pilot/scene/chat_knowledge/extract_entity/prompt.py
new file mode 100644
index 000000000..77349bd28
--- /dev/null
+++ b/pilot/scene/chat_knowledge/extract_entity/prompt.py
@@ -0,0 +1,52 @@
+import json
+
+from pilot.prompts.prompt_new import PromptTemplate
+from pilot.configs.config import Config
+from pilot.scene.base import ChatScene
+from pilot.common.schema import SeparatorStyle
+from pilot.scene.chat_knowledge.extract_entity.out_parser import ExtractEntityParser
+
+
+CFG = Config()
+
+PROMPT_SCENE_DEFINE = """"""
+
+_DEFAULT_TEMPLATE = """
+"A question is provided below. Given the question, extract up to 10 "
+ "keywords from the text. Focus on extracting the keywords that we can use "
+ "to best lookup answers to the question. Avoid stopwords.\n"
+ "Example:"
+ "Text: Alice is Bob's mother."
+ "KEYWORDS:Alice,mother,Bob\n"
+ "---------------------\n"
+ "{text}\n"
+ "---------------------\n"
+ "Provide keywords in the following comma-separated format: 'KEYWORDS: '\n"
+"""
+PROMPT_RESPONSE = """"""
+
+
+RESPONSE_FORMAT = """"""
+
+
+PROMPT_SEP = SeparatorStyle.SINGLE.value
+
+PROMPT_NEED_NEED_STREAM_OUT = False
+
+prompt = PromptTemplate(
+ template_scene=ChatScene.ExtractEntity.value(),
+ input_variables=["text"],
+ response_format="",
+ template_define=PROMPT_SCENE_DEFINE,
+ template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
+ stream_out=PROMPT_NEED_NEED_STREAM_OUT,
+ output_parser=ExtractEntityParser(
+ sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
+ ),
+)
+
+CFG.prompt_template_registry.register(prompt, is_default=True)
diff --git a/pilot/scene/chat_knowledge/extract_triplet/__init__.py b/pilot/scene/chat_knowledge/extract_triplet/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/scene/chat_knowledge/extract_triplet/chat.py b/pilot/scene/chat_knowledge/extract_triplet/chat.py
new file mode 100644
index 000000000..11fe871ab
--- /dev/null
+++ b/pilot/scene/chat_knowledge/extract_triplet/chat.py
@@ -0,0 +1,35 @@
+from typing import Dict
+
+from pilot.scene.base_chat import BaseChat
+from pilot.scene.base import ChatScene
+from pilot.configs.config import Config
+
+from pilot.scene.chat_knowledge.extract_triplet.prompt import prompt
+
+CFG = Config()
+
+
+class ExtractTriplet(BaseChat):
+    """Extract triplets from text by LLM."""
+
+    chat_scene: str = ChatScene.ExtractTriplet.value()
+
+ def __init__(self, chat_param: Dict):
+ """ """
+ chat_param["chat_mode"] = ChatScene.ExtractTriplet
+ super().__init__(
+ chat_param=chat_param,
+ )
+
+ self.user_input = chat_param["current_user_input"]
+ self.extract_mode = chat_param["select_param"]
+
+ def generate_input_values(self):
+ input_values = {
+ "text": self.user_input,
+ }
+ return input_values
+
+ @property
+ def chat_type(self) -> str:
+ return ChatScene.ExtractTriplet.value
diff --git a/pilot/scene/chat_knowledge/extract_triplet/out_parser.py b/pilot/scene/chat_knowledge/extract_triplet/out_parser.py
new file mode 100644
index 000000000..75606bd0f
--- /dev/null
+++ b/pilot/scene/chat_knowledge/extract_triplet/out_parser.py
@@ -0,0 +1,57 @@
+import json
+import logging
+import re
+from typing import List, Tuple
+
+from pilot.out_parser.base import BaseOutputParser, T
+from pilot.configs.config import Config
+
+CFG = Config()
+
+
+logger = logging.getLogger(__name__)
+
+
+class ExtractTripleParser(BaseOutputParser):
+ def __init__(self, sep: str, is_stream_out: bool):
+ super().__init__(sep=sep, is_stream_out=is_stream_out)
+
+ def parse_prompt_response(
+ self, response, max_length: int = 128
+ ) -> List[Tuple[str, str, str]]:
+ # clean_str = super().parse_prompt_response(response)
+ print("clean prompt response:", response)
+
+ if response.startswith("Triplets:"):
+ response = response[len("Triplets:") :]
+ pattern = r"\([^()]+\)"
+ response = re.findall(pattern, response)
+ # response = response.strip().split("\n")
+ print("parse prompt response:", response)
+ results = []
+ for text in response:
+ if not text or text[0] != "(" or text[-1] != ")":
+ # skip empty lines and non-triplets
+ continue
+ tokens = text[1:-1].split(",")
+ if len(tokens) != 3:
+ continue
+
+ if any(len(s.encode("utf-8")) > max_length for s in tokens):
+ # We count byte-length instead of len() for UTF-8 chars,
+ # will skip if any of the tokens are too long.
+ # This is normally due to a poorly formatted triplet
+ # extraction, in more serious KG building cases
+ # we'll need NLP models to better extract triplets.
+ continue
+
+ subject, predicate, obj = map(str.strip, tokens)
+ if not subject or not predicate or not obj:
+ # skip partial triplets
+ continue
+ results.append((subject.lower(), predicate.lower(), obj.lower()))
+ return results
+
+ def parse_view_response(self, speak, data) -> str:
+ ### tool out data to table view
+ return data
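
A matching sketch for the triplet parser: each parenthesized group is pulled out with the regex, split into three tokens, and lowercased (`sep` is again a placeholder).

```python
from pilot.scene.chat_knowledge.extract_triplet.out_parser import ExtractTripleParser

parser = ExtractTripleParser(sep="\n", is_stream_out=False)
response = (
    "Triplets:\n"
    "(Philz, is, coffee shop)\n"
    "(Philz, founded in, Berkeley)\n"
    "(Philz, founded in, 1982)\n"
)
print(parser.parse_prompt_response(response))
# [('philz', 'is', 'coffee shop'),
#  ('philz', 'founded in', 'berkeley'),
#  ('philz', 'founded in', '1982')]
```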
diff --git a/pilot/scene/chat_knowledge/extract_triplet/prompt.py b/pilot/scene/chat_knowledge/extract_triplet/prompt.py
new file mode 100644
index 000000000..dd391bce8
--- /dev/null
+++ b/pilot/scene/chat_knowledge/extract_triplet/prompt.py
@@ -0,0 +1,57 @@
+import json
+
+from pilot.prompts.prompt_new import PromptTemplate
+from pilot.configs.config import Config
+from pilot.scene.base import ChatScene
+from pilot.common.schema import SeparatorStyle
+
+from pilot.scene.chat_knowledge.extract_triplet.out_parser import (
+ ExtractTripleParser,
+)
+
+
+CFG = Config()
+
+PROMPT_SCENE_DEFINE = """"""
+
+_DEFAULT_TEMPLATE = """
+"Some text is provided below. Given the text, extract up to 10"
+ "knowledge triplets in the form of (subject, predicate, object). Avoid stopwords.\n"
+ "---------------------\n"
+ "Example:"
+ "Text: Alice is Bob's mother."
+ "Triplets:\n(Alice, is mother of, Bob)\n"
+ "Text: Philz is a coffee shop founded in Berkeley in 1982.\n"
+ "Triplets:\n"
+ "(Philz, is, coffee shop)\n"
+ "(Philz, founded in, Berkeley)\n"
+ "(Philz, founded in, 1982)\n"
+ "---------------------\n"
+ "Text: {text}\n"
+ "Triplets:\n"
+ ensure Respond in the following List(Tuple) format:
+ '(Stephen Curry, plays for, Golden State Warriors)\n(Stephen Curry, known for, shooting skills)\n(Stephen Curry, attended, Davidson College)\n(Stephen Curry, led, team to success)'
+"""
+PROMPT_RESPONSE = """"""
+
+
+RESPONSE_FORMAT = """"""
+
+
+PROMPT_SEP = SeparatorStyle.SINGLE.value
+
+PROMPT_NEED_NEED_STREAM_OUT = False
+
+prompt = PromptTemplate(
+ template_scene=ChatScene.ExtractTriplet.value(),
+ input_variables=["text"],
+ response_format="",
+ template_define=PROMPT_SCENE_DEFINE,
+ template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
+ stream_out=PROMPT_NEED_NEED_STREAM_OUT,
+ output_parser=ExtractTripleParser(
+ sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
+ ),
+)
+
+CFG.prompt_template_registry.register(prompt, is_default=True)
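
Setting the `PromptTemplate` machinery aside, the raw template only has the `{text}` placeholder, so a quick sanity check can render it with plain `str.format` (the sample text is illustrative).

```python
from pilot.scene.chat_knowledge.extract_triplet.prompt import _DEFAULT_TEMPLATE

print(_DEFAULT_TEMPLATE.format(text="Alice is Bob's mother."))
```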
diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py
index ebecddd19..c381546f8 100644
--- a/pilot/scene/chat_knowledge/v1/chat.py
+++ b/pilot/scene/chat_knowledge/v1/chat.py
@@ -88,25 +88,7 @@ class ChatKnowledge(BaseChat):
if self.space_context:
self.prompt_template.template_define = self.space_context["prompt"]["scene"]
self.prompt_template.template = self.space_context["prompt"]["template"]
- # docs = self.rag_engine.search(query=self.current_user_input)
- # import httpx
- # with httpx.Client() as client:
- # request = client.build_request(
- # "post",
- # "http://127.0.0.1/api/knowledge/entities/extract",
- # json="application/json", # using json for data to ensure it sends as application/json
- # params={"text": self.current_user_input},
- # headers={},
- # )
- #
- # response = client.send(request)
- # if response.status_code != 200:
- # error_msg = f"request /api/knowledge/entities/extract failed, error: {response.text}"
- # raise Exception(error_msg)
- # docs = response.json()
- # import requests
- # docs = requests.post("http://127.0.0.1:5000/api/knowledge/entities/extract", headers={}, json={"text": self.current_user_input})
-
+ docs = self.rag_engine.search(query=self.current_user_input)
docs = self.knowledge_embedding_client.similar_search(
self.current_user_input, self.top_k
)
diff --git a/pilot/server/knowledge/api.py b/pilot/server/knowledge/api.py
index e0f31031e..8e5e52b58 100644
--- a/pilot/server/knowledge/api.py
+++ b/pilot/server/knowledge/api.py
@@ -205,12 +205,6 @@ def similar_query(space_name: str, query_request: KnowledgeQueryRequest):
async def entity_extract(request: EntityExtractRequest):
logger.info(f"Received params: {request}")
try:
- # from pilot.graph_engine.graph_factory import RAGGraphFactory
- # from pilot.component import ComponentType
- # rag_engine = CFG.SYSTEM_APP.get_component(
- # ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
- # ).create()
- # return Result.succ(await rag_engine.search(request.text))
from pilot.scene.base import ChatScene
from pilot.common.chat_util import llm_chat_response_nostream
import uuid
@@ -222,11 +216,6 @@ async def entity_extract(request: EntityExtractRequest):
"model_name": request.model_name,
}
- # import nest_asyncio
- # nest_asyncio.apply()
- # loop = asyncio.get_event_loop()
- # loop.stop()
- # loop = utils.get_or_create_event_loop()
res = await llm_chat_response_nostream(
ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
)
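
The endpoint now just builds a `chat_param` dict and runs the ExtractEntity scene without streaming. A standalone sketch of that call pattern follows; the helper name and argument values are illustrative, the dict keys match those used by the endpoint.

```python
import uuid

from pilot.common.chat_util import llm_chat_response_nostream
from pilot.scene.base import ChatScene


async def extract_entities(text: str, model_name: str):
    chat_param = {
        "chat_session_id": uuid.uuid1(),
        "current_user_input": text,
        "select_param": "entity",
        "model_name": model_name,
    }
    # Returns the parsed output of ExtractEntityParser, i.e. a set of keywords.
    return await llm_chat_response_nostream(
        ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
    )
```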
From b63fa2dfe10a50c970fa402b819bfef98dec2138 Mon Sep 17 00:00:00 2001
From: aries_ckt <916701291@qq.com>
Date: Mon, 16 Oct 2023 14:09:04 +0800
Subject: [PATCH 05/57] feat:rag graph
---
pilot/common/chat_util.py | 2 +-
pilot/graph_engine/graph_engine.py | 17 ++------
pilot/graph_engine/graph_search.py | 43 +++++++++++--------
pilot/graph_engine/node.py | 1 +
pilot/graph_engine/search.py | 6 +--
pilot/scene/base_chat.py | 14 ++++--
.../chat_knowledge/extract_entity/chat.py | 2 +-
.../chat_knowledge/extract_triplet/chat.py | 2 +-
pilot/scene/chat_knowledge/v1/chat.py | 12 +++---
pilot/server/knowledge/service.py | 3 --
10 files changed, 51 insertions(+), 51 deletions(-)
diff --git a/pilot/common/chat_util.py b/pilot/common/chat_util.py
index 159db99d0..0de0b9bda 100644
--- a/pilot/common/chat_util.py
+++ b/pilot/common/chat_util.py
@@ -9,7 +9,7 @@ chat_factory = ChatFactory()
async def llm_chat_response_nostream(chat_scene: str, **chat_param):
- """ llm_chat_response_nostream """
+ """llm_chat_response_nostream"""
chat: BaseChat = chat_factory.get_implementation(chat_scene, **chat_param)
res = await chat.get_llm_response()
return res
diff --git a/pilot/graph_engine/graph_engine.py b/pilot/graph_engine/graph_engine.py
index c20142123..04c4f54d9 100644
--- a/pilot/graph_engine/graph_engine.py
+++ b/pilot/graph_engine/graph_engine.py
@@ -45,8 +45,7 @@ class RAGGraphEngine:
**kwargs: Any,
) -> None:
"""Initialize params."""
- # from llama_index.graph_stores import SimpleGraphStore
- # from llama_index.graph_stores.types import GraphStore
+ from llama_index.graph_stores import SimpleGraphStore
# need to set parameters before building index in base class.
self.knowledge_source = knowledge_source
@@ -55,8 +54,8 @@ class RAGGraphEngine:
self.text_splitter = text_splitter
self.index_struct = index_struct
self.include_embeddings = include_embeddings
- # self.graph_store = graph_store or SimpleGraphStore()
- self.graph_store = graph_store
+ self.graph_store = graph_store or SimpleGraphStore()
+ # self.graph_store = graph_store
self.max_triplets_per_chunk = max_triplets_per_chunk
self._max_object_length = max_object_length
self._extract_triplet_fn = extract_triplet_fn
@@ -103,14 +102,6 @@ class RAGGraphEngine:
)
)
return triplets
- # response = self._service_context.llm_predictor.predict(
- # self.kg_triple_extract_template,
- # text=text,
- # )
- # print(response, flush=True)
- # return self._parse_triplet_response(
- # response, max_length=self._max_object_length
- # )
def _build_index_from_docs(self, documents: List[Document]) -> KG:
"""Build the index from nodes."""
@@ -126,7 +117,6 @@ class RAGGraphEngine:
self.graph_store.upsert_triplet(*triplet)
index_struct.add_node([subj, obj], text_node)
-
return index_struct
def search(self, query):
@@ -134,4 +124,3 @@ class RAGGraphEngine:
graph_search = RAGGraphSearch(graph_engine=self)
return graph_search.search(query)
-
diff --git a/pilot/graph_engine/graph_search.py b/pilot/graph_engine/graph_search.py
index 9b06fd234..d1f6a4519 100644
--- a/pilot/graph_engine/graph_search.py
+++ b/pilot/graph_engine/graph_search.py
@@ -4,6 +4,8 @@ from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional, Dict, Any, Set, Callable
+from langchain.schema import Document
+
from pilot.graph_engine.node import BaseNode, TextNode, NodeWithScore
from pilot.graph_engine.search import BaseSearch, SearchMode
from pilot.utils import utils
@@ -67,14 +69,14 @@ class RAGGraphSearch(BaseSearch):
logger.warn(f"can not to find graph schema: {e}")
self._graph_schema = ""
- def _extract_subject_entities(self, query_str: str) -> Set[str]:
+ async def _extract_subject_entities(self, query_str: str) -> Set[str]:
"""extract subject entities."""
if self.extract_subject_entities_fn is not None:
- return self.extract_subject_entities_fn(query_str)
+ return await self.extract_subject_entities_fn(query_str)
else:
- return self._extract_entities_by_llm(query_str)
+ return await self._extract_entities_by_llm(query_str)
- def _extract_entities_by_llm(self, text: str) -> Set[str]:
+ async def _extract_entities_by_llm(self, text: str) -> Set[str]:
"""extract subject entities from text by llm"""
from pilot.scene.base import ChatScene
from pilot.common.chat_util import llm_chat_response_nostream
@@ -86,21 +88,23 @@ class RAGGraphSearch(BaseSearch):
"select_param": "entity",
"model_name": self.model_name,
}
- loop = utils.get_or_create_event_loop()
- entities = loop.run_until_complete(
- llm_chat_response_nostream(
- ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
- )
+ # loop = utils.get_or_create_event_loop()
+ # entities = loop.run_until_complete(
+ # llm_chat_response_nostream(
+ # ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
+ # )
+ # )
+ return await llm_chat_response_nostream(
+ ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
)
- return entities
- def _search(
+ async def _search(
self,
query_str: str,
- ) -> List[NodeWithScore]:
+ ) -> List[Document]:
"""Get nodes for response."""
node_visited = set()
- keywords = self._extract_subject_entities(query_str)
+ keywords = await self._extract_subject_entities(query_str)
print(f"extract entities: {keywords}\n")
rel_texts = []
cur_rel_map = {}
@@ -114,8 +118,8 @@ class RAGGraphSearch(BaseSearch):
if node_id in node_visited:
continue
- if self._include_text:
- chunk_indices_count[node_id] += 1
+ # if self._include_text:
+ # chunk_indices_count[node_id] += 1
node_visited.add(node_id)
@@ -179,8 +183,11 @@ class RAGGraphSearch(BaseSearch):
sorted_nodes_with_scores.append(
NodeWithScore(node=rel_text_node, score=DEFAULT_NODE_SCORE)
)
-
- return sorted_nodes_with_scores
+ docs = [
+ Document(page_content=node.text, metadata=node.metadata)
+ for node in sorted_nodes_with_scores
+ ]
+ return docs
def _get_metadata_for_response(
self, nodes: List[BaseNode]
@@ -190,4 +197,4 @@ class RAGGraphSearch(BaseSearch):
if node.metadata is None or "kg_rel_map" not in node.metadata:
continue
return node.metadata
- raise ValueError("kg_rel_map must be found in at least one Node.")
\ No newline at end of file
+ raise ValueError("kg_rel_map must be found in at least one Node.")
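
After this change the retrieval path is async end to end and yields LangChain `Document`s. A sketch of how a caller consumes it (`engine` stands in for an already initialized RAGGraphEngine; error handling is elided):

```python
import asyncio

from pilot.graph_engine.graph_search import RAGGraphSearch


async def query_graph(engine, question: str):
    searcher = RAGGraphSearch(graph_engine=engine)
    docs = await searcher.search(question)  # List[langchain.schema.Document]
    return [doc.page_content for doc in docs]

# e.g. asyncio.run(query_graph(engine, "Who founded Philz?"))
```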
diff --git a/pilot/graph_engine/node.py b/pilot/graph_engine/node.py
index 6f6d45ae4..b23681010 100644
--- a/pilot/graph_engine/node.py
+++ b/pilot/graph_engine/node.py
@@ -21,6 +21,7 @@ WRAP_WIDTH = 70
class BaseComponent(BaseModel):
"""Base component object to caputure class names."""
+
"""reference llama-index"""
@classmethod
diff --git a/pilot/graph_engine/search.py b/pilot/graph_engine/search.py
index 8db837278..297620b00 100644
--- a/pilot/graph_engine/search.py
+++ b/pilot/graph_engine/search.py
@@ -23,7 +23,7 @@ class SearchMode(str, Enum):
class BaseSearch(ABC):
"""Base Search."""
- def search(self, query: str):
+ async def search(self, query: str):
"""Retrieve nodes given query.
Args:
@@ -32,10 +32,10 @@ class BaseSearch(ABC):
"""
# if isinstance(query, str):
- return self._search(query)
+ return await self._search(query)
@abstractmethod
- def _search(self, query: str):
+ async def _search(self, query: str):
"""search nodes given query.
Implemented by the user.
diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index c1880f48a..bea00bde3 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -105,8 +105,14 @@ class BaseChat(ABC):
speak_to_user = prompt_define_response
return speak_to_user
- def __call_base(self):
- input_values = self.generate_input_values()
+ async def __call_base(self):
+ import inspect
+
+ input_values = (
+ await self.generate_input_values()
+ if inspect.isawaitable(self.generate_input_values())
+ else self.generate_input_values()
+ )
### Chat sequence advance
self.current_message.chat_order = len(self.history_message) + 1
self.current_message.add_user_message(self.current_user_input)
@@ -146,7 +152,7 @@ class BaseChat(ABC):
async def stream_call(self):
# TODO Retry when server connection error
- payload = self.__call_base()
+ payload = await self.__call_base()
self.skip_echo_len = len(payload.get("prompt").replace("", " ")) + 11
logger.info(f"Request: \n{payload}")
@@ -234,7 +240,7 @@ class BaseChat(ABC):
return self.current_ai_response()
async def get_llm_response(self):
- payload = self.__call_base()
+ payload = await self.__call_base()
logger.info(f"Request: \n{payload}")
ai_response_text = ""
try:
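
The `inspect.isawaitable` guard lets `__call_base` accept both sync and async `generate_input_values` implementations. Below is a standalone sketch of the same dispatch idea; note that it calls the method once and only awaits the returned value, which avoids the double call visible in the hunk above.

```python
import asyncio
import inspect


class SyncProducer:
    def generate_input_values(self):
        return {"text": "hello"}


class AsyncProducer:
    async def generate_input_values(self):
        return {"text": "hello"}


async def collect(producer):
    result = producer.generate_input_values()
    # Await only when the call returned a coroutine (async implementation).
    return await result if inspect.isawaitable(result) else result


print(asyncio.run(collect(SyncProducer())))   # {'text': 'hello'}
print(asyncio.run(collect(AsyncProducer())))  # {'text': 'hello'}
```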
diff --git a/pilot/scene/chat_knowledge/extract_entity/chat.py b/pilot/scene/chat_knowledge/extract_entity/chat.py
index bb52961b5..373bb4e5d 100644
--- a/pilot/scene/chat_knowledge/extract_entity/chat.py
+++ b/pilot/scene/chat_knowledge/extract_entity/chat.py
@@ -24,7 +24,7 @@ class ExtractEntity(BaseChat):
self.user_input = chat_param["current_user_input"]
self.extract_mode = chat_param["select_param"]
- def generate_input_values(self):
+ async def generate_input_values(self):
input_values = {
"text": self.user_input,
}
diff --git a/pilot/scene/chat_knowledge/extract_triplet/chat.py b/pilot/scene/chat_knowledge/extract_triplet/chat.py
index 11fe871ab..28152b92e 100644
--- a/pilot/scene/chat_knowledge/extract_triplet/chat.py
+++ b/pilot/scene/chat_knowledge/extract_triplet/chat.py
@@ -24,7 +24,7 @@ class ExtractTriplet(BaseChat):
self.user_input = chat_param["current_user_input"]
self.extract_mode = chat_param["select_param"]
- def generate_input_values(self):
+ async def generate_input_values(self):
input_values = {
"text": self.user_input,
}
diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py
index c381546f8..ea7ca1922 100644
--- a/pilot/scene/chat_knowledge/v1/chat.py
+++ b/pilot/scene/chat_knowledge/v1/chat.py
@@ -64,7 +64,7 @@ class ChatKnowledge(BaseChat):
self.prompt_template.template_is_strict = False
async def stream_call(self):
- input_values = self.generate_input_values()
+ input_values = await self.generate_input_values()
# Source of knowledge file
relations = input_values.get("relations")
last_output = None
@@ -84,14 +84,14 @@ class ChatKnowledge(BaseChat):
)
yield last_output
- def generate_input_values(self):
+ async def generate_input_values(self):
if self.space_context:
self.prompt_template.template_define = self.space_context["prompt"]["scene"]
self.prompt_template.template = self.space_context["prompt"]["template"]
- docs = self.rag_engine.search(query=self.current_user_input)
- docs = self.knowledge_embedding_client.similar_search(
- self.current_user_input, self.top_k
- )
+ docs = await self.rag_engine.search(query=self.current_user_input)
+ # docs = self.knowledge_embedding_client.similar_search(
+ # self.current_user_input, self.top_k
+ # )
if not docs:
raise ValueError(
"you have no knowledge space, please add your knowledge space"
diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py
index f4150fa73..95949d319 100644
--- a/pilot/server/knowledge/service.py
+++ b/pilot/server/knowledge/service.py
@@ -261,9 +261,6 @@ class KnowledgeService:
# docs = engine.search(
# "Comparing Curry and James in terms of their positions, playing styles, and achievements in the NBA"
# )
- embedding_factory = CFG.SYSTEM_APP.get_component(
- "embedding_factory", EmbeddingFactory
- )
# update document status
doc.status = SyncStatus.RUNNING.name
doc.chunk_size = len(chunk_docs)
From 68c9010e5c6e8ef464e139b0b479c76ca534e002 Mon Sep 17 00:00:00 2001
From: aries_ckt <916701291@qq.com>
Date: Mon, 16 Oct 2023 21:14:20 +0800
Subject: [PATCH 06/57] feat:rag graph
---
pilot/graph_engine/graph_engine.py | 49 ++++++++++++++++++++++------
pilot/scene/base_chat.py | 5 ++-
pilot/vector_store/weaviate_store.py | 8 ++---
3 files changed, 43 insertions(+), 19 deletions(-)
diff --git a/pilot/graph_engine/graph_engine.py b/pilot/graph_engine/graph_engine.py
index 04c4f54d9..e34baba79 100644
--- a/pilot/graph_engine/graph_engine.py
+++ b/pilot/graph_engine/graph_engine.py
@@ -106,21 +106,50 @@ class RAGGraphEngine:
def _build_index_from_docs(self, documents: List[Document]) -> KG:
"""Build the index from nodes."""
index_struct = self.index_struct_cls()
- for doc in documents:
- triplets = self._extract_triplets(doc.page_content)
- if len(triplets) == 0:
- continue
- text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
- logger.info(f"extracted knowledge triplets: {triplets}")
- for triplet in triplets:
- subj, _, obj = triplet
- self.graph_store.upsert_triplet(*triplet)
- index_struct.add_node([subj, obj], text_node)
+ num_threads = 5
+ chunk_size = len(documents) if (len(documents) < num_threads) else len(documents) / num_threads
+ import concurrent
+ future_tasks = []
+ with concurrent.futures.ThreadPoolExecutor() as executor:
+ for i in range(num_threads):
+ start = i * chunk_size
+ end = start + chunk_size if i < num_threads - 1 else None
+ future_tasks.append(executor.submit(self._extract_triplets_task, documents[start:end][0], index_struct))
+
+ result = [future.result() for future in future_tasks]
return index_struct
+ # for doc in documents:
+ # triplets = self._extract_triplets(doc.page_content)
+ # if len(triplets) == 0:
+ # continue
+ # text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
+ # logger.info(f"extracted knowledge triplets: {triplets}")
+ # for triplet in triplets:
+ # subj, _, obj = triplet
+ # self.graph_store.upsert_triplet(*triplet)
+ # index_struct.add_node([subj, obj], text_node)
+ #
+ # return index_struct
def search(self, query):
from pilot.graph_engine.graph_search import RAGGraphSearch
graph_search = RAGGraphSearch(graph_engine=self)
return graph_search.search(query)
+
+ def _extract_triplets_task(self, doc, index_struct):
+ import threading
+ thread_id = threading.get_ident()
+ print(f"current thread-{thread_id} begin extract triplets task")
+ triplets = self._extract_triplets(doc.page_content)
+ if len(triplets) == 0:
+ triplets = []
+ text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
+ logger.info(f"extracted knowledge triplets: {triplets}")
+ print(f"current thread-{thread_id} end extract triplets tasks, triplets-{triplets}")
+ for triplet in triplets:
+ subj, _, obj = triplet
+ self.graph_store.upsert_triplet(*triplet)
+ self.graph_store.upsert_triplet(*triplet)
+ index_struct.add_node([subj, obj], text_node)
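
The hunk above fans document chunks out to a thread pool. Here is a standalone sketch of that chunking pattern (function and argument names are illustrative, not part of the patch); it uses integer division for the chunk size, as the following patch also does, and hands each worker its whole slice rather than only the first document of the slice.

```python
import concurrent.futures
from typing import Callable, List


def parallel_extract(documents: List, extract: Callable[[List], List], num_threads: int = 5):
    """Split `documents` into roughly equal slices and run `extract` on each slice."""
    if not documents:
        return []
    num_threads = min(num_threads, len(documents))
    chunk_size = len(documents) // num_threads
    results = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
        futures = []
        for i in range(num_threads):
            start = i * chunk_size
            # The last worker takes the remainder of the list.
            end = start + chunk_size if i < num_threads - 1 else None
            futures.append(executor.submit(extract, documents[start:end]))
        for future in concurrent.futures.as_completed(futures):
            results.extend(future.result())
    return results
```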
diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index bea00bde3..58a0becc9 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -107,10 +107,9 @@ class BaseChat(ABC):
async def __call_base(self):
import inspect
-
input_values = (
await self.generate_input_values()
- if inspect.isawaitable(self.generate_input_values())
+ if inspect.isawaitable(self.generate_input_values)
else self.generate_input_values()
)
### Chat sequence advance
@@ -181,7 +180,7 @@ class BaseChat(ABC):
span.end(metadata={"error": str(e)})
async def nostream_call(self):
- payload = self.__call_base()
+ payload = await self.__call_base()
logger.info(f"Request: \n{payload}")
ai_response_text = ""
span = root_tracer.start_span(
diff --git a/pilot/vector_store/weaviate_store.py b/pilot/vector_store/weaviate_store.py
index 795cf21f9..a8e126eb5 100644
--- a/pilot/vector_store/weaviate_store.py
+++ b/pilot/vector_store/weaviate_store.py
@@ -1,11 +1,7 @@
import os
-import json
import logging
-import weaviate
+#import weaviate
from langchain.schema import Document
-from langchain.vectorstores import Weaviate
-from weaviate.exceptions import WeaviateBaseError
-
from pilot.configs.config import Config
from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.vector_store.base import VectorStoreBase
@@ -72,7 +68,7 @@ class WeaviateStore(VectorStoreBase):
if self.vector_store_client.schema.get(self.vector_name):
return True
return False
- except WeaviateBaseError as e:
+ except Exception as e:
logger.error("vector_name_exists error", e.message)
return False
From f93af985eda442e9f960bcb9d9eeec34276427ea Mon Sep 17 00:00:00 2001
From: aries_ckt <916701291@qq.com>
Date: Tue, 17 Oct 2023 10:35:46 +0800
Subject: [PATCH 07/57] feat:rag graph
---
pilot/graph_engine/graph_engine.py | 3 ++-
pilot/scene/base_chat.py | 2 +-
pilot/scene/chat_knowledge/extract_entity/chat.py | 2 +-
pilot/scene/chat_knowledge/extract_triplet/chat.py | 2 +-
4 files changed, 5 insertions(+), 4 deletions(-)
diff --git a/pilot/graph_engine/graph_engine.py b/pilot/graph_engine/graph_engine.py
index e34baba79..80cbab066 100644
--- a/pilot/graph_engine/graph_engine.py
+++ b/pilot/graph_engine/graph_engine.py
@@ -107,7 +107,7 @@ class RAGGraphEngine:
"""Build the index from nodes."""
index_struct = self.index_struct_cls()
num_threads = 5
- chunk_size = len(documents) if (len(documents) < num_threads) else len(documents) / num_threads
+ chunk_size = len(documents) if (len(documents) < num_threads) else len(documents) // num_threads
import concurrent
future_tasks = []
@@ -132,6 +132,7 @@ class RAGGraphEngine:
#
# return index_struct
+
def search(self, query):
from pilot.graph_engine.graph_search import RAGGraphSearch
diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index 58a0becc9..a1d6d9f08 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -109,7 +109,7 @@ class BaseChat(ABC):
import inspect
input_values = (
await self.generate_input_values()
- if inspect.isawaitable(self.generate_input_values)
+ if inspect.isawaitable(self.generate_input_values())
else self.generate_input_values()
)
### Chat sequence advance
diff --git a/pilot/scene/chat_knowledge/extract_entity/chat.py b/pilot/scene/chat_knowledge/extract_entity/chat.py
index 373bb4e5d..bb52961b5 100644
--- a/pilot/scene/chat_knowledge/extract_entity/chat.py
+++ b/pilot/scene/chat_knowledge/extract_entity/chat.py
@@ -24,7 +24,7 @@ class ExtractEntity(BaseChat):
self.user_input = chat_param["current_user_input"]
self.extract_mode = chat_param["select_param"]
- async def generate_input_values(self):
+ def generate_input_values(self):
input_values = {
"text": self.user_input,
}
diff --git a/pilot/scene/chat_knowledge/extract_triplet/chat.py b/pilot/scene/chat_knowledge/extract_triplet/chat.py
index 28152b92e..11fe871ab 100644
--- a/pilot/scene/chat_knowledge/extract_triplet/chat.py
+++ b/pilot/scene/chat_knowledge/extract_triplet/chat.py
@@ -24,7 +24,7 @@ class ExtractTriplet(BaseChat):
self.user_input = chat_param["current_user_input"]
self.extract_mode = chat_param["select_param"]
- async def generate_input_values(self):
+ def generate_input_values(self):
input_values = {
"text": self.user_input,
}
From aff0553b7ecb5d2173b45995fdf08ca1409ca75f Mon Sep 17 00:00:00 2001
From: aries_ckt <916701291@qq.com>
Date: Thu, 19 Oct 2023 09:40:05 +0800
Subject: [PATCH 08/57] style:fmt
---
pilot/graph_engine/graph_engine.py | 21 +++++++++++++++++----
pilot/scene/base_chat.py | 1 +
pilot/server/knowledge/service.py | 3 ---
pilot/vector_store/weaviate_store.py | 3 ++-
4 files changed, 20 insertions(+), 8 deletions(-)
diff --git a/pilot/graph_engine/graph_engine.py b/pilot/graph_engine/graph_engine.py
index 80cbab066..a50ebabdc 100644
--- a/pilot/graph_engine/graph_engine.py
+++ b/pilot/graph_engine/graph_engine.py
@@ -107,15 +107,26 @@ class RAGGraphEngine:
"""Build the index from nodes."""
index_struct = self.index_struct_cls()
num_threads = 5
- chunk_size = len(documents) if (len(documents) < num_threads) else len(documents) // num_threads
+ chunk_size = (
+ len(documents)
+ if (len(documents) < num_threads)
+ else len(documents) // num_threads
+ )
import concurrent
+
future_tasks = []
with concurrent.futures.ThreadPoolExecutor() as executor:
for i in range(num_threads):
start = i * chunk_size
end = start + chunk_size if i < num_threads - 1 else None
- future_tasks.append(executor.submit(self._extract_triplets_task, documents[start:end][0], index_struct))
+ future_tasks.append(
+ executor.submit(
+ self._extract_triplets_task,
+ documents[start:end][0],
+ index_struct,
+ )
+ )
result = [future.result() for future in future_tasks]
return index_struct
@@ -132,7 +143,6 @@ class RAGGraphEngine:
#
# return index_struct
-
def search(self, query):
from pilot.graph_engine.graph_search import RAGGraphSearch
@@ -141,6 +151,7 @@ class RAGGraphEngine:
def _extract_triplets_task(self, doc, index_struct):
import threading
+
thread_id = threading.get_ident()
print(f"current thread-{thread_id} begin extract triplets task")
triplets = self._extract_triplets(doc.page_content)
@@ -148,7 +159,9 @@ class RAGGraphEngine:
triplets = []
text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
logger.info(f"extracted knowledge triplets: {triplets}")
- print(f"current thread-{thread_id} end extract triplets tasks, triplets-{triplets}")
+ print(
+ f"current thread-{thread_id} end extract triplets tasks, triplets-{triplets}"
+ )
for triplet in triplets:
subj, _, obj = triplet
self.graph_store.upsert_triplet(*triplet)
diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index a1d6d9f08..10c89d620 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -107,6 +107,7 @@ class BaseChat(ABC):
async def __call_base(self):
import inspect
+
input_values = (
await self.generate_input_values()
if inspect.isawaitable(self.generate_input_values())
diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py
index 95949d319..7bba99c0a 100644
--- a/pilot/server/knowledge/service.py
+++ b/pilot/server/knowledge/service.py
@@ -258,9 +258,6 @@ class KnowledgeService:
ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
).create()
rag_engine.knowledge_graph(docs=chunk_docs)
- # docs = engine.search(
- # "Comparing Curry and James in terms of their positions, playing styles, and achievements in the NBA"
- # )
# update document status
doc.status = SyncStatus.RUNNING.name
doc.chunk_size = len(chunk_docs)
diff --git a/pilot/vector_store/weaviate_store.py b/pilot/vector_store/weaviate_store.py
index a8e126eb5..93816ea66 100644
--- a/pilot/vector_store/weaviate_store.py
+++ b/pilot/vector_store/weaviate_store.py
@@ -1,6 +1,7 @@
import os
import logging
-#import weaviate
+
+# import weaviate
from langchain.schema import Document
from pilot.configs.config import Config
from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
From 39219a4fdce92e1c464e7ea9f56e7a4b7a2eabd1 Mon Sep 17 00:00:00 2001
From: aries_ckt <916701291@qq.com>
Date: Thu, 19 Oct 2023 12:02:27 +0800
Subject: [PATCH 09/57] feat:rag_graph
---
pilot/graph_engine/graph_engine.py | 16 +++++++++++-----
pilot/graph_engine/graph_search.py | 2 +-
2 files changed, 12 insertions(+), 6 deletions(-)
diff --git a/pilot/graph_engine/graph_engine.py b/pilot/graph_engine/graph_engine.py
index a50ebabdc..faf53aba1 100644
--- a/pilot/graph_engine/graph_engine.py
+++ b/pilot/graph_engine/graph_engine.py
@@ -129,6 +129,11 @@ class RAGGraphEngine:
)
result = [future.result() for future in future_tasks]
+ # for triplet in triplets:
+ # subj, _, obj = triplet
+ # self.graph_store.upsert_triplet(*triplet)
+ # self.graph_store.upsert_triplet(*triplet)
+ # index_struct.add_node([subj, obj], text_node)
return index_struct
# for doc in documents:
# triplets = self._extract_triplets(doc.page_content)
@@ -162,8 +167,9 @@ class RAGGraphEngine:
print(
f"current thread-{thread_id} end extract triplets tasks, triplets-{triplets}"
)
- for triplet in triplets:
- subj, _, obj = triplet
- self.graph_store.upsert_triplet(*triplet)
- self.graph_store.upsert_triplet(*triplet)
- index_struct.add_node([subj, obj], text_node)
+ return triplets
+ # for triplet in triplets:
+ # subj, _, obj = triplet
+ # self.graph_store.upsert_triplet(*triplet)
+ # self.graph_store.upsert_triplet(*triplet)
+ # index_struct.add_node([subj, obj], text_node)
diff --git a/pilot/graph_engine/graph_search.py b/pilot/graph_engine/graph_search.py
index d1f6a4519..fb883e48b 100644
--- a/pilot/graph_engine/graph_search.py
+++ b/pilot/graph_engine/graph_search.py
@@ -145,7 +145,7 @@ class RAGGraphSearch(BaseSearch):
if len(sorted_nodes_with_scores) == 0:
logger.info("> No nodes found by keywords, returning empty response.")
return [
- NodeWithScore(node=TextNode(text="No relationships found."), score=1.0)
+ Document(page_content="No relationships found.")
]
# add relationships as Node
From 724456dc3e82e9bc53d9cdf21bbd485f4a48e4eb Mon Sep 17 00:00:00 2001
From: aries_ckt <916701291@qq.com>
Date: Wed, 25 Oct 2023 21:18:37 +0800
Subject: [PATCH 10/57] feat:extract summary
---
pilot/graph_engine/graph_engine.py | 115 ++++++++++++++++++-----------
pilot/graph_engine/graph_search.py | 19 +++--
pilot/scene/base.py | 7 ++
pilot/scene/chat_factory.py | 1 +
pilot/server/knowledge/service.py | 60 ++++++++++++---
5 files changed, 141 insertions(+), 61 deletions(-)
diff --git a/pilot/graph_engine/graph_engine.py b/pilot/graph_engine/graph_engine.py
index faf53aba1..491a8625c 100644
--- a/pilot/graph_engine/graph_engine.py
+++ b/pilot/graph_engine/graph_engine.py
@@ -15,10 +15,12 @@ logger = logging.getLogger(__name__)
class RAGGraphEngine:
"""Knowledge RAG Graph Engine.
- Build a KG by extracting triplets, and leveraging the KG during query-time.
+    Build a RAG graph client that extracts triplets and inserts them into a graph store.
Args:
knowledge_type (Optional[str]): Default: KnowledgeType.DOCUMENT.value
extracting triplets.
+        knowledge_source (Optional[str]): the source to extract knowledge from.
+        model_name (Optional[str]): llm model name.
graph_store (Optional[GraphStore]): The graph store to use.refrence:llama-index
include_embeddings (bool): Whether to include embeddings in the index.
Defaults to False.
@@ -104,37 +106,64 @@ class RAGGraphEngine:
return triplets
def _build_index_from_docs(self, documents: List[Document]) -> KG:
- """Build the index from nodes."""
+ """Build the index from nodes.
+ Args:documents:List[Document]
+ """
index_struct = self.index_struct_cls()
- num_threads = 5
- chunk_size = (
- len(documents)
- if (len(documents) < num_threads)
- else len(documents) // num_threads
- )
-
- import concurrent
-
- future_tasks = []
- with concurrent.futures.ThreadPoolExecutor() as executor:
- for i in range(num_threads):
- start = i * chunk_size
- end = start + chunk_size if i < num_threads - 1 else None
- future_tasks.append(
- executor.submit(
- self._extract_triplets_task,
- documents[start:end][0],
- index_struct,
- )
- )
-
- result = [future.result() for future in future_tasks]
+        triplets = []
+        for doc in documents:
+            trips = self._extract_triplets_task([doc], index_struct)
+            triplets.extend(trips)
+            text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
+            # only upsert the triplets extracted from the current document
+            for triplet in trips:
+                subj, _, obj = triplet
+                self.graph_store.upsert_triplet(*triplet)
+                index_struct.add_node([subj, obj], text_node)
+ return index_struct
+ # num_threads = 5
+ # chunk_size = (
+ # len(documents)
+ # if (len(documents) < num_threads)
+ # else len(documents) // num_threads
+ # )
+ #
+ # import concurrent
+ # triples = []
+ # future_tasks = []
+ # with concurrent.futures.ThreadPoolExecutor() as executor:
+ # for i in range(num_threads):
+ # start = i * chunk_size
+ # end = start + chunk_size if i < num_threads - 1 else None
+ # # doc = documents[start:end]
+ # future_tasks.append(
+ # executor.submit(
+ # self._extract_triplets_task,
+ # documents[start:end],
+ # index_struct,
+ # )
+ # )
+ # # for doc in documents[start:end]:
+ # # future_tasks.append(
+ # # executor.submit(
+ # # self._extract_triplets_task,
+ # # doc,
+ # # index_struct,
+ # # )
+ # # )
+ #
+ # # result = [future.result() for future in future_tasks]
+ # completed_futures, _ = concurrent.futures.wait(future_tasks, return_when=concurrent.futures.ALL_COMPLETED)
+ # for future in completed_futures:
+            #     # collect each completed future's result and extend the triplets list
+ # result = future.result()
+ # triplets.extend(result)
+ # print(f"total triplets-{triples}")
# for triplet in triplets:
# subj, _, obj = triplet
# self.graph_store.upsert_triplet(*triplet)
- # self.graph_store.upsert_triplet(*triplet)
- # index_struct.add_node([subj, obj], text_node)
- return index_struct
+ # # index_struct.add_node([subj, obj], text_node)
+ # return index_struct
# for doc in documents:
# triplets = self._extract_triplets(doc.page_content)
# if len(triplets) == 0:
@@ -154,20 +183,22 @@ class RAGGraphEngine:
graph_search = RAGGraphSearch(graph_engine=self)
return graph_search.search(query)
- def _extract_triplets_task(self, doc, index_struct):
- import threading
-
- thread_id = threading.get_ident()
- print(f"current thread-{thread_id} begin extract triplets task")
- triplets = self._extract_triplets(doc.page_content)
- if len(triplets) == 0:
- triplets = []
- text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
- logger.info(f"extracted knowledge triplets: {triplets}")
- print(
- f"current thread-{thread_id} end extract triplets tasks, triplets-{triplets}"
- )
- return triplets
+ def _extract_triplets_task(self, docs, index_struct):
+ triple_results = []
+        import threading
+        for doc in docs:
+            thread_id = threading.get_ident()
+ print(f"current thread-{thread_id} begin extract triplets task")
+ triplets = self._extract_triplets(doc.page_content)
+ if len(triplets) == 0:
+ triplets = []
+ text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
+ logger.info(f"extracted knowledge triplets: {triplets}")
+ print(
+ f"current thread-{thread_id} end extract triplets tasks, triplets-{triplets}"
+ )
+ triple_results.extend(triplets)
+ return triple_results
# for triplet in triplets:
# subj, _, obj = triplet
# self.graph_store.upsert_triplet(*triplet)
diff --git a/pilot/graph_engine/graph_search.py b/pilot/graph_engine/graph_search.py
index fb883e48b..f3025be85 100644
--- a/pilot/graph_engine/graph_search.py
+++ b/pilot/graph_engine/graph_search.py
@@ -8,7 +8,6 @@ from langchain.schema import Document
from pilot.graph_engine.node import BaseNode, TextNode, NodeWithScore
from pilot.graph_engine.search import BaseSearch, SearchMode
-from pilot.utils import utils
logger = logging.getLogger(__name__)
DEFAULT_NODE_SCORE = 1000.0
@@ -113,15 +112,15 @@ class RAGGraphSearch(BaseSearch):
for keyword in keywords:
keyword = keyword.lower()
subjs = set((keyword,))
- node_ids = self._index_struct.search_node_by_keyword(keyword)
- for node_id in node_ids[:GLOBAL_EXPLORE_NODE_LIMIT]:
- if node_id in node_visited:
- continue
-
- # if self._include_text:
- # chunk_indices_count[node_id] += 1
-
- node_visited.add(node_id)
+ # node_ids = self._index_struct.search_node_by_keyword(keyword)
+ # for node_id in node_ids[:GLOBAL_EXPLORE_NODE_LIMIT]:
+ # if node_id in node_visited:
+ # continue
+ #
+ # # if self._include_text:
+ # # chunk_indices_count[node_id] += 1
+ #
+ # node_visited.add(node_id)
rel_map = self._graph_store.get_rel_map(
list(subjs), self.graph_store_query_depth
diff --git a/pilot/scene/base.py b/pilot/scene/base.py
index 6abc9c937..5c98003d9 100644
--- a/pilot/scene/base.py
+++ b/pilot/scene/base.py
@@ -89,6 +89,13 @@ class ChatScene(Enum):
["Extract Select"],
True,
)
+ ExtractSummary = Scene(
+ "extract_summary",
+ "Extract Summary",
+ "Extract Summary",
+ ["Extract Select"],
+ True,
+ )
ExtractEntity = Scene(
"extract_entity", "Extract Entity", "Extract Entity", ["Extract Select"], True
)
diff --git a/pilot/scene/chat_factory.py b/pilot/scene/chat_factory.py
index de11332f5..a57855a2b 100644
--- a/pilot/scene/chat_factory.py
+++ b/pilot/scene/chat_factory.py
@@ -15,6 +15,7 @@ class ChatFactory(metaclass=Singleton):
from pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary
from pilot.scene.chat_knowledge.extract_triplet.chat import ExtractTriplet
from pilot.scene.chat_knowledge.extract_entity.chat import ExtractEntity
+ from pilot.scene.chat_knowledge.summary.chat import ExtractSummary
from pilot.scene.chat_data.chat_excel.excel_analyze.chat import ChatExcel
from pilot.scene.chat_agent.chat import ChatAgent
diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py
index ed8c2846e..4c1c41994 100644
--- a/pilot/server/knowledge/service.py
+++ b/pilot/server/knowledge/service.py
@@ -280,12 +280,6 @@ class KnowledgeService:
embedding_factory=embedding_factory,
)
chunk_docs = client.read()
- from pilot.graph_engine.graph_factory import RAGGraphFactory
-
- rag_engine = CFG.SYSTEM_APP.get_component(
- ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
- ).create()
- rag_engine.knowledge_graph(docs=chunk_docs)
# update document status
doc.status = SyncStatus.RUNNING.name
doc.chunk_size = len(chunk_docs)
@@ -294,8 +288,8 @@ class KnowledgeService:
executor = CFG.SYSTEM_APP.get_component(
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()
- executor.submit(self.async_doc_embedding, client, chunk_docs, doc)
-
+ executor.submit(self.async_knowledge_graph, chunk_docs, doc)
+ # executor.submit(self.async_doc_embedding, client, chunk_docs, doc)
logger.info(f"begin save document chunks, doc:{doc.doc_name}")
# save chunk details
chunk_entities = [
@@ -397,13 +391,40 @@ class KnowledgeService:
res.total = document_chunk_dao.get_document_chunks_count(query)
res.page = request.page
return res
+ def async_knowledge_graph(self, chunk_docs, doc):
+ """async document extract triplets and save into graph db
+ Args:
+ - chunk_docs: List[Document]
+ - doc: KnowledgeDocumentEntity
+ """
+        # iterate with a new name so the KnowledgeDocumentEntity `doc` argument is not shadowed
+        for chunk_doc in chunk_docs:
+            self._llm_extract_summary(chunk_doc.page_content)
+ logger.info(
+ f"async_knowledge_graph, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to graph store"
+ )
+ # try:
+ # from pilot.graph_engine.graph_factory import RAGGraphFactory
+ #
+ # rag_engine = CFG.SYSTEM_APP.get_component(
+ # ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
+ # ).create()
+ # rag_engine.knowledge_graph(chunk_docs)
+ # doc.status = SyncStatus.FINISHED.name
+ # doc.result = "document build graph success"
+ # except Exception as e:
+ # doc.status = SyncStatus.FAILED.name
+ # doc.result = "document build graph failed" + str(e)
+ # logger.error(f"document build graph failed:{doc.doc_name}, {str(e)}")
+ return knowledge_document_dao.update_knowledge_document(doc)
+
def async_doc_embedding(self, client, chunk_docs, doc):
"""async document embedding into vector db
Args:
- client: EmbeddingEngine Client
- chunk_docs: List[Document]
- - doc: doc
+ - doc: KnowledgeDocumentEntity
"""
logger.info(
f"async_doc_embedding, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
@@ -461,3 +482,24 @@ class KnowledgeService:
if space.context is not None:
return json.loads(spaces[0].context)
return None
+
+ def _llm_extract_summary(self, doc: str):
+ """Extract triplets from text by llm"""
+ from pilot.scene.base import ChatScene
+ from pilot.common.chat_util import llm_chat_response_nostream
+ import uuid
+
+ chat_param = {
+ "chat_session_id": uuid.uuid1(),
+ "current_user_input": doc,
+ "select_param": "summery",
+ "model_name": "proxyllm",
+ }
+ from pilot.utils import utils
+ loop = utils.get_or_create_event_loop()
+ triplets = loop.run_until_complete(
+ llm_chat_response_nostream(
+ ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
+ )
+ )
+ return triplets
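
Since `async_knowledge_graph` runs inside a synchronous executor task, the summary extraction bridges into async code through an event loop. Below is a standalone sketch of that bridge; the helper name is illustrative, and it assumes the ExtractSummary scene registered in this patch.

```python
import uuid

from pilot.common.chat_util import llm_chat_response_nostream
from pilot.scene.base import ChatScene
from pilot.utils import utils


def run_extract_summary(text: str, model_name: str = "proxyllm"):
    chat_param = {
        "chat_session_id": uuid.uuid1(),
        "current_user_input": text,
        "select_param": "summary",
        "model_name": model_name,
    }
    # Reuse (or create) an event loop to drive the async chat scene to completion.
    loop = utils.get_or_create_event_loop()
    return loop.run_until_complete(
        llm_chat_response_nostream(
            ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
        )
    )
```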
From 95d3f5222b9e3b707f7baf62c0ca1e7f6969aa28 Mon Sep 17 00:00:00 2001
From: FangYin Cheng
Date: Mon, 30 Oct 2023 11:48:05 +0800
Subject: [PATCH 11/57] feat(model): Support AquilaChat2-34B
---
.env.template | 9 ++
docs/getting_started/faq/llm/llm_faq.md | 28 ++--
.../getting_started/faq/llm/llm_faq.po | 126 +++++++++++-------
pilot/configs/config.py | 2 +
pilot/model/cluster/controller/controller.py | 6 +-
pilot/model/cluster/worker/default_worker.py | 2 +-
.../model/cluster/worker/embedding_worker.py | 2 +-
pilot/model/cluster/worker/manager.py | 8 +-
pilot/model/loader.py | 2 +-
pilot/model/model_adapter.py | 52 ++++++--
pilot/model/proxy/llms/chatgpt.py | 14 +-
pilot/server/dbgpt_server.py | 6 +-
pilot/utils/parameter_utils.py | 21 ++-
pilot/utils/utils.py | 41 ++++++
14 files changed, 234 insertions(+), 85 deletions(-)
diff --git a/.env.template b/.env.template
index 16f4e3a6e..e03650033 100644
--- a/.env.template
+++ b/.env.template
@@ -23,6 +23,15 @@ WEB_SERVER_PORT=7860
#*******************************************************************#
# LLM_MODEL, see /pilot/configs/model_config.LLM_MODEL_CONFIG
LLM_MODEL=vicuna-13b-v1.5
+## LLM model path, by default, DB-GPT will read the model path from LLM_MODEL_CONFIG based on the LLM_MODEL.
+## Of course you can specify your model path according to LLM_MODEL_PATH
+## In DB-GPT, the priority from high to low to read model path:
+## 1. environment variable with key: {LLM_MODEL}_MODEL_PATH (Avoid multi-model conflicts)
+## 2. environment variable with key: MODEL_PATH
+## 3. environment variable with key: LLM_MODEL_PATH
+## 4. the config in /pilot/configs/model_config.LLM_MODEL_CONFIG
+# LLM_MODEL_PATH=/app/models/vicuna-13b-v1.5
+# LLM_PROMPT_TEMPLATE=vicuna_v1.1
MODEL_SERVER=http://127.0.0.1:8000
LIMIT_MODEL_CONCURRENCY=5
MAX_POSITION_EMBEDDINGS=4096
diff --git a/docs/getting_started/faq/llm/llm_faq.md b/docs/getting_started/faq/llm/llm_faq.md
index 7b4409d1f..53b8cf279 100644
--- a/docs/getting_started/faq/llm/llm_faq.md
+++ b/docs/getting_started/faq/llm/llm_faq.md
@@ -1,6 +1,6 @@
LLM USE FAQ
==================================
-##### Q1:how to use openai chatgpt service
+##### Q1: how to use openai chatgpt service
change your LLM_MODEL in `.env`
````shell
LLM_MODEL=proxyllm
@@ -15,7 +15,7 @@ PROXY_SERVER_URL=https://api.openai.com/v1/chat/completions
make sure your openapi API_KEY is available
-##### Q2 What difference between `python dbgpt_server --light` and `python dbgpt_server`
+##### Q2: What difference between `python dbgpt_server --light` and `python dbgpt_server`
```{note}
* `python dbgpt_server --light` dbgpt_server does not start the llm service. Users can deploy the llm service separately by using `python llmserver`, and dbgpt_server accesses the llm service through set the LLM_SERVER environment variable in .env. The purpose is to allow for the separate deployment of dbgpt's backend service and llm service.
@@ -35,7 +35,7 @@ python pilot/server/dbgpt_server.py --light
```
-##### Q3 How to use MultiGPUs
+##### Q3: How to use MultiGPUs
DB-GPT will use all available gpu by default. And you can modify the setting `CUDA_VISIBLE_DEVICES=0,1` in `.env` file
to use the specific gpu IDs.
@@ -52,7 +52,7 @@ CUDA_VISIBLE_DEVICES=3,4,5,6 python3 pilot/server/dbgpt_server.py
You can modify the setting `MAX_GPU_MEMORY=xxGib` in `.env` file to configure the maximum memory used by each GPU.
-##### Q4 Not Enough Memory
+##### Q4: Not Enough Memory
DB-GPT supported 8-bit quantization and 4-bit quantization.
@@ -60,9 +60,9 @@ You can modify the setting `QUANTIZE_8bit=True` or `QUANTIZE_4bit=True` in `.env
Llama-2-70b with 8-bit quantization can run with 80 GB of VRAM, and 4-bit quantization can run with 48 GB of VRAM.
-Note: you need to install the latest dependencies according to [requirements.txt](https://github.com/eosphoros-ai/DB-GPT/blob/main/requirements.txt).
+Note: you need to install the quantization dependencies with `pip install -e ".[quantization]"`
-##### Q5 How to Add LLM Service dynamic local mode
+##### Q5: How to Add LLM Service dynamic local mode
Now DB-GPT through multi-llm service switch, so how to add llm service dynamic,
@@ -75,7 +75,7 @@ eg: dbgpt model start --model_name chatglm2-6b --model_path /root/DB-GPT/models/
chatgpt
eg: dbgpt model start --model_name chatgpt_proxyllm --model_path chatgpt_proxyllm --proxy_api_key ${OPENAI_KEY} --proxy_server_url {OPENAI_URL}
```
-##### Q6 How to Add LLM Service dynamic in remote mode
+##### Q6: How to Add LLM Service dynamic in remote mode
If you deploy llm service in remote machine instance, and you want to add model service to dbgpt server to manage
use dbgpt start worker and set --controller_addr.
@@ -88,13 +88,13 @@ eg: dbgpt start worker --model_name vicuna-13b-v1.5 \
```
-##### Q7 dbgpt command not found
+##### Q7: dbgpt command not found
```commandline
pip install -e "pip install -e ".[default]"
```
-##### Q8 When starting the worker_manager on a cloud server and registering it with the controller, it is noticed that the worker's exposed IP is a private IP instead of a public IP, which leads to the inability to access the service.
+##### Q8: When starting the worker_manager on a cloud server and registering it with the controller, it is noticed that the worker's exposed IP is a private IP instead of a public IP, which leads to the inability to access the service.
```commandline
@@ -103,4 +103,14 @@ pip install -e "pip install -e ".[default]"
automatically determined
```
+##### Q9: How to customize model path and prompt template
+
+DB-GPT will read the model path from `pilot.configs.model_config.LLM_MODEL_CONFIG` based on the `LLM_MODEL`.
+Of course, you can use the environment variable `LLM_MODEL_PATH` to specify the model path and `LLM_PROMPT_TEMPLATE` to specify your model prompt template.
+
+```
+LLM_MODEL=vicuna-13b-v1.5
+LLM_MODEL_PATH=/app/models/vicuna-13b-v1.5
+# LLM_PROMPT_TEMPLATE=vicuna_v1.1
+```
diff --git a/docs/locales/zh_CN/LC_MESSAGES/getting_started/faq/llm/llm_faq.po b/docs/locales/zh_CN/LC_MESSAGES/getting_started/faq/llm/llm_faq.po
index c0791b7cb..590cccfba 100644
--- a/docs/locales/zh_CN/LC_MESSAGES/getting_started/faq/llm/llm_faq.po
+++ b/docs/locales/zh_CN/LC_MESSAGES/getting_started/faq/llm/llm_faq.po
@@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: DB-GPT 👏👏 0.3.5\n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2023-10-20 22:29+0800\n"
+"POT-Creation-Date: 2023-10-30 11:37+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME \n"
"Language: zh_CN\n"
@@ -19,34 +19,36 @@ msgstr ""
"Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.12.1\n"
-#: ../../getting_started/faq/llm/llm_faq.md:1 54763acec7da4deb90669195c54ec3a1
+#: ../../getting_started/faq/llm/llm_faq.md:1 98e23f85313c45169ff2ba7f80193356
msgid "LLM USE FAQ"
msgstr "LLM模型使用FAQ"
-#: ../../getting_started/faq/llm/llm_faq.md:3 66f73fd2ee7b462e92d3f263792a5e33
-msgid "Q1:how to use openai chatgpt service"
+#: ../../getting_started/faq/llm/llm_faq.md:3 0d49acfb4af947cb969b249346b00d33
+#, fuzzy
+msgid "Q1: how to use openai chatgpt service"
msgstr "我怎么使用OPENAI服务"
-#: ../../getting_started/faq/llm/llm_faq.md:4 9d178d8462b74cb188bbacf2ac2ac12b
+#: ../../getting_started/faq/llm/llm_faq.md:4 7010fec33e264987a29de86c54da93e8
#, fuzzy
msgid "change your LLM_MODEL in `.env`"
msgstr "通过在.env文件设置LLM_MODEL"
-#: ../../getting_started/faq/llm/llm_faq.md:9 f7ca82f257be4ac09639a7f8af5e83eb
+#: ../../getting_started/faq/llm/llm_faq.md:9 0982d6d5d0b3434fb00698aaf675f3f3
msgid "set your OPENAPI KEY"
msgstr "set your OPENAPI KEY"
-#: ../../getting_started/faq/llm/llm_faq.md:16 d6255b20dce34a2690df7e2af3505d97
+#: ../../getting_started/faq/llm/llm_faq.md:16 63650494c1574de09c007e1d470dd53d
msgid "make sure your openapi API_KEY is available"
msgstr "确认openapi API_KEY是否可用"
-#: ../../getting_started/faq/llm/llm_faq.md:18 6f1c6dbdb31f4210a6d21f0f3a6ae589
+#: ../../getting_started/faq/llm/llm_faq.md:18 5721ec71e344499d96c55b7e531d7c08
+#, fuzzy
msgid ""
-"Q2 What difference between `python dbgpt_server --light` and `python "
+"Q2: What difference between `python dbgpt_server --light` and `python "
"dbgpt_server`"
-msgstr "Q2 `python dbgpt_server --light` 和 `python dbgpt_server`的区别是什么?"
+msgstr "Q2: `python dbgpt_server --light` 和 `python dbgpt_server`的区别是什么?"
-#: ../../getting_started/faq/llm/llm_faq.md:20 b839771ae9e34e998b0edf8d69deabdd
+#: ../../getting_started/faq/llm/llm_faq.md:20 76a650f195dd40b6a3a3564030cdc040
msgid ""
"`python dbgpt_server --light` dbgpt_server does not start the llm "
"service. Users can deploy the llm service separately by using `python "
@@ -58,75 +60,75 @@ msgstr ""
"用户可以通过`python "
"llmserver`单独部署模型服务,dbgpt_server通过LLM_SERVER环境变量来访问模型服务。目的是为了可以将dbgpt后台服务和大模型服务分离部署。"
-#: ../../getting_started/faq/llm/llm_faq.md:22 aba39cef6fe84799bcd03e8f36c41296
+#: ../../getting_started/faq/llm/llm_faq.md:22 8cd87e3504784d9e891e1fb96c79e143
msgid ""
"`python dbgpt_server` dbgpt_server service and the llm service are "
"deployed on the same instance. when dbgpt_server starts the service, it "
"also starts the llm service at the same time."
msgstr "`python dbgpt_server` 是将后台服务和模型服务部署在同一台实例上.dbgpt_server在启动服务的时候同时开启模型服务."
-#: ../../getting_started/faq/llm/llm_faq.md:27 c65270d479af49e28e99b35a7932adbd
+#: ../../getting_started/faq/llm/llm_faq.md:27 58a6eaf57e6d425685f67058b1a642d4
msgid ""
"If you want to access an external LLM service(deployed by DB-GPT), you "
"need to"
msgstr "如果模型服务部署(通过DB-GPT部署)在别的机器,想通过dbgpt服务访问模型服务"
-#: ../../getting_started/faq/llm/llm_faq.md:29 da153e6d18c543f28e0c4e85618e3d3d
+#: ../../getting_started/faq/llm/llm_faq.md:29 67ac8823ca2e49ba9c833368e2cfb53c
msgid ""
"1.set the variables LLM_MODEL=YOUR_MODEL_NAME, "
"MODEL_SERVER=YOUR_MODEL_SERVER(eg:http://localhost:5000) in the .env "
"file."
msgstr ""
-#: ../../getting_started/faq/llm/llm_faq.md:31 cd89b8a2075f4407b8036a74151a6377
+#: ../../getting_started/faq/llm/llm_faq.md:31 e5c066bcdf0649a1b33bbfc7fd3b1a66
msgid "2.execute dbgpt_server.py in light mode"
msgstr "2.execute dbgpt_server.py light 模式"
-#: ../../getting_started/faq/llm/llm_faq.md:33 8f4b9401ac4f4a25a7479bee9ef5e8c1
+#: ../../getting_started/faq/llm/llm_faq.md:33 402ff01d7ee94d97be4a0eb964e39b97
msgid "python pilot/server/dbgpt_server.py --light"
msgstr ""
-#: ../../getting_started/faq/llm/llm_faq.md:38 69e1064cd7554ce6b49da732f800eacc
+#: ../../getting_started/faq/llm/llm_faq.md:38 86190c689d8f4d9a9b58d904e0b5867b
#, fuzzy
-msgid "Q3 How to use MultiGPUs"
-msgstr "Q2 怎么使用 MultiGPUs"
+msgid "Q3: How to use MultiGPUs"
+msgstr "Q3: 怎么使用 MultiGPUs"
-#: ../../getting_started/faq/llm/llm_faq.md:40 6de3f105ce96430db5756f38bbd9ca12
+#: ../../getting_started/faq/llm/llm_faq.md:40 6b08cff88750440b98956203d8b8a084
msgid ""
"DB-GPT will use all available gpu by default. And you can modify the "
"setting `CUDA_VISIBLE_DEVICES=0,1` in `.env` file to use the specific gpu"
" IDs."
msgstr "DB-GPT默认加载可利用的gpu,你也可以通过修改 在`.env`文件 `CUDA_VISIBLE_DEVICES=0,1`来指定gpu IDs"
-#: ../../getting_started/faq/llm/llm_faq.md:43 87cb9bfb20af4b259d719df797c42a7d
+#: ../../getting_started/faq/llm/llm_faq.md:43 93b39089e5be4475b9e90e7813f5a7d9
msgid ""
"Optionally, you can also specify the gpu ID to use before the starting "
"command, as shown below:"
msgstr "你也可以指定gpu ID启动"
-#: ../../getting_started/faq/llm/llm_faq.md:53 bcfa35cda6304ee5ab9a775a2d4eda63
+#: ../../getting_started/faq/llm/llm_faq.md:53 62e3074c109d401fa4bf1ddbdc6c7be1
msgid ""
"You can modify the setting `MAX_GPU_MEMORY=xxGib` in `.env` file to "
"configure the maximum memory used by each GPU."
msgstr "同时你可以通过在.env文件设置`MAX_GPU_MEMORY=xxGib`修改每个GPU的最大使用内存"
-#: ../../getting_started/faq/llm/llm_faq.md:55 a05c5484927844c8bb4791f0a9ccc82e
+#: ../../getting_started/faq/llm/llm_faq.md:55 d235bd83545c476f8e12572658d1c723
#, fuzzy
-msgid "Q4 Not Enough Memory"
-msgstr "Q3 机器显存不够 "
+msgid "Q4: Not Enough Memory"
+msgstr "Q4: 机器显存不够 "
-#: ../../getting_started/faq/llm/llm_faq.md:57 fe17a023b6eb4a92b1b927e1b94e3784
+#: ../../getting_started/faq/llm/llm_faq.md:57 b3243ed9147f42bba987d7f9b778e66f
msgid "DB-GPT supported 8-bit quantization and 4-bit quantization."
msgstr "DB-GPT 支持 8-bit quantization 和 4-bit quantization."
-#: ../../getting_started/faq/llm/llm_faq.md:59 76c3684c10864b8e87e5c2255b6c0b7f
+#: ../../getting_started/faq/llm/llm_faq.md:59 1ddb9f94ab994bfebfee46d1c19888d4
msgid ""
"You can modify the setting `QUANTIZE_8bit=True` or `QUANTIZE_4bit=True` "
"in `.env` file to use quantization(8-bit quantization is enabled by "
"default)."
msgstr "你可以通过在.env文件设置`QUANTIZE_8bit=True` or `QUANTIZE_4bit=True`"
-#: ../../getting_started/faq/llm/llm_faq.md:61 c5d849a38f1a4f0687bbcffb6699dc39
+#: ../../getting_started/faq/llm/llm_faq.md:61 54b85daa3fb24b17b67a6da31d2be8b0
msgid ""
"Llama-2-70b with 8-bit quantization can run with 80 GB of VRAM, and 4-bit"
" quantization can run with 48 GB of VRAM."
@@ -134,49 +136,77 @@ msgstr ""
"Llama-2-70b with 8-bit quantization 可以运行在 80 GB VRAM机器, 4-bit "
"quantization可以运行在 48 GB VRAM"
-#: ../../getting_started/faq/llm/llm_faq.md:63 867329a5e3b0403083e96f72b8747fb2
+#: ../../getting_started/faq/llm/llm_faq.md:63 097d680aed184fee9eceebee55a47ac1
msgid ""
-"Note: you need to install the latest dependencies according to "
-"[requirements.txt](https://github.com/eosphoros-ai/DB-"
-"GPT/blob/main/requirements.txt)."
+"Note: you need to install the quantization dependencies with `pip install"
+" -e \".[quantization]\"`"
msgstr ""
-#: ../../getting_started/faq/llm/llm_faq.md:65 60ceee25e9fb4ddba40c5306bfb0a82f
+#: ../../getting_started/faq/llm/llm_faq.md:65 f3a51056043c49eb84471040f2b364aa
#, fuzzy
-msgid "Q5 How to Add LLM Service dynamic local mode"
-msgstr "Q5 怎样动态新增模型服务"
+msgid "Q5: How to Add LLM Service dynamic local mode"
+msgstr "Q5: 怎样动态新增模型服务"
-#: ../../getting_started/faq/llm/llm_faq.md:67 c99eb7f7ae844884a8f0da94238ea7e0
+#: ../../getting_started/faq/llm/llm_faq.md:67 43ee6b0f23814c94a4ddb2429801a5e1
msgid ""
"Now DB-GPT through multi-llm service switch, so how to add llm service "
"dynamic,"
msgstr "DB-GPT支持多个模型服务切换, 怎样添加一个模型服务呢"
-#: ../../getting_started/faq/llm/llm_faq.md:78 cd89b8a2075f4407b8036a74151a6377
+#: ../../getting_started/faq/llm/llm_faq.md:78 c217bbf0d2b6425fa7a1c691b7704a8d
#, fuzzy
-msgid "Q6 How to Add LLM Service dynamic in remote mode"
-msgstr "Q5 怎样动态新增模型服务"
+msgid "Q6: How to Add LLM Service dynamic in remote mode"
+msgstr "Q6: 怎样动态新增模型服务"
-#: ../../getting_started/faq/llm/llm_faq.md:79 8833ce89465848259b08ef0a4fa68d96
+#: ../../getting_started/faq/llm/llm_faq.md:79 195bdaa937a94c7aa0d8c6e1a5430d6e
msgid ""
"If you deploy llm service in remote machine instance, and you want to "
"add model service to dbgpt server to manage"
msgstr "如果你想在远程机器实例部署大模型服务并添加到本地dbgpt_server进行管理"
-#: ../../getting_started/faq/llm/llm_faq.md:81 992eb37e3cca48829636c15ba3ec2ee8
+#: ../../getting_started/faq/llm/llm_faq.md:81 c64098b838a94821963a1d16e56497ff
msgid "use dbgpt start worker and set --controller_addr."
msgstr "使用1`dbgpt start worker`命令并设置注册地址--controller_addr"
-#: ../../getting_started/faq/llm/llm_faq.md:91 0d06d7d6dd3d4780894ecd914c89b5a2
+#: ../../getting_started/faq/llm/llm_faq.md:91 cb12d5e9d9d24f14abc3ebea877a4b24
#, fuzzy
-msgid "Q7 dbgpt command not found"
-msgstr "Q6 dbgpt command not found"
+msgid "Q7: dbgpt command not found"
+msgstr "Q7: dbgpt command not found"
-#: ../../getting_started/faq/llm/llm_faq.md:97 5d9beed0d95a4503a43d0e025664273b
+#: ../../getting_started/faq/llm/llm_faq.md:97 f95cdccfa82d4b3eb2a23dd297131faa
+#, fuzzy
msgid ""
-"Q8 When starting the worker_manager on a cloud server and registering it "
-"with the controller, it is noticed that the worker's exposed IP is a "
+"Q8: When starting the worker_manager on a cloud server and registering it"
+" with the controller, it is noticed that the worker's exposed IP is a "
"private IP instead of a public IP, which leads to the inability to access"
" the service."
-msgstr "云服务器启动worker_manager注册到controller时,发现worker暴露的ip是私网ip, 没有以公网ip暴露,导致服务访问不到"
+msgstr ""
+"Q8: 云服务器启动worker_manager注册到controller时,发现worker暴露的ip是私网ip, "
+"没有以公网ip暴露,导致服务访问不到"
+
+#: ../../getting_started/faq/llm/llm_faq.md:106
+#: 739a2983f3484acf98e877dc12f4ccda
+msgid "Q9: How to customize model path and prompt template"
+msgstr "Q9: 如何自定义模型路径和 prompt 模板"
+
+#: ../../getting_started/faq/llm/llm_faq.md:108
+#: 8b82a33a311649c7850c30c00c987c72
+#, fuzzy
+msgid ""
+"DB-GPT will read the model path from "
+"`pilot.configs.model_config.LLM_MODEL_CONFIG` based on the `LLM_MODEL`. "
+"Of course, you can use the environment variable `LLM_MODEL_PATH` to "
+"specify the model path and `LLM_PROMPT_TEMPLATE` to specify your model "
+"prompt template."
+msgstr ""
+"DB-GPT 会根据 `LLM_MODEL` 从 `pilot.configs.model_config.LLM_MODEL_CONFIG` "
+"中读取模型路径。当然,你可以使用环境 `LLM_MODEL_PATH` 来指定模型路径,以及使用 `LLM_PROMPT_TEMPLATE` "
+"来指定模型的 prompt 模板。"
+
+#~ msgid ""
+#~ "Note: you need to install the "
+#~ "latest dependencies according to "
+#~ "[requirements.txt](https://github.com/eosphoros-ai/DB-"
+#~ "GPT/blob/main/requirements.txt)."
+#~ msgstr ""
diff --git a/pilot/configs/config.py b/pilot/configs/config.py
index 6213cedf2..b263b46c4 100644
--- a/pilot/configs/config.py
+++ b/pilot/configs/config.py
@@ -194,6 +194,8 @@ class Config(metaclass=Singleton):
### LLM Model Service Configuration
self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b-v1.5")
+ self.LLM_MODEL_PATH = os.getenv("LLM_MODEL_PATH")
+
### Proxy llm backend, this configuration is only valid when "LLM_MODEL=proxyllm"
### When we use the rest API provided by deployment frameworks like fastchat as a proxyllm, "PROXYLLM_BACKEND" is the model they actually deploy.
### We need to use "PROXYLLM_BACKEND" to load the prompt of the corresponding scene.
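With the new `LLM_MODEL_PATH` option, an explicit path overrides the static lookup keyed by `LLM_MODEL`. A minimal sketch of that precedence, using a hypothetical one-entry config map rather than the project's real `LLM_MODEL_CONFIG`:

import os

# Hypothetical stand-in for pilot.configs.model_config.LLM_MODEL_CONFIG.
LLM_MODEL_CONFIG = {"vicuna-13b-v1.5": "/data/models/vicuna-13b-v1.5"}

def resolve_model_path(model_name: str) -> str:
    # An explicit LLM_MODEL_PATH wins; otherwise fall back to the config map.
    return os.getenv("LLM_MODEL_PATH") or LLM_MODEL_CONFIG.get(model_name, "")

# With LLM_MODEL_PATH unset this prints the mapped path; with
# LLM_MODEL_PATH=/models/my-model it prints /models/my-model.
print(resolve_model_path(os.getenv("LLM_MODEL", "vicuna-13b-v1.5")))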
diff --git a/pilot/model/cluster/controller/controller.py b/pilot/model/cluster/controller/controller.py
index 35a91ee3c..173c8c019 100644
--- a/pilot/model/cluster/controller/controller.py
+++ b/pilot/model/cluster/controller/controller.py
@@ -13,7 +13,7 @@ from pilot.utils.api_utils import (
_api_remote as api_remote,
_sync_api_remote as sync_api_remote,
)
-from pilot.utils.utils import setup_logging
+from pilot.utils.utils import setup_logging, setup_http_service_logging
logger = logging.getLogger(__name__)
@@ -149,6 +149,7 @@ def initialize_controller(
else:
import uvicorn
+ setup_http_service_logging()
app = FastAPI()
app.include_router(router, prefix="/api", tags=["Model"])
uvicorn.run(app, host=host, port=port, log_level="info")
@@ -179,7 +180,8 @@ def run_model_controller():
parser = EnvArgumentParser()
env_prefix = "controller_"
controller_params: ModelControllerParameters = parser.parse_args_into_dataclass(
- ModelControllerParameters, env_prefix=env_prefix
+ ModelControllerParameters,
+ env_prefixes=[env_prefix],
)
setup_logging(
diff --git a/pilot/model/cluster/worker/default_worker.py b/pilot/model/cluster/worker/default_worker.py
index 04b47cbdb..5caa2ee7e 100644
--- a/pilot/model/cluster/worker/default_worker.py
+++ b/pilot/model/cluster/worker/default_worker.py
@@ -76,7 +76,7 @@ class DefaultModelWorker(ModelWorker):
model_type = self.llm_adapter.model_type()
model_params: ModelParameters = model_args.parse_args_into_dataclass(
param_cls,
- env_prefix=env_prefix,
+ env_prefixes=[env_prefix, "LLM_"],
command_args=command_args,
model_name=self.model_name,
model_path=self.model_path,
diff --git a/pilot/model/cluster/worker/embedding_worker.py b/pilot/model/cluster/worker/embedding_worker.py
index 62b799864..22c644034 100644
--- a/pilot/model/cluster/worker/embedding_worker.py
+++ b/pilot/model/cluster/worker/embedding_worker.py
@@ -106,7 +106,7 @@ def _parse_embedding_params(
env_prefix = EnvArgumentParser.get_env_prefix(model_name)
model_params: BaseEmbeddingModelParameters = model_args.parse_args_into_dataclass(
param_cls,
- env_prefix=env_prefix,
+ env_prefixes=[env_prefix],
command_args=command_args,
model_name=model_name,
model_path=model_path,
diff --git a/pilot/model/cluster/worker/manager.py b/pilot/model/cluster/worker/manager.py
index cc5ef97d6..a76fa6685 100644
--- a/pilot/model/cluster/worker/manager.py
+++ b/pilot/model/cluster/worker/manager.py
@@ -38,7 +38,7 @@ from pilot.utils.parameter_utils import (
_dict_to_command_args,
_get_dict_from_obj,
)
-from pilot.utils.utils import setup_logging
+from pilot.utils.utils import setup_logging, setup_http_service_logging
from pilot.utils.tracer import initialize_tracer, root_tracer, SpanType, SpanTypeRunName
from pilot.utils.system_utils import get_system_info
@@ -735,6 +735,8 @@ def _setup_fastapi(
):
if not app:
app = FastAPI()
+ setup_http_service_logging()
+
if worker_params.standalone:
from pilot.model.cluster.controller.controller import initialize_controller
from pilot.model.cluster.controller.controller import (
@@ -781,7 +783,7 @@ def _parse_worker_params(
env_prefix = EnvArgumentParser.get_env_prefix(model_name)
worker_params: ModelWorkerParameters = worker_args.parse_args_into_dataclass(
ModelWorkerParameters,
- env_prefix=env_prefix,
+ env_prefixes=[env_prefix],
model_name=model_name,
model_path=model_path,
**kwargs,
@@ -790,7 +792,7 @@ def _parse_worker_params(
# Read parameters again with the prefix of the model name.
new_worker_params = worker_args.parse_args_into_dataclass(
ModelWorkerParameters,
- env_prefix=env_prefix,
+ env_prefixes=[env_prefix],
model_name=worker_params.model_name,
model_path=worker_params.model_path,
**kwargs,
diff --git a/pilot/model/loader.py b/pilot/model/loader.py
index 2f5f10c2d..b7cf57815 100644
--- a/pilot/model/loader.py
+++ b/pilot/model/loader.py
@@ -95,7 +95,7 @@ class ModelLoader:
env_prefix = env_prefix.replace("-", "_")
model_params = args_parser.parse_args_into_dataclass(
param_cls,
- env_prefix=env_prefix,
+ env_prefixes=[env_prefix],
device=self.device,
model_path=self.model_path,
model_name=self.model_name,
diff --git a/pilot/model/model_adapter.py b/pilot/model/model_adapter.py
index 3809729bc..1580e8863 100644
--- a/pilot/model/model_adapter.py
+++ b/pilot/model/model_adapter.py
@@ -445,17 +445,47 @@ class VLLMModelAdaperWrapper(LLMModelAdaper):
# Covering the configuration of fastchat; we will regularly feed the code here back to fastchat.
# We also recommend that you modify it directly in the fastchat repository.
+
+# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L212
register_conv_template(
Conversation(
- name="internlm-chat",
- system_message="A chat between a curious <|User|> and an <|Bot|>. The <|Bot|> gives helpful, detailed, and polite answers to the <|User|>'s questions.\n\n",
- roles=("<|User|>", "<|Bot|>"),
- sep_style=SeparatorStyle.CHATINTERN,
- sep="",
- sep2="",
- stop_token_ids=[1, 103028],
- # TODO feedback stop_str to fastchat
- stop_str="",
- ),
- override=True,
+ name="aquila-legacy",
+ system_message="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+ roles=("### Human: ", "### Assistant: ", "System"),
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.NO_COLON_TWO,
+ sep="\n",
+ sep2="",
+ stop_str=["", "[UNK]"],
+ )
+)
+# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L227
+register_conv_template(
+ Conversation(
+ name="aquila",
+ system_message="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("Human", "Assistant", "System"),
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.ADD_COLON_TWO,
+ sep="###",
+ sep2="",
+ stop_str=["", "[UNK]"],
+ )
+)
+# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L242
+register_conv_template(
+ Conversation(
+ name="aquila-v1",
+ roles=("<|startofpiece|>", "<|endofpiece|>", ""),
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.NO_COLON_TWO,
+ sep="",
+ sep2="",
+ stop_str=["", "<|endoftext|>"],
+ )
)
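For orientation, this is roughly how a registered template gets used at inference time, assuming the fastchat-style helpers this module builds on (`get_conv_template`, `append_message`, `get_prompt`); the question text is only illustrative:

from fastchat.conversation import get_conv_template

conv = get_conv_template("aquila")        # copy of the registered template
conv.append_message(conv.roles[0], "Which tables exist in the sales schema?")
conv.append_message(conv.roles[1], None)  # empty slot for the model's reply
prompt = conv.get_prompt()                # final string handed to the model
print(prompt)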
diff --git a/pilot/model/proxy/llms/chatgpt.py b/pilot/model/proxy/llms/chatgpt.py
index a2aff0b86..ab95e58f6 100644
--- a/pilot/model/proxy/llms/chatgpt.py
+++ b/pilot/model/proxy/llms/chatgpt.py
@@ -5,8 +5,6 @@ import os
from typing import List
import logging
-import openai
-
from pilot.model.proxy.llms.proxy_model import ProxyModel
from pilot.model.parameter import ProxyModelParameters
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
@@ -15,6 +13,14 @@ logger = logging.getLogger(__name__)
def _initialize_openai(params: ProxyModelParameters):
+ try:
+ import openai
+ except ImportError as exc:
+ raise ValueError(
+ "Could not import python package: openai "
+ "Please install openai by command `pip install openai` "
+ ) from exc
+
api_type = params.proxy_api_type or os.getenv("OPENAI_API_TYPE", "open_ai")
api_base = params.proxy_api_base or os.getenv(
@@ -106,6 +112,8 @@ def _build_request(model: ProxyModel, params):
def chatgpt_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
+ import openai
+
history, payloads = _build_request(model, params)
res = openai.ChatCompletion.create(messages=history, **payloads)
@@ -121,6 +129,8 @@ def chatgpt_generate_stream(
async def async_chatgpt_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
+ import openai
+
history, payloads = _build_request(model, params)
res = await openai.ChatCompletion.acreate(messages=history, **payloads)
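The same lazy-import idea, shown as a standalone sketch (a generic pattern, not the project's exact helper): the optional dependency is only imported when a proxy call actually happens, so installs without `openai` keep working until that code path is hit.

def _require_openai():
    # Import on demand so openai stays an optional dependency.
    try:
        import openai
    except ImportError as exc:
        raise ValueError(
            "Could not import python package: openai. "
            "Please install it with `pip install openai`."
        ) from exc
    return openai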
diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py
index c6e084a93..e94526b9a 100644
--- a/pilot/server/dbgpt_server.py
+++ b/pilot/server/dbgpt_server.py
@@ -2,6 +2,7 @@ import os
import argparse
import sys
from typing import List
+import logging
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
@@ -39,6 +40,7 @@ from pilot.utils.utils import (
setup_logging,
_get_logging_level,
logging_str_to_uvicorn_level,
+ setup_http_service_logging,
)
from pilot.utils.tracer import root_tracer, initialize_tracer, SpanType, SpanTypeRunName
from pilot.utils.parameter_utils import _get_dict_from_obj
@@ -127,6 +129,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
setup_logging(
"pilot", logging_level=param.log_level, logger_filename=param.log_file
)
+
# Before start
system_app.before_start()
@@ -141,7 +144,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
model_name = param.model_name or CFG.LLM_MODEL
- model_path = LLM_MODEL_CONFIG.get(model_name)
+ model_path = CFG.LLM_MODEL_PATH or LLM_MODEL_CONFIG.get(model_name)
if not param.light:
print("Model Unified Deployment Mode!")
if not param.remote_embedding:
@@ -180,6 +183,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
def run_uvicorn(param: WebWerverParameters):
import uvicorn
+ setup_http_service_logging()
uvicorn.run(
app,
host=param.host,
diff --git a/pilot/utils/parameter_utils.py b/pilot/utils/parameter_utils.py
index fbf9c5fb5..e76904203 100644
--- a/pilot/utils/parameter_utils.py
+++ b/pilot/utils/parameter_utils.py
@@ -190,6 +190,17 @@ def _genenv_ignoring_key_case(env_key: str, env_prefix: str = None, default_valu
)
+def _genenv_ignoring_key_case_with_prefixes(
+ env_key: str, env_prefixes: List[str] = None, default_value=None
+) -> str:
+ if env_prefixes:
+ for env_prefix in env_prefixes:
+ env_var_value = _genenv_ignoring_key_case(env_key, env_prefix)
+ if env_var_value:
+ return env_var_value
+ return _genenv_ignoring_key_case(env_key, default_value=default_value)
+
+
class EnvArgumentParser:
@staticmethod
def get_env_prefix(env_key: str) -> str:
@@ -201,18 +212,16 @@ class EnvArgumentParser:
def parse_args_into_dataclass(
self,
dataclass_type: Type,
- env_prefix: str = None,
+ env_prefixes: List[str] = None,
command_args: List[str] = None,
**kwargs,
) -> Any:
"""Parse parameters from environment variables and command lines and populate them into data class"""
parser = argparse.ArgumentParser()
for field in fields(dataclass_type):
- env_var_value = _genenv_ignoring_key_case(field.name, env_prefix)
- if not env_var_value:
- # Read without env prefix
- env_var_value = _genenv_ignoring_key_case(field.name)
-
+ env_var_value = _genenv_ignoring_key_case_with_prefixes(
+ field.name, env_prefixes
+ )
if env_var_value:
env_var_value = env_var_value.strip()
if field.type is int or field.type == Optional[int]:
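A self-contained sketch of the prefix fallback that `env_prefixes` enables; the helper name and example variables here are illustrative, not the project's API:

import os
from typing import List, Optional

def lookup_env(key: str, prefixes: Optional[List[str]] = None) -> Optional[str]:
    # Try each prefix in order (e.g. "vicuna_13b_v1_5_", then "LLM_"),
    # then the bare key, matching names case-insensitively like the real parser.
    candidates = [f"{prefix}{key}" for prefix in (prefixes or [])] + [key]
    env = {name.lower(): value for name, value in os.environ.items()}
    for name in candidates:
        value = env.get(name.lower())
        if value:
            return value
    return None

# model_path can come from VICUNA_13B_V1_5_MODEL_PATH, LLM_MODEL_PATH,
# or plain MODEL_PATH, checked in that order.
print(lookup_env("model_path", ["vicuna_13b_v1_5_", "LLM_"]))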
diff --git a/pilot/utils/utils.py b/pilot/utils/utils.py
index b72745a33..e3add3a0a 100644
--- a/pilot/utils/utils.py
+++ b/pilot/utils/utils.py
@@ -3,6 +3,8 @@
import logging
import logging.handlers
+from typing import Any, List
+
import os
import sys
import asyncio
@@ -184,3 +186,42 @@ def logging_str_to_uvicorn_level(log_level_str):
"NOTSET": "info",
}
return level_str_mapping.get(log_level_str.upper(), "info")
+
+
+class EndpointFilter(logging.Filter):
+ """Disable access log on certain endpoint
+
+ source: https://github.com/encode/starlette/issues/864#issuecomment-1254987630
+ """
+
+ def __init__(
+ self,
+ path: str,
+ *args: Any,
+ **kwargs: Any,
+ ):
+ super().__init__(*args, **kwargs)
+ self._path = path
+
+ def filter(self, record: logging.LogRecord) -> bool:
+ return record.getMessage().find(self._path) == -1
+
+
+def setup_http_service_logging(exclude_paths: List[str] = None):
+ """Setup http service logging
+
+ Now just disable some logs
+
+ Args:
+ exclude_paths (List[str]): The paths to disable log
+ """
+ if not exclude_paths:
+ # Not show heartbeat log
+ exclude_paths = ["/api/controller/heartbeat"]
+ uvicorn_logger = logging.getLogger("uvicorn.access")
+ if uvicorn_logger:
+ for path in exclude_paths:
+ uvicorn_logger.addFilter(EndpointFilter(path=path))
+ httpx_logger = logging.getLogger("httpx")
+ if httpx_logger:
+ httpx_logger.setLevel(logging.WARNING)
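A hedged usage sketch for the new helper: call it once before serving so the uvicorn access logger already carries the heartbeat filter. FastAPI and uvicorn are assumed installed; the filtered path is just the default shown above.

from fastapi import FastAPI
from pilot.utils.utils import setup_http_service_logging

app = FastAPI()

if __name__ == "__main__":
    import uvicorn

    setup_http_service_logging()  # drops /api/controller/heartbeat access logs
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")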
From b3d3716de74903f6a725f5915bca1629c0881c07 Mon Sep 17 00:00:00 2001
From: Aditya Aryaman Das <128703909+alienishi@users.noreply.github.com>
Date: Mon, 30 Oct 2023 13:33:21 +0530
Subject: [PATCH 12/57] docs: corrected all grammatical errors in README.md
---
README.md | 18 +++++++++---------
1 file changed, 9 insertions(+), 9 deletions(-)
diff --git a/README.md b/README.md
index c819c8a51..4a6878297 100644
--- a/README.md
+++ b/README.md
@@ -107,11 +107,11 @@ Currently, we have released multiple key features, which are listed below to dem
- Multi-Agents&Plugins
- Supports custom plug-ins to perform tasks, natively supports the Auto-GPT plug-in model, and the Agents protocol adopts the Agent Protocol standard
+ It supports custom plug-ins to perform tasks, natively supports the Auto-GPT plug-in model, and the Agents protocol adopts the Agent Protocol standard.
- Fine-tuning text2SQL
- An automated fine-tuning lightweight framework built around large language models, Text2SQL data sets, LoRA/QLoRA/Pturning and other fine-tuning methods, making TextSQL fine-tuning as convenient as an assembly line. [DB-GPT-Hub](https://github.com/eosphoros-ai/DB-GPT-Hub)
+ An automated fine-tuning lightweight framework built around large language models, Text2SQL data sets, LoRA/QLoRA/Pturning, and other fine-tuning methods, making TextSQL fine-tuning as convenient as an assembly line. [DB-GPT-Hub](https://github.com/eosphoros-ai/DB-GPT-Hub)
- Multi LLMs Support, Supports multiple large language models, currently supporting
@@ -141,7 +141,7 @@ Currently, we have released multiple key features, which are listed below to dem
- [Llama2-Chinese-13b-Chat](https://huggingface.co/FlagAlpha/Llama2-Chinese-13b-Chat)
- [OpenLLaMa OpenInstruct](https://huggingface.co/VMware/open-llama-7b-open-instruct)
- etc.
+ Etc.
- Support API Proxy LLMs
- [x] [ChatGPT](https://api.openai.com/)
@@ -151,7 +151,7 @@ Currently, we have released multiple key features, which are listed below to dem
- Privacy and security
- The privacy and security of data are ensured through various technologies such as privatized large models and proxy desensitization.
+ The privacy and security of data are ensured through various technologies, such as privatized large models and proxy desensitization.
- Support Datasources
@@ -185,7 +185,7 @@ Is the architecture of the entire DB-GPT shown in the following figure:
The core capabilities mainly consist of the following parts:
1. Multi-Models: Support multi-LLMs, such as LLaMA/LLaMA2、CodeLLaMA、ChatGLM, QWen、Vicuna and proxy model ChatGPT、Baichuan、tongyi、wenxin etc
-2. Knowledge Based QA: You can perform high-quality intelligent Q&A based on local documents such as pdf, word, excel and other data.
+2. Knowledge-Based QA: You can perform high-quality intelligent Q&A based on local documents such as PDF, word, excel, and other data.
3. Embedding: Unified data vector storage and indexing, Embed data as vectors and store them in vector databases, providing content similarity search.
4. Multi-Datasources: Used to connect different modules and data sources to achieve data flow and interaction.
5. Multi-Agents: Provides Agent and plugin mechanisms, allowing users to customize and enhance the system's behavior.
@@ -199,7 +199,7 @@ The core capabilities mainly consist of the following parts:
### SubModule
- [DB-GPT-Hub](https://github.com/eosphoros-ai/DB-GPT-Hub) Text-to-SQL performance by applying Supervised Fine-Tuning (SFT) on large language models.
-- [DB-GPT-Plugins](https://github.com/eosphoros-ai/DB-GPT-Plugins) DB-GPT Plugins, Can run autogpt plugin directly
+- [DB-GPT-Plugins](https://github.com/eosphoros-ai/DB-GPT-Plugins) DB-GPT Plugins, can run Auto-GPT plugins directly
- [DB-GPT-Web](https://github.com/eosphoros-ai/DB-GPT-Web) ChatUI for DB-GPT
## Image
@@ -213,7 +213,7 @@ The core capabilities mainly consist of the following parts:
## Contribution
-- Please run `black .` before submitting the code. contributing guidelines, [how to contribution](https://github.com/csunny/DB-GPT/blob/main/CONTRIBUTING.md)
+- Please run `black .` before submitting the code. Contributing guidelines, [how to contribution](https://github.com/csunny/DB-GPT/blob/main/CONTRIBUTING.md)
## RoadMap
@@ -224,7 +224,7 @@ The core capabilities mainly consist of the following parts:
### KBQA RAG optimization
- [x] Multi Documents
- [x] PDF
- - [x] Excel, csv
+ - [x] Excel, CSV
- [x] Word
- [x] Text
- [x] MarkDown
@@ -235,7 +235,7 @@ The core capabilities mainly consist of the following parts:
- [ ] Graph Database
- [ ] Neo4j Graph
- [ ] Nebula Graph
-- [x] Multi Vector Database
+- [x] Multi-Vector Database
- [x] Chroma
- [x] Milvus
- [x] Weaviate
From 7619a16c163abd1a9fb52cfad097815d3d8b7364 Mon Sep 17 00:00:00 2001
From: wangzaistone
Date: Mon, 30 Oct 2023 17:06:24 +0800
Subject: [PATCH 13/57] support DB-GPT-Hub sft codellama
---
.env.template | 3 ++-
pilot/configs/model_config.py | 2 ++
pilot/model/adapter.py | 13 +++++++++++++
pilot/model/conversation.py | 22 ++++++++++++++++++++++
pilot/model/model_adapter.py | 1 +
pilot/server/chat_adapter.py | 10 ++++++++++
6 files changed, 50 insertions(+), 1 deletion(-)
diff --git a/.env.template b/.env.template
index e03650033..272ee2922 100644
--- a/.env.template
+++ b/.env.template
@@ -22,7 +22,8 @@ WEB_SERVER_PORT=7860
#** LLM MODELS **#
#*******************************************************************#
# LLM_MODEL, see /pilot/configs/model_config.LLM_MODEL_CONFIG
-LLM_MODEL=vicuna-13b-v1.5
+# LLM_MODEL=vicuna-13b-v1.5
+LLM_MODEL=codellama-13b-sql-sft
## LLM model path, by default, DB-GPT will read the model path from LLM_MODEL_CONFIG based on the LLM_MODEL.
## Of course you can specify your model path according to LLM_MODEL_PATH
## In DB-GPT, the priority from high to low to read model path:
diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index e1575ea03..16deee50a 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -78,6 +78,8 @@ LLM_MODEL_CONFIG = {
"internlm-7b": os.path.join(MODEL_PATH, "internlm-chat-7b"),
"internlm-7b-8k": os.path.join(MODEL_PATH, "internlm-chat-7b-8k"),
"internlm-20b": os.path.join(MODEL_PATH, "internlm-chat-20b"),
+ "codellama-13b-sql-sft": os.path.join(MODEL_PATH, "codellama-13b-sql-sft"),
+
# For test now
"opt-125m": os.path.join(MODEL_PATH, "opt-125m"),
}
diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index 69b159a13..02fbe8aa9 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -319,6 +319,18 @@ class Llama2Adapter(BaseLLMAdaper):
model.config.pad_token_id = tokenizer.pad_token_id
return model, tokenizer
+class CodeLlamaAdapter(BaseLLMAdaper):
+ """The model adapter for codellama """
+
+ def match(self, model_path: str):
+ return "codelama" in model_path.lower()
+
+ def loader(self, model_path: str, from_pretrained_kwargs: dict):
+ model, tokenizer = super().loader(model_path, from_pretrained_kwargs)
+ model.config.eos_token_id = tokenizer.eos_token_id
+ model.config.pad_token_id = tokenizer.pad_token_id
+ return model, tokenizer
+
class BaichuanAdapter(BaseLLMAdaper):
"""The model adapter for Baichuan models (e.g., baichuan-inc/Baichuan-13B-Chat)"""
@@ -420,6 +432,7 @@ register_llm_model_adapters(FalconAdapater)
register_llm_model_adapters(GorillaAdapter)
register_llm_model_adapters(GPT4AllAdapter)
register_llm_model_adapters(Llama2Adapter)
+register_llm_model_adapters(CodeLlamaAdapter)
register_llm_model_adapters(BaichuanAdapter)
register_llm_model_adapters(WizardLMAdapter)
register_llm_model_adapters(LlamaCppAdapater)
diff --git a/pilot/model/conversation.py b/pilot/model/conversation.py
index b3674e946..98dfc720d 100644
--- a/pilot/model/conversation.py
+++ b/pilot/model/conversation.py
@@ -339,6 +339,28 @@ register_conv_template(
)
)
+
+# codellama template
+# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
+# reference2 : https://github.com/eosphoros-ai/DB-GPT-Hub/blob/main/README.zh.md
+register_conv_template(
+ Conversation(
+ name="codellama",
+ system="[INST] <>\nI want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request."
+ "If you don't know the answer to the request, please don't share false information.\n<>\n\n",
+ roles=("[INST]", "[/INST]"),
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.LLAMA2,
+ sep=" ",
+ sep2=" ",
+ stop_token_ids=[2],
+        system_formatter=lambda msg: f"[INST] <<SYS>>\n{msg}\n<</SYS>>\n\n",
+ )
+)
+
+
+
# Alpaca default template
register_conv_template(
Conversation(
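To make the codellama template registered above concrete, a rough rendering of a single-turn prompt in the Llama-2 chat layout it encodes (the exact whitespace produced by the LLAMA2 separator style may differ, and the question is illustrative):

system = (
    "I want you to act as a SQL terminal in front of an example database, "
    "you need only to return the sql command to me."
)
question = "List the ten most recent orders."
prompt = f"[INST] <<SYS>>\n{system}\n<</SYS>>\n\n{question} [/INST]"
print(prompt)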
diff --git a/pilot/model/model_adapter.py b/pilot/model/model_adapter.py
index 1580e8863..112fb468a 100644
--- a/pilot/model/model_adapter.py
+++ b/pilot/model/model_adapter.py
@@ -45,6 +45,7 @@ _OLD_MODELS = [
"llama-cpp",
"proxyllm",
"gptj-6b",
+ "codellama-13b-sql-sft"
]
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index cb486021b..509305247 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -213,6 +213,15 @@ class Llama2ChatAdapter(BaseChatAdpter):
def get_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("llama-2")
+
+
+class CodeLlamaChatAdapter(BaseChatAdpter):
+ """The model ChatAdapter for codellama ."""
+ def match(self, model_path: str):
+ return "codelama" in model_path.lower()
+
+ def get_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("codellama")
class BaichuanChatAdapter(BaseChatAdpter):
@@ -268,6 +277,7 @@ register_llm_model_chat_adapter(FalconChatAdapter)
register_llm_model_chat_adapter(GorillaChatAdapter)
register_llm_model_chat_adapter(GPT4AllChatAdapter)
register_llm_model_chat_adapter(Llama2ChatAdapter)
+register_llm_model_chat_adapter(CodeLlamaChatAdapter)
register_llm_model_chat_adapter(BaichuanChatAdapter)
register_llm_model_chat_adapter(WizardLMChatAdapter)
register_llm_model_chat_adapter(LlamaCppChatAdapter)
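The chat adapter is chosen by a first-match substring test on the model path, roughly like the sketch below (class and registry names here are illustrative, not the project's); this is also why a misspelled substring such as "codelama" would silently fall through to the default adapter.

class BaseAdapter:
    def match(self, model_path: str) -> bool:
        return False

class CodeLlamaAdapter(BaseAdapter):
    def match(self, model_path: str) -> bool:
        return "codellama" in model_path.lower()

class DefaultAdapter(BaseAdapter):
    def match(self, model_path: str) -> bool:
        return True

ADAPTERS = [CodeLlamaAdapter(), DefaultAdapter()]

def get_adapter(model_path: str) -> BaseAdapter:
    # First registered adapter whose match() accepts the path wins.
    return next(adapter for adapter in ADAPTERS if adapter.match(model_path))

print(type(get_adapter("/data/models/codellama-13b-sql-sft")).__name__)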
From 53b1fc40901cb59ebcdc93caaa852c942ea5f858 Mon Sep 17 00:00:00 2001
From: aries_ckt <916701291@qq.com>
Date: Mon, 30 Oct 2023 19:06:09 +0800
Subject: [PATCH 14/57] feat:document summary
---
pilot/scene/base.py | 7 ++
pilot/scene/base_chat.py | 2 -
pilot/scene/chat_factory.py | 1 +
pilot/server/knowledge/document_db.py | 3 +-
pilot/server/knowledge/request/response.py | 3 +-
pilot/server/knowledge/service.py | 89 ++++++++++++++++------
6 files changed, 79 insertions(+), 26 deletions(-)
diff --git a/pilot/scene/base.py b/pilot/scene/base.py
index 5c98003d9..e3478f7c3 100644
--- a/pilot/scene/base.py
+++ b/pilot/scene/base.py
@@ -96,6 +96,13 @@ class ChatScene(Enum):
["Extract Select"],
True,
)
+ ExtractRefineSummary = Scene(
+ "extract_refine_summary",
+ "Extract Summary",
+ "Extract Summary",
+ ["Extract Select"],
+ True,
+ )
ExtractEntity = Scene(
"extract_entity", "Extract Entity", "Extract Entity", ["Extract Select"], True
)
diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index 73a2c5ef6..e43cdc812 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -127,8 +127,6 @@ class BaseChat(ABC):
speak_to_user = prompt_define_response
return speak_to_user
- async def __call_base(self):
- input_values = await self.generate_input_values()
async def __call_base(self):
import inspect
diff --git a/pilot/scene/chat_factory.py b/pilot/scene/chat_factory.py
index 2e103f15d..10a588c04 100644
--- a/pilot/scene/chat_factory.py
+++ b/pilot/scene/chat_factory.py
@@ -17,6 +17,7 @@ class ChatFactory(metaclass=Singleton):
from pilot.scene.chat_knowledge.extract_triplet.chat import ExtractTriplet
from pilot.scene.chat_knowledge.extract_entity.chat import ExtractEntity
from pilot.scene.chat_knowledge.summary.chat import ExtractSummary
+ from pilot.scene.chat_knowledge.refine_summary.chat import ExtractRefineSummary
from pilot.scene.chat_data.chat_excel.excel_analyze.chat import ChatExcel
from pilot.scene.chat_agent.chat import ChatAgent
diff --git a/pilot/server/knowledge/document_db.py b/pilot/server/knowledge/document_db.py
index 3e6dfb0c4..bbe1426d7 100644
--- a/pilot/server/knowledge/document_db.py
+++ b/pilot/server/knowledge/document_db.py
@@ -30,11 +30,12 @@ class KnowledgeDocumentEntity(Base):
content = Column(Text)
result = Column(Text)
vector_ids = Column(Text)
+ summary = Column(Text)
gmt_created = Column(DateTime)
gmt_modified = Column(DateTime)
def __repr__(self):
- return f"KnowledgeDocumentEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', chunk_size='{self.chunk_size}', status='{self.status}', last_sync='{self.last_sync}', content='{self.content}', result='{self.result}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
+ return f"KnowledgeDocumentEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', chunk_size='{self.chunk_size}', status='{self.status}', last_sync='{self.last_sync}', content='{self.content}', result='{self.result}', summary='{self.summary}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
class KnowledgeDocumentDao(BaseDao):
diff --git a/pilot/server/knowledge/request/response.py b/pilot/server/knowledge/request/response.py
index fb7aa55e9..2e3e5f0ab 100644
--- a/pilot/server/knowledge/request/response.py
+++ b/pilot/server/knowledge/request/response.py
@@ -5,8 +5,9 @@ from pydantic import BaseModel
class ChunkQueryResponse(BaseModel):
"""data: data"""
-
data: List = None
+ """summary: document summary"""
+ summary: str = None
"""total: total size"""
total: int = None
"""page: current page"""
diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py
index 4c1c41994..017fef3ec 100644
--- a/pilot/server/knowledge/service.py
+++ b/pilot/server/knowledge/service.py
@@ -288,8 +288,8 @@ class KnowledgeService:
executor = CFG.SYSTEM_APP.get_component(
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()
- executor.submit(self.async_knowledge_graph, chunk_docs, doc)
- # executor.submit(self.async_doc_embedding, client, chunk_docs, doc)
+ executor.submit(self.async_document_summary, chunk_docs, doc)
+ executor.submit(self.async_doc_embedding, client, chunk_docs, doc)
logger.info(f"begin save document chunks, doc:{doc.doc_name}")
# save chunk details
chunk_entities = [
@@ -384,38 +384,59 @@ class KnowledgeService:
doc_name=request.doc_name,
doc_type=request.doc_type,
)
+ document_query = KnowledgeDocumentEntity(id=request.document_id)
+ documents = knowledge_document_dao.get_documents(document_query)
+
res = ChunkQueryResponse()
res.data = document_chunk_dao.get_document_chunks(
query, page=request.page, page_size=request.page_size
)
+ res.summary = documents[0].summary
res.total = document_chunk_dao.get_document_chunks_count(query)
res.page = request.page
return res
+
def async_knowledge_graph(self, chunk_docs, doc):
"""async document extract triplets and save into graph db
Args:
- chunk_docs: List[Document]
- doc: KnowledgeDocumentEntity
"""
- for doc in chunk_docs:
- text = doc.page_content
- self._llm_extract_summary(text)
logger.info(
f"async_knowledge_graph, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to graph store"
)
- # try:
- # from pilot.graph_engine.graph_factory import RAGGraphFactory
- #
- # rag_engine = CFG.SYSTEM_APP.get_component(
- # ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
- # ).create()
- # rag_engine.knowledge_graph(chunk_docs)
- # doc.status = SyncStatus.FINISHED.name
- # doc.result = "document build graph success"
- # except Exception as e:
- # doc.status = SyncStatus.FAILED.name
- # doc.result = "document build graph failed" + str(e)
- # logger.error(f"document build graph failed:{doc.doc_name}, {str(e)}")
+ try:
+ from pilot.graph_engine.graph_factory import RAGGraphFactory
+
+ rag_engine = CFG.SYSTEM_APP.get_component(
+ ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
+ ).create()
+ rag_engine.knowledge_graph(chunk_docs)
+ doc.status = SyncStatus.FINISHED.name
+ doc.result = "document build graph success"
+ except Exception as e:
+ doc.status = SyncStatus.FAILED.name
+ doc.result = "document build graph failed" + str(e)
+ logger.error(f"document build graph failed:{doc.doc_name}, {str(e)}")
+ return knowledge_document_dao.update_knowledge_document(doc)
+
+ def async_document_summary(self, chunk_docs, doc):
+ """async document extract summary
+ Args:
+ - chunk_docs: List[Document]
+ - doc: KnowledgeDocumentEntity
+ """
+ from llama_index import PromptHelper
+ from llama_index.prompts.default_prompt_selectors import DEFAULT_TREE_SUMMARIZE_PROMPT_SEL
+ texts = [doc.page_content for doc in chunk_docs]
+ prompt_helper = PromptHelper()
+ texts = prompt_helper.repack(prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=texts)
+ summary = self._llm_extract_summary(chunk_docs[0])
+ outputs, summary = self._refine_extract_summary(texts[1:], summary)
+ logger.info(
+ f"async_document_summary, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to graph store"
+ )
+ doc.summary = summary
return knowledge_document_dao.update_knowledge_document(doc)
@@ -491,15 +512,39 @@ class KnowledgeService:
chat_param = {
"chat_session_id": uuid.uuid1(),
- "current_user_input": doc,
- "select_param": "summery",
+ "current_user_input": doc.page_content,
+ "select_param": "summary",
"model_name": "proxyllm",
}
from pilot.utils import utils
loop = utils.get_or_create_event_loop()
- triplets = loop.run_until_complete(
+ summary = loop.run_until_complete(
llm_chat_response_nostream(
ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
)
)
- return triplets
+ return summary
+ def _refine_extract_summary(self, docs, summary: str):
+ """Extract refine summary by llm"""
+ from pilot.scene.base import ChatScene
+ from pilot.common.chat_util import llm_chat_response_nostream
+ import uuid
+ outputs = []
+ for doc in docs:
+ chat_param = {
+ "chat_session_id": uuid.uuid1(),
+ "current_user_input": doc,
+ "select_param": summary,
+ "model_name": "proxyllm",
+ }
+ from pilot.utils import utils
+ loop = utils.get_or_create_event_loop()
+ summary = loop.run_until_complete(
+ llm_chat_response_nostream(
+ ChatScene.ExtractRefineSummary.value(), **{"chat_param": chat_param}
+ )
+ )
+ outputs.append(summary)
+ return outputs, summary
+
+
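Conceptually, the new summary flow is a seed summary from the first chunk followed by an iterative refine loop over the rest; a synchronous toy sketch (the real code routes each step through the ExtractSummary / ExtractRefineSummary chat scenes):

from typing import Callable, List, Tuple

def refine_summarize(
    chunks: List[str],
    summarize: Callable[[str], str],
    refine: Callable[[str, str], str],
) -> Tuple[List[str], str]:
    # The first chunk seeds the summary; each later chunk refines it in turn.
    summary = summarize(chunks[0])
    outputs: List[str] = []
    for chunk in chunks[1:]:
        summary = refine(chunk, summary)
        outputs.append(summary)
    return outputs, summary

# Toy stand-ins for the LLM-backed scenes.
outputs, final = refine_summarize(
    ["chunk one", "chunk two", "chunk three"],
    summarize=lambda text: f"summary({text})",
    refine=lambda text, prev: f"refine({prev} + {text})",
)
print(final)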
From d4d231afe742753f49af7b475eeb8d8b58b6e538 Mon Sep 17 00:00:00 2001
From: aries_ckt <916701291@qq.com>
Date: Mon, 30 Oct 2023 19:09:00 +0800
Subject: [PATCH 15/57] chore:discord expire
---
README.md | 8 ++++----
README.zh.md | 6 +++---
2 files changed, 7 insertions(+), 7 deletions(-)
diff --git a/README.md b/README.md
index c819c8a51..b73db828d 100644
--- a/README.md
+++ b/README.md
@@ -25,8 +25,8 @@
-
-
+
+
@@ -34,7 +34,7 @@
-[**简体中文**](README.zh.md) |[**Discord**](https://discord.gg/vqBrcV7Nd) |[**Documents**](https://db-gpt.readthedocs.io/en/latest/)|[**Wechat**](https://github.com/eosphoros-ai/DB-GPT/blob/main/README.zh.md#%E8%81%94%E7%B3%BB%E6%88%91%E4%BB%AC)|[**Community**](https://github.com/eosphoros-ai/community)
+[**简体中文**](README.zh.md) |[**Discord**](https://discord.gg/nASQyBjvY) |[**Documents**](https://db-gpt.readthedocs.io/en/latest/)|[**Wechat**](https://github.com/eosphoros-ai/DB-GPT/blob/main/README.zh.md#%E8%81%94%E7%B3%BB%E6%88%91%E4%BB%AC)|[**Community**](https://github.com/eosphoros-ai/community)
## What is DB-GPT?
@@ -331,7 +331,7 @@ The MIT License (MIT)
## Contact Information
We are working on building a community, if you have any ideas about building the community, feel free to contact us.
-[](https://discord.gg/vqBrcV7Nd)
+[](https://discord.gg/nASQyBjvY)
diff --git a/README.zh.md b/README.zh.md
index a6119663d..dca465541 100644
--- a/README.zh.md
+++ b/README.zh.md
@@ -22,15 +22,15 @@
-
-
+
+
-[**English**](README.md)|[**Discord**](https://discord.gg/vqBrcV7Nd)|[**文档**](https://db-gpt.readthedocs.io/projects/db-gpt-docs-zh-cn/zh_CN/latest/)|[**微信**](https://github.com/csunny/DB-GPT/blob/main/README.zh.md#%E8%81%94%E7%B3%BB%E6%88%91%E4%BB%AC)|[**社区**](https://github.com/eosphoros-ai/community)
+[**English**](README.md)|[**Discord**](https://discord.gg/nASQyBjvY)|[**文档**](https://db-gpt.readthedocs.io/projects/db-gpt-docs-zh-cn/zh_CN/latest/)|[**微信**](https://github.com/csunny/DB-GPT/blob/main/README.zh.md#%E8%81%94%E7%B3%BB%E6%88%91%E4%BB%AC)|[**社区**](https://github.com/eosphoros-ai/community)
## DB-GPT 是什么?
From 6841050d43c015c6355549b3834d7743b4275cdb Mon Sep 17 00:00:00 2001
From: hairyputtar <148847552+hairyputtar@users.noreply.github.com>
Date: Mon, 30 Oct 2023 22:11:32 +0530
Subject: [PATCH 16/57] fix typo
---
docs/getting_started/install/environment/environment.md | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/docs/getting_started/install/environment/environment.md b/docs/getting_started/install/environment/environment.md
index 4c8b23751..021bbe861 100644
--- a/docs/getting_started/install/environment/environment.md
+++ b/docs/getting_started/install/environment/environment.md
@@ -59,11 +59,11 @@ Embedding Chunk size, default 500
Embedding Chunk Overlap, default 100
* KNOWLEDGE_CHUNK_OVERLAP=100
-embeding recall top k,5
+embedding recall top k,5
* KNOWLEDGE_SEARCH_TOP_SIZE=5
-embeding recall max token ,2000
+embedding recall max token, 2000
* KNOWLEDGE_SEARCH_MAX_TOKEN=5
```
From e74e37f8187206466c9db028c5018f6e5954883e Mon Sep 17 00:00:00 2001
From: Abhishekgupta204 <116148980+Abhishekgupta204@users.noreply.github.com>
Date: Mon, 30 Oct 2023 23:16:52 +0530
Subject: [PATCH 17/57] Added CODE_OF_CONDUCT file
---
CODE_OF_CONDUCT | 126 ++++++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 126 insertions(+)
create mode 100644 CODE_OF_CONDUCT
diff --git a/CODE_OF_CONDUCT b/CODE_OF_CONDUCT
new file mode 100644
index 000000000..b7efcc0b3
--- /dev/null
+++ b/CODE_OF_CONDUCT
@@ -0,0 +1,126 @@
+# Contributor Covenant Code of Conduct
+
+## Our Pledge
+
+We as members, contributors, and leaders pledge to make participation in our
+community a harassment-free experience for everyone, regardless of age, body
+size, visible or invisible disability, ethnicity, sex characteristics, gender
+identity and expression, level of experience, education, socio-economic status,
+nationality, personal appearance, race, caste, color, religion, or sexual
+identity and orientation.
+
+We pledge to act and interact in ways that contribute to an open, welcoming,
+diverse, inclusive, and healthy community.
+
+## Our Standards
+
+Examples of behavior that contributes to a positive environment for our
+community include:
+
+* Demonstrating empathy and kindness toward other people
+* Being respectful of differing opinions, viewpoints, and experiences
+* Giving and gracefully accepting constructive feedback
+* Accepting responsibility and apologizing to those affected by our mistakes,
+ and learning from the experience
+* Focusing on what is best not just for us as individuals, but for the overall
+ community
+
+Examples of unacceptable behavior include:
+
+* The use of sexualized language or imagery, and sexual attention or advances of
+ any kind
+* Trolling, insulting or derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or email address,
+ without their explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+ professional setting
+
+## Enforcement Responsibilities
+
+Community leaders are responsible for clarifying and enforcing our standards of
+acceptable behavior and will take appropriate and fair corrective action in
+response to any behavior that they deem inappropriate, threatening, offensive,
+or harmful.
+
+Community leaders have the right and responsibility to remove, edit, or reject
+comments, commits, code, wiki edits, issues, and other contributions that are
+not aligned to this Code of Conduct, and will communicate reasons for moderation
+decisions when appropriate.
+
+## Scope
+
+This Code of Conduct applies within all community spaces, and also applies when
+an individual is officially representing the community in public spaces.
+Examples of representing our community include using an official e-mail address,
+posting via an official social media account, or acting as an appointed
+representative at an online or offline event.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported to the community leaders responsible for enforcement at
+[INSERT CONTACT METHOD].
+All complaints will be reviewed and investigated promptly and fairly.
+
+All community leaders are obligated to respect the privacy and security of the
+reporter of any incident.
+
+## Enforcement Guidelines
+
+Community leaders will follow these Community Impact Guidelines in determining
+the consequences for any action they deem in violation of this Code of Conduct:
+
+### 1. Correction
+
+*Community Impact*: Use of inappropriate language or other behavior deemed
+unprofessional or unwelcome in the community.
+
+*Consequence*: A private, written warning from community leaders, providing
+clarity around the nature of the violation and an explanation of why the
+behavior was inappropriate. A public apology may be requested.
+
+### 2. Warning
+
+*Community Impact*: A violation through a single incident or series of
+actions.
+
+*Consequence*: A warning with consequences for continued behavior. No
+interaction with the people involved, including unsolicited interaction with
+those enforcing the Code of Conduct, for a specified period of time. This
+includes avoiding interactions in community spaces as well as external channels
+like social media. Violating these terms may lead to a temporary or permanent
+ban.
+
+### 3. Temporary Ban
+
+*Community Impact*: A serious violation of community standards, including
+sustained inappropriate behavior.
+
+*Consequence*: A temporary ban from any sort of interaction or public
+communication with the community for a specified period of time. No public or
+private interaction with the people involved, including unsolicited interaction
+with those enforcing the Code of Conduct, is allowed during this period.
+Violating these terms may lead to a permanent ban.
+
+### 4. Permanent Ban
+
+*Community Impact*: Demonstrating a pattern of violation of community
+standards, including sustained inappropriate behavior, harassment of an
+individual, or aggression toward or disparagement of classes of individuals.
+
+*Consequence*: A permanent ban from any sort of public interaction within the
+community.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage],
+version 2.1, available at
+[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
+
+Community Impact Guidelines were inspired by
+[Mozilla's code of conduct enforcement ladder][Mozilla CoC].
+
+For answers to common questions about this code of conduct, see the FAQ at
+[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
+[https://www.contributor-covenant.org/translations][translations].
From 29dac02955c02da2bbe50a0472d32db3a707c13f Mon Sep 17 00:00:00 2001
From: chinmay7016 <75988613+chinmay7016@users.noreply.github.com>
Date: Tue, 31 Oct 2023 00:27:51 +0530
Subject: [PATCH 18/57] Update and rename knownledge.md to knowledge.md
typo solved
---
docs/modules/{knownledge.md => knowledge.md} | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
rename docs/modules/{knownledge.md => knowledge.md} (97%)
diff --git a/docs/modules/knownledge.md b/docs/modules/knowledge.md
similarity index 97%
rename from docs/modules/knownledge.md
rename to docs/modules/knowledge.md
index 03fb45eec..7c8b2758f 100644
--- a/docs/modules/knownledge.md
+++ b/docs/modules/knowledge.md
@@ -1,4 +1,4 @@
-# Knownledge
+# Knowledge
As the knowledge base is currently the most significant user demand scenario, we natively support the construction and processing of knowledge bases. At the same time, we also provide multiple knowledge base management strategies in this project, such as:
1. Default built-in knowledge base
@@ -32,4 +32,4 @@ Optionally, you can run `dbgpt knowledge load --help` command to see more usage.
3.Add the knowledge repository in the interface by entering the name of your knowledge repository (if not specified, enter "default") so you can use it for Q&A based on your knowledge base.
-Note that the default vector model used is text2vec-large-chinese (which is a large model, so if your personal computer configuration is not enough, it is recommended to use text2vec-base-chinese). Therefore, ensure that you download the model and place it in the models directory.
\ No newline at end of file
+Note that the default vector model used is text2vec-large-chinese (which is a large model, so if your personal computer configuration is not enough, it is recommended to use text2vec-base-chinese). Therefore, ensure that you download the model and place it in the models directory.
From dca3ddb93113cd426e26253fdebffaec37f0da4b Mon Sep 17 00:00:00 2001
From: aries_ckt <916701291@qq.com>
Date: Tue, 31 Oct 2023 13:47:19 +0800
Subject: [PATCH 19/57] feat:add summary
---
pilot/common/chat_util.py | 35 ++++++++++++
.../chat_knowledge/refine_summary/__init__.py | 0
.../chat_knowledge/refine_summary/chat.py | 37 ++++++++++++
.../refine_summary/out_parser.py | 57 +++++++++++++++++++
.../chat_knowledge/refine_summary/prompt.py | 40 +++++++++++++
.../scene/chat_knowledge/summary/__init__.py | 0
pilot/scene/chat_knowledge/summary/chat.py | 35 ++++++++++++
.../chat_knowledge/summary/out_parser.py | 28 +++++++++
pilot/scene/chat_knowledge/summary/prompt.py | 47 +++++++++++++++
pilot/server/knowledge/service.py | 41 +++++++++++--
10 files changed, 315 insertions(+), 5 deletions(-)
create mode 100644 pilot/scene/chat_knowledge/refine_summary/__init__.py
create mode 100644 pilot/scene/chat_knowledge/refine_summary/chat.py
create mode 100644 pilot/scene/chat_knowledge/refine_summary/out_parser.py
create mode 100644 pilot/scene/chat_knowledge/refine_summary/prompt.py
create mode 100644 pilot/scene/chat_knowledge/summary/__init__.py
create mode 100644 pilot/scene/chat_knowledge/summary/chat.py
create mode 100644 pilot/scene/chat_knowledge/summary/out_parser.py
create mode 100644 pilot/scene/chat_knowledge/summary/prompt.py
diff --git a/pilot/common/chat_util.py b/pilot/common/chat_util.py
index 0de0b9bda..ae0ce73ed 100644
--- a/pilot/common/chat_util.py
+++ b/pilot/common/chat_util.py
@@ -1,4 +1,5 @@
import asyncio
+from typing import Coroutine, List, Any
from starlette.responses import StreamingResponse
@@ -18,3 +19,37 @@ async def llm_chat_response_nostream(chat_scene: str, **chat_param):
async def llm_chat_response(chat_scene: str, **chat_param):
chat: BaseChat = chat_factory.get_implementation(chat_scene, **chat_param)
return chat.stream_call()
+
+
+def run_async_tasks(
+ tasks: List[Coroutine],
+ show_progress: bool = False,
+ progress_bar_desc: str = "Running async tasks",
+) -> List[Any]:
+ """Run a list of async tasks."""
+
+ tasks_to_execute: List[Any] = tasks
+ if show_progress:
+ try:
+ import nest_asyncio
+ from tqdm.asyncio import tqdm
+
+ nest_asyncio.apply()
+ loop = asyncio.get_event_loop()
+
+ async def _tqdm_gather() -> List[Any]:
+ return await tqdm.gather(*tasks_to_execute, desc=progress_bar_desc)
+
+ tqdm_outputs: List[Any] = loop.run_until_complete(_tqdm_gather())
+ return tqdm_outputs
+ # run the operation w/o tqdm on hitting a fatal
+ # may occur in some environments where tqdm.asyncio
+ # is not supported
+ except Exception:
+ pass
+
+ async def _gather() -> List[Any]:
+ return await asyncio.gather(*tasks_to_execute)
+
+ outputs: List[Any] = asyncio.run(_gather())
+ return outputs
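A hedged usage example for `run_async_tasks`; the coroutines here are placeholders, while in the knowledge service they are the LLM chat calls issued per chunk:

import asyncio
from pilot.common.chat_util import run_async_tasks

async def fake_llm_call(i: int) -> str:
    await asyncio.sleep(0.01)
    return f"result-{i}"

# Gathers the coroutines and returns their results in submission order.
results = run_async_tasks([fake_llm_call(i) for i in range(3)])
print(results)  # ['result-0', 'result-1', 'result-2']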
diff --git a/pilot/scene/chat_knowledge/refine_summary/__init__.py b/pilot/scene/chat_knowledge/refine_summary/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/scene/chat_knowledge/refine_summary/chat.py b/pilot/scene/chat_knowledge/refine_summary/chat.py
new file mode 100644
index 000000000..b3a934dd5
--- /dev/null
+++ b/pilot/scene/chat_knowledge/refine_summary/chat.py
@@ -0,0 +1,37 @@
+from typing import Dict
+
+from pilot.scene.base_chat import BaseChat
+from pilot.scene.base import ChatScene
+from pilot.configs.config import Config
+
+from pilot.scene.chat_knowledge.refine_summary.prompt import prompt
+
+CFG = Config()
+
+
+class ExtractRefineSummary(BaseChat):
+ chat_scene: str = ChatScene.ExtractRefineSummary.value()
+
+ """get summary by llm"""
+
+ def __init__(self, chat_param: Dict):
+ """ """
+ chat_param["chat_mode"] = ChatScene.ExtractRefineSummary
+ super().__init__(
+ chat_param=chat_param,
+ )
+
+ self.user_input = chat_param["current_user_input"]
+ self.existing_answer = chat_param["select_param"]
+ # self.extract_mode = chat_param["select_param"]
+
+ def generate_input_values(self):
+ input_values = {
+ "context": self.user_input,
+ "existing_answer": self.existing_answer,
+ }
+ return input_values
+
+ @property
+ def chat_type(self) -> str:
+ return ChatScene.ExtractRefineSummary.value
diff --git a/pilot/scene/chat_knowledge/refine_summary/out_parser.py b/pilot/scene/chat_knowledge/refine_summary/out_parser.py
new file mode 100644
index 000000000..104419e88
--- /dev/null
+++ b/pilot/scene/chat_knowledge/refine_summary/out_parser.py
@@ -0,0 +1,57 @@
+import json
+import logging
+import re
+from typing import List, Tuple
+
+from pilot.out_parser.base import BaseOutputParser, T
+from pilot.configs.config import Config
+
+CFG = Config()
+
+
+logger = logging.getLogger(__name__)
+
+
+class ExtractRefineSummaryParser(BaseOutputParser):
+ def __init__(self, sep: str, is_stream_out: bool):
+ super().__init__(sep=sep, is_stream_out=is_stream_out)
+
+ def parse_prompt_response(
+ self, response, max_length: int = 128
+ ) -> List[Tuple[str, str, str]]:
+ # clean_str = super().parse_prompt_response(response)
+ print("clean prompt response:", response)
+
+ # if response.startswith("Triplets:"):
+ # response = response[len("Triplets:") :]
+ # pattern = r"\([^()]+\)"
+ # response = re.findall(pattern, response)
+ # # response = response.strip().split("\n")
+ # print("parse prompt response:", response)
+ # results = []
+ # for text in response:
+ # if not text or text[0] != "(" or text[-1] != ")":
+ # # skip empty lines and non-triplets
+ # continue
+ # tokens = text[1:-1].split(",")
+ # if len(tokens) != 3:
+ # continue
+ #
+ # if any(len(s.encode("utf-8")) > max_length for s in tokens):
+ # # We count byte-length instead of len() for UTF-8 chars,
+ # # will skip if any of the tokens are too long.
+ # # This is normally due to a poorly formatted triplet
+ # # extraction, in more serious KG building cases
+ # # we'll need NLP models to better extract triplets.
+ # continue
+ #
+ # subject, predicate, obj = map(str.strip, tokens)
+ # if not subject or not predicate or not obj:
+ # # skip partial triplets
+ # continue
+ # results.append((subject.lower(), predicate.lower(), obj.lower()))
+ return response
+
+ def parse_view_response(self, speak, data) -> str:
+ ### tool out data to table view
+ return data
diff --git a/pilot/scene/chat_knowledge/refine_summary/prompt.py b/pilot/scene/chat_knowledge/refine_summary/prompt.py
new file mode 100644
index 000000000..0161cee35
--- /dev/null
+++ b/pilot/scene/chat_knowledge/refine_summary/prompt.py
@@ -0,0 +1,40 @@
+from pilot.prompts.prompt_new import PromptTemplate
+from pilot.configs.config import Config
+from pilot.scene.base import ChatScene
+from pilot.common.schema import SeparatorStyle
+
+from pilot.scene.chat_knowledge.refine_summary.out_parser import ExtractRefineSummaryParser
+
+CFG = Config()
+
+
+PROMPT_SCENE_DEFINE = """Your job is to produce a final summary."""
+
+_DEFAULT_TEMPLATE = """
+We have provided an existing summary up to a certain point: {existing_answer}\nWe have the opportunity to refine the existing summary (only if needed) with some more context below.\n------------\n{context}\n------------\nGiven the new context, refine the original summary.\nIf the context isn't useful, return the original summary.
+
+please use original language.
+"""
+PROMPT_RESPONSE = """"""
+
+
+RESPONSE_FORMAT = """"""
+
+
+PROMPT_SEP = SeparatorStyle.SINGLE.value
+
+PROMPT_NEED_NEED_STREAM_OUT = False
+
+prompt = PromptTemplate(
+ template_scene=ChatScene.ExtractRefineSummary.value(),
+ input_variables=["existing_answer","context"],
+ response_format="",
+ template_define=PROMPT_SCENE_DEFINE,
+ template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
+ stream_out=PROMPT_NEED_NEED_STREAM_OUT,
+ output_parser=ExtractRefineSummaryParser(
+ sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
+ ),
+)
+
+CFG.prompt_template_registry.register(prompt, is_default=True)
diff --git a/pilot/scene/chat_knowledge/summary/__init__.py b/pilot/scene/chat_knowledge/summary/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/scene/chat_knowledge/summary/chat.py b/pilot/scene/chat_knowledge/summary/chat.py
new file mode 100644
index 000000000..f887bde82
--- /dev/null
+++ b/pilot/scene/chat_knowledge/summary/chat.py
@@ -0,0 +1,35 @@
+from typing import Dict
+
+from pilot.scene.base_chat import BaseChat
+from pilot.scene.base import ChatScene
+from pilot.configs.config import Config
+
+from pilot.scene.chat_knowledge.summary.prompt import prompt
+
+CFG = Config()
+
+
+class ExtractSummary(BaseChat):
+ chat_scene: str = ChatScene.ExtractSummary.value()
+
+ """get summary by llm"""
+
+ def __init__(self, chat_param: Dict):
+ """ """
+ chat_param["chat_mode"] = ChatScene.ExtractSummary
+ super().__init__(
+ chat_param=chat_param,
+ )
+
+ self.user_input = chat_param["current_user_input"]
+ # self.extract_mode = chat_param["select_param"]
+
+ def generate_input_values(self):
+ input_values = {
+ "context": self.user_input,
+ }
+ return input_values
+
+ @property
+ def chat_type(self) -> str:
+        return ChatScene.ExtractSummary.value()
diff --git a/pilot/scene/chat_knowledge/summary/out_parser.py b/pilot/scene/chat_knowledge/summary/out_parser.py
new file mode 100644
index 000000000..5626d0d4a
--- /dev/null
+++ b/pilot/scene/chat_knowledge/summary/out_parser.py
@@ -0,0 +1,28 @@
+import json
+import logging
+import re
+from typing import List, Tuple
+
+from pilot.out_parser.base import BaseOutputParser, T
+from pilot.configs.config import Config
+
+CFG = Config()
+
+
+logger = logging.getLogger(__name__)
+
+
+class ExtractSummaryParser(BaseOutputParser):
+ def __init__(self, sep: str, is_stream_out: bool):
+ super().__init__(sep=sep, is_stream_out=is_stream_out)
+
+ def parse_prompt_response(
+ self, response, max_length: int = 128
+ ) -> List[Tuple[str, str, str]]:
+ # clean_str = super().parse_prompt_response(response)
+ print("clean prompt response:", response)
+ return response
+
+ def parse_view_response(self, speak, data) -> str:
+        ### render tool output data as a table view
+ return data
diff --git a/pilot/scene/chat_knowledge/summary/prompt.py b/pilot/scene/chat_knowledge/summary/prompt.py
new file mode 100644
index 000000000..cbf452c99
--- /dev/null
+++ b/pilot/scene/chat_knowledge/summary/prompt.py
@@ -0,0 +1,47 @@
+from pilot.prompts.prompt_new import PromptTemplate
+from pilot.configs.config import Config
+from pilot.scene.base import ChatScene
+from pilot.common.schema import SeparatorStyle
+
+from pilot.scene.chat_knowledge.summary.out_parser import ExtractSummaryParser
+
+CFG = Config()
+
+# PROMPT_SCENE_DEFINE = """You are an expert Q&A system that is trusted around the world.\nAlways answer the query using the provided context information, and not prior knowledge.\nSome rules to follow:\n1. Never directly reference the given context in your answer.\n2. Avoid statements like 'Based on the context, ...' or 'The context information ...' or anything along those lines."""
+
+PROMPT_SCENE_DEFINE = """Your job is to produce a final summary."""
+
+# _DEFAULT_TEMPLATE = """
+# Context information from multiple sources is below.\n---------------------\n
+# {context}
+# Given the information from multiple sources and not prior knowledge, answer the query.\nQuery: Describe what the provided text is about. Also describe some of the questions that this text can answer. \nAnswer: "
+# """
+
+_DEFAULT_TEMPLATE = """
+Write a concise summary of the following context:
+{context}
+please use original language.
+"""
+PROMPT_RESPONSE = """"""
+
+
+RESPONSE_FORMAT = """"""
+
+
+PROMPT_SEP = SeparatorStyle.SINGLE.value
+
+PROMPT_NEED_NEED_STREAM_OUT = False
+
+prompt = PromptTemplate(
+ template_scene=ChatScene.ExtractSummary.value(),
+ input_variables=["context"],
+ response_format="",
+ template_define=PROMPT_SCENE_DEFINE,
+ template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
+ stream_out=PROMPT_NEED_NEED_STREAM_OUT,
+ output_parser=ExtractSummaryParser(
+ sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
+ ),
+)
+
+CFG.prompt_template_registry.register(prompt, is_default=True)
diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py
index 017fef3ec..cde2b7bb7 100644
--- a/pilot/server/knowledge/service.py
+++ b/pilot/server/knowledge/service.py
@@ -288,7 +288,7 @@ class KnowledgeService:
executor = CFG.SYSTEM_APP.get_component(
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()
- executor.submit(self.async_document_summary, chunk_docs, doc)
+ # executor.submit(self.async_document_summary, chunk_docs, doc)
executor.submit(self.async_doc_embedding, client, chunk_docs, doc)
logger.info(f"begin save document chunks, doc:{doc.doc_name}")
# save chunk details
@@ -431,7 +431,8 @@ class KnowledgeService:
texts = [doc.page_content for doc in chunk_docs]
prompt_helper = PromptHelper()
texts = prompt_helper.repack(prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=texts)
- summary = self._llm_extract_summary(chunk_docs[0])
+ summary = self._llm_extract_summary(texts[0])
+ # summaries = self._mapreduce_extract_summary(texts)
outputs, summary = self._refine_extract_summary(texts[1:], summary)
logger.info(
f"async_document_summary, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to graph store"
@@ -452,6 +453,7 @@ class KnowledgeService:
)
try:
vector_ids = client.knowledge_embedding_batch(chunk_docs)
+ self.async_document_summary(chunk_docs, doc)
doc.status = SyncStatus.FINISHED.name
doc.result = "document embedding success"
if vector_ids is not None:
@@ -512,9 +514,9 @@ class KnowledgeService:
chat_param = {
"chat_session_id": uuid.uuid1(),
- "current_user_input": doc.page_content,
+ "current_user_input": doc,
"select_param": "summary",
- "model_name": "proxyllm",
+ "model_name": CFG.LLM_MODEL,
}
from pilot.utils import utils
loop = utils.get_or_create_event_loop()
@@ -535,7 +537,7 @@ class KnowledgeService:
"chat_session_id": uuid.uuid1(),
"current_user_input": doc,
"select_param": summary,
- "model_name": "proxyllm",
+ "model_name": CFG.LLM_MODEL,
}
from pilot.utils import utils
loop = utils.get_or_create_event_loop()
@@ -547,4 +549,33 @@ class KnowledgeService:
outputs.append(summary)
return outputs, summary
+ def _mapreduce_extract_summary(self, docs):
+ """Extract mapreduce summary by llm"""
+ from pilot.scene.base import ChatScene
+ from pilot.common.chat_util import llm_chat_response_nostream
+ import uuid
+ outputs = []
+ tasks = []
+ for doc in docs:
+ chat_param = {
+ "chat_session_id": uuid.uuid1(),
+ "current_user_input": doc,
+ "select_param": "summary",
+ "model_name": CFG.LLM_MODEL,
+ }
+ tasks.append(llm_chat_response_nostream(
+ ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
+ ))
+ from pilot.common.chat_util import run_async_tasks
+ summaries = run_async_tasks(tasks)
+ # from pilot.utils import utils
+ # loop = utils.get_or_create_event_loop()
+ # summary = loop.run_until_complete(
+ # llm_chat_response_nostream(
+ # ChatScene.ExtractRefineSummary.value(), **{"chat_param": chat_param}
+ # )
+ # )
+ # outputs.append(summary)
+ return summaries
+
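The new _mapreduce_extract_summary builds one ExtractSummary request per repacked chunk and runs them together through run_async_tasks. A minimal sketch of the same map-reduce idea, assuming a hypothetical summarize_chunk coroutine that stands in for a single llm_chat_response_nostream call:

    import asyncio
    from typing import Awaitable, Callable, List


    async def map_reduce_summary(
        chunks: List[str],
        summarize_chunk: Callable[[str], Awaitable[str]],
    ) -> str:
        # map: one summary request per chunk, awaited concurrently
        partial_summaries = await asyncio.gather(
            *(summarize_chunk(chunk) for chunk in chunks)
        )
        # reduce: join the per-chunk summaries and summarize the merged text once
        return await summarize_chunk(" ".join(partial_summaries))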
From 16dd8e3ef550765144e53ed824be695b79ace0df Mon Sep 17 00:00:00 2001
From: aries_ckt <916701291@qq.com>
Date: Tue, 31 Oct 2023 13:48:15 +0800
Subject: [PATCH 20/57] feat:document summary
---
pilot/scene/base_chat.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index e43cdc812..5c33f4770 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -284,6 +284,7 @@ class BaseChat(ABC):
)
### model result deal
self.current_message.add_ai_message(ai_response_text)
+ prompt_define_response = None
prompt_define_response = (
self.prompt_template.output_parser.parse_prompt_response(
ai_response_text
From 523838fb796d6611b2e0075311d9e2a9741b15e5 Mon Sep 17 00:00:00 2001
From: aries_ckt <916701291@qq.com>
Date: Tue, 31 Oct 2023 15:09:11 +0800
Subject: [PATCH 21/57] feat:document summary
---
pilot/server/knowledge/service.py | 25 +++++++++++++++++--------
1 file changed, 17 insertions(+), 8 deletions(-)
diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py
index cde2b7bb7..81e1dbdcc 100644
--- a/pilot/server/knowledge/service.py
+++ b/pilot/server/knowledge/service.py
@@ -429,14 +429,15 @@ class KnowledgeService:
from llama_index import PromptHelper
from llama_index.prompts.default_prompt_selectors import DEFAULT_TREE_SUMMARIZE_PROMPT_SEL
texts = [doc.page_content for doc in chunk_docs]
- prompt_helper = PromptHelper()
+ prompt_helper = PromptHelper(context_window=5000)
texts = prompt_helper.repack(prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=texts)
+ logger.info(
+ f"async_document_summary, doc:{doc.doc_name}, chunk_size:{len(texts)}, begin generate summary"
+ )
summary = self._llm_extract_summary(texts[0])
# summaries = self._mapreduce_extract_summary(texts)
outputs, summary = self._refine_extract_summary(texts[1:], summary)
- logger.info(
- f"async_document_summary, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to graph store"
- )
+
doc.summary = summary
return knowledge_document_dao.update_knowledge_document(doc)
@@ -525,14 +526,18 @@ class KnowledgeService:
ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
)
)
+ logger.info(
+ f"initialize summary is :{summary}"
+ )
return summary
- def _refine_extract_summary(self, docs, summary: str):
+ def _refine_extract_summary(self, docs, summary: str, max_iteration:int = 5):
"""Extract refine summary by llm"""
from pilot.scene.base import ChatScene
from pilot.common.chat_util import llm_chat_response_nostream
import uuid
outputs = []
- for doc in docs:
+ max_iteration = max_iteration if len(docs) > max_iteration else len(docs)
+ for doc in docs[0:max_iteration]:
chat_param = {
"chat_session_id": uuid.uuid1(),
"current_user_input": doc,
@@ -547,6 +552,9 @@ class KnowledgeService:
)
)
outputs.append(summary)
+ logger.info(
+ f"iterator is {len(outputs)} current summary is :{summary}"
+ )
return outputs, summary
def _mapreduce_extract_summary(self, docs):
@@ -567,7 +575,8 @@ class KnowledgeService:
ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
))
from pilot.common.chat_util import run_async_tasks
- summaries = run_async_tasks(tasks)
+ summary_iters = run_async_tasks(tasks)
+ summary = self._llm_extract_summary(" ".join(summary_iters))
# from pilot.utils import utils
# loop = utils.get_or_create_event_loop()
# summary = loop.run_until_complete(
@@ -576,6 +585,6 @@ class KnowledgeService:
# )
# )
# outputs.append(summary)
- return summaries
+ return summary
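Patch 21 bounds the refine loop with max_iteration so a large document cannot trigger an unbounded chain of LLM calls, and it threads the running summary through each call. A small sketch of that pattern, with a hypothetical refine_once(chunk, summary) helper standing in for the ExtractRefineSummary chat:

    from typing import Callable, List, Tuple


    def refine_summary(
        chunks: List[str],
        initial_summary: str,
        refine_once: Callable[[str, str], str],
        max_iteration: int = 5,
    ) -> Tuple[List[str], str]:
        # Cap the number of refine calls for very large documents.
        limit = min(max_iteration, len(chunks))
        outputs = [initial_summary]
        summary = initial_summary
        for chunk in chunks[:limit]:
            summary = refine_once(chunk, summary)  # fold the next chunk into the summary
            outputs.append(summary)
        return outputs, summary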
From 7dcfa1921d10f02c9a8c57093dfd57f77fe92e58 Mon Sep 17 00:00:00 2001
From: aries_ckt <916701291@qq.com>
Date: Tue, 31 Oct 2023 15:53:40 +0800
Subject: [PATCH 22/57] feat:document summary
---
pilot/server/knowledge/service.py | 10 +++++++---
1 file changed, 7 insertions(+), 3 deletions(-)
diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py
index 81e1dbdcc..570a549ed 100644
--- a/pilot/server/knowledge/service.py
+++ b/pilot/server/knowledge/service.py
@@ -437,7 +437,11 @@ class KnowledgeService:
summary = self._llm_extract_summary(texts[0])
# summaries = self._mapreduce_extract_summary(texts)
outputs, summary = self._refine_extract_summary(texts[1:], summary)
-
+ summaries = prompt_helper.repack(prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=outputs)
+ summary = self._llm_extract_summary("|".join(summaries))
+ print(
+ f"final summary:{summary}"
+ )
doc.summary = summary
return knowledge_document_dao.update_knowledge_document(doc)
@@ -526,7 +530,7 @@ class KnowledgeService:
ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
)
)
- logger.info(
+ print(
f"initialize summary is :{summary}"
)
return summary
@@ -552,7 +556,7 @@ class KnowledgeService:
)
)
outputs.append(summary)
- logger.info(
+ print(
f"iterator is {len(outputs)} current summary is :{summary}"
)
return outputs, summary
From b3dbf31209ebfa6c2dde1d46d53a520b46f054bf Mon Sep 17 00:00:00 2001
From: aries_ckt <916701291@qq.com>
Date: Tue, 31 Oct 2023 16:10:06 +0800
Subject: [PATCH 23/57] feat:document summary
---
pilot/server/knowledge/service.py | 11 +++++++----
1 file changed, 7 insertions(+), 4 deletions(-)
diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py
index 570a549ed..f906ff372 100644
--- a/pilot/server/knowledge/service.py
+++ b/pilot/server/knowledge/service.py
@@ -437,6 +437,9 @@ class KnowledgeService:
summary = self._llm_extract_summary(texts[0])
# summaries = self._mapreduce_extract_summary(texts)
outputs, summary = self._refine_extract_summary(texts[1:], summary)
+ print(
+ f"refine summary outputs:{outputs}"
+ )
summaries = prompt_helper.repack(prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=outputs)
summary = self._llm_extract_summary("|".join(summaries))
print(
@@ -530,16 +533,16 @@ class KnowledgeService:
ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
)
)
- print(
- f"initialize summary is :{summary}"
- )
return summary
def _refine_extract_summary(self, docs, summary: str, max_iteration:int = 5):
"""Extract refine summary by llm"""
from pilot.scene.base import ChatScene
from pilot.common.chat_util import llm_chat_response_nostream
import uuid
- outputs = []
+ print(
+ f"initialize summary is :{summary}"
+ )
+ outputs = [summary]
max_iteration = max_iteration if len(docs) > max_iteration else len(docs)
for doc in docs[0:max_iteration]:
chat_param = {
From 40eed546ac346061989c429a319f1e72f3ac3df4 Mon Sep 17 00:00:00 2001
From: aries_ckt <916701291@qq.com>
Date: Tue, 31 Oct 2023 16:14:04 +0800
Subject: [PATCH 24/57] chore:discord expire
---
pilot/server/knowledge/service.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py
index f906ff372..cd18c539f 100644
--- a/pilot/server/knowledge/service.py
+++ b/pilot/server/knowledge/service.py
@@ -429,7 +429,7 @@ class KnowledgeService:
from llama_index import PromptHelper
from llama_index.prompts.default_prompt_selectors import DEFAULT_TREE_SUMMARIZE_PROMPT_SEL
texts = [doc.page_content for doc in chunk_docs]
- prompt_helper = PromptHelper(context_window=5000)
+ prompt_helper = PromptHelper(context_window=3900)
texts = prompt_helper.repack(prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=texts)
logger.info(
f"async_document_summary, doc:{doc.doc_name}, chunk_size:{len(texts)}, begin generate summary"
From 04dcd905022b55bc5354828d6e5c283ccec5d6b1 Mon Sep 17 00:00:00 2001
From: aries_ckt <916701291@qq.com>
Date: Tue, 31 Oct 2023 16:14:26 +0800
Subject: [PATCH 25/57] feat:document summary
---
pilot/server/knowledge/service.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py
index cd18c539f..4db3d6c51 100644
--- a/pilot/server/knowledge/service.py
+++ b/pilot/server/knowledge/service.py
@@ -429,7 +429,7 @@ class KnowledgeService:
from llama_index import PromptHelper
from llama_index.prompts.default_prompt_selectors import DEFAULT_TREE_SUMMARIZE_PROMPT_SEL
texts = [doc.page_content for doc in chunk_docs]
- prompt_helper = PromptHelper(context_window=3900)
+ prompt_helper = PromptHelper()
texts = prompt_helper.repack(prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=texts)
logger.info(
f"async_document_summary, doc:{doc.doc_name}, chunk_size:{len(texts)}, begin generate summary"
From 48539d7206a754644306cd1541337ad37b2d1c4c Mon Sep 17 00:00:00 2001
From: wangzaistone
Date: Tue, 31 Oct 2023 17:10:12 +0800
Subject: [PATCH 26/57] codellama bug fix
---
pilot/server/chat_adapter.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index 509305247..4b6dd0eed 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -218,7 +218,7 @@ class Llama2ChatAdapter(BaseChatAdpter):
class CodeLlamaChatAdapter(BaseChatAdpter):
"""The model ChatAdapter for codellama ."""
def match(self, model_path: str):
- return "codelama" in model_path.lower()
+ return "codellama" in model_path.lower()
def get_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("codellama")
From a670e5c00d31007652ef370a85d7f9996cd5a36b Mon Sep 17 00:00:00 2001
From: wangzaistone
Date: Tue, 31 Oct 2023 17:24:54 +0800
Subject: [PATCH 27/57] add other codellama models
---
pilot/configs/model_config.py | 7 +++++++
pilot/model/adapter.py | 2 +-
pilot/model/model_adapter.py | 5 ++++-
3 files changed, 12 insertions(+), 2 deletions(-)
diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index 16deee50a..803d0fae9 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -78,8 +78,15 @@ LLM_MODEL_CONFIG = {
"internlm-7b": os.path.join(MODEL_PATH, "internlm-chat-7b"),
"internlm-7b-8k": os.path.join(MODEL_PATH, "internlm-chat-7b-8k"),
"internlm-20b": os.path.join(MODEL_PATH, "internlm-chat-20b"),
+ "codellama-7b": os.path.join(MODEL_PATH, "CodeLlama-7b-Instruct-hf"),
+ "codellama-7b-sql-sft": os.path.join(MODEL_PATH, "codellama-7b-sql-sft"),
+ "codellama-13b": os.path.join(MODEL_PATH, "CodeLlama-13b-Instruct-hf"),
"codellama-13b-sql-sft": os.path.join(MODEL_PATH, "codellama-13b-sql-sft"),
+
+
+
+
# For test now
"opt-125m": os.path.join(MODEL_PATH, "opt-125m"),
}
diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index 02fbe8aa9..cb9885d2a 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -323,7 +323,7 @@ class CodeLlamaAdapter(BaseLLMAdaper):
"""The model adapter for codellama """
def match(self, model_path: str):
- return "codelama" in model_path.lower()
+ return "codellama" in model_path.lower()
def loader(self, model_path: str, from_pretrained_kwargs: dict):
model, tokenizer = super().loader(model_path, from_pretrained_kwargs)
diff --git a/pilot/model/model_adapter.py b/pilot/model/model_adapter.py
index 112fb468a..cadb1cebd 100644
--- a/pilot/model/model_adapter.py
+++ b/pilot/model/model_adapter.py
@@ -45,7 +45,10 @@ _OLD_MODELS = [
"llama-cpp",
"proxyllm",
"gptj-6b",
- "codellama-13b-sql-sft"
+ "codellama-13b-sql-sft",
+ "codellama-7b",
+ "codellama-7b-sql-sft",
+ "codellama-13b"
]
From 17e21a395bcff6228ad3dacc0c251af0ec38d4b8 Mon Sep 17 00:00:00 2001
From: wangzaistone
Date: Tue, 31 Oct 2023 17:26:59 +0800
Subject: [PATCH 28/57] keep as origin default param
---
.env.template | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/.env.template b/.env.template
index 272ee2922..e03650033 100644
--- a/.env.template
+++ b/.env.template
@@ -22,8 +22,7 @@ WEB_SERVER_PORT=7860
#** LLM MODELS **#
#*******************************************************************#
# LLM_MODEL, see /pilot/configs/model_config.LLM_MODEL_CONFIG
-# LLM_MODEL=vicuna-13b-v1.5
-LLM_MODEL=codellama-13b-sql-sft
+LLM_MODEL=vicuna-13b-v1.5
## LLM model path, by default, DB-GPT will read the model path from LLM_MODEL_CONFIG based on the LLM_MODEL.
## Of course you can specify your model path according to LLM_MODEL_PATH
## In DB-GPT, the priority from high to low to read model path:
From 3233e260b20e16fa39424dfe8a606b3df4d92b1a Mon Sep 17 00:00:00 2001
From: wangzaistone
Date: Tue, 31 Oct 2023 17:39:14 +0800
Subject: [PATCH 29/57] add conv judge
---
pilot/configs/model_config.py | 5 -----
pilot/model/adapter.py | 3 ++-
pilot/model/conversation.py | 1 -
pilot/model/model_adapter.py | 8 ++++++--
pilot/server/chat_adapter.py | 3 ++-
5 files changed, 10 insertions(+), 10 deletions(-)
diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index 803d0fae9..0e1fb3d40 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -82,11 +82,6 @@ LLM_MODEL_CONFIG = {
"codellama-7b-sql-sft": os.path.join(MODEL_PATH, "codellama-7b-sql-sft"),
"codellama-13b": os.path.join(MODEL_PATH, "CodeLlama-13b-Instruct-hf"),
"codellama-13b-sql-sft": os.path.join(MODEL_PATH, "codellama-13b-sql-sft"),
-
-
-
-
-
# For test now
"opt-125m": os.path.join(MODEL_PATH, "opt-125m"),
}
diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index cb9885d2a..5ce5b2173 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -319,8 +319,9 @@ class Llama2Adapter(BaseLLMAdaper):
model.config.pad_token_id = tokenizer.pad_token_id
return model, tokenizer
+
class CodeLlamaAdapter(BaseLLMAdaper):
- """The model adapter for codellama """
+ """The model adapter for codellama"""
def match(self, model_path: str):
return "codellama" in model_path.lower()
diff --git a/pilot/model/conversation.py b/pilot/model/conversation.py
index 98dfc720d..5d4309d9f 100644
--- a/pilot/model/conversation.py
+++ b/pilot/model/conversation.py
@@ -360,7 +360,6 @@ register_conv_template(
)
-
# Alpaca default template
register_conv_template(
Conversation(
diff --git a/pilot/model/model_adapter.py b/pilot/model/model_adapter.py
index cadb1cebd..e09b868e7 100644
--- a/pilot/model/model_adapter.py
+++ b/pilot/model/model_adapter.py
@@ -48,7 +48,7 @@ _OLD_MODELS = [
"codellama-13b-sql-sft",
"codellama-7b",
"codellama-7b-sql-sft",
- "codellama-13b"
+ "codellama-13b",
]
@@ -152,8 +152,12 @@ class LLMModelAdaper:
conv.append_message(conv.roles[1], content)
else:
raise ValueError(f"Unknown role: {role}")
+
if system_messages:
- conv.set_system_message("".join(system_messages))
+ if isinstance(conv, Conversation):
+ conv.set_system_message("".join(system_messages))
+ else:
+ conv.update_system_message("".join(system_messages))
# Add a blank message for the assistant.
conv.append_message(conv.roles[1], None)
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index 4b6dd0eed..64b72739b 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -213,10 +213,11 @@ class Llama2ChatAdapter(BaseChatAdpter):
def get_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("llama-2")
-
+
class CodeLlamaChatAdapter(BaseChatAdpter):
"""The model ChatAdapter for codellama ."""
+
def match(self, model_path: str):
return "codellama" in model_path.lower()
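The isinstance branch added in this patch exists because two kinds of conversation objects can reach that code path: DB-GPT's own Conversation, which exposes set_system_message(), and the external conversation templates, which are assumed to expose update_system_message() instead. A duck-typed sketch of the same dispatch (helper name is illustrative):

    def apply_system_messages(conv, system_messages) -> None:
        # DB-GPT's own Conversation exposes set_system_message(); other
        # conversation templates are assumed to expose update_system_message().
        text = "".join(system_messages)
        if hasattr(conv, "set_system_message"):
            conv.set_system_message(text)
        else:
            conv.update_system_message(text)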
From de902448280a16fdd0b7756e904cc256280d3551 Mon Sep 17 00:00:00 2001
From: aries_ckt <916701291@qq.com>
Date: Tue, 31 Oct 2023 18:52:58 +0800
Subject: [PATCH 30/57] feat:document summary
---
pilot/model/cluster/worker/remote_worker.py | 2 +-
.../chat_knowledge/refine_summary/prompt.py | 16 +++--
pilot/scene/chat_knowledge/summary/prompt.py | 17 +++--
pilot/server/knowledge/service.py | 70 ++++++++++---------
4 files changed, 58 insertions(+), 47 deletions(-)
diff --git a/pilot/model/cluster/worker/remote_worker.py b/pilot/model/cluster/worker/remote_worker.py
index f974ba714..149f8b86a 100644
--- a/pilot/model/cluster/worker/remote_worker.py
+++ b/pilot/model/cluster/worker/remote_worker.py
@@ -13,7 +13,7 @@ class RemoteModelWorker(ModelWorker):
def __init__(self) -> None:
self.headers = {}
# TODO Configured by ModelParameters
- self.timeout = 180
+ self.timeout = 360
self.host = None
self.port = None
diff --git a/pilot/scene/chat_knowledge/refine_summary/prompt.py b/pilot/scene/chat_knowledge/refine_summary/prompt.py
index 0161cee35..69d4e46df 100644
--- a/pilot/scene/chat_knowledge/refine_summary/prompt.py
+++ b/pilot/scene/chat_knowledge/refine_summary/prompt.py
@@ -8,19 +8,21 @@ from pilot.scene.chat_knowledge.refine_summary.out_parser import ExtractRefineSu
CFG = Config()
-PROMPT_SCENE_DEFINE = """Your job is to produce a final summary."""
+PROMPT_SCENE_DEFINE = """"""
-_DEFAULT_TEMPLATE = """
-We have provided an existing summary up to a certain point: {existing_answer}\nWe have the opportunity to refine the existing summary (only if needed) with some more context below.\n------------\n{context}\n------------\nGiven the new context, refine the original summary.\nIf the context isn't useful, return the original summary.
+_DEFAULT_TEMPLATE_ZH = """根据提供的上下文信息,我们已经提供了一个到某一点的现有总结:{existing_answer}\n 我们有机会在下面提供的更多上下文信息的基础上进一步完善现有的总结(仅在需要的情况下)。请根据新的上下文信息,完善原来的总结。\n------------\n{context}\n------------\n如果上下文信息没有用处,请返回原来的总结。"""
+_DEFAULT_TEMPLATE_EN = """
+We have provided an existing summary up to a certain point: {existing_answer}\nWe have the opportunity to refine the existing summary (only if needed) with some more context below.\n------------\n{context}\n------------\nGiven the new context, refine the original summary. \nIf the context isn't useful, return the original summary.
please use original language.
"""
+
+_DEFAULT_TEMPLATE = (
+ _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
+)
+
PROMPT_RESPONSE = """"""
-
-RESPONSE_FORMAT = """"""
-
-
PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_NEED_STREAM_OUT = False
diff --git a/pilot/scene/chat_knowledge/summary/prompt.py b/pilot/scene/chat_knowledge/summary/prompt.py
index cbf452c99..ec7c05c32 100644
--- a/pilot/scene/chat_knowledge/summary/prompt.py
+++ b/pilot/scene/chat_knowledge/summary/prompt.py
@@ -9,19 +9,22 @@ CFG = Config()
# PROMPT_SCENE_DEFINE = """You are an expert Q&A system that is trusted around the world.\nAlways answer the query using the provided context information, and not prior knowledge.\nSome rules to follow:\n1. Never directly reference the given context in your answer.\n2. Avoid statements like 'Based on the context, ...' or 'The context information ...' or anything along those lines."""
-PROMPT_SCENE_DEFINE = """Your job is to produce a final summary."""
+PROMPT_SCENE_DEFINE = """"""
-# _DEFAULT_TEMPLATE = """
-# Context information from multiple sources is below.\n---------------------\n
-# {context}
-# Given the information from multiple sources and not prior knowledge, answer the query.\nQuery: Describe what the provided text is about. Also describe some of the questions that this text can answer. \nAnswer: "
-# """
+_DEFAULT_TEMPLATE_ZH = """请根据提供的上下文信息的进行简洁地总结:
+{context}
+"""
-_DEFAULT_TEMPLATE = """
+_DEFAULT_TEMPLATE_EN = """
Write a concise summary of the following context:
{context}
please use original language.
"""
+
+_DEFAULT_TEMPLATE = (
+ _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
+)
+
PROMPT_RESPONSE = """"""
diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py
index 4db3d6c51..d7fb476d7 100644
--- a/pilot/server/knowledge/service.py
+++ b/pilot/server/knowledge/service.py
@@ -429,19 +429,22 @@ class KnowledgeService:
from llama_index import PromptHelper
from llama_index.prompts.default_prompt_selectors import DEFAULT_TREE_SUMMARIZE_PROMPT_SEL
texts = [doc.page_content for doc in chunk_docs]
- prompt_helper = PromptHelper()
+ prompt_helper = PromptHelper(context_window=2500)
+
texts = prompt_helper.repack(prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=texts)
logger.info(
f"async_document_summary, doc:{doc.doc_name}, chunk_size:{len(texts)}, begin generate summary"
)
- summary = self._llm_extract_summary(texts[0])
- # summaries = self._mapreduce_extract_summary(texts)
- outputs, summary = self._refine_extract_summary(texts[1:], summary)
- print(
- f"refine summary outputs:{outputs}"
- )
- summaries = prompt_helper.repack(prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=outputs)
- summary = self._llm_extract_summary("|".join(summaries))
+ # summary = self._llm_extract_summary(texts[0])
+ summary = self._mapreduce_extract_summary(texts)
+ # summaries = prompt_helper.repack(prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=summaries)
+ # if (len(summaries)) > 1:
+ # outputs, summary = self._refine_extract_summary(summaries[1:], summaries[0])
+ # else:
+ # summary = self._llm_extract_summary("\n".join(summaries))
+ # print(
+ # f"refine summary outputs:{summaries}"
+ # )
print(
f"final summary:{summary}"
)
@@ -565,33 +568,36 @@ class KnowledgeService:
return outputs, summary
def _mapreduce_extract_summary(self, docs):
- """Extract mapreduce summary by llm"""
+        """Extract a summary using a map-reduce strategy.
+        map -> generate a summary for each chunk concurrently
+        reduce -> merge the per-chunk summaries into one final summary
+        Args:
+            docs: List[str]
+        """
from pilot.scene.base import ChatScene
from pilot.common.chat_util import llm_chat_response_nostream
import uuid
- outputs = []
tasks = []
- for doc in docs:
- chat_param = {
- "chat_session_id": uuid.uuid1(),
- "current_user_input": doc,
- "select_param": "summary",
- "model_name": CFG.LLM_MODEL,
- }
- tasks.append(llm_chat_response_nostream(
+ if len(docs) == 1:
+ summary = self._llm_extract_summary(doc=docs[0])
+ return summary
+ else:
+ for doc in docs:
+ chat_param = {
+ "chat_session_id": uuid.uuid1(),
+ "current_user_input": doc,
+ "select_param": "summary",
+ "model_name": CFG.LLM_MODEL,
+ }
+ tasks.append(llm_chat_response_nostream(
ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
- ))
- from pilot.common.chat_util import run_async_tasks
- summary_iters = run_async_tasks(tasks)
- summary = self._llm_extract_summary(" ".join(summary_iters))
- # from pilot.utils import utils
- # loop = utils.get_or_create_event_loop()
- # summary = loop.run_until_complete(
- # llm_chat_response_nostream(
- # ChatScene.ExtractRefineSummary.value(), **{"chat_param": chat_param}
- # )
- # )
- # outputs.append(summary)
- return summary
+ ))
+ from pilot.common.chat_util import run_async_tasks
+ summary_iters = run_async_tasks(tasks)
+ from pilot.common.prompt_util import PromptHelper
+ from llama_index.prompts.default_prompt_selectors import DEFAULT_TREE_SUMMARIZE_PROMPT_SEL
+ prompt_helper = PromptHelper(context_window=2500)
+ summary_iters = prompt_helper.repack(prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=summary_iters)
+ return self._mapreduce_extract_summary(summary_iters)
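The reworked _mapreduce_extract_summary now recurses: the per-chunk summaries are repacked into context-window-sized chunks and summarized again until a single summary remains. A compact sketch of that recursive reduce, with summarize and repack as hypothetical stand-ins for the LLM call and PromptHelper.repack:

    from typing import Callable, List


    def recursive_mapreduce_summary(
        chunks: List[str],
        summarize: Callable[[str], str],
        repack: Callable[[List[str]], List[str]],
    ) -> str:
        if not chunks:
            return ""
        # Base case: a single chunk can be summarized directly.
        if len(chunks) == 1:
            return summarize(chunks[0])
        # Map: summarize each chunk (the patch fans these out as async tasks).
        partial = [summarize(chunk) for chunk in chunks]
        # Reduce: pack the partial summaries into fewer, larger chunks
        # (PromptHelper.repack in the patch) and recurse until one remains.
        return recursive_mapreduce_summary(repack(partial), summarize, repack)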
From 67f41559a8ce7c614d1088f1a361e74e405e29fb Mon Sep 17 00:00:00 2001
From: aries_ckt <916701291@qq.com>
Date: Tue, 31 Oct 2023 19:33:58 +0800
Subject: [PATCH 31/57] feat:mapreduce summary
---
pilot/server/knowledge/service.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py
index d7fb476d7..0b6e260dc 100644
--- a/pilot/server/knowledge/service.py
+++ b/pilot/server/knowledge/service.py
@@ -578,10 +578,12 @@ class KnowledgeService:
from pilot.common.chat_util import llm_chat_response_nostream
import uuid
tasks = []
+ max_iteration = 5
if len(docs) == 1:
summary = self._llm_extract_summary(doc=docs[0])
return summary
else:
+ max_iteration = max_iteration if len(docs) > max_iteration else len(docs)
for doc in docs:
chat_param = {
"chat_session_id": uuid.uuid1(),
From be1e1cb1603b3e3eb7478f944d0a140b5da19379 Mon Sep 17 00:00:00 2001
From: aries_ckt <916701291@qq.com>
Date: Tue, 31 Oct 2023 19:38:20 +0800
Subject: [PATCH 32/57] feat:document summary set max iteration
---
pilot/server/knowledge/service.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py
index 0b6e260dc..7e899ba78 100644
--- a/pilot/server/knowledge/service.py
+++ b/pilot/server/knowledge/service.py
@@ -584,7 +584,7 @@ class KnowledgeService:
return summary
else:
max_iteration = max_iteration if len(docs) > max_iteration else len(docs)
- for doc in docs:
+ for doc in docs[0:max_iteration]:
chat_param = {
"chat_session_id": uuid.uuid1(),
"current_user_input": doc,
From d319470f69e5f84cbaa41f8f31c4d2d74a84ea21 Mon Sep 17 00:00:00 2001
From: 0xrahul6 <113128186+0xrahul6@users.noreply.github.com>
Date: Tue, 31 Oct 2023 19:49:40 +0530
Subject: [PATCH 33/57] Update README.md
---
README.md | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/README.md b/README.md
index 7d7af852a..eaee7ec9d 100644
--- a/README.md
+++ b/README.md
@@ -177,7 +177,7 @@ Currently, we have released multiple key features, which are listed below to dem
| [StarRocks](https://github.com/StarRocks/starrocks) | No | TODO |
## Introduction
-Is the architecture of the entire DB-GPT shown in the following figure:
+The architecture of DB-GPT is shown in the following figure:
@@ -330,7 +330,7 @@ As of October 10, 2023, by fine-tuning an open-source model of 13 billion parame
The MIT License (MIT)
## Contact Information
-We are working on building a community, if you have any ideas about building the community, feel free to contact us.
+We are working on building a community; if you have any ideas for building the community, feel free to contact us.
[](https://discord.gg/nASQyBjvY)
From eb3ddccd0a04a1a74c7556068265e5840a1e6fa0 Mon Sep 17 00:00:00 2001
From: nobunagaaa <146952817+nobunagaaa@users.noreply.github.com>
Date: Tue, 31 Oct 2023 19:54:48 +0530
Subject: [PATCH 34/57] fixed typos
---
README.md | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/README.md b/README.md
index 7d7af852a..e74db5c76 100644
--- a/README.md
+++ b/README.md
@@ -43,8 +43,8 @@ DB-GPT is an experimental open-source project that uses localized GPT large mode
## Contents
-- [install](#install)
-- [demo](#demo)
+- [Install](#install)
+- [Demo](#demo)
- [introduction](#introduction)
- [features](#features)
- [contribution](#contribution)
@@ -213,7 +213,7 @@ The core capabilities mainly consist of the following parts:
## Contribution
-- Please run `black .` before submitting the code. Contributing guidelines, [how to contribution](https://github.com/csunny/DB-GPT/blob/main/CONTRIBUTING.md)
+- Please run `black .` before submitting the code. Contributing guidelines, [how to contribute](https://github.com/csunny/DB-GPT/blob/main/CONTRIBUTING.md)
## RoadMap
From 606d384a55e548621fc9d0f8c235267eaf601279 Mon Sep 17 00:00:00 2001
From: aries_ckt <916701291@qq.com>
Date: Wed, 1 Nov 2023 21:55:24 +0800
Subject: [PATCH 35/57] feat:add knowledge reference
---
pilot/graph_engine/graph_engine.py | 1 +
pilot/graph_engine/graph_search.py | 4 +-
pilot/scene/base_chat.py | 6 +-
.../chat_knowledge/refine_summary/chat.py | 4 +-
.../chat_knowledge/refine_summary/prompt.py | 12 +--
pilot/scene/chat_knowledge/summary/prompt.py | 7 +-
pilot/scene/chat_knowledge/v1/chat.py | 70 ++++++++++++----
pilot/server/knowledge/request/request.py | 2 +
pilot/server/knowledge/request/response.py | 1 +
pilot/server/knowledge/service.py | 81 +++++++++++--------
10 files changed, 122 insertions(+), 66 deletions(-)
diff --git a/pilot/graph_engine/graph_engine.py b/pilot/graph_engine/graph_engine.py
index 491a8625c..bea5f3123 100644
--- a/pilot/graph_engine/graph_engine.py
+++ b/pilot/graph_engine/graph_engine.py
@@ -187,6 +187,7 @@ class RAGGraphEngine:
triple_results = []
for doc in docs:
import threading
+
thread_id = threading.get_ident()
print(f"current thread-{thread_id} begin extract triplets task")
triplets = self._extract_triplets(doc.page_content)
diff --git a/pilot/graph_engine/graph_search.py b/pilot/graph_engine/graph_search.py
index f3025be85..9419a4979 100644
--- a/pilot/graph_engine/graph_search.py
+++ b/pilot/graph_engine/graph_search.py
@@ -143,9 +143,7 @@ class RAGGraphSearch(BaseSearch):
logger.info("> No relationships found, returning nodes found by keywords.")
if len(sorted_nodes_with_scores) == 0:
logger.info("> No nodes found by keywords, returning empty response.")
- return [
- Document(page_content="No relationships found.")
- ]
+ return [Document(page_content="No relationships found.")]
# add relationships as Node
# TODO: make initial text customizable
diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index 5c33f4770..34f294c31 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -141,7 +141,6 @@ class BaseChat(ABC):
self.current_message.start_date = datetime.datetime.now().strftime(
"%Y-%m-%d %H:%M:%S"
)
-
self.current_message.tokens = 0
if self.prompt_template.template:
current_prompt = self.prompt_template.format(**input_values)
@@ -152,7 +151,6 @@ class BaseChat(ABC):
# Not new server mode, we convert the message format(List[ModelMessage]) to list of dict
# fix the error of "Object of type ModelMessage is not JSON serializable" when passing the payload to request.post
llm_messages = list(map(lambda m: m.dict(), llm_messages))
-
payload = {
"model": self.llm_model,
"prompt": self.generate_llm_text(),
@@ -167,6 +165,9 @@ class BaseChat(ABC):
def stream_plugin_call(self, text):
return text
+ def knowledge_reference_call(self, text):
+ return text
+
async def check_iterator_end(iterator):
try:
await asyncio.anext(iterator)
@@ -196,6 +197,7 @@ class BaseChat(ABC):
view_msg = view_msg.replace("\n", "\\n")
yield view_msg
self.current_message.add_ai_message(msg)
+ view_msg = self.knowledge_reference_call(msg)
self.current_message.add_view_message(view_msg)
except Exception as e:
print(traceback.format_exc())
diff --git a/pilot/scene/chat_knowledge/refine_summary/chat.py b/pilot/scene/chat_knowledge/refine_summary/chat.py
index b3a934dd5..d0b1e9471 100644
--- a/pilot/scene/chat_knowledge/refine_summary/chat.py
+++ b/pilot/scene/chat_knowledge/refine_summary/chat.py
@@ -21,13 +21,13 @@ class ExtractRefineSummary(BaseChat):
chat_param=chat_param,
)
- self.user_input = chat_param["current_user_input"]
+ # self.user_input = chat_param["current_user_input"]
self.existing_answer = chat_param["select_param"]
# self.extract_mode = chat_param["select_param"]
def generate_input_values(self):
input_values = {
- "context": self.user_input,
+ # "context": self.user_input,
"existing_answer": self.existing_answer,
}
return input_values
diff --git a/pilot/scene/chat_knowledge/refine_summary/prompt.py b/pilot/scene/chat_knowledge/refine_summary/prompt.py
index 69d4e46df..cd5087f35 100644
--- a/pilot/scene/chat_knowledge/refine_summary/prompt.py
+++ b/pilot/scene/chat_knowledge/refine_summary/prompt.py
@@ -3,18 +3,20 @@ from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.common.schema import SeparatorStyle
-from pilot.scene.chat_knowledge.refine_summary.out_parser import ExtractRefineSummaryParser
+from pilot.scene.chat_knowledge.refine_summary.out_parser import (
+ ExtractRefineSummaryParser,
+)
CFG = Config()
PROMPT_SCENE_DEFINE = """"""
-_DEFAULT_TEMPLATE_ZH = """根据提供的上下文信息,我们已经提供了一个到某一点的现有总结:{existing_answer}\n 我们有机会在下面提供的更多上下文信息的基础上进一步完善现有的总结(仅在需要的情况下)。请根据新的上下文信息,完善原来的总结。\n------------\n{context}\n------------\n如果上下文信息没有用处,请返回原来的总结。"""
+_DEFAULT_TEMPLATE_ZH = """根据提供的上下文信息,我们已经提供了一个到某一点的现有总结:{existing_answer}\n 请再完善一下原来的总结。\n回答的时候最好按照1.2.3.点进行总结"""
_DEFAULT_TEMPLATE_EN = """
-We have provided an existing summary up to a certain point: {existing_answer}\nWe have the opportunity to refine the existing summary (only if needed) with some more context below.\n------------\n{context}\n------------\nGiven the new context, refine the original summary. \nIf the context isn't useful, return the original summary.
-please use original language.
+We have provided an existing summary up to a certain point: {existing_answer}\nWe have the opportunity to refine the existing summary (only if needed) with some more context below.please refine the original summary.
+\nWhen answering, it is best to summarize according to points 1.2.3.
"""
_DEFAULT_TEMPLATE = (
@@ -29,7 +31,7 @@ PROMPT_NEED_NEED_STREAM_OUT = False
prompt = PromptTemplate(
template_scene=ChatScene.ExtractRefineSummary.value(),
- input_variables=["existing_answer","context"],
+ input_variables=["existing_answer"],
response_format="",
template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
diff --git a/pilot/scene/chat_knowledge/summary/prompt.py b/pilot/scene/chat_knowledge/summary/prompt.py
index ec7c05c32..10a239586 100644
--- a/pilot/scene/chat_knowledge/summary/prompt.py
+++ b/pilot/scene/chat_knowledge/summary/prompt.py
@@ -11,14 +11,15 @@ CFG = Config()
PROMPT_SCENE_DEFINE = """"""
-_DEFAULT_TEMPLATE_ZH = """请根据提供的上下文信息的进行简洁地总结:
+_DEFAULT_TEMPLATE_ZH = """请根据提供的上下文信息的进行总结:
{context}
+回答的时候最好按照1.2.3.点进行总结
"""
_DEFAULT_TEMPLATE_EN = """
-Write a concise summary of the following context:
+Write a summary of the following context:
{context}
-please use original language.
+When answering, it is best to summarize according to points 1.2.3.
"""
_DEFAULT_TEMPLATE = (
diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py
index 576a93872..e6c3d9056 100644
--- a/pilot/scene/chat_knowledge/v1/chat.py
+++ b/pilot/scene/chat_knowledge/v1/chat.py
@@ -1,5 +1,5 @@
import os
-from typing import Dict
+from typing import Dict, List
from pilot.component import ComponentType
from pilot.scene.base_chat import BaseChat
@@ -20,7 +20,6 @@ CFG = Config()
class ChatKnowledge(BaseChat):
chat_scene: str = ChatScene.ChatKnowledge.value()
-
"""KBQA Chat Module"""
def __init__(self, chat_param: Dict):
@@ -46,7 +45,6 @@ class ChatKnowledge(BaseChat):
if self.space_context is None
else int(self.space_context["embedding"]["topk"])
)
- # self.recall_score = CFG.KNOWLEDGE_SEARCH_TOP_SIZE if self.space_context is None else self.space_context["embedding"]["recall_score"]
self.max_token = (
CFG.KNOWLEDGE_SEARCH_MAX_TOKEN
if self.space_context is None
@@ -56,11 +54,11 @@ class ChatKnowledge(BaseChat):
"vector_store_name": self.knowledge_space,
"vector_store_type": CFG.VECTOR_STORE_TYPE,
}
- from pilot.graph_engine.graph_factory import RAGGraphFactory
-
- self.rag_engine = CFG.SYSTEM_APP.get_component(
- ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
- ).create()
+ # from pilot.graph_engine.graph_factory import RAGGraphFactory
+ #
+ # self.rag_engine = CFG.SYSTEM_APP.get_component(
+ # ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
+ # ).create()
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
@@ -90,25 +88,29 @@ class ChatKnowledge(BaseChat):
last_output.text = (
last_output.text + "\n\nrelations:\n\n" + ",".join(relations)
)
- yield last_output
+ reference = f"\n\n{self.parse_source_view(self.sources)}"
+ last_output = last_output + reference
+ yield last_output
+
+ def knowledge_reference_call(self, text):
+ """return reference"""
+ return text + f"\n\n{self.parse_source_view(self.sources)}"
async def generate_input_values(self) -> Dict:
if self.space_context:
self.prompt_template.template_define = self.space_context["prompt"]["scene"]
self.prompt_template.template = self.space_context["prompt"]["template"]
- # docs = self.knowledge_embedding_client.similar_search(
- # self.current_user_input, self.top_k
- # )
docs = await blocking_func_to_async(
self._executor,
self.knowledge_embedding_client.similar_search,
self.current_user_input,
self.top_k,
)
- docs = await self.rag_engine.search(query=self.current_user_input)
- # docs = self.knowledge_embedding_client.similar_search(
- # self.current_user_input, self.top_k
- # )
+ self.sources = self.merge_by_key(
+ list(map(lambda doc: doc.metadata, docs)), "source"
+ )
+
+ self.current_message.knowledge_source = self.sources
if not docs:
raise ValueError(
"you have no knowledge space, please add your knowledge space"
@@ -125,6 +127,42 @@ class ChatKnowledge(BaseChat):
}
return input_values
+ def parse_source_view(self, sources: List):
+ html_title = f"##### **References:**"
+ lines = ""
+ for item in sources:
+ source = item["source"] if "source" in item else ""
+ pages = ",".join(item["pages"]) if "pages" in item else ""
+ lines += f"{source}"
+ if len(pages) > 0:
+ lines += f", **pages**:{pages}\n\n"
+ else:
+ lines += "\n\n"
+ html = f"""{html_title}\n{lines}"""
+ return html
+
+ def merge_by_key(self, data, key):
+ result = {}
+ for item in data:
+ item_key = os.path.basename(item.get(key))
+ if item_key in result:
+ if "pages" in result[item_key] and "page" in item:
+ result[item_key]["pages"].append(str(item["page"]))
+ elif "page" in item:
+                    # no "pages" recorded yet for this source, start a new list
+                    result[item_key]["pages"] = [
+                        str(item["page"]),
+                    ]
+ else:
+ if "page" in item:
+ result[item_key] = {
+ "source": item_key,
+ "pages": [str(item["page"])],
+ }
+ else:
+ result[item_key] = {"source": item_key}
+ return list(result.values())
+
@property
def chat_type(self) -> str:
return ChatScene.ChatKnowledge.value()
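merge_by_key and parse_source_view turn the retrieved chunks' metadata into the Markdown "References" block appended to answers: metadata is grouped by source file name and page numbers are collected per source. A self-contained restatement of that grouping for illustration (not the class methods themselves; the sample paths are made up):

    import os
    from typing import Dict, List


    def build_references(metadata: List[Dict]) -> str:
        grouped: Dict[str, Dict] = {}
        for item in metadata:
            name = os.path.basename(item.get("source", ""))
            entry = grouped.setdefault(name, {"source": name, "pages": []})
            if "page" in item:
                entry["pages"].append(str(item["page"]))
        lines = ""
        for entry in grouped.values():
            lines += entry["source"]
            lines += f", **pages**:{','.join(entry['pages'])}\n\n" if entry["pages"] else "\n\n"
        return f"##### **References:**\n{lines}"


    print(build_references([
        {"source": "/data/db-gpt.pdf", "page": 1},
        {"source": "/data/db-gpt.pdf", "page": 3},
        {"source": "/data/notes.md"},
    ]))
    # ##### **References:**
    # db-gpt.pdf, **pages**:1,3
    #
    # notes.md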
diff --git a/pilot/server/knowledge/request/request.py b/pilot/server/knowledge/request/request.py
index c6b94ff0d..032b97ba1 100644
--- a/pilot/server/knowledge/request/request.py
+++ b/pilot/server/knowledge/request/request.py
@@ -59,6 +59,8 @@ class DocumentSyncRequest(BaseModel):
"""doc_ids: doc ids"""
doc_ids: List
+ model_name: Optional[str] = None
+
"""Preseparator, this separator is used for pre-splitting before the document is actually split by the text splitter.
Preseparator are not included in the vectorized text.
"""
diff --git a/pilot/server/knowledge/request/response.py b/pilot/server/knowledge/request/response.py
index 2e3e5f0ab..5c1c7efd1 100644
--- a/pilot/server/knowledge/request/response.py
+++ b/pilot/server/knowledge/request/response.py
@@ -5,6 +5,7 @@ from pydantic import BaseModel
class ChunkQueryResponse(BaseModel):
"""data: data"""
+
data: List = None
"""summary: document summary"""
summary: str = None
diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py
index 7e899ba78..1dece9054 100644
--- a/pilot/server/knowledge/service.py
+++ b/pilot/server/knowledge/service.py
@@ -199,6 +199,7 @@ class KnowledgeService:
# import langchain is very very slow!!!
doc_ids = sync_request.doc_ids
+ self.model_name = sync_request.model_name or CFG.LLM_MODEL
for doc_id in doc_ids:
query = KnowledgeDocumentEntity(
id=doc_id,
@@ -427,11 +428,16 @@ class KnowledgeService:
- doc: KnowledgeDocumentEntity
"""
from llama_index import PromptHelper
- from llama_index.prompts.default_prompt_selectors import DEFAULT_TREE_SUMMARIZE_PROMPT_SEL
- texts = [doc.page_content for doc in chunk_docs]
- prompt_helper = PromptHelper(context_window=2500)
+ from llama_index.prompts.default_prompt_selectors import (
+ DEFAULT_TREE_SUMMARIZE_PROMPT_SEL,
+ )
- texts = prompt_helper.repack(prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=texts)
+ texts = [doc.page_content for doc in chunk_docs]
+ prompt_helper = PromptHelper(context_window=2000)
+
+ texts = prompt_helper.repack(
+ prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=texts
+ )
logger.info(
f"async_document_summary, doc:{doc.doc_name}, chunk_size:{len(texts)}, begin generate summary"
)
@@ -445,13 +451,10 @@ class KnowledgeService:
# print(
# f"refine summary outputs:{summaries}"
# )
- print(
- f"final summary:{summary}"
- )
+ print(f"final summary:{summary}")
doc.summary = summary
return knowledge_document_dao.update_knowledge_document(doc)
-
def async_doc_embedding(self, client, chunk_docs, doc):
"""async document embedding into vector db
Args:
@@ -460,11 +463,11 @@ class KnowledgeService:
- doc: KnowledgeDocumentEntity
"""
logger.info(
- f"async_doc_embedding, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
+ f"async doc sync, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
)
try:
- vector_ids = client.knowledge_embedding_batch(chunk_docs)
self.async_document_summary(chunk_docs, doc)
+ vector_ids = client.knowledge_embedding_batch(chunk_docs)
doc.status = SyncStatus.FINISHED.name
doc.result = "document embedding success"
if vector_ids is not None:
@@ -526,25 +529,27 @@ class KnowledgeService:
chat_param = {
"chat_session_id": uuid.uuid1(),
"current_user_input": doc,
- "select_param": "summary",
- "model_name": CFG.LLM_MODEL,
+ "select_param": doc,
+ "model_name": self.model_name,
}
- from pilot.utils import utils
- loop = utils.get_or_create_event_loop()
- summary = loop.run_until_complete(
- llm_chat_response_nostream(
- ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
- )
+ from pilot.common.chat_util import run_async_tasks
+
+ summary_iters = run_async_tasks(
+ [
+ llm_chat_response_nostream(
+ ChatScene.ExtractRefineSummary.value(), **{"chat_param": chat_param}
+ )
+ ]
)
- return summary
- def _refine_extract_summary(self, docs, summary: str, max_iteration:int = 5):
+ return summary_iters[0]
+
+ def _refine_extract_summary(self, docs, summary: str, max_iteration: int = 5):
"""Extract refine summary by llm"""
from pilot.scene.base import ChatScene
from pilot.common.chat_util import llm_chat_response_nostream
import uuid
- print(
- f"initialize summary is :{summary}"
- )
+
+ print(f"initialize summary is :{summary}")
outputs = [summary]
max_iteration = max_iteration if len(docs) > max_iteration else len(docs)
for doc in docs[0:max_iteration]:
@@ -552,9 +557,10 @@ class KnowledgeService:
"chat_session_id": uuid.uuid1(),
"current_user_input": doc,
"select_param": summary,
- "model_name": CFG.LLM_MODEL,
+ "model_name": self.model_name,
}
from pilot.utils import utils
+
loop = utils.get_or_create_event_loop()
summary = loop.run_until_complete(
llm_chat_response_nostream(
@@ -562,9 +568,7 @@ class KnowledgeService:
)
)
outputs.append(summary)
- print(
- f"iterator is {len(outputs)} current summary is :{summary}"
- )
+ print(f"iterator is {len(outputs)} current summary is :{summary}")
return outputs, summary
def _mapreduce_extract_summary(self, docs):
@@ -577,6 +581,7 @@ class KnowledgeService:
from pilot.scene.base import ChatScene
from pilot.common.chat_util import llm_chat_response_nostream
import uuid
+
tasks = []
max_iteration = 5
if len(docs) == 1:
@@ -589,17 +594,23 @@ class KnowledgeService:
"chat_session_id": uuid.uuid1(),
"current_user_input": doc,
"select_param": "summary",
- "model_name": CFG.LLM_MODEL,
+ "model_name": self.model_name,
}
- tasks.append(llm_chat_response_nostream(
- ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
- ))
+ tasks.append(
+ llm_chat_response_nostream(
+ ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
+ )
+ )
from pilot.common.chat_util import run_async_tasks
+
summary_iters = run_async_tasks(tasks)
from pilot.common.prompt_util import PromptHelper
- from llama_index.prompts.default_prompt_selectors import DEFAULT_TREE_SUMMARIZE_PROMPT_SEL
+ from llama_index.prompts.default_prompt_selectors import (
+ DEFAULT_TREE_SUMMARIZE_PROMPT_SEL,
+ )
+
prompt_helper = PromptHelper(context_window=2500)
- summary_iters = prompt_helper.repack(prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=summary_iters)
+ summary_iters = prompt_helper.repack(
+ prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=summary_iters
+ )
return self._mapreduce_extract_summary(summary_iters)
-
-
From 41ad69cb3901376e6374edfbd0ff3ce85c6271da Mon Sep 17 00:00:00 2001
From: csunny
Date: Thu, 2 Nov 2023 14:25:50 +0800
Subject: [PATCH 36/57] chores: update wechat
---
assets/wechat.jpg | Bin 234518 -> 206792 bytes
1 file changed, 0 insertions(+), 0 deletions(-)
diff --git a/assets/wechat.jpg b/assets/wechat.jpg
index c7de6223b996de33e357f4ecf70edc60fd2a59af..ec465785c48264626c9fb04b51dde1cd09fd7769 100644
GIT binary patch
literal 206792