feat: knowledge rag graph

Commit 2f82f98e31 (parent fa6a9040d5) in https://github.com/csunny/DB-GPT.git
pilot/common/chat_util.py (new file, 20 lines)
@@ -0,0 +1,20 @@
import asyncio

from starlette.responses import StreamingResponse

from pilot.scene.base_chat import BaseChat
from pilot.scene.chat_factory import ChatFactory

chat_factory = ChatFactory()


async def llm_chat_response_nostream(chat_scene: str, **chat_param):
    """Get a complete (non-streaming) LLM response for the given chat scene."""
    chat: BaseChat = chat_factory.get_implementation(chat_scene, **chat_param)
    res = await chat.get_llm_response()
    return res


async def llm_chat_response(chat_scene: str, **chat_param):
    chat: BaseChat = chat_factory.get_implementation(chat_scene, **chat_param)
    return chat.stream_call()
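Usage sketch (not part of the diff): this is how the non-streaming helper is driven from synchronous code, mirroring the call sites in graph_engine.py and graph_search.py below. It assumes a running DB-GPT model service and that the ExtractTriplet scene is registered elsewhere in this commit; the session id and input are illustrative.

    import uuid

    from pilot.common.chat_util import llm_chat_response_nostream
    from pilot.scene.base import ChatScene
    from pilot.utils import utils

    chat_param = {
        "chat_session_id": uuid.uuid1(),  # illustrative session id
        "current_user_input": "Alice is Bob's mother.",
        "select_param": "triplet",        # selects the triplet-extraction prompt
        "model_name": "proxyllm",
    }
    loop = utils.get_or_create_event_loop()
    # run the async helper to completion and collect the parsed triplets
    triplets = loop.run_until_complete(
        llm_chat_response_nostream(
            ChatScene.ExtractTriplet.value(), **{"chat_param": chat_param}
        )
    )
    print(triplets)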
pilot/graph_engine/__init__.py (new file, empty)
pilot/graph_engine/graph_engine.py (new file, 137 lines)
@@ -0,0 +1,137 @@
import logging
from typing import Any, Optional, Callable, Tuple, List

from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter

from pilot.embedding_engine import KnowledgeType
from pilot.embedding_engine.knowledge_type import get_knowledge_embedding
from pilot.graph_engine.index_struct import KG
from pilot.graph_engine.node import TextNode
from pilot.utils import utils

logger = logging.getLogger(__name__)


class RAGGraphEngine:
    """Knowledge RAG Graph Engine.

    Build a KG by extracting triplets, and leverage the KG during query time.

    Args:
        knowledge_type (Optional[str]): Knowledge type. Defaults to
            KnowledgeType.DOCUMENT.value.
        graph_store (Optional[GraphStore]): The graph store to use. Reference: llama-index.
        include_embeddings (bool): Whether to include embeddings in the index.
            Defaults to False.
        max_object_length (int): The maximum length of the object in a triplet.
            Defaults to 128.
        extract_triplet_fn (Optional[Callable]): The function to use for
            extracting triplets. Defaults to None.
    """

    index_struct_cls = KG

    def __init__(
        self,
        knowledge_type: Optional[str] = KnowledgeType.DOCUMENT.value,
        knowledge_source: Optional[str] = None,
        text_splitter=None,
        graph_store=None,
        index_struct: Optional[KG] = None,
        model_name: Optional[str] = None,
        max_triplets_per_chunk: int = 10,
        include_embeddings: bool = False,
        max_object_length: int = 128,
        extract_triplet_fn: Optional[Callable] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize params."""
        # from llama_index.graph_stores import SimpleGraphStore
        # from llama_index.graph_stores.types import GraphStore

        # need to set parameters before building index in base class.
        self.knowledge_source = knowledge_source
        self.knowledge_type = knowledge_type
        self.model_name = model_name
        self.text_splitter = text_splitter
        self.index_struct = index_struct
        self.include_embeddings = include_embeddings
        # self.graph_store = graph_store or SimpleGraphStore()
        self.graph_store = graph_store
        self.max_triplets_per_chunk = max_triplets_per_chunk
        self._max_object_length = max_object_length
        self._extract_triplet_fn = extract_triplet_fn

    def knowledge_graph(self, docs=None):
        """Load knowledge docs into the graph store."""
        if not docs:
            # fall back to a default splitter when none was provided
            # (the original condition was inverted: `if self.text_splitter:`)
            if not self.text_splitter:
                self.text_splitter = RecursiveCharacterTextSplitter(
                    chunk_size=2000, chunk_overlap=100
                )
            knowledge_source = get_knowledge_embedding(
                knowledge_type=self.knowledge_type,
                knowledge_source=self.knowledge_source,
                text_splitter=self.text_splitter,
            )
            docs = knowledge_source.read()
        if self.index_struct is None:
            self.index_struct = self._build_index_from_docs(docs)

    def _extract_triplets(self, text: str) -> List[Tuple[str, str, str]]:
        """Extract triplets from text by function or llm."""
        if self._extract_triplet_fn is not None:
            return self._extract_triplet_fn(text)
        else:
            return self._llm_extract_triplets(text)

    def _llm_extract_triplets(self, text: str) -> List[Tuple[str, str, str]]:
        """Extract triplets from text by llm."""
        from pilot.scene.base import ChatScene
        from pilot.common.chat_util import llm_chat_response_nostream
        import uuid

        chat_param = {
            "chat_session_id": uuid.uuid1(),
            "current_user_input": text,
            "select_param": "triplet",
            "model_name": self.model_name,
        }
        loop = utils.get_or_create_event_loop()
        triplets = loop.run_until_complete(
            llm_chat_response_nostream(
                ChatScene.ExtractTriplet.value(), **{"chat_param": chat_param}
            )
        )
        return triplets
        # response = self._service_context.llm_predictor.predict(
        #     self.kg_triple_extract_template,
        #     text=text,
        # )
        # print(response, flush=True)
        # return self._parse_triplet_response(
        #     response, max_length=self._max_object_length
        # )

    def _build_index_from_docs(self, documents: List[Document]) -> KG:
        """Build the index from docs."""
        index_struct = self.index_struct_cls()
        for doc in documents:
            triplets = self._extract_triplets(doc.page_content)
            if len(triplets) == 0:
                continue
            text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
            logger.info(f"extracted knowledge triplets: {triplets}")
            for triplet in triplets:
                subj, _, obj = triplet
                self.graph_store.upsert_triplet(*triplet)
                index_struct.add_node([subj, obj], text_node)

        return index_struct

    def search(self, query):
        from pilot.graph_engine.graph_search import RAGGraphSearch

        graph_search = RAGGraphSearch(graph_engine=self)
        return graph_search.search(query)
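End-to-end sketch (not part of the diff): build a KG from a document source and query it. It assumes a running DB-GPT model service, a hypothetical document path, and uses llama-index's SimpleGraphStore (the same store the commented-out import above points at) as one store that provides upsert_triplet()/get_rel_map().

    from llama_index.graph_stores import SimpleGraphStore  # mirrors the commented-out import above

    from pilot.embedding_engine import KnowledgeType
    from pilot.graph_engine.graph_engine import RAGGraphEngine

    engine = RAGGraphEngine(
        knowledge_type=KnowledgeType.DOCUMENT.value,
        knowledge_source="./docs/example.md",  # hypothetical document path
        model_name="proxyllm",
        graph_store=SimpleGraphStore(),
    )
    engine.knowledge_graph()  # read + split docs, extract triplets, build the KG index
    nodes = engine.search("Who is Bob's mother?")
    for n in nodes:
        print(n.get_text())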
pilot/graph_engine/graph_factory.py (new file, 34 lines)
@@ -0,0 +1,34 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Type

from pilot.component import BaseComponent, ComponentType


class RAGGraphFactory(BaseComponent, ABC):
    name = ComponentType.RAG_GRAPH_DEFAULT.value

    @abstractmethod
    def create(self, model_name: str = None, embedding_cls: Type = None):
        """Create RAG Graph Engine"""


class DefaultRAGGraphFactory(RAGGraphFactory):
    def __init__(
        self, system_app=None, default_model_name: str = None, **kwargs: Any
    ) -> None:
        super().__init__(system_app=system_app)
        self._default_model_name = default_model_name
        self.kwargs = kwargs
        from pilot.graph_engine.graph_engine import RAGGraphEngine

        self.rag_engine = RAGGraphEngine(model_name="proxyllm")

    def init_app(self, system_app):
        pass

    def create(self, model_name: str = None, rag_cls: Type = None):
        if not model_name:
            model_name = self._default_model_name

        return self.rag_engine
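Note on the design: the default factory builds one RAGGraphEngine eagerly in __init__ (hard-coded to "proxyllm") and create() hands back that same cached instance regardless of the model_name argument. A sketch, with system_app registration left out since it is outside this diff:

    from pilot.graph_engine.graph_factory import DefaultRAGGraphFactory

    factory = DefaultRAGGraphFactory(system_app=None, default_model_name="proxyllm")
    engine = factory.create()  # falls back to the default model name
    # the same (cached) engine instance is returned on every call
    assert engine is factory.create(model_name="proxyllm")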
pilot/graph_engine/graph_search.py (new file, 193 lines)
@@ -0,0 +1,193 @@
import logging
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional, Dict, Any, Set, Callable

from pilot.graph_engine.node import BaseNode, TextNode, NodeWithScore
from pilot.graph_engine.search import BaseSearch, SearchMode
from pilot.utils import utils

logger = logging.getLogger(__name__)
DEFAULT_NODE_SCORE = 1000.0
GLOBAL_EXPLORE_NODE_LIMIT = 3
REL_TEXT_LIMIT = 30


class RAGGraphSearch(BaseSearch):
    """RAG Graph Search.

    Args:
        graph_engine (RAGGraphEngine): The graph engine to search against.
        model_name (str): model name (see :ref:`Prompt-Templates`).
        text_qa_template (Optional[BasePromptTemplate]): A Question Answering Prompt
            (see :ref:`Prompt-Templates`).
        max_keywords_per_query (int): Maximum number of keywords to extract from query.
        num_chunks_per_query (int): Maximum number of text chunks to query.
        search_mode (Optional[SearchMode]): Specifies whether to use keywords,
            embeddings, or both to find relevant triplets. Should be one of
            "keyword", "embedding", or "hybrid". Defaults to SearchMode.KEYWORD.
        graph_store_query_depth (int): The depth of the graph store query.
        extract_subject_entities_fn (Optional[Callable]): extract_subject_entities callback.
    """

    def __init__(
        self,
        graph_engine,
        model_name: str = None,
        max_keywords_per_query: int = 10,
        num_chunks_per_query: int = 10,
        search_mode: Optional[SearchMode] = SearchMode.KEYWORD,
        graph_store_query_depth: int = 2,
        extract_subject_entities_fn: Optional[Callable] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize params."""
        from pilot.graph_engine.graph_engine import RAGGraphEngine

        self.graph_engine: RAGGraphEngine = graph_engine
        self.model_name = model_name or self.graph_engine.model_name
        self._index_struct = self.graph_engine.index_struct
        self.max_keywords_per_query = max_keywords_per_query
        self.num_chunks_per_query = num_chunks_per_query
        self._search_mode = search_mode

        self._graph_store = self.graph_engine.graph_store
        self.graph_store_query_depth = graph_store_query_depth
        self._verbose = kwargs.get("verbose", False)
        # _include_text is read in _search() below but was never initialized;
        # default it from kwargs so the attribute always exists.
        self._include_text = kwargs.get("include_text", False)
        refresh_schema = kwargs.get("refresh_schema", False)
        self.extract_subject_entities_fn = extract_subject_entities_fn
        self.executor = ThreadPoolExecutor(max_workers=os.cpu_count() * 5)
        try:
            self._graph_schema = self._graph_store.get_schema(refresh=refresh_schema)
        except NotImplementedError:
            self._graph_schema = ""
        except Exception as e:
            logger.warning(f"can not find graph schema: {e}")
            self._graph_schema = ""

    def _extract_subject_entities(self, query_str: str) -> Set[str]:
        """Extract subject entities."""
        if self.extract_subject_entities_fn is not None:
            return self.extract_subject_entities_fn(query_str)
        else:
            return self._extract_entities_by_llm(query_str)

    def _extract_entities_by_llm(self, text: str) -> Set[str]:
        """Extract subject entities from text by llm."""
        from pilot.scene.base import ChatScene
        from pilot.common.chat_util import llm_chat_response_nostream
        import uuid

        chat_param = {
            "chat_session_id": uuid.uuid1(),
            "current_user_input": text,
            "select_param": "entity",
            "model_name": self.model_name,
        }
        loop = utils.get_or_create_event_loop()
        entities = loop.run_until_complete(
            llm_chat_response_nostream(
                ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
            )
        )
        return entities

    def _search(
        self,
        query_str: str,
    ) -> List[NodeWithScore]:
        """Get nodes for response."""
        node_visited = set()
        keywords = self._extract_subject_entities(query_str)
        print(f"extract entities: {keywords}\n")
        rel_texts = []
        cur_rel_map = {}
        chunk_indices_count: Dict[str, int] = defaultdict(int)
        if self._search_mode != SearchMode.EMBEDDING:
            for keyword in keywords:
                keyword = keyword.lower()
                subjs = set((keyword,))
                node_ids = self._index_struct.search_node_by_keyword(keyword)
                for node_id in node_ids[:GLOBAL_EXPLORE_NODE_LIMIT]:
                    if node_id in node_visited:
                        continue

                    if self._include_text:
                        chunk_indices_count[node_id] += 1

                    node_visited.add(node_id)

                rel_map = self._graph_store.get_rel_map(
                    list(subjs), self.graph_store_query_depth
                )
                logger.debug(f"rel_map: {rel_map}")

                if not rel_map:
                    continue
                rel_texts.extend(
                    [
                        str(rel_obj)
                        for rel_objs in rel_map.values()
                        for rel_obj in rel_objs
                    ]
                )
                cur_rel_map.update(rel_map)

        sorted_nodes_with_scores = []
        if not rel_texts:
            logger.info("> No relationships found, returning nodes found by keywords.")
            if len(sorted_nodes_with_scores) == 0:
                logger.info("> No nodes found by keywords, returning empty response.")
            return [
                NodeWithScore(node=TextNode(text="No relationships found."), score=1.0)
            ]

        # add relationships as Node
        # TODO: make initial text customizable
        rel_initial_text = (
            f"The following are knowledge sequence in max depth"
            f" {self.graph_store_query_depth} "
            f"in the form of directed graph like:\n"
            f"`subject -[predicate]->, object, <-[predicate_next_hop]-,"
            f" object_next_hop ...`"
        )
        rel_info = [rel_initial_text] + rel_texts
        rel_node_info = {
            "kg_rel_texts": rel_texts,
            "kg_rel_map": cur_rel_map,
        }
        if self._graph_schema != "":
            rel_node_info["kg_schema"] = {"schema": self._graph_schema}
        rel_info_text = "\n".join(
            [
                str(item)
                for sublist in rel_info
                for item in (sublist if isinstance(sublist, list) else [sublist])
            ]
        )
        if self._verbose:
            # built-in print() has no color kwarg; print the context plainly
            print(f"KG context:\n{rel_info_text}\n")
        rel_text_node = TextNode(
            text=rel_info_text,
            metadata=rel_node_info,
            excluded_embed_metadata_keys=["kg_rel_map", "kg_rel_texts"],
            excluded_llm_metadata_keys=["kg_rel_map", "kg_rel_texts"],
        )
        # this node is constructed from rel_texts, give high confidence to avoid cutoff
        sorted_nodes_with_scores.append(
            NodeWithScore(node=rel_text_node, score=DEFAULT_NODE_SCORE)
        )

        return sorted_nodes_with_scores

    def _get_metadata_for_response(
        self, nodes: List[BaseNode]
    ) -> Optional[Dict[str, Any]]:
        """Get metadata for response."""
        for node in nodes:
            if node.metadata is None or "kg_rel_map" not in node.metadata:
                continue
            return node.metadata
        raise ValueError("kg_rel_map must be found in at least one Node.")
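Usage sketch (not part of the diff): RAGGraphSearch is normally reached through RAGGraphEngine.search(), but it can be driven directly once the index is built. Here my_graph_store and my_docs are hypothetical placeholders for a store with get_rel_map()/upsert_triplet() and a list of pre-split langchain Documents.

    from pilot.graph_engine.graph_engine import RAGGraphEngine
    from pilot.graph_engine.graph_search import RAGGraphSearch

    engine = RAGGraphEngine(model_name="proxyllm", graph_store=my_graph_store)  # my_graph_store: hypothetical
    engine.knowledge_graph(docs=my_docs)  # my_docs: hypothetical pre-split documents
    searcher = RAGGraphSearch(graph_engine=engine, graph_store_query_depth=2, verbose=True)
    for node in searcher.search("Who is Bob's mother?"):
        print(node.get_score(), node.get_text())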
pilot/graph_engine/index_struct.py (new file, 259 lines)
@@ -0,0 +1,259 @@
"""Data structures.

Nodes are decoupled from the indices.

"""

import uuid
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Sequence, Set

from dataclasses_json import DataClassJsonMixin

from pilot.graph_engine.index_type import IndexStructType
from pilot.graph_engine.node import TextNode, BaseNode

# TODO: legacy backport of old Node class
Node = TextNode


@dataclass
class IndexStruct(DataClassJsonMixin):
    """A base data struct for a LlamaIndex."""

    index_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    summary: Optional[str] = None

    def get_summary(self) -> str:
        """Get text summary."""
        if self.summary is None:
            raise ValueError("summary field of the index_struct not set.")
        return self.summary

    @classmethod
    @abstractmethod
    def get_type(cls):
        """Get index struct type."""


@dataclass
class IndexGraph(IndexStruct):
    """A graph representing the tree-structured index."""

    # mapping from index in tree to Node doc id.
    all_nodes: Dict[int, str] = field(default_factory=dict)
    root_nodes: Dict[int, str] = field(default_factory=dict)
    node_id_to_children_ids: Dict[str, List[str]] = field(default_factory=dict)

    @property
    def node_id_to_index(self) -> Dict[str, int]:
        """Map from node id to index."""
        return {node_id: index for index, node_id in self.all_nodes.items()}

    @property
    def size(self) -> int:
        """Get the size of the graph."""
        return len(self.all_nodes)

    def get_index(self, node: BaseNode) -> int:
        """Get index of node."""
        return self.node_id_to_index[node.node_id]

    def insert(
        self,
        node: BaseNode,
        index: Optional[int] = None,
        children_nodes: Optional[Sequence[BaseNode]] = None,
    ) -> None:
        """Insert node."""
        index = index or self.size
        node_id = node.node_id

        self.all_nodes[index] = node_id

        if children_nodes is None:
            children_nodes = []
        children_ids = [n.node_id for n in children_nodes]
        self.node_id_to_children_ids[node_id] = children_ids

    def get_children(self, parent_node: Optional[BaseNode]) -> Dict[int, str]:
        """Get children nodes."""
        if parent_node is None:
            return self.root_nodes
        else:
            parent_id = parent_node.node_id
            children_ids = self.node_id_to_children_ids[parent_id]
            return {
                self.node_id_to_index[child_id]: child_id for child_id in children_ids
            }

    def insert_under_parent(
        self,
        node: BaseNode,
        parent_node: Optional[BaseNode],
        new_index: Optional[int] = None,
    ) -> None:
        """Insert under parent node."""
        new_index = new_index or self.size
        if parent_node is None:
            self.root_nodes[new_index] = node.node_id
            self.node_id_to_children_ids[node.node_id] = []
        else:
            if parent_node.node_id not in self.node_id_to_children_ids:
                self.node_id_to_children_ids[parent_node.node_id] = []
            self.node_id_to_children_ids[parent_node.node_id].append(node.node_id)

        self.all_nodes[new_index] = node.node_id

    @classmethod
    def get_type(cls) -> IndexStructType:
        """Get type."""
        return IndexStructType.TREE


@dataclass
class KeywordTable(IndexStruct):
    """A table of keywords mapping keywords to text chunks."""

    table: Dict[str, Set[str]] = field(default_factory=dict)

    def add_node(self, keywords: List[str], node: BaseNode) -> None:
        """Add text to table."""
        for keyword in keywords:
            if keyword not in self.table:
                self.table[keyword] = set()
            self.table[keyword].add(node.node_id)

    @property
    def node_ids(self) -> Set[str]:
        """Get all node ids."""
        return set.union(*self.table.values())

    @property
    def keywords(self) -> Set[str]:
        """Get all keywords in the table."""
        return set(self.table.keys())

    @property
    def size(self) -> int:
        """Get the size of the table."""
        return len(self.table)

    @classmethod
    def get_type(cls) -> IndexStructType:
        """Get type."""
        return IndexStructType.KEYWORD_TABLE


@dataclass
class IndexList(IndexStruct):
    """A list of documents."""

    nodes: List[str] = field(default_factory=list)

    def add_node(self, node: BaseNode) -> None:
        """Add text to table, return current position in list."""
        # don't worry about child indices for now, nodes are all in order
        self.nodes.append(node.node_id)

    @classmethod
    def get_type(cls) -> IndexStructType:
        """Get type."""
        return IndexStructType.LIST


@dataclass
class IndexDict(IndexStruct):
    """A simple dictionary of documents."""

    # TODO: slightly deprecated, should likely be a list or set now
    # mapping from vector store id to node doc_id
    nodes_dict: Dict[str, str] = field(default_factory=dict)

    # TODO: deprecated, not used
    # mapping from node doc_id to vector store id
    doc_id_dict: Dict[str, List[str]] = field(default_factory=dict)

    # TODO: deprecated, not used
    # this should be empty for all other indices
    embeddings_dict: Dict[str, List[float]] = field(default_factory=dict)

    def add_node(
        self,
        node: BaseNode,
        text_id: Optional[str] = None,
    ) -> str:
        """Add text to table, return current position in list."""
        # # don't worry about child indices for now, nodes are all in order
        # self.nodes_dict[int_id] = node
        vector_id = text_id if text_id is not None else node.node_id
        self.nodes_dict[vector_id] = node.node_id

        return vector_id

    def delete(self, doc_id: str) -> None:
        """Delete a Node."""
        del self.nodes_dict[doc_id]

    @classmethod
    def get_type(cls) -> IndexStructType:
        """Get type."""
        return IndexStructType.VECTOR_STORE


@dataclass
class KG(IndexStruct):
    """A knowledge-graph index struct: a table mapping keywords to text chunks."""

    # Unidirectional

    # table of keywords to node ids
    table: Dict[str, Set[str]] = field(default_factory=dict)

    # TODO: legacy attribute, remove in future releases
    rel_map: Dict[str, List[List[str]]] = field(default_factory=dict)

    # TBD, should support vector store, now we just persist the embedding memory
    # maybe chainable abstractions for *_stores could be designed
    embedding_dict: Dict[str, List[float]] = field(default_factory=dict)

    @property
    def node_ids(self) -> Set[str]:
        """Get all node ids."""
        return set.union(*self.table.values())

    def add_to_embedding_dict(self, triplet_str: str, embedding: List[float]) -> None:
        """Add embedding to dict."""
        self.embedding_dict[triplet_str] = embedding

    def add_node(self, keywords: List[str], node: BaseNode) -> None:
        """Add text to table."""
        node_id = node.node_id
        for keyword in keywords:
            keyword = keyword.lower()
            if keyword not in self.table:
                self.table[keyword] = set()
            self.table[keyword].add(node_id)

    def search_node_by_keyword(self, keyword: str) -> List[str]:
        """Search for nodes by keyword."""
        if keyword not in self.table:
            return []
        return list(self.table[keyword])

    @classmethod
    def get_type(cls) -> IndexStructType:
        """Get type."""
        return IndexStructType.KG


@dataclass
class EmptyIndexStruct(IndexStruct):
    """Empty index."""

    @classmethod
    def get_type(cls) -> IndexStructType:
        """Get type."""
        return IndexStructType.EMPTY
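Small illustration (not part of the diff) of how the KG struct ties keywords to node ids, matching how _build_index_from_docs uses it in graph_engine.py: keys are lowercased on insert, but lookups are not, so callers must lowercase first (as graph_search.py does).

    from pilot.graph_engine.index_struct import KG
    from pilot.graph_engine.node import TextNode

    kg = KG()
    node = TextNode(text="Alice is Bob's mother.")
    # index the node under both the subject and the object of the triplet
    kg.add_node(["Alice", "Bob"], node)

    print(kg.search_node_by_keyword("alice"))  # [node.node_id] — keys were lowercased
    print(kg.search_node_by_keyword("Alice"))  # [] — lookups must already be lowercase
    print(kg.get_type())                       # IndexStructType.KG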
pilot/graph_engine/index_type.py (new file, 48 lines)
@@ -0,0 +1,48 @@
"""IndexStructType class."""

from enum import Enum


class IndexStructType(str, Enum):
    """Index struct type. Identifier for a "type" of index.

    Attributes:
        TREE ("tree"): Tree index. See :ref:`Ref-Indices-Tree` for tree indices.
        LIST ("list"): Summary index. See :ref:`Ref-Indices-List` for summary indices.
        KEYWORD_TABLE ("keyword_table"): Keyword table index. See
            :ref:`Ref-Indices-Table` for keyword table indices.
        DICT ("dict"): Faiss Vector Store Index. See
            :ref:`Ref-Indices-VectorStore`
            for more information on the faiss vector store index.
        SIMPLE_DICT ("simple_dict"): Simple Vector Store Index. See
            :ref:`Ref-Indices-VectorStore`
            for more information on the simple vector store index.
        KG ("kg"): Knowledge Graph index.
            See :ref:`Ref-Indices-Knowledge-Graph` for KG indices.
        DOCUMENT_SUMMARY ("document_summary"): Document Summary Index.
            See :ref:`Ref-Indices-Document-Summary` for Summary Indices.

    """

    # TODO: refactor so these are properties on the base class

    NODE = "node"
    TREE = "tree"
    LIST = "list"
    KEYWORD_TABLE = "keyword_table"

    DICT = "dict"
    # simple
    SIMPLE_DICT = "simple_dict"
    # generic vector store; referenced by IndexDict.get_type() in index_struct.py
    # (present in the upstream llama-index enum, missing from the original diff)
    VECTOR_STORE = "vector_store"
    # for KG index
    KG = "kg"
    SIMPLE_KG = "simple_kg"
    NEBULAGRAPH = "nebulagraph"
    FALKORDB = "falkordb"

    # EMPTY
    EMPTY = "empty"
    COMPOSITE = "composite"

    DOCUMENT_SUMMARY = "document_summary"
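Because the enum derives from str, members compare equal to their raw values, which is how they round-trip through JSON persistence in kv_index.py below:

    from pilot.graph_engine.index_type import IndexStructType

    assert IndexStructType.KG == "kg"
    assert IndexStructType("keyword_table") is IndexStructType.KEYWORD_TABLE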
pilot/graph_engine/kv_index.py (new file, 74 lines)
@@ -0,0 +1,74 @@
from typing import List, Optional

from llama_index.data_structs.data_structs import IndexStruct
from llama_index.storage.index_store.utils import (
    index_struct_to_json,
    json_to_index_struct,
)
from llama_index.storage.kvstore.types import BaseKVStore

DEFAULT_NAMESPACE = "index_store"


class KVIndexStore:
    """Key-Value Index store.

    Args:
        kvstore (BaseKVStore): key-value store
        namespace (str): namespace for the index store

    """

    def __init__(self, kvstore: BaseKVStore, namespace: Optional[str] = None) -> None:
        """Init a KVIndexStore."""
        self._kvstore = kvstore
        self._namespace = namespace or DEFAULT_NAMESPACE
        self._collection = f"{self._namespace}/data"

    def add_index_struct(self, index_struct: IndexStruct) -> None:
        """Add an index struct.

        Args:
            index_struct (IndexStruct): index struct

        """
        key = index_struct.index_id
        data = index_struct_to_json(index_struct)
        self._kvstore.put(key, data, collection=self._collection)

    def delete_index_struct(self, key: str) -> None:
        """Delete an index struct.

        Args:
            key (str): index struct key

        """
        self._kvstore.delete(key, collection=self._collection)

    def get_index_struct(
        self, struct_id: Optional[str] = None
    ) -> Optional[IndexStruct]:
        """Get an index struct.

        Args:
            struct_id (Optional[str]): index struct id

        """
        if struct_id is None:
            structs = self.index_structs()
            assert len(structs) == 1
            return structs[0]
        else:
            json = self._kvstore.get(struct_id, collection=self._collection)
            if json is None:
                return None
            return json_to_index_struct(json)

    def index_structs(self) -> List[IndexStruct]:
        """Get all index structs.

        Returns:
            List[IndexStruct]: index structs

        """
        jsons = self._kvstore.get_all(collection=self._collection)
        return [json_to_index_struct(json) for json in jsons.values()]
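Usage sketch (not part of the diff), assuming llama-index's in-memory SimpleKVStore as the backing BaseKVStore (any implementation of that interface should work). Note the store serializes llama-index IndexStruct types, so the example uses llama-index's own KG struct rather than the pilot one defined above.

    from llama_index.data_structs.data_structs import KG  # the IndexStruct family this store serializes
    from llama_index.storage.kvstore import SimpleKVStore  # assumption: in-memory BaseKVStore

    from pilot.graph_engine.kv_index import KVIndexStore

    store = KVIndexStore(SimpleKVStore(), namespace="rag_graph")
    kg = KG(summary="demo knowledge graph")
    store.add_index_struct(kg)                  # persisted as JSON under "rag_graph/data"
    loaded = store.get_index_struct(kg.index_id)
    print(loaded.index_id == kg.index_id)       # True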
pilot/graph_engine/node.py (new file, 569 lines)
@@ -0,0 +1,569 @@
"""Base schema for data structures."""
import json
import textwrap
import uuid
from abc import abstractmethod
from enum import Enum, auto
from hashlib import sha256
from typing import Any, Dict, List, Optional, Union

# aliased so it is not shadowed by the Document class defined below
from langchain.schema import Document as LCDocument
from pydantic import BaseModel, Field, root_validator
from typing_extensions import Self


DEFAULT_TEXT_NODE_TMPL = "{metadata_str}\n\n{content}"
DEFAULT_METADATA_TMPL = "{key}: {value}"
# NOTE: for pretty printing
TRUNCATE_LENGTH = 350
WRAP_WIDTH = 70


def truncate_text(text: str, max_length: int) -> str:
    """Truncate text to a maximum length."""
    if len(text) <= max_length:
        return text
    return text[: max_length - 3] + "..."


class BaseComponent(BaseModel):
    """Base component object to capture class names. Reference: llama-index."""

    @classmethod
    @abstractmethod
    def class_name(cls) -> str:
        """Get class name."""

    def to_dict(self, **kwargs: Any) -> Dict[str, Any]:
        data = self.dict(**kwargs)
        data["class_name"] = self.class_name()
        return data

    def to_json(self, **kwargs: Any) -> str:
        data = self.to_dict(**kwargs)
        return json.dumps(data)

    # TODO: return type here not supported by current mypy version
    @classmethod
    def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self:  # type: ignore
        if isinstance(kwargs, dict):
            data.update(kwargs)

        data.pop("class_name", None)
        return cls(**data)

    @classmethod
    def from_json(cls, data_str: str, **kwargs: Any) -> Self:  # type: ignore
        data = json.loads(data_str)
        return cls.from_dict(data, **kwargs)


class NodeRelationship(str, Enum):
    """Node relationships used in `BaseNode` class.

    Attributes:
        SOURCE: The node is the source document.
        PREVIOUS: The node is the previous node in the document.
        NEXT: The node is the next node in the document.
        PARENT: The node is the parent node in the document.
        CHILD: The node is a child node in the document.

    """

    SOURCE = auto()
    PREVIOUS = auto()
    NEXT = auto()
    PARENT = auto()
    CHILD = auto()


class ObjectType(str, Enum):
    TEXT = auto()
    IMAGE = auto()
    INDEX = auto()
    DOCUMENT = auto()


class MetadataMode(str, Enum):
    ALL = auto()
    EMBED = auto()
    LLM = auto()
    NONE = auto()


class RelatedNodeInfo(BaseComponent):
    node_id: str
    node_type: Optional[ObjectType] = None
    metadata: Dict[str, Any] = Field(default_factory=dict)
    hash: Optional[str] = None

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "RelatedNodeInfo"


RelatedNodeType = Union[RelatedNodeInfo, List[RelatedNodeInfo]]


# Node classes for indexes
class BaseNode(BaseComponent):
    """Base node Object.

    Generic abstract interface for retrievable nodes

    """

    class Config:
        allow_population_by_field_name = True

    id_: str = Field(
        default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the node."
    )
    embedding: Optional[List[float]] = Field(
        default=None, description="Embedding of the node."
    )

    """
    metadata fields
    - injected as part of the text shown to LLMs as context
    - injected as part of the text for generating embeddings
    - used by vector DBs for metadata filtering

    """
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="A flat dictionary of metadata fields",
        alias="extra_info",
    )
    excluded_embed_metadata_keys: List[str] = Field(
        default_factory=list,
        description="Metadata keys that are excluded from text for the embed model.",
    )
    excluded_llm_metadata_keys: List[str] = Field(
        default_factory=list,
        description="Metadata keys that are excluded from text for the LLM.",
    )
    relationships: Dict[NodeRelationship, RelatedNodeType] = Field(
        default_factory=dict,
        description="A mapping of relationships to other node information.",
    )
    hash: str = Field(default="", description="Hash of the node content.")

    @classmethod
    @abstractmethod
    def get_type(cls) -> str:
        """Get Object type."""

    @abstractmethod
    def get_content(self, metadata_mode: MetadataMode = MetadataMode.ALL) -> str:
        """Get object content."""

    @abstractmethod
    def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
        """Metadata string."""

    @abstractmethod
    def set_content(self, value: Any) -> None:
        """Set the content of the node."""

    @property
    def node_id(self) -> str:
        return self.id_

    @node_id.setter
    def node_id(self, value: str) -> None:
        self.id_ = value

    @property
    def source_node(self) -> Optional[RelatedNodeInfo]:
        """Source object node.

        Extracted from the relationships field.

        """
        if NodeRelationship.SOURCE not in self.relationships:
            return None

        relation = self.relationships[NodeRelationship.SOURCE]
        if isinstance(relation, list):
            raise ValueError("Source object must be a single RelatedNodeInfo object")
        return relation

    @property
    def prev_node(self) -> Optional[RelatedNodeInfo]:
        """Prev node."""
        if NodeRelationship.PREVIOUS not in self.relationships:
            return None

        relation = self.relationships[NodeRelationship.PREVIOUS]
        if not isinstance(relation, RelatedNodeInfo):
            raise ValueError("Previous object must be a single RelatedNodeInfo object")
        return relation

    @property
    def next_node(self) -> Optional[RelatedNodeInfo]:
        """Next node."""
        if NodeRelationship.NEXT not in self.relationships:
            return None

        relation = self.relationships[NodeRelationship.NEXT]
        if not isinstance(relation, RelatedNodeInfo):
            raise ValueError("Next object must be a single RelatedNodeInfo object")
        return relation

    @property
    def parent_node(self) -> Optional[RelatedNodeInfo]:
        """Parent node."""
        if NodeRelationship.PARENT not in self.relationships:
            return None

        relation = self.relationships[NodeRelationship.PARENT]
        if not isinstance(relation, RelatedNodeInfo):
            raise ValueError("Parent object must be a single RelatedNodeInfo object")
        return relation

    @property
    def child_nodes(self) -> Optional[List[RelatedNodeInfo]]:
        """Child nodes."""
        if NodeRelationship.CHILD not in self.relationships:
            return None

        relation = self.relationships[NodeRelationship.CHILD]
        if not isinstance(relation, list):
            raise ValueError("Child objects must be a list of RelatedNodeInfo objects.")
        return relation

    @property
    def ref_doc_id(self) -> Optional[str]:
        """Deprecated: Get ref doc id."""
        source_node = self.source_node
        if source_node is None:
            return None
        return source_node.node_id

    @property
    def extra_info(self) -> Dict[str, Any]:
        """TODO: DEPRECATED: Extra info."""
        return self.metadata

    def __str__(self) -> str:
        source_text_truncated = truncate_text(
            self.get_content().strip(), TRUNCATE_LENGTH
        )
        source_text_wrapped = textwrap.fill(
            f"Text: {source_text_truncated}\n", width=WRAP_WIDTH
        )
        return f"Node ID: {self.node_id}\n{source_text_wrapped}"

    def get_embedding(self) -> List[float]:
        """Get embedding.

        Errors if embedding is None.

        """
        if self.embedding is None:
            raise ValueError("embedding not set.")
        return self.embedding

    def as_related_node_info(self) -> RelatedNodeInfo:
        """Get node as RelatedNodeInfo."""
        return RelatedNodeInfo(
            node_id=self.node_id, metadata=self.metadata, hash=self.hash
        )


class TextNode(BaseNode):
    text: str = Field(default="", description="Text content of the node.")
    start_char_idx: Optional[int] = Field(
        default=None, description="Start char index of the node."
    )
    end_char_idx: Optional[int] = Field(
        default=None, description="End char index of the node."
    )
    text_template: str = Field(
        default=DEFAULT_TEXT_NODE_TMPL,
        description=(
            "Template for how text is formatted, with {content} and "
            "{metadata_str} placeholders."
        ),
    )
    metadata_template: str = Field(
        default=DEFAULT_METADATA_TMPL,
        description=(
            "Template for how metadata is formatted, with {key} and "
            "{value} placeholders."
        ),
    )
    metadata_seperator: str = Field(
        default="\n",
        description="Separator between metadata fields when converting to string.",
    )

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "TextNode"

    @root_validator
    def _check_hash(cls, values: dict) -> dict:
        """Generate a hash to represent the node."""
        text = values.get("text", "")
        metadata = values.get("metadata", {})
        doc_identity = str(text) + str(metadata)
        values["hash"] = str(
            sha256(doc_identity.encode("utf-8", "surrogatepass")).hexdigest()
        )
        return values

    @classmethod
    def get_type(cls) -> str:
        """Get Object type."""
        return ObjectType.TEXT

    def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
        """Get object content."""
        metadata_str = self.get_metadata_str(mode=metadata_mode).strip()
        if not metadata_str:
            return self.text

        return self.text_template.format(
            content=self.text, metadata_str=metadata_str
        ).strip()

    def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
        """Metadata info string."""
        if mode == MetadataMode.NONE:
            return ""

        usable_metadata_keys = set(self.metadata.keys())
        if mode == MetadataMode.LLM:
            for key in self.excluded_llm_metadata_keys:
                if key in usable_metadata_keys:
                    usable_metadata_keys.remove(key)
        elif mode == MetadataMode.EMBED:
            for key in self.excluded_embed_metadata_keys:
                if key in usable_metadata_keys:
                    usable_metadata_keys.remove(key)

        return self.metadata_seperator.join(
            [
                self.metadata_template.format(key=key, value=str(value))
                for key, value in self.metadata.items()
                if key in usable_metadata_keys
            ]
        )

    def set_content(self, value: str) -> None:
        """Set the content of the node."""
        self.text = value

    def get_node_info(self) -> Dict[str, Any]:
        """Get node info."""
        return {"start": self.start_char_idx, "end": self.end_char_idx}

    def get_text(self) -> str:
        return self.get_content(metadata_mode=MetadataMode.NONE)

    @property
    def node_info(self) -> Dict[str, Any]:
        """Deprecated: Get node info."""
        return self.get_node_info()


# TODO: legacy backport of old Node class
Node = TextNode


class ImageNode(TextNode):
    """Node with image."""

    # TODO: store reference instead of actual image
    # base64 encoded image str
    image: Optional[str] = None

    @classmethod
    def get_type(cls) -> str:
        return ObjectType.IMAGE

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "ImageNode"


class IndexNode(TextNode):
    """Node with reference to any object.

    This can include other indices, query engines, retrievers.

    This can also include other nodes (though this is overlapping with `relationships`
    on the Node class).

    """

    index_id: str

    @classmethod
    def from_text_node(
        cls,
        node: TextNode,
        index_id: str,
    ) -> "IndexNode":
        """Create index node from text node."""
        # copy all attributes from text node, add index id
        return cls(
            **node.dict(),
            index_id=index_id,
        )

    @classmethod
    def get_type(cls) -> str:
        return ObjectType.INDEX

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "IndexNode"


class NodeWithScore(BaseComponent):
    node: BaseNode
    score: Optional[float] = None

    def __str__(self) -> str:
        return f"{self.node}\nScore: {self.score: 0.3f}\n"

    def get_score(self, raise_error: bool = False) -> float:
        """Get score."""
        if self.score is None:
            if raise_error:
                raise ValueError("Score not set.")
            else:
                return 0.0
        else:
            return self.score

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "NodeWithScore"

    ##### pass through methods to BaseNode #####
    @property
    def node_id(self) -> str:
        return self.node.node_id

    @property
    def id_(self) -> str:
        return self.node.id_

    @property
    def text(self) -> str:
        if isinstance(self.node, TextNode):
            return self.node.text
        else:
            raise ValueError("Node must be a TextNode to get text.")

    @property
    def metadata(self) -> Dict[str, Any]:
        return self.node.metadata

    @property
    def embedding(self) -> Optional[List[float]]:
        return self.node.embedding

    def get_text(self) -> str:
        if isinstance(self.node, TextNode):
            return self.node.get_text()
        else:
            raise ValueError("Node must be a TextNode to get text.")

    def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
        return self.node.get_content(metadata_mode=metadata_mode)

    def get_embedding(self) -> List[float]:
        return self.node.get_embedding()


# Document Classes for Readers


class Document(TextNode):
    """Generic interface for a data document.

    This document connects to data sources.

    """

    # TODO: A lot of backwards compatibility logic here, clean up
    id_: str = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="Unique ID of the node.",
        alias="doc_id",
    )

    _compat_fields = {"doc_id": "id_", "extra_info": "metadata"}

    @classmethod
    def get_type(cls) -> str:
        """Get Document type."""
        return ObjectType.DOCUMENT

    @property
    def doc_id(self) -> str:
        """Get document ID."""
        return self.id_

    def __str__(self) -> str:
        source_text_truncated = truncate_text(
            self.get_content().strip(), TRUNCATE_LENGTH
        )
        source_text_wrapped = textwrap.fill(
            f"Text: {source_text_truncated}\n", width=WRAP_WIDTH
        )
        return f"Doc ID: {self.doc_id}\n{source_text_wrapped}"

    def get_doc_id(self) -> str:
        """TODO: Deprecated: Get document ID."""
        return self.id_

    def __setattr__(self, name: str, value: object) -> None:
        if name in self._compat_fields:
            name = self._compat_fields[name]
        super().__setattr__(name, value)

    def to_langchain_format(self) -> LCDocument:
        """Convert struct to LangChain document format."""
        metadata = self.metadata or {}
        return LCDocument(page_content=self.text, metadata=metadata)

    @classmethod
    def from_langchain_format(cls, doc: LCDocument) -> "Document":
        """Convert struct from LangChain document format."""
        return cls(text=doc.page_content, metadata=doc.metadata)

    @classmethod
    def example(cls) -> "Document":
        document = Document(
            text="",
            metadata={"filename": "README.md", "category": "codebase"},
        )
        return document

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "Document"


class ImageDocument(Document):
    """Data document containing an image."""

    # base64 encoded image str
    image: Optional[str] = None

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "ImageDocument"
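Short demonstration (not part of the diff) of how the metadata modes interact with the exclusion lists above; this is exactly the mechanism graph_search.py relies on to keep kg_rel_map out of LLM context:

    from pilot.graph_engine.node import MetadataMode, TextNode

    node = TextNode(
        text="Alice is Bob's mother.",
        metadata={"source": "family.txt", "kg_rel_map": {"alice": []}},
        excluded_llm_metadata_keys=["kg_rel_map"],
    )

    print(node.get_content(metadata_mode=MetadataMode.NONE))  # raw text only
    # LLM view: "source: family.txt" is injected; kg_rel_map is filtered out
    print(node.get_content(metadata_mode=MetadataMode.LLM))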
pilot/graph_engine/search.py (new file, 44 lines)
@@ -0,0 +1,44 @@
from abc import ABC, abstractmethod
from enum import Enum


class SearchMode(str, Enum):
    """Query mode enum for Knowledge Graphs.

    Can be passed as the enum struct, or as the underlying string.

    Attributes:
        KEYWORD ("keyword"): Default query mode, using keywords to find triplets.
        EMBEDDING ("embedding"): Embedding mode, using embeddings to find
            similar triplets.
        HYBRID ("hybrid"): Hybrid mode, combining both keywords and embeddings
            to find relevant triplets.
    """

    KEYWORD = "keyword"
    EMBEDDING = "embedding"
    HYBRID = "hybrid"


class BaseSearch(ABC):
    """Base Search."""

    def search(self, query: str):
        """Retrieve nodes given query.

        Args:
            query (str): The query string to search with.

        """
        # if isinstance(query, str):
        return self._search(query)

    @abstractmethod
    def _search(self, query: str):
        """Search nodes given query.

        Implemented by the user.

        """
        pass
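Any retriever only needs to implement _search; a minimal toy sketch (not part of the diff):

    from typing import List

    from pilot.graph_engine.node import NodeWithScore, TextNode
    from pilot.graph_engine.search import BaseSearch


    class EchoSearch(BaseSearch):
        """Toy retriever that wraps the query itself in a scored node."""

        def _search(self, query: str) -> List[NodeWithScore]:
            return [NodeWithScore(node=TextNode(text=query), score=1.0)]


    print(EchoSearch().search("hello")[0].get_text())  # "hello"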
pilot/scene/chat_knowledge/extract_entity/chat.py (new file, 35 lines)
@@ -0,0 +1,35 @@
from typing import Dict

from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.configs.config import Config

from pilot.scene.chat_knowledge.extract_entity.prompt import prompt

CFG = Config()


class ExtractEntity(BaseChat):
    """Extract entities by llm."""

    chat_scene: str = ChatScene.ExtractEntity.value()

    def __init__(self, chat_param: Dict):
        chat_param["chat_mode"] = ChatScene.ExtractEntity
        super().__init__(
            chat_param=chat_param,
        )

        self.user_input = chat_param["current_user_input"]
        self.extract_mode = chat_param["select_param"]

    def generate_input_values(self):
        input_values = {
            "text": self.user_input,
        }
        return input_values

    @property
    def chat_type(self) -> str:
        # call value() for consistency with chat_scene above, so a str is returned
        return ChatScene.ExtractEntity.value()
pilot/scene/chat_knowledge/extract_entity/out_parser.py (new file, 39 lines)
@@ -0,0 +1,39 @@
import json
import logging
from typing import Set

from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.config import Config

CFG = Config()


logger = logging.getLogger(__name__)


class ExtractEntityParser(BaseOutputParser):
    def __init__(self, sep: str, is_stream_out: bool):
        super().__init__(sep=sep, is_stream_out=is_stream_out)

    def parse_prompt_response(self, response, max_length: int = 128) -> Set[str]:
        lowercase = True
        # clean_str = super().parse_prompt_response(response)
        print("clean prompt response:", response)

        results = []
        response = response.strip()  # Strip newlines from responses.

        if response.startswith("KEYWORDS:"):
            response = response[len("KEYWORDS:") :]

        keywords = response.split(",")
        for k in keywords:
            rk = k
            if lowercase:
                rk = rk.lower()
            results.append(rk.strip())

        return set(results)

    def parse_view_response(self, speak, data) -> str:
        return data
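What the parser does to a well-formed model reply, given the KEYWORDS: convention set by the prompt in the next file (a sketch, assuming the BaseOutputParser constructor accepts these arguments as the subclass passes them through):

    from pilot.scene.chat_knowledge.extract_entity.out_parser import ExtractEntityParser

    parser = ExtractEntityParser(sep="\n", is_stream_out=False)
    entities = parser.parse_prompt_response("KEYWORDS: Alice, mother, Bob")
    print(entities)  # {'alice', 'mother', 'bob'} — lowercased and whitespace-stripped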
pilot/scene/chat_knowledge/extract_entity/prompt.py (new file, 52 lines)
@ -0,0 +1,52 @@
import json

from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.common.schema import SeparatorStyle
from pilot.scene.chat_knowledge.extract_entity.out_parser import ExtractEntityParser

from pilot.scene.chat_knowledge.extract_triplet.out_parser import (
    ExtractTripleParser,
)

CFG = Config()

PROMPT_SCENE_DEFINE = """"""

_DEFAULT_TEMPLATE = """
"A question is provided below. Given the question, extract up to 10 "
"keywords from the text. Focus on extracting the keywords that we can use "
"to best lookup answers to the question. Avoid stopwords.\n"
"Example:"
"Text: Alice is Bob's mother."
"KEYWORDS:Alice,mother,Bob\n"
"---------------------\n"
"{text}\n"
"---------------------\n"
"Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>'\n"
"""

PROMPT_RESPONSE = """"""

RESPONSE_FORMAT = """"""

PROMPT_SEP = SeparatorStyle.SINGLE.value

PROMPT_NEED_STREAM_OUT = False

prompt = PromptTemplate(
    template_scene=ChatScene.ExtractEntity.value(),
    input_variables=["text"],
    response_format="",
    template_define=PROMPT_SCENE_DEFINE,
    template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
    stream_out=PROMPT_NEED_STREAM_OUT,
    output_parser=ExtractEntityParser(
        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
    ),
)

CFG.prompt_template_registry.register(prompt, is_default=True)
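Since _DEFAULT_TEMPLATE carries a single {text} placeholder, the final prompt can be previewed with plain str.format; a quick sketch that assumes nothing about PromptTemplate internals:

# Hypothetical preview of the prompt text rendered for one question.
question = "Who is Bob's mother?"
print(_DEFAULT_TEMPLATE.format(text=question))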
35
pilot/scene/chat_knowledge/extract_triplet/chat.py
Normal file
@ -0,0 +1,35 @@
from typing import Dict

from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.configs.config import Config

from pilot.scene.chat_knowledge.extract_triplet.prompt import prompt

CFG = Config()


class ExtractTriplet(BaseChat):
    """Extract knowledge triplets from text with an LLM."""

    chat_scene: str = ChatScene.ExtractTriplet.value()

    def __init__(self, chat_param: Dict):
        chat_param["chat_mode"] = ChatScene.ExtractTriplet
        super().__init__(
            chat_param=chat_param,
        )
        self.user_input = chat_param["current_user_input"]
        self.extract_mode = chat_param["select_param"]

    def generate_input_values(self):
        input_values = {
            "text": self.user_input,
        }
        return input_values

    @property
    def chat_type(self) -> str:
        return ChatScene.ExtractTriplet.value
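A hedged sketch of driving this scene through the non-streaming helper in pilot/common/chat_util.py; "current_user_input", "select_param", and "model_name" mirror keys read in this diff, while the "chat_session_id" key, its value, and the model name are assumptions:

import asyncio
import uuid

from pilot.scene.base import ChatScene
from pilot.common.chat_util import llm_chat_response_nostream

chat_param = {
    "chat_session_id": uuid.uuid1(),  # assumed key, placeholder session id
    "current_user_input": "Philz is a coffee shop founded in Berkeley in 1982.",
    "select_param": "triplet",        # consumed as self.extract_mode
    "model_name": "vicuna-13b",       # assumed model name
}
triplets = asyncio.run(
    llm_chat_response_nostream(
        ChatScene.ExtractTriplet.value(), **{"chat_param": chat_param}
    )
)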
57
pilot/scene/chat_knowledge/extract_triplet/out_parser.py
Normal file
@ -0,0 +1,57 @@
import json
import logging
import re
from typing import List, Tuple

from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.config import Config

CFG = Config()

logger = logging.getLogger(__name__)


class ExtractTripleParser(BaseOutputParser):
    def __init__(self, sep: str, is_stream_out: bool):
        super().__init__(sep=sep, is_stream_out=is_stream_out)

    def parse_prompt_response(
        self, response, max_length: int = 128
    ) -> List[Tuple[str, str, str]]:
        # clean_str = super().parse_prompt_response(response)
        print("clean prompt response:", response)

        if response.startswith("Triplets:"):
            response = response[len("Triplets:") :]
        # Pull out every parenthesized group, e.g. "(subject, predicate, object)".
        pattern = r"\([^()]+\)"
        response = re.findall(pattern, response)
        # response = response.strip().split("\n")
        print("parse prompt response:", response)
        results = []
        for text in response:
            if not text or text[0] != "(" or text[-1] != ")":
                # skip empty lines and non-triplets
                continue
            tokens = text[1:-1].split(",")
            if len(tokens) != 3:
                continue

            if any(len(s.encode("utf-8")) > max_length for s in tokens):
                # We count byte-length instead of len() for UTF-8 chars,
                # and skip if any of the tokens are too long.
                # This is normally due to a poorly formatted triplet
                # extraction; in more serious KG building cases
                # we'll need NLP models to better extract triplets.
                continue

            subject, predicate, obj = map(str.strip, tokens)
            if not subject or not predicate or not obj:
                # skip partial triplets
                continue
            results.append((subject.lower(), predicate.lower(), obj.lower()))
        return results

    def parse_view_response(self, speak, data) -> str:
        # tool out data to table view
        return data
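A self-contained sketch of what parse_prompt_response does to a well-formed completion; the sample response string is invented for illustration, and the core steps (prefix strip, regex, tokenize, normalize) are lifted directly from the method above:

import re

sample = "Triplets:\n(Philz, is, coffee shop)\n(Philz, founded in, Berkeley)"
body = sample[len("Triplets:"):]
groups = re.findall(r"\([^()]+\)", body)                      # ['(Philz, is, coffee shop)', ...]
triplets = [
    tuple(t.strip().lower() for t in g[1:-1].split(","))      # strip parens, split, normalize
    for g in groups
]
print(triplets)
# [('philz', 'is', 'coffee shop'), ('philz', 'founded in', 'berkeley')]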
57
pilot/scene/chat_knowledge/extract_triplet/prompt.py
Normal file
@ -0,0 +1,57 @@
import json

from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.common.schema import SeparatorStyle

from pilot.scene.chat_knowledge.extract_triplet.out_parser import (
    ExtractTripleParser,
)

CFG = Config()

PROMPT_SCENE_DEFINE = """"""

_DEFAULT_TEMPLATE = """
"Some text is provided below. Given the text, extract up to 10 "
"knowledge triplets in the form of (subject, predicate, object). Avoid stopwords.\n"
"---------------------\n"
"Example:"
"Text: Alice is Bob's mother."
"Triplets:\n(Alice, is mother of, Bob)\n"
"Text: Philz is a coffee shop founded in Berkeley in 1982.\n"
"Triplets:\n"
"(Philz, is, coffee shop)\n"
"(Philz, founded in, Berkeley)\n"
"(Philz, founded in, 1982)\n"
"---------------------\n"
"Text: {text}\n"
"Triplets:\n"
Ensure you respond in the following List(Tuple) format:
'(Stephen Curry, plays for, Golden State Warriors)\n(Stephen Curry, known for, shooting skills)\n(Stephen Curry, attended, Davidson College)\n(Stephen Curry, led, team to success)'
"""

PROMPT_RESPONSE = """"""

RESPONSE_FORMAT = """"""

PROMPT_SEP = SeparatorStyle.SINGLE.value

PROMPT_NEED_STREAM_OUT = False

prompt = PromptTemplate(
    template_scene=ChatScene.ExtractTriplet.value(),
    input_variables=["text"],
    response_format="",
    template_define=PROMPT_SCENE_DEFINE,
    template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
    stream_out=PROMPT_NEED_STREAM_OUT,
    output_parser=ExtractTripleParser(
        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
    ),
)

CFG.prompt_template_registry.register(prompt, is_default=True)
@ -88,25 +88,7 @@ class ChatKnowledge(BaseChat):
         if self.space_context:
             self.prompt_template.template_define = self.space_context["prompt"]["scene"]
             self.prompt_template.template = self.space_context["prompt"]["template"]
-        # docs = self.rag_engine.search(query=self.current_user_input)
-        # import httpx
-        # with httpx.Client() as client:
-        #     request = client.build_request(
-        #         "post",
-        #         "http://127.0.0.1/api/knowledge/entities/extract",
-        #         json="application/json",  # using json for data to ensure it sends as application/json
-        #         params={"text": self.current_user_input},
-        #         headers={},
-        #     )
-        #
-        #     response = client.send(request)
-        #     if response.status_code != 200:
-        #         error_msg = f"request /api/knowledge/entities/extract failed, error: {response.text}"
-        #         raise Exception(error_msg)
-        #     docs = response.json()
-        # import requests
-        # docs = requests.post("http://127.0.0.1:5000/api/knowledge/entities/extract", headers={}, json={"text": self.current_user_input})
-
+        docs = self.rag_engine.search(query=self.current_user_input)
         docs = self.knowledge_embedding_client.similar_search(
             self.current_user_input, self.top_k
         )
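The commented-out block removed above hints at how the new entity-extract endpoint can be exercised over HTTP. A hedged sketch reconstructed from those comments; the host, port, payload shape, and model name are assumptions:

import httpx

# Assumed local deployment; the endpoint path comes from the removed comments.
resp = httpx.post(
    "http://127.0.0.1:5000/api/knowledge/entities/extract",
    json={
        "text": "Philz is a coffee shop founded in Berkeley in 1982.",
        "model_name": "vicuna-13b",  # assumed model name
    },
)
resp.raise_for_status()
print(resp.json())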
@ -205,12 +205,6 @@ def similar_query(space_name: str, query_request: KnowledgeQueryRequest):
 async def entity_extract(request: EntityExtractRequest):
     logger.info(f"Received params: {request}")
     try:
-        # from pilot.graph_engine.graph_factory import RAGGraphFactory
-        # from pilot.component import ComponentType
-        # rag_engine = CFG.SYSTEM_APP.get_component(
-        #     ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
-        # ).create()
-        # return Result.succ(await rag_engine.search(request.text))
         from pilot.scene.base import ChatScene
         from pilot.common.chat_util import llm_chat_response_nostream
         import uuid
@ -222,11 +216,6 @@ async def entity_extract(request: EntityExtractRequest):
             "model_name": request.model_name,
         }

-        # import nest_asyncio
-        # nest_asyncio.apply()
-        # loop = asyncio.get_event_loop()
-        # loop.stop()
-        # loop = utils.get_or_create_event_loop()
         res = await llm_chat_response_nostream(
             ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
         )
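The deleted nest_asyncio and event-loop experiments are obsolete because an async handler can await the helper directly. A minimal sketch of that pattern; the "chat_session_id" key and the "entity" select_param value are assumptions:

import uuid

from pilot.scene.base import ChatScene
from pilot.common.chat_util import llm_chat_response_nostream

async def extract_entities(text: str, model_name: str):
    chat_param = {
        "chat_session_id": uuid.uuid1(),  # assumed key, mirrors other scenes
        "current_user_input": text,
        "select_param": "entity",         # assumed extract-mode value
        "model_name": model_name,
    }
    # Direct await: no nest_asyncio or manual event-loop juggling required.
    return await llm_chat_response_nostream(
        ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
    )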