mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-14 14:34:28 +00:00
feat:knowledge rag graph
This commit is contained in:
parent
fa6a9040d5
commit
2f82f98e31
20
pilot/common/chat_util.py
Normal file
20
pilot/common/chat_util.py
Normal file
@ -0,0 +1,20 @@
|
||||
import asyncio
|
||||
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
from pilot.scene.chat_factory import ChatFactory
|
||||
|
||||
chat_factory = ChatFactory()
|
||||
|
||||
|
||||
async def llm_chat_response_nostream(chat_scene: str, **chat_param):
    """Run one non-streaming LLM chat round for a scene.

    Args:
        chat_scene: Scene identifier handed to
            ``chat_factory.get_implementation``.
        **chat_param: Keyword arguments forwarded unchanged to the factory.

    Returns:
        The awaited result of the chat's ``get_llm_response()``.
    """
    scene_chat: BaseChat = chat_factory.get_implementation(chat_scene, **chat_param)
    return await scene_chat.get_llm_response()
|
||||
|
||||
|
||||
async def llm_chat_response(chat_scene: str, **chat_param):
    """Start a streaming LLM chat round for a scene.

    Args:
        chat_scene: Scene identifier handed to
            ``chat_factory.get_implementation``.
        **chat_param: Keyword arguments forwarded unchanged to the factory.

    Returns:
        Whatever the chat's ``stream_call()`` returns; it is handed back
        as-is, without being awaited or iterated here.
    """
    scene_chat: BaseChat = chat_factory.get_implementation(chat_scene, **chat_param)
    stream = scene_chat.stream_call()
    return stream
|
0
pilot/graph_engine/__init__.py
Normal file
0
pilot/graph_engine/__init__.py
Normal file
137
pilot/graph_engine/graph_engine.py
Normal file
137
pilot/graph_engine/graph_engine.py
Normal file
@ -0,0 +1,137 @@
|
||||
import logging
|
||||
from typing import Any, Optional, Callable, Tuple, List
|
||||
|
||||
from langchain.schema import Document
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from pilot.embedding_engine import KnowledgeType
|
||||
from pilot.embedding_engine.knowledge_type import get_knowledge_embedding
|
||||
from pilot.graph_engine.index_struct import KG
|
||||
from pilot.graph_engine.node import TextNode
|
||||
from pilot.utils import utils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RAGGraphEngine:
    """Knowledge RAG Graph Engine.

    Build a KG by extracting triplets, and leveraging the KG during query-time.

    Args:
        knowledge_type (Optional[str]): Kind of knowledge source to load.
            Defaults to ``KnowledgeType.DOCUMENT.value``.
        knowledge_source (Optional[str]): Source passed to
            ``get_knowledge_embedding`` when ``knowledge_graph`` is called
            without documents.
        text_splitter: Splitter used while reading ``knowledge_source``. When
            omitted, ``knowledge_graph`` creates a
            ``RecursiveCharacterTextSplitter`` (chunk_size=2000,
            chunk_overlap=100).
        graph_store (Optional[GraphStore]): The graph store to use
            (reference: llama-index). It must provide ``upsert_triplet``.
        index_struct (Optional[KG]): Pre-built index. When given,
            ``knowledge_graph`` does not rebuild it.
        model_name (Optional[str]): Model name forwarded to the LLM chat scene
            used for triplet extraction.
        max_triplets_per_chunk (int): Stored on the instance; not consulted by
            the code in this class. Defaults to 10.
        include_embeddings (bool): Whether to include embeddings in the index.
            Stored on the instance only. Defaults to False.
        max_object_length (int): The maximum length of the object in a triplet.
            Stored on the instance only. Defaults to 128.
        extract_triplet_fn (Optional[Callable]): The function to use for
            extracting triplets. Defaults to None, in which case the LLM is
            asked to extract them.
        **kwargs: Accepted and ignored.
    """

    index_struct_cls = KG

    def __init__(
        self,
        knowledge_type: Optional[str] = KnowledgeType.DOCUMENT.value,
        knowledge_source: Optional[str] = None,
        text_splitter=None,
        graph_store=None,
        index_struct: Optional[KG] = None,
        model_name: Optional[str] = None,
        max_triplets_per_chunk: int = 10,
        include_embeddings: bool = False,
        max_object_length: int = 128,
        extract_triplet_fn: Optional[Callable] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize params."""
        self.knowledge_source = knowledge_source
        self.knowledge_type = knowledge_type
        self.model_name = model_name
        self.text_splitter = text_splitter
        self.index_struct = index_struct
        self.include_embeddings = include_embeddings
        self.graph_store = graph_store
        self.max_triplets_per_chunk = max_triplets_per_chunk
        self._max_object_length = max_object_length
        self._extract_triplet_fn = extract_triplet_fn

    def knowledge_graph(self, docs=None):
        """Load knowledge documents into the graph store and keyword index.

        Args:
            docs: Documents to index. When empty or omitted, they are read
                from ``self.knowledge_source`` via ``get_knowledge_embedding``.

        The index is only built when ``self.index_struct`` is still ``None``;
        an existing index is left untouched.
        """
        if not docs:
            # Fall back to a default splitter only when the caller supplied
            # none. (The original test was inverted and replaced a supplied
            # splitter while leaving a missing one as None.)
            if not self.text_splitter:
                self.text_splitter = RecursiveCharacterTextSplitter(
                    chunk_size=2000, chunk_overlap=100
                )
            knowledge_source = get_knowledge_embedding(
                knowledge_type=self.knowledge_type,
                knowledge_source=self.knowledge_source,
                text_splitter=self.text_splitter,
            )
            docs = knowledge_source.read()
        if self.index_struct is None:
            self.index_struct = self._build_index_from_docs(docs)

    def _extract_triplets(self, text: str) -> List[Tuple[str, str, str]]:
        """Extract triplets from text with the custom function or the LLM."""
        if self._extract_triplet_fn is not None:
            return self._extract_triplet_fn(text)
        return self._llm_extract_triplets(text)

    def _llm_extract_triplets(self, text: str) -> List[Tuple[str, str, str]]:
        """Extract triplets from text by running the ExtractTriplet chat scene.

        Blocks on the event loop returned by ``utils.get_or_create_event_loop``
        until the LLM response is available.
        """
        # Imported lazily (presumably to avoid an import cycle with the
        # scene/chat modules -- confirm).
        from pilot.scene.base import ChatScene
        from pilot.common.chat_util import llm_chat_response_nostream
        import uuid

        chat_param = {
            "chat_session_id": uuid.uuid1(),
            "current_user_input": text,
            "select_param": "triplet",
            "model_name": self.model_name,
        }
        loop = utils.get_or_create_event_loop()
        triplets = loop.run_until_complete(
            llm_chat_response_nostream(
                ChatScene.ExtractTriplet.value(), chat_param=chat_param
            )
        )
        return triplets

    def _build_index_from_docs(self, documents: List[Document]) -> KG:
        """Build a new index from documents.

        For every document that yields triplets, each triplet is upserted into
        ``self.graph_store`` and its subject and object are registered as
        keywords pointing at a ``TextNode`` holding the document text.

        Args:
            documents: Documents exposing ``page_content`` and ``metadata``.

        Returns:
            The freshly built index structure.
        """
        index_struct = self.index_struct_cls()
        for doc in documents:
            triplets = self._extract_triplets(doc.page_content)
            # Also covers an extractor that returns None instead of a list.
            if not triplets:
                continue
            text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
            logger.info("extracted knowledge triplets: %s", triplets)
            for triplet in triplets:
                subj, _, obj = triplet
                self.graph_store.upsert_triplet(*triplet)
                index_struct.add_node([subj, obj], text_node)
        return index_struct

    def search(self, query):
        """Search the graph for ``query`` with a new ``RAGGraphSearch``."""
        from pilot.graph_engine.graph_search import RAGGraphSearch

        graph_search = RAGGraphSearch(graph_engine=self)
        return graph_search.search(query)
|
||||
|
34
pilot/graph_engine/graph_factory.py
Normal file
34
pilot/graph_engine/graph_factory.py
Normal file
@ -0,0 +1,34 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Type
|
||||
|
||||
from pilot.component import BaseComponent, ComponentType
|
||||
|
||||
|
||||
class RAGGraphFactory(BaseComponent, ABC):
    """Abstract component interface for obtaining a RAG graph engine."""

    # Component name, taken from ComponentType.RAG_GRAPH_DEFAULT.
    name = ComponentType.RAG_GRAPH_DEFAULT.value

    @abstractmethod
    def create(self, model_name: str = None, embedding_cls: Type = None):
        """Create RAG Graph Engine.

        Args:
            model_name: Presumably the LLM model name the engine should use
                -- confirm against the concrete factory.
            embedding_cls: Optional class argument; its use is left to the
                implementation.
        """
|
||||
|
||||
|
||||
class DefaultRAGGraphFactory(RAGGraphFactory):
    """Factory that builds one ``RAGGraphEngine`` up front and reuses it."""

    def __init__(
        self, system_app=None, default_model_name: str = None, **kwargs: Any
    ) -> None:
        """Create the factory and its shared engine.

        Args:
            system_app: Passed to the base component constructor.
            default_model_name: Stored as ``_default_model_name``.
            **kwargs: Stored verbatim on ``self.kwargs``.
        """
        super().__init__(system_app=system_app)
        from pilot.graph_engine.graph_engine import RAGGraphEngine

        self.kwargs = kwargs
        self._default_model_name = default_model_name
        # NOTE(review): the engine is always created for "proxyllm";
        # default_model_name is stored but does not influence it.
        self.rag_engine = RAGGraphEngine(model_name="proxyllm")

    def init_app(self, system_app):
        """No-op hook; nothing is initialised against the system app."""
        pass

    def create(self, model_name: str = None, rag_cls: Type = None):
        """Return the shared engine built in ``__init__``.

        NOTE(review): ``model_name`` (after falling back to the stored default)
        and ``rag_cls`` are not used to build or select the engine.
        """
        model_name = model_name or self._default_model_name
        return self.rag_engine
|
193
pilot/graph_engine/graph_search.py
Normal file
193
pilot/graph_engine/graph_search.py
Normal file
@ -0,0 +1,193 @@
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import List, Optional, Dict, Any, Set, Callable
|
||||
|
||||
from pilot.graph_engine.node import BaseNode, TextNode, NodeWithScore
|
||||
from pilot.graph_engine.search import BaseSearch, SearchMode
|
||||
from pilot.utils import utils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_NODE_SCORE = 1000.0
|
||||
GLOBAL_EXPLORE_NODE_LIMIT = 3
|
||||
REL_TEXT_LIMIT = 30
|
||||
|
||||
|
||||
class RAGGraphSearch(BaseSearch):
    """RAG Graph Search.

    Args:
        graph_engine (RAGGraphEngine): Engine whose index structure, graph
            store and default model name are used.
        model_name (str): Model name; falls back to
            ``graph_engine.model_name`` when empty.
        max_keywords_per_query (int): Maximum number of keywords to extract
            from query. Stored on the instance.
        num_chunks_per_query (int): Maximum number of text chunks to query.
            Stored on the instance.
        search_mode (Optional[SearchMode]): Specifies whether to use keywords,
            embeddings, or both to find relevant triplets. Defaults to
            ``SearchMode.KEYWORD``. The graph store is only queried when the
            mode is not ``SearchMode.EMBEDDING``.
        graph_store_query_depth (int): The depth of the graph store query.
        extract_subject_entities_fn (Optional[Callable]): Callback used to
            extract subject entities instead of the LLM.
        **kwargs: ``verbose`` (default False) logs the assembled KG context;
            ``refresh_schema`` (default False) is forwarded to
            ``graph_store.get_schema``.
    """

    def __init__(
        self,
        graph_engine,
        model_name: str = None,
        max_keywords_per_query: int = 10,
        num_chunks_per_query: int = 10,
        search_mode: Optional[SearchMode] = SearchMode.KEYWORD,
        graph_store_query_depth: int = 2,
        extract_subject_entities_fn: Optional[Callable] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize params."""
        from pilot.graph_engine.graph_engine import RAGGraphEngine

        self.graph_engine: RAGGraphEngine = graph_engine
        self.model_name = model_name or self.graph_engine.model_name
        self._index_struct = self.graph_engine.index_struct
        self.max_keywords_per_query = max_keywords_per_query
        self.num_chunks_per_query = num_chunks_per_query
        self._search_mode = search_mode

        self._graph_store = self.graph_engine.graph_store
        self.graph_store_query_depth = graph_store_query_depth
        self._verbose = kwargs.get("verbose", False)
        refresh_schema = kwargs.get("refresh_schema", False)
        self.extract_subject_entities_fn = extract_subject_entities_fn
        # os.cpu_count() may return None; fall back to a single CPU.
        self.executor = ThreadPoolExecutor(max_workers=(os.cpu_count() or 1) * 5)
        try:
            self._graph_schema = self._graph_store.get_schema(refresh=refresh_schema)
        except NotImplementedError:
            self._graph_schema = ""
        except Exception as e:
            # Best effort: a missing schema must not prevent searching.
            logger.warning("can not to find graph schema: %s", e)
            self._graph_schema = ""

    def _extract_subject_entities(self, query_str: str) -> Set[str]:
        """Extract subject entities with the callback or, failing that, the LLM."""
        if self.extract_subject_entities_fn is not None:
            return self.extract_subject_entities_fn(query_str)
        return self._extract_entities_by_llm(query_str)

    def _extract_entities_by_llm(self, text: str) -> Set[str]:
        """Extract subject entities by running the ExtractEntity chat scene.

        Blocks on the event loop returned by ``utils.get_or_create_event_loop``
        until the LLM response is available.
        """
        from pilot.scene.base import ChatScene
        from pilot.common.chat_util import llm_chat_response_nostream
        import uuid

        chat_param = {
            "chat_session_id": uuid.uuid1(),
            "current_user_input": text,
            "select_param": "entity",
            "model_name": self.model_name,
        }
        loop = utils.get_or_create_event_loop()
        entities = loop.run_until_complete(
            llm_chat_response_nostream(
                ChatScene.ExtractEntity.value(), chat_param=chat_param
            )
        )
        return entities

    def _search(
        self,
        query_str: str,
    ) -> List[NodeWithScore]:
        """Get nodes for response.

        Extracts entities from the query, looks up each lower-cased entity in
        the graph store up to ``graph_store_query_depth`` and packs every
        relationship found into a single high-score text node.

        Args:
            query_str: The user query.

        Returns:
            A one-element list: the relationship node, or a placeholder node
            with the text "No relationships found." when nothing matched.
        """
        keywords = self._extract_subject_entities(query_str)
        logger.info("extract entities: %s", keywords)
        rel_texts = []
        cur_rel_map = {}
        # The original also walked the keyword index to count matching chunks.
        # Nothing consumed that count and it read ``self._include_text``,
        # which this class never sets, so that dead code has been dropped.
        if self._search_mode != SearchMode.EMBEDDING:
            for keyword in keywords:
                rel_map = self._graph_store.get_rel_map(
                    [keyword.lower()], self.graph_store_query_depth
                )
                logger.debug("rel_map: %s", rel_map)
                if not rel_map:
                    continue
                rel_texts.extend(
                    str(rel_obj)
                    for rel_objs in rel_map.values()
                    for rel_obj in rel_objs
                )
                cur_rel_map.update(rel_map)

        if not rel_texts:
            logger.info("> No relationships found, returning nodes found by keywords.")
            logger.info("> No nodes found by keywords, returning empty response.")
            return [
                NodeWithScore(node=TextNode(text="No relationships found."), score=1.0)
            ]

        # add relationships as Node
        # TODO: make initial text customizable
        rel_initial_text = (
            f"The following are knowledge sequence in max depth"
            f" {self.graph_store_query_depth} "
            f"in the form of directed graph like:\n"
            f"`subject -[predicate]->, object, <-[predicate_next_hop]-,"
            f" object_next_hop ...`"
        )
        rel_node_info = {
            "kg_rel_texts": rel_texts,
            "kg_rel_map": cur_rel_map,
        }
        if self._graph_schema != "":
            rel_node_info["kg_schema"] = {"schema": self._graph_schema}
        # Every entry of rel_texts is already a plain string, so no
        # flattening is needed before joining.
        rel_info_text = "\n".join([rel_initial_text, *rel_texts])
        if self._verbose:
            # The builtin print() has no ``color`` keyword; log instead.
            logger.info("KG context:\n%s\n", rel_info_text)
        rel_text_node = TextNode(
            text=rel_info_text,
            metadata=rel_node_info,
            excluded_embed_metadata_keys=["kg_rel_map", "kg_rel_texts"],
            excluded_llm_metadata_keys=["kg_rel_map", "kg_rel_texts"],
        )
        # this node is constructed from rel_texts, give high confidence to avoid cutoff
        return [NodeWithScore(node=rel_text_node, score=DEFAULT_NODE_SCORE)]

    def _get_metadata_for_response(
        self, nodes: List[BaseNode]
    ) -> Optional[Dict[str, Any]]:
        """Return the metadata of the first node that carries ``kg_rel_map``.

        Raises:
            ValueError: If no node has ``kg_rel_map`` in its metadata.
        """
        for node in nodes:
            if node.metadata is None or "kg_rel_map" not in node.metadata:
                continue
            return node.metadata
        raise ValueError("kg_rel_map must be found in at least one Node.")
|
259
pilot/graph_engine/index_struct.py
Normal file
259
pilot/graph_engine/index_struct.py
Normal file
@ -0,0 +1,259 @@
|
||||
"""Data structures.
|
||||
|
||||
Nodes are decoupled from the indices.
|
||||
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Sequence, Set
|
||||
|
||||
from dataclasses_json import DataClassJsonMixin
|
||||
|
||||
|
||||
from pilot.graph_engine.index_type import IndexStructType
|
||||
from pilot.graph_engine.node import TextNode, BaseNode
|
||||
|
||||
# TODO: legacy backport of old Node class
|
||||
Node = TextNode
|
||||
|
||||
|
||||
@dataclass
class IndexStruct(DataClassJsonMixin):
    """Base data structure shared by all index structs (after LlamaIndex)."""

    # Unique id; a random UUID4 string unless supplied.
    index_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    # Optional free-text summary of the index.
    summary: Optional[str] = None

    def get_summary(self) -> str:
        """Return the text summary.

        Raises:
            ValueError: If ``summary`` has not been set.
        """
        text = self.summary
        if text is not None:
            return text
        raise ValueError("summary field of the index_struct not set.")

    @classmethod
    @abstractmethod
    def get_type(cls):
        """Get index struct type."""
|
||||
|
||||
|
||||
@dataclass
class IndexGraph(IndexStruct):
    """A graph representing the tree-structured index."""

    # mapping from index in tree to Node doc id.
    all_nodes: Dict[int, str] = field(default_factory=dict)
    # subset of all_nodes that have no parent.
    root_nodes: Dict[int, str] = field(default_factory=dict)
    # node id -> ids of its children, in insertion order.
    node_id_to_children_ids: Dict[str, List[str]] = field(default_factory=dict)

    @property
    def node_id_to_index(self) -> Dict[str, int]:
        """Map from node id to index (rebuilt from ``all_nodes`` on each access)."""
        return {node_id: index for index, node_id in self.all_nodes.items()}

    @property
    def size(self) -> int:
        """Get the size of the graph."""
        return len(self.all_nodes)

    def get_index(self, node: BaseNode) -> int:
        """Get index of node.

        Raises:
            KeyError: If the node is not part of the graph.
        """
        return self.node_id_to_index[node.node_id]

    def insert(
        self,
        node: BaseNode,
        index: Optional[int] = None,
        children_nodes: Optional[Sequence[BaseNode]] = None,
    ) -> None:
        """Insert node.

        Args:
            node: Node to register.
            index: Position to store the node at. Defaults to the current
                size of the graph.
            children_nodes: Children of the node. Defaults to none.
        """
        # Test for None explicitly: ``index or self.size`` would discard an
        # explicitly requested index of 0.
        if index is None:
            index = self.size
        node_id = node.node_id

        self.all_nodes[index] = node_id

        self.node_id_to_children_ids[node_id] = [
            child.node_id for child in (children_nodes or [])
        ]

    def get_children(self, parent_node: Optional[BaseNode]) -> Dict[int, str]:
        """Get children nodes.

        Args:
            parent_node: Parent to look up, or ``None`` for the root nodes.

        Returns:
            Mapping from index to node id of each child.
        """
        if parent_node is None:
            return self.root_nodes
        children_ids = self.node_id_to_children_ids[parent_node.node_id]
        # Build the reverse mapping once rather than once per child.
        index_of = self.node_id_to_index
        return {index_of[child_id]: child_id for child_id in children_ids}

    def insert_under_parent(
        self,
        node: BaseNode,
        parent_node: Optional[BaseNode],
        new_index: Optional[int] = None,
    ) -> None:
        """Insert under parent node.

        Args:
            node: Node to insert.
            parent_node: Parent to attach to; ``None`` makes ``node`` a root.
            new_index: Position to store the node at. Defaults to the current
                size of the graph.
        """
        # As in ``insert``: an explicit index of 0 must be honoured.
        if new_index is None:
            new_index = self.size
        if parent_node is None:
            self.root_nodes[new_index] = node.node_id
            self.node_id_to_children_ids[node.node_id] = []
        else:
            self.node_id_to_children_ids.setdefault(parent_node.node_id, []).append(
                node.node_id
            )

        self.all_nodes[new_index] = node.node_id

    @classmethod
    def get_type(cls) -> IndexStructType:
        """Get type."""
        return IndexStructType.TREE
|
||||
|
||||
|
||||
@dataclass
class KeywordTable(IndexStruct):
    """A table of keywords mapping keywords to text chunks."""

    # keyword -> ids of the nodes registered under it.
    table: Dict[str, Set[str]] = field(default_factory=dict)

    def add_node(self, keywords: List[str], node: BaseNode) -> None:
        """Register the node's id under every keyword.

        Args:
            keywords: Keywords to file the node under (used verbatim).
            node: Node whose ``node_id`` is stored.
        """
        for keyword in keywords:
            self.table.setdefault(keyword, set()).add(node.node_id)

    @property
    def node_ids(self) -> Set[str]:
        """Get all node ids (an empty set when the table is empty)."""
        # set().union(...) also works with zero arguments, unlike the unbound
        # set.union(*...), which raises TypeError on an empty table.
        return set().union(*self.table.values())

    @property
    def keywords(self) -> Set[str]:
        """Get all keywords in the table."""
        return set(self.table.keys())

    @property
    def size(self) -> int:
        """Get the size of the table (number of keywords)."""
        return len(self.table)

    @classmethod
    def get_type(cls) -> IndexStructType:
        """Get type."""
        return IndexStructType.KEYWORD_TABLE
|
||||
|
||||
|
||||
@dataclass
class IndexList(IndexStruct):
    """A list of documents, stored as node ids in insertion order."""

    # Node ids, appended by add_node.
    nodes: List[str] = field(default_factory=list)

    def add_node(self, node: BaseNode) -> None:
        """Append the node's id to ``nodes``.

        Args:
            node: Node whose ``node_id`` is recorded.
        """
        # don't worry about child indices for now, nodes are all in order
        self.nodes.append(node.node_id)

    @classmethod
    def get_type(cls) -> IndexStructType:
        """Get type."""
        return IndexStructType.LIST
|
||||
|
||||
|
||||
@dataclass
class IndexDict(IndexStruct):
    """A simple dictionary of documents."""

    # TODO: slightly deprecated, should likely be a list or set now
    # mapping from vector store id to node doc_id
    nodes_dict: Dict[str, str] = field(default_factory=dict)

    # TODO: deprecated, not used
    # mapping from node doc_id to vector store id
    doc_id_dict: Dict[str, List[str]] = field(default_factory=dict)

    # TODO: deprecated, not used
    # this should be empty for all other indices
    embeddings_dict: Dict[str, List[float]] = field(default_factory=dict)

    def add_node(
        self,
        node: BaseNode,
        text_id: Optional[str] = None,
    ) -> str:
        """Record the node under a vector-store id and return that id.

        Args:
            node: Node whose ``node_id`` is stored as the value.
            text_id: Key to store the node under; the node's own id is used
                when this is ``None``.

        Returns:
            The key the node was stored under.
        """
        if text_id is None:
            key = node.node_id
        else:
            key = text_id
        self.nodes_dict[key] = node.node_id
        return key

    def delete(self, doc_id: str) -> None:
        """Delete a Node.

        Raises:
            KeyError: If ``doc_id`` is not a key of ``nodes_dict``.
        """
        self.nodes_dict.pop(doc_id)

    @classmethod
    def get_type(cls) -> IndexStructType:
        """Get type (relies on ``IndexStructType`` defining ``VECTOR_STORE``)."""
        return IndexStructType.VECTOR_STORE
|
||||
|
||||
|
||||
@dataclass
class KG(IndexStruct):
    """Keyword index of a knowledge graph.

    Maps lower-cased keywords (triplet subjects and objects) to the ids of
    the text nodes they were extracted from.
    """

    # Unidirectional

    # table of keywords to node ids
    table: Dict[str, Set[str]] = field(default_factory=dict)

    # TODO: legacy attribute, remove in future releases
    rel_map: Dict[str, List[List[str]]] = field(default_factory=dict)

    # TBD, should support vector store, now we just persist the embedding memory
    # maybe chainable abstractions for *_stores could be designed
    embedding_dict: Dict[str, List[float]] = field(default_factory=dict)

    @property
    def node_ids(self) -> Set[str]:
        """Get all node ids (an empty set when the table is empty)."""
        # set().union(...) also works with zero arguments, unlike the unbound
        # set.union(*...), which raises TypeError on an empty table.
        return set().union(*self.table.values())

    def add_to_embedding_dict(self, triplet_str: str, embedding: List[float]) -> None:
        """Add embedding to dict, keyed by the triplet string."""
        self.embedding_dict[triplet_str] = embedding

    def add_node(self, keywords: List[str], node: BaseNode) -> None:
        """Register the node's id under every keyword, lower-cased.

        Args:
            keywords: Keywords to file the node under.
            node: Node whose ``node_id`` is stored.
        """
        node_id = node.node_id
        for keyword in keywords:
            self.table.setdefault(keyword.lower(), set()).add(node_id)

    def search_node_by_keyword(self, keyword: str) -> List[str]:
        """Search for nodes by keyword.

        The keyword is matched exactly; because ``add_node`` stores keywords
        lower-cased, callers should pass a lower-cased keyword.

        Returns:
            Ids of the matching nodes, or an empty list for an unknown keyword.
        """
        return list(self.table.get(keyword, ()))

    @classmethod
    def get_type(cls) -> IndexStructType:
        """Get type."""
        return IndexStructType.KG
|
||||
|
||||
|
||||
@dataclass
class EmptyIndexStruct(IndexStruct):
    """Empty index: adds no fields beyond those of ``IndexStruct``."""

    @classmethod
    def get_type(cls) -> IndexStructType:
        """Get type."""
        return IndexStructType.EMPTY
|
48
pilot/graph_engine/index_type.py
Normal file
48
pilot/graph_engine/index_type.py
Normal file
@ -0,0 +1,48 @@
|
||||
"""IndexStructType class."""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class IndexStructType(str, Enum):
    """Index struct type. Identifier for a "type" of index.

    Attributes:
        NODE ("node"): Node type.
        TREE ("tree"): Tree index. See :ref:`Ref-Indices-Tree` for tree indices.
        LIST ("list"): Summary index. See :ref:`Ref-Indices-List` for summary indices.
        KEYWORD_TABLE ("keyword_table"): Keyword table index. See
            :ref:`Ref-Indices-Table`
            for keyword table indices.
        DICT ("dict"): Faiss Vector Store Index. See
            :ref:`Ref-Indices-VectorStore`
            for more information on the faiss vector store index.
        SIMPLE_DICT ("simple_dict"): Simple Vector Store Index. See
            :ref:`Ref-Indices-VectorStore`
            for more information on the simple vector store index.
        VECTOR_STORE ("vector_store"): Generic vector store index; the type
            reported by ``IndexDict``.
        KG ("kg"): Knowledge Graph index.
            See :ref:`Ref-Indices-Knowledge-Graph` for KG indices.
        SIMPLE_KG ("simple_kg"): Simple knowledge graph identifier.
        NEBULAGRAPH ("nebulagraph"): NebulaGraph identifier.
        FALKORDB ("falkordb"): FalkorDB identifier.
        EMPTY ("empty"): Empty index.
        COMPOSITE ("composite"): Composite index identifier.
        DOCUMENT_SUMMARY ("document_summary"): Document Summary Index.
            See :ref:`Ref-Indices-Document-Summary` for Summary Indices.

    """

    # TODO: refactor so these are properties on the base class

    NODE = "node"
    TREE = "tree"
    LIST = "list"
    KEYWORD_TABLE = "keyword_table"

    DICT = "dict"
    # simple
    SIMPLE_DICT = "simple_dict"
    # Referenced by IndexDict.get_type; without this member that call fails
    # with AttributeError.
    VECTOR_STORE = "vector_store"
    # for KG index
    KG = "kg"
    SIMPLE_KG = "simple_kg"
    NEBULAGRAPH = "nebulagraph"
    FALKORDB = "falkordb"

    # EMPTY
    EMPTY = "empty"
    COMPOSITE = "composite"

    DOCUMENT_SUMMARY = "document_summary"
|
74
pilot/graph_engine/kv_index.py
Normal file
74
pilot/graph_engine/kv_index.py
Normal file
@ -0,0 +1,74 @@
|
||||
from typing import List, Optional
|
||||
from llama_index.data_structs.data_structs import IndexStruct
|
||||
from llama_index.storage.index_store.utils import (
|
||||
index_struct_to_json,
|
||||
json_to_index_struct,
|
||||
)
|
||||
from llama_index.storage.kvstore.types import BaseKVStore
|
||||
|
||||
DEFAULT_NAMESPACE = "index_store"
|
||||
|
||||
|
||||
class KVIndexStore:
    """Index store backed by a key-value store.

    Args:
        kvstore (BaseKVStore): key-value store
        namespace (str): namespace for the index store; ``DEFAULT_NAMESPACE``
            is used when it is omitted or empty.

    """

    def __init__(self, kvstore: BaseKVStore, namespace: Optional[str] = None) -> None:
        """Init a KVIndexStore."""
        self._kvstore = kvstore
        self._namespace = namespace or DEFAULT_NAMESPACE
        # All index structs of this store live in one collection.
        self._collection = f"{self._namespace}/data"

    def add_index_struct(self, index_struct: IndexStruct) -> None:
        """Serialize an index struct and store it under its ``index_id``.

        Args:
            index_struct (IndexStruct): index struct

        """
        self._kvstore.put(
            index_struct.index_id,
            index_struct_to_json(index_struct),
            collection=self._collection,
        )

    def delete_index_struct(self, key: str) -> None:
        """Remove the index struct stored under ``key``.

        Args:
            key (str): index struct key

        """
        self._kvstore.delete(key, collection=self._collection)

    def get_index_struct(
        self, struct_id: Optional[str] = None
    ) -> Optional[IndexStruct]:
        """Fetch one index struct.

        Args:
            struct_id (Optional[str]): index struct id. When ``None``, the
                store must hold exactly one struct, which is returned.

        Returns:
            The deserialized struct, or ``None`` when ``struct_id`` is given
            but nothing is stored under it.

        """
        if struct_id is not None:
            payload = self._kvstore.get(struct_id, collection=self._collection)
            return None if payload is None else json_to_index_struct(payload)

        structs = self.index_structs()
        assert len(structs) == 1
        return structs[0]

    def index_structs(self) -> List[IndexStruct]:
        """Deserialize and return every index struct in the collection.

        Returns:
            List[IndexStruct]: index structs

        """
        stored = self._kvstore.get_all(collection=self._collection)
        return [json_to_index_struct(raw) for raw in stored.values()]
|
569
pilot/graph_engine/node.py
Normal file
569
pilot/graph_engine/node.py
Normal file
@ -0,0 +1,569 @@
|
||||
"""Base schema for data structures."""
|
||||
import json
|
||||
import textwrap
|
||||
import uuid
|
||||
from abc import abstractmethod
|
||||
from enum import Enum, auto
|
||||
from hashlib import sha256
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain.schema import Document
|
||||
from pydantic import BaseModel, Field, root_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
DEFAULT_TEXT_NODE_TMPL = "{metadata_str}\n\n{content}"
|
||||
DEFAULT_METADATA_TMPL = "{key}: {value}"
|
||||
# NOTE: for pretty printing
|
||||
TRUNCATE_LENGTH = 350
|
||||
WRAP_WIDTH = 70
|
||||
|
||||
|
||||
class BaseComponent(BaseModel):
    """Base component object to capture class names (reference: llama-index)."""

    @classmethod
    @abstractmethod
    def class_name(cls) -> str:
        """Get class name."""

    def to_dict(self, **kwargs: Any) -> Dict[str, Any]:
        """Return the model as a dict with an extra ``class_name`` entry.

        Args:
            **kwargs: Forwarded to the pydantic ``dict()`` call.
        """
        data = self.dict(**kwargs)
        data["class_name"] = self.class_name()
        return data

    def to_json(self, **kwargs: Any) -> str:
        """Return ``to_dict(**kwargs)`` serialized as a JSON string."""
        return json.dumps(self.to_dict(**kwargs))

    # TODO: return type here not supported by current mypy version
    @classmethod
    def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self:  # type: ignore
        """Build an instance from a dict such as the one ``to_dict`` produces.

        Args:
            data: Field values; a ``class_name`` entry, if present, is ignored.
                The dict is not modified.
            **kwargs: Extra field values; they override entries of ``data``.
        """
        # Merge into a fresh dict so the caller's ``data`` is left untouched.
        fields = {**data, **kwargs}
        fields.pop("class_name", None)
        return cls(**fields)

    @classmethod
    def from_json(cls, data_str: str, **kwargs: Any) -> Self:  # type: ignore
        """Build an instance from a JSON string (see ``from_dict``)."""
        return cls.from_dict(json.loads(data_str), **kwargs)
|
||||
|
||||
|
||||
class NodeRelationship(str, Enum):
    """Node relationships used in `BaseNode` class.

    Members are the keys of ``BaseNode.relationships``.

    Attributes:
        SOURCE: The node is the source document.
        PREVIOUS: The node is the previous node in the document.
        NEXT: The node is the next node in the document.
        PARENT: The node is the parent node in the document.
        CHILD: The node is a child node in the document.

    """

    # Values are generated by auto(); refer to members by name.
    SOURCE = auto()
    PREVIOUS = auto()
    NEXT = auto()
    PARENT = auto()
    CHILD = auto()
|
||||
|
||||
|
||||
class ObjectType(str, Enum):
    """Kind of object a node stands for; used as ``RelatedNodeInfo.node_type``."""

    # Values are generated by auto(); refer to members by name.
    TEXT = auto()
    IMAGE = auto()
    INDEX = auto()
    DOCUMENT = auto()
|
||||
|
||||
|
||||
class MetadataMode(str, Enum):
    """Selector passed to ``get_content`` / ``get_metadata_str`` of a node.

    ``ALL`` is the default of both accessors.
    NOTE(review): EMBED and LLM presumably apply the node's
    ``excluded_embed_metadata_keys`` / ``excluded_llm_metadata_keys`` and NONE
    omits metadata -- the implementations are not visible here; confirm.
    """

    # Values are generated by auto(); refer to members by name.
    ALL = auto()
    EMBED = auto()
    LLM = auto()
    NONE = auto()
|
||||
|
||||
|
||||
class RelatedNodeInfo(BaseComponent):
    """Reference to another node, as stored in ``BaseNode.relationships``."""

    # Id of the referenced node (required).
    node_id: str
    # Kind of the referenced node, when known.
    node_type: Optional[ObjectType] = None
    # Free-form metadata; a fresh empty dict per instance by default.
    metadata: Dict[str, Any] = Field(default_factory=dict)
    # Optional hash value for the referenced node.
    hash: Optional[str] = None

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "RelatedNodeInfo"
|
||||
|
||||
|
||||
RelatedNodeType = Union[RelatedNodeInfo, List[RelatedNodeInfo]]
|
||||
|
||||
|
||||
# Node classes for indexes
|
||||
def truncate_text(text: str, max_length: int) -> str:
    """Truncate text to a maximum length.

    Args:
        text: The string to shorten.
        max_length: Longest string returned unchanged.

    Returns:
        ``text`` itself when it is at most ``max_length`` characters long,
        otherwise its first ``max_length - 3`` characters followed by ``"..."``.
    """
    if len(text) <= max_length:
        return text
    return text[: max_length - 3] + "..."


class BaseNode(BaseComponent):
    """Base node Object.

    Generic abstract interface for retrievable nodes. Subclasses provide the
    content accessors (``get_type``, ``get_content``, ``get_metadata_str`` and
    ``set_content``); this class supplies the shared fields and the
    relationship helpers.
    """

    class Config:
        allow_population_by_field_name = True

    id_: str = Field(
        default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the node."
    )
    embedding: Optional[List[float]] = Field(
        default=None, description="Embedding of the node."
    )

    # metadata fields
    # - injected as part of the text shown to LLMs as context
    # - injected as part of the text for generating embeddings
    # - used by vector DBs for metadata filtering
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="A flat dictionary of metadata fields",
        alias="extra_info",
    )
    excluded_embed_metadata_keys: List[str] = Field(
        default_factory=list,
        description="Metadata keys that are exluded from text for the embed model.",
    )
    excluded_llm_metadata_keys: List[str] = Field(
        default_factory=list,
        description="Metadata keys that are exluded from text for the LLM.",
    )
    relationships: Dict[NodeRelationship, RelatedNodeType] = Field(
        default_factory=dict,
        description="A mapping of relationships to other node information.",
    )
    hash: str = Field(default="", description="Hash of the node content.")

    @classmethod
    @abstractmethod
    def get_type(cls) -> str:
        """Get Object type."""

    @abstractmethod
    def get_content(self, metadata_mode: MetadataMode = MetadataMode.ALL) -> str:
        """Get object content."""

    @abstractmethod
    def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
        """Metadata string."""

    @abstractmethod
    def set_content(self, value: Any) -> None:
        """Set the content of the node."""

    @property
    def node_id(self) -> str:
        """Alias for ``id_``."""
        return self.id_

    @node_id.setter
    def node_id(self, value: str) -> None:
        self.id_ = value

    @property
    def source_node(self) -> Optional[RelatedNodeInfo]:
        """Source object node.

        Extracted from the relationships field.

        Returns:
            The SOURCE relationship entry, or None when there is none.

        Raises:
            ValueError: If the SOURCE entry is a list.
        """
        if NodeRelationship.SOURCE not in self.relationships:
            return None

        relation = self.relationships[NodeRelationship.SOURCE]
        if isinstance(relation, list):
            raise ValueError("Source object must be a single RelatedNodeInfo object")
        return relation

    @property
    def prev_node(self) -> Optional[RelatedNodeInfo]:
        """Prev node.

        Raises:
            ValueError: If the PREVIOUS entry is not a RelatedNodeInfo.
        """
        if NodeRelationship.PREVIOUS not in self.relationships:
            return None

        relation = self.relationships[NodeRelationship.PREVIOUS]
        if not isinstance(relation, RelatedNodeInfo):
            raise ValueError("Previous object must be a single RelatedNodeInfo object")
        return relation

    @property
    def next_node(self) -> Optional[RelatedNodeInfo]:
        """Next node.

        Raises:
            ValueError: If the NEXT entry is not a RelatedNodeInfo.
        """
        if NodeRelationship.NEXT not in self.relationships:
            return None

        relation = self.relationships[NodeRelationship.NEXT]
        if not isinstance(relation, RelatedNodeInfo):
            raise ValueError("Next object must be a single RelatedNodeInfo object")
        return relation

    @property
    def parent_node(self) -> Optional[RelatedNodeInfo]:
        """Parent node.

        Raises:
            ValueError: If the PARENT entry is not a RelatedNodeInfo.
        """
        if NodeRelationship.PARENT not in self.relationships:
            return None

        relation = self.relationships[NodeRelationship.PARENT]
        if not isinstance(relation, RelatedNodeInfo):
            raise ValueError("Parent object must be a single RelatedNodeInfo object")
        return relation

    @property
    def child_nodes(self) -> Optional[List[RelatedNodeInfo]]:
        """Child nodes.

        Raises:
            ValueError: If the CHILD entry is not a list.
        """
        if NodeRelationship.CHILD not in self.relationships:
            return None

        relation = self.relationships[NodeRelationship.CHILD]
        if not isinstance(relation, list):
            raise ValueError("Child objects must be a list of RelatedNodeInfo objects.")
        return relation

    @property
    def ref_doc_id(self) -> Optional[str]:
        """Deprecated: Get ref doc id (the source node's id, or None)."""
        source_node = self.source_node
        if source_node is None:
            return None
        return source_node.node_id

    @property
    def extra_info(self) -> Dict[str, Any]:
        """TODO: DEPRECATED: Extra info (same object as ``metadata``)."""
        return self.metadata

    def __str__(self) -> str:
        """Render the node id plus a truncated, wrapped preview of the content."""
        # Resolves to the module-level helper defined above this class.
        source_text_truncated = truncate_text(
            self.get_content().strip(), TRUNCATE_LENGTH
        )
        source_text_wrapped = textwrap.fill(
            f"Text: {source_text_truncated}\n", width=WRAP_WIDTH
        )
        return f"Node ID: {self.node_id}\n{source_text_wrapped}"

    # The helper used to be defined here, inside the class body, although
    # __str__ calls it as a plain global. It stays reachable as
    # ``BaseNode.truncate_text(text, max_length)`` for existing callers.
    truncate_text = staticmethod(truncate_text)

    def get_embedding(self) -> List[float]:
        """Get embedding.

        Raises:
            ValueError: If ``embedding`` is None.
        """
        if self.embedding is None:
            raise ValueError("embedding not set.")
        return self.embedding

    def as_related_node_info(self) -> RelatedNodeInfo:
        """Get node as RelatedNodeInfo (id, metadata and hash are carried over)."""
        return RelatedNodeInfo(
            node_id=self.node_id, metadata=self.metadata, hash=self.hash
        )
|
||||
|
||||
|
||||
class TextNode(BaseNode):
    """Node whose content is a plain text string.

    The text can be rendered together with its metadata through
    ``text_template`` / ``metadata_template``; which metadata keys take part
    is selected by a ``MetadataMode``.
    """

    text: str = Field(default="", description="Text content of the node.")
    start_char_idx: Optional[int] = Field(
        default=None, description="Start char index of the node."
    )
    end_char_idx: Optional[int] = Field(
        default=None, description="End char index of the node."
    )
    text_template: str = Field(
        default=DEFAULT_TEXT_NODE_TMPL,
        description=(
            "Template for how text is formatted, with {content} and "
            "{metadata_str} placeholders."
        ),
    )
    metadata_template: str = Field(
        default=DEFAULT_METADATA_TMPL,
        description=(
            "Template for how metadata is formatted, with {key} and "
            "{value} placeholders."
        ),
    )
    metadata_seperator: str = Field(
        default="\n",
        description="Seperator between metadata fields when converting to string.",
    )

    @classmethod
    def class_name(cls) -> str:
        """Return the fixed name ``"TextNode"``."""
        return "TextNode"

    @root_validator
    def _check_hash(cls, values: dict) -> dict:
        """Store a SHA-256 hex digest of the text plus metadata under ``hash``."""
        identity = str(values.get("text", "")) + str(values.get("metadata", {}))
        digest = sha256(identity.encode("utf-8", "surrogatepass")).hexdigest()
        values["hash"] = str(digest)
        return values

    @classmethod
    def get_type(cls) -> str:
        """Return ``ObjectType.TEXT``."""
        return ObjectType.TEXT

    def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
        """Return the text, prefixed by metadata when the mode yields any.

        When the metadata string for ``metadata_mode`` is empty after
        stripping, the raw text is returned untouched; otherwise the text
        template is filled in and the result stripped.
        """
        metadata_str = self.get_metadata_str(mode=metadata_mode).strip()
        if metadata_str:
            formatted = self.text_template.format(
                content=self.text, metadata_str=metadata_str
            )
            return formatted.strip()
        return self.text

    def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
        """Render the metadata entries allowed for ``mode`` as one string.

        ``NONE`` yields an empty string; ``LLM`` and ``EMBED`` leave out the
        keys listed in the matching ``excluded_*_metadata_keys`` field; any
        other mode renders every entry. Entries keep the dict's own order.
        """
        if mode == MetadataMode.NONE:
            return ""

        if mode == MetadataMode.LLM:
            hidden = set(self.excluded_llm_metadata_keys)
        elif mode == MetadataMode.EMBED:
            hidden = set(self.excluded_embed_metadata_keys)
        else:
            hidden = set()

        rendered = []
        for name, item in self.metadata.items():
            if name in hidden:
                continue
            rendered.append(self.metadata_template.format(key=name, value=str(item)))
        return self.metadata_seperator.join(rendered)

    def set_content(self, value: str) -> None:
        """Replace the node's text with ``value``."""
        self.text = value

    def get_node_info(self) -> Dict[str, Any]:
        """Return the start/end character indices under ``start`` and ``end``."""
        return {"start": self.start_char_idx, "end": self.end_char_idx}

    def get_text(self) -> str:
        """Return the content rendered with ``MetadataMode.NONE``."""
        return self.get_content(metadata_mode=MetadataMode.NONE)

    @property
    def node_info(self) -> Dict[str, Any]:
        """Deprecated: same value as ``get_node_info()``."""
        return self.get_node_info()
|
||||
|
||||
|
||||
# TODO: legacy backport of old Node class
|
||||
# `Node` is bound to the very same class object as `TextNode`.
Node = TextNode
|
||||
|
||||
|
||||
class ImageNode(TextNode):
    """Text node that can additionally carry an image."""

    # TODO: store reference instead of actual image
    # Base64-encoded image string; None when the node has no image.
    image: Optional[str] = None

    @classmethod
    def class_name(cls) -> str:
        """Return the fixed name ``"ImageNode"``."""
        return "ImageNode"

    @classmethod
    def get_type(cls) -> str:
        """Return ``ObjectType.IMAGE``."""
        return ObjectType.IMAGE
|
||||
|
||||
|
||||
class IndexNode(TextNode):
    """Text node that also references another object through ``index_id``.

    The referenced object can be another index, a query engine or a
    retriever. It can also be another node, although that overlaps with the
    ``relationships`` field of the Node class.
    """

    # Identifier of the referenced object (required, no default).
    index_id: str

    @classmethod
    def from_text_node(
        cls,
        node: TextNode,
        index_id: str,
    ) -> "IndexNode":
        """Build an index node carrying every field of ``node`` plus ``index_id``."""
        copied_fields = node.dict()
        return cls(index_id=index_id, **copied_fields)

    @classmethod
    def class_name(cls) -> str:
        """Return the fixed name ``"IndexNode"``."""
        return "IndexNode"

    @classmethod
    def get_type(cls) -> str:
        """Return ``ObjectType.INDEX``."""
        return ObjectType.INDEX
|
||||
|
||||
|
||||
class NodeWithScore(BaseComponent):
    """A node paired with an optional score.

    Most members simply forward to the wrapped ``node``.
    """

    node: BaseNode
    score: Optional[float] = None

    def __str__(self) -> str:
        """Render the node followed by its score.

        ``score`` defaults to None, and applying a float format spec to None
        raises TypeError, so an unset score is shown as the text ``None``.
        """
        score_str = "None" if self.score is None else f"{self.score: 0.3f}"
        return f"{self.node}\nScore: {score_str}\n"

    def get_score(self, raise_error: bool = False) -> float:
        """Get score.

        Args:
            raise_error: When True, an unset score raises instead of
                falling back to 0.0.

        Returns:
            The stored score, or 0.0 when it is unset and ``raise_error``
            is False.

        Raises:
            ValueError: If the score is unset and ``raise_error`` is True.
        """
        if self.score is not None:
            return self.score
        if raise_error:
            raise ValueError("Score not set.")
        return 0.0

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "NodeWithScore"

    ##### pass through methods to BaseNode #####
    @property
    def node_id(self) -> str:
        return self.node.node_id

    @property
    def id_(self) -> str:
        return self.node.id_

    @property
    def text(self) -> str:
        """Text of the wrapped node.

        Raises:
            ValueError: If the wrapped node is not a TextNode.
        """
        if isinstance(self.node, TextNode):
            return self.node.text
        raise ValueError("Node must be a TextNode to get text.")

    @property
    def metadata(self) -> Dict[str, Any]:
        return self.node.metadata

    @property
    def embedding(self) -> Optional[List[float]]:
        return self.node.embedding

    def get_text(self) -> str:
        """Return ``get_text()`` of the wrapped node.

        Raises:
            ValueError: If the wrapped node is not a TextNode.
        """
        if isinstance(self.node, TextNode):
            return self.node.get_text()
        raise ValueError("Node must be a TextNode to get text.")

    def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
        return self.node.get_content(metadata_mode=metadata_mode)

    def get_embedding(self) -> List[float]:
        return self.node.get_embedding()
|
||||
|
||||
|
||||
# Document Classes for Readers
|
||||
|
||||
|
||||
class Document(TextNode):
    """Generic interface for a data document.

    This document connects to data sources.

    The legacy names ``doc_id`` and ``extra_info`` are still accepted on
    attribute assignment and are redirected to ``id_`` and ``metadata``.
    """

    # TODO: A lot of backwards compatibility logic here, clean up
    id_: str = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="Unique ID of the node.",
        alias="doc_id",
    )

    # Legacy attribute name -> current field name, consulted by __setattr__.
    _compat_fields = {"doc_id": "id_", "extra_info": "metadata"}

    @classmethod
    def get_type(cls) -> str:
        """Get Document type."""
        return ObjectType.DOCUMENT

    @property
    def doc_id(self) -> str:
        """Get document ID."""
        return self.id_

    def __str__(self) -> str:
        """Render the document id plus a truncated, wrapped content preview."""
        # NOTE(review): relies on a module-level `truncate_text` being in
        # scope when this runs; confirm one is defined or imported.
        source_text_truncated = truncate_text(
            self.get_content().strip(), TRUNCATE_LENGTH
        )
        source_text_wrapped = textwrap.fill(
            f"Text: {source_text_truncated}\n", width=WRAP_WIDTH
        )
        return f"Doc ID: {self.doc_id}\n{source_text_wrapped}"

    def get_doc_id(self) -> str:
        """TODO: Deprecated: Get document ID."""
        return self.id_

    def __setattr__(self, name: str, value: object) -> None:
        """Set an attribute, mapping legacy names to their current fields."""
        if name in self._compat_fields:
            name = self._compat_fields[name]
        super().__setattr__(name, value)

    def to_langchain_format(self) -> Document:
        """Convert struct to LangChain document format.

        Returns:
            A LangChain document whose ``page_content`` is this document's
            text and whose metadata is this document's metadata (an empty
            dict when the metadata is falsy).
        """
        # Once this class statement has run, the global name `Document`
        # refers to this class, so calling it here would build another
        # instance of this class instead of a LangChain document. Import the
        # LangChain class under an alias at call time.
        # NOTE(review): import path assumed to be `langchain.schema`; confirm
        # it matches the Document imported at the top of this module.
        from langchain.schema import Document as LCDocument

        metadata = self.metadata or {}
        return LCDocument(page_content=self.text, metadata=metadata)

    @classmethod
    def from_langchain_format(cls, doc: Document) -> "Document":
        """Convert struct from LangChain document format."""
        return cls(text=doc.page_content, metadata=doc.metadata)

    @classmethod
    def example(cls) -> "Document":
        """Return a sample document with empty text and README-style metadata."""
        # `Document` is resolved at call time, i.e. to this class.
        document = Document(
            text="",
            metadata={"filename": "README.md", "category": "codebase"},
        )
        return document

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "Document"
|
||||
|
||||
|
||||
class ImageDocument(Document):
    """Document that can additionally carry an image."""

    # Base64-encoded image string; None when the document has no image.
    image: Optional[str] = None

    @classmethod
    def class_name(cls) -> str:
        """Return the fixed name ``"ImageDocument"``."""
        return "ImageDocument"
|
44
pilot/graph_engine/search.py
Normal file
44
pilot/graph_engine/search.py
Normal file
@ -0,0 +1,44 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class SearchMode(str, Enum):
    """Search strategies available for Knowledge Graph queries.

    Every member is also a ``str``, so callers may pass either the member or
    its plain string value.

    Attributes:
        KEYWORD ("keyword"): Default mode; triplets are found via keywords.
        EMBEDDING ("embedding"): Similar triplets are found via embeddings.
        HYBRID ("hybrid"): Keywords and embeddings are combined to find
            relevant triplets.
    """

    KEYWORD = "keyword"
    EMBEDDING = "embedding"
    HYBRID = "hybrid"
|
||||
|
||||
|
||||
class BaseSearch(ABC):
    """Abstract base for search strategies.

    Subclasses implement ``_search``; callers use ``search``.
    """

    @abstractmethod
    def _search(self, query: str):
        """Run the actual search for ``query``.

        Implemented by the user.
        """

    def search(self, query: str):
        """Retrieve nodes for a query by delegating to ``_search``.

        Args:
            query (str): The query string.

        Returns:
            Whatever the subclass's ``_search`` returns for ``query``.
        """
        return self._search(query)
|
35
pilot/scene/chat_knowledge/extract_entity/chat.py
Normal file
35
pilot/scene/chat_knowledge/extract_entity/chat.py
Normal file
@ -0,0 +1,35 @@
|
||||
from typing import Dict
|
||||
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.configs.config import Config
|
||||
|
||||
from pilot.scene.chat_knowledge.extract_entity.prompt import prompt
|
||||
|
||||
# Config instance created when this module is imported.
CFG = Config()
|
||||
|
||||
|
||||
class ExtractEntity(BaseChat):
    """Chat scene that extracts entities from the user input via the LLM."""

    chat_scene: str = ChatScene.ExtractEntity.value()

    def __init__(self, chat_param: Dict):
        """Set up the chat from ``chat_param``.

        Args:
            chat_param: Chat parameters. ``"chat_mode"`` is written into this
                dict before it is handed to the base class, and
                ``"current_user_input"`` and ``"select_param"`` are read
                from it.
        """
        # The caller's dict is updated in place.
        chat_param["chat_mode"] = ChatScene.ExtractEntity
        super().__init__(chat_param=chat_param)

        self.user_input = chat_param["current_user_input"]
        self.extract_mode = chat_param["select_param"]

    def generate_input_values(self):
        """Return the prompt inputs: the user input under the key ``"text"``."""
        return {"text": self.user_input}

    @property
    def chat_type(self) -> str:
        """Return ``ChatScene.ExtractEntity.value``."""
        # NOTE(review): `chat_scene` above calls `.value()`, while this
        # returns `.value` without calling it; confirm which one yields the
        # string this property is annotated to return.
        return ChatScene.ExtractEntity.value
|
39
pilot/scene/chat_knowledge/extract_entity/out_parser.py
Normal file
39
pilot/scene/chat_knowledge/extract_entity/out_parser.py
Normal file
@ -0,0 +1,39 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Set
|
||||
|
||||
from pilot.out_parser.base import BaseOutputParser, T
|
||||
from pilot.configs.config import Config
|
||||
|
||||
# Config instance created when this module is imported.
CFG = Config()
|
||||
|
||||
|
||||
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExtractEntityParser(BaseOutputParser):
    """Output parser that turns a comma-separated keyword reply into a set."""

    def __init__(self, sep: str, is_stream_out: bool):
        """Forward ``sep`` and ``is_stream_out`` to the base parser."""
        super().__init__(sep=sep, is_stream_out=is_stream_out)

    def parse_prompt_response(self, response, max_length: int = 128) -> Set[str]:
        """Extract the keywords from a model reply.

        Args:
            response: Reply text, optionally starting with ``KEYWORDS:``,
                followed by comma-separated keywords.
            max_length: Accepted for signature compatibility; not used here.

        Returns:
            The distinct keywords, lower-cased and stripped of surrounding
            whitespace. Empty entries (blank reply, doubled or trailing
            commas) are left out.
        """
        logger.info("clean prompt response: %s", response)

        response = response.strip()  # Strip newlines from responses.
        if response.startswith("KEYWORDS:"):
            response = response[len("KEYWORDS:") :]

        keywords = (part.strip().lower() for part in response.split(","))
        # An empty string is not a usable keyword, so it is dropped.
        return {keyword for keyword in keywords if keyword}

    def parse_view_response(self, speak, data) -> str:
        """Return ``data`` unchanged."""
        return data
|
52
pilot/scene/chat_knowledge/extract_entity/prompt.py
Normal file
52
pilot/scene/chat_knowledge/extract_entity/prompt.py
Normal file
@ -0,0 +1,52 @@
|
||||
import json
|
||||
|
||||
from pilot.prompts.prompt_new import PromptTemplate
|
||||
from pilot.configs.config import Config
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.common.schema import SeparatorStyle
|
||||
from pilot.scene.chat_knowledge.extract_entity.out_parser import ExtractEntityParser
|
||||
|
||||
from pilot.scene.chat_knowledge.extract_triplet.out_parser import (
|
||||
ExtractTripleParser,
|
||||
)
|
||||
|
||||
|
||||
# Config instance; its prompt_template_registry is used at the bottom of this module.
CFG = Config()

# Scene definition text; empty for this prompt.
PROMPT_SCENE_DEFINE = """"""

# Keyword-extraction prompt with a single `{text}` placeholder.
# NOTE(review): the whole template is one triple-quoted literal, so the
# double-quote characters on each line are part of the prompt text and every
# `\n` escape is a real newline in addition to the physical line break.
# Confirm that this is intended.
# NOTE(review): any leading whitespace inside the literal could not be
# recovered from the view this was reviewed in; verify against the file.
_DEFAULT_TEMPLATE = """
"A question is provided below. Given the question, extract up to 10 "
"keywords from the text. Focus on extracting the keywords that we can use "
"to best lookup answers to the question. Avoid stopwords.\n"
"Example:"
"Text: Alice is Bob's mother."
"KEYWORDS:Alice,mother,Bob\n"
"---------------------\n"
"{text}\n"
"---------------------\n"
"Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>'\n"
"""
# Empty suffix appended to the template below.
PROMPT_RESPONSE = """"""


# Empty; not referenced by the statements below.
RESPONSE_FORMAT = """"""


PROMPT_SEP = SeparatorStyle.SINGLE.value

# Passed both as the template's stream_out and the parser's is_stream_out.
PROMPT_NEED_NEED_STREAM_OUT = False

# Prompt template for the ExtractEntity scene, parsed by ExtractEntityParser.
prompt = PromptTemplate(
    template_scene=ChatScene.ExtractEntity.value(),
    input_variables=["text"],
    response_format="",
    template_define=PROMPT_SCENE_DEFINE,
    template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
    stream_out=PROMPT_NEED_NEED_STREAM_OUT,
    output_parser=ExtractEntityParser(
        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
    ),
)

# Runs at import time: registers the template, passing is_default=True.
CFG.prompt_template_registry.register(prompt, is_default=True)
|
35
pilot/scene/chat_knowledge/extract_triplet/chat.py
Normal file
35
pilot/scene/chat_knowledge/extract_triplet/chat.py
Normal file
@ -0,0 +1,35 @@
|
||||
from typing import Dict
|
||||
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.configs.config import Config
|
||||
|
||||
from pilot.scene.chat_knowledge.extract_triplet.prompt import prompt
|
||||
|
||||
# Config instance created when this module is imported.
CFG = Config()
|
||||
|
||||
|
||||
class ExtractTriplet(BaseChat):
    """Chat scene that extracts triplets from the user input via the LLM."""

    chat_scene: str = ChatScene.ExtractTriplet.value()

    def __init__(self, chat_param: Dict):
        """Set up the chat from ``chat_param``.

        Args:
            chat_param: Chat parameters. ``"chat_mode"`` is written into this
                dict before it is handed to the base class, and
                ``"current_user_input"`` and ``"select_param"`` are read
                from it.
        """
        # The caller's dict is updated in place.
        chat_param["chat_mode"] = ChatScene.ExtractTriplet
        super().__init__(chat_param=chat_param)

        self.user_input = chat_param["current_user_input"]
        self.extract_mode = chat_param["select_param"]

    def generate_input_values(self):
        """Return the prompt inputs: the user input under the key ``"text"``."""
        return {"text": self.user_input}

    @property
    def chat_type(self) -> str:
        """Return ``ChatScene.ExtractTriplet.value``."""
        # NOTE(review): `chat_scene` above calls `.value()`, while this
        # returns `.value` without calling it; confirm which one yields the
        # string this property is annotated to return.
        return ChatScene.ExtractTriplet.value
|
57
pilot/scene/chat_knowledge/extract_triplet/out_parser.py
Normal file
57
pilot/scene/chat_knowledge/extract_triplet/out_parser.py
Normal file
@ -0,0 +1,57 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Tuple
|
||||
|
||||
from pilot.out_parser.base import BaseOutputParser, T
|
||||
from pilot.configs.config import Config
|
||||
|
||||
# Config instance created when this module is imported.
CFG = Config()
|
||||
|
||||
|
||||
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExtractTripleParser(BaseOutputParser):
    """Output parser that pulls ``(subject, predicate, object)`` triplets from a reply."""

    def __init__(self, sep: str, is_stream_out: bool):
        """Forward ``sep`` and ``is_stream_out`` to the base parser."""
        super().__init__(sep=sep, is_stream_out=is_stream_out)

    def parse_prompt_response(
        self, response, max_length: int = 128
    ) -> List[Tuple[str, str, str]]:
        """Extract triplets from a model reply.

        Every parenthesised group without nested parentheses is a candidate.
        A candidate is kept only when it splits on commas into exactly three
        parts, none longer than ``max_length`` bytes, and none empty after
        stripping.

        Args:
            response: Reply text, optionally starting with ``Triplets:``.
            max_length: Maximum UTF-8 byte length allowed for each part,
                measured before surrounding whitespace is stripped.

        Returns:
            The accepted triplets in order of appearance, each part stripped
            and lower-cased.
        """
        logger.info("clean prompt response: %s", response)

        if response.startswith("Triplets:"):
            response = response[len("Triplets:") :]
        # Each match begins with "(" and ends with ")" by construction, so no
        # further shape check is needed on the candidates.
        candidates = re.findall(r"\([^()]+\)", response)
        logger.info("parse prompt response: %s", candidates)

        results = []
        for candidate in candidates:
            tokens = candidate[1:-1].split(",")
            if len(tokens) != 3:
                continue

            if any(len(token.encode("utf-8")) > max_length for token in tokens):
                # We count byte-length instead of len() for UTF-8 chars,
                # will skip if any of the tokens are too long.
                # This is normally due to a poorly formatted triplet
                # extraction, in more serious KG building cases
                # we'll need NLP models to better extract triplets.
                continue

            subject, predicate, obj = map(str.strip, tokens)
            if not subject or not predicate or not obj:
                # skip partial triplets
                continue
            results.append((subject.lower(), predicate.lower(), obj.lower()))
        return results

    def parse_view_response(self, speak, data) -> str:
        """Return ``data`` unchanged."""
        return data
|
57
pilot/scene/chat_knowledge/extract_triplet/prompt.py
Normal file
57
pilot/scene/chat_knowledge/extract_triplet/prompt.py
Normal file
@ -0,0 +1,57 @@
|
||||
import json
|
||||
|
||||
from pilot.prompts.prompt_new import PromptTemplate
|
||||
from pilot.configs.config import Config
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.common.schema import SeparatorStyle
|
||||
|
||||
from pilot.scene.chat_knowledge.extract_triplet.out_parser import (
|
||||
ExtractTripleParser,
|
||||
)
|
||||
|
||||
|
||||
# Config instance; its prompt_template_registry is used at the bottom of this module.
CFG = Config()

# Scene definition text; empty for this prompt.
PROMPT_SCENE_DEFINE = """"""

# Triplet-extraction prompt with a single `{text}` placeholder.
# NOTE(review): the whole template is one triple-quoted literal, so the
# double-quote characters on each line are part of the prompt text and every
# `\n` escape is a real newline in addition to the physical line break.
# Confirm that this is intended.
# NOTE(review): any leading whitespace inside the literal could not be
# recovered from the view this was reviewed in; verify against the file.
_DEFAULT_TEMPLATE = """
"Some text is provided below. Given the text, extract up to 10"
"knowledge triplets in the form of (subject, predicate, object). Avoid stopwords.\n"
"---------------------\n"
"Example:"
"Text: Alice is Bob's mother."
"Triplets:\n(Alice, is mother of, Bob)\n"
"Text: Philz is a coffee shop founded in Berkeley in 1982.\n"
"Triplets:\n"
"(Philz, is, coffee shop)\n"
"(Philz, founded in, Berkeley)\n"
"(Philz, founded in, 1982)\n"
"---------------------\n"
"Text: {text}\n"
"Triplets:\n"
ensure Respond in the following List(Tuple) format:
'(Stephen Curry, plays for, Golden State Warriors)\n(Stephen Curry, known for, shooting skills)\n(Stephen Curry, attended, Davidson College)\n(Stephen Curry, led, team to success)'
"""
# Empty suffix appended to the template below.
PROMPT_RESPONSE = """"""


# Empty; not referenced by the statements below.
RESPONSE_FORMAT = """"""


PROMPT_SEP = SeparatorStyle.SINGLE.value

# Passed both as the template's stream_out and the parser's is_stream_out.
PROMPT_NEED_NEED_STREAM_OUT = False

# Prompt template for the ExtractTriplet scene, parsed by ExtractTripleParser.
prompt = PromptTemplate(
    template_scene=ChatScene.ExtractTriplet.value(),
    input_variables=["text"],
    response_format="",
    template_define=PROMPT_SCENE_DEFINE,
    template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
    stream_out=PROMPT_NEED_NEED_STREAM_OUT,
    output_parser=ExtractTripleParser(
        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
    ),
)

# Runs at import time: registers the template, passing is_default=True.
CFG.prompt_template_registry.register(prompt, is_default=True)
|
@ -88,25 +88,7 @@ class ChatKnowledge(BaseChat):
|
||||
if self.space_context:
|
||||
self.prompt_template.template_define = self.space_context["prompt"]["scene"]
|
||||
self.prompt_template.template = self.space_context["prompt"]["template"]
|
||||
# docs = self.rag_engine.search(query=self.current_user_input)
|
||||
# import httpx
|
||||
# with httpx.Client() as client:
|
||||
# request = client.build_request(
|
||||
# "post",
|
||||
# "http://127.0.0.1/api/knowledge/entities/extract",
|
||||
# json="application/json", # using json for data to ensure it sends as application/json
|
||||
# params={"text": self.current_user_input},
|
||||
# headers={},
|
||||
# )
|
||||
#
|
||||
# response = client.send(request)
|
||||
# if response.status_code != 200:
|
||||
# error_msg = f"request /api/knowledge/entities/extract failed, error: {response.text}"
|
||||
# raise Exception(error_msg)
|
||||
# docs = response.json()
|
||||
# import requests
|
||||
# docs = requests.post("http://127.0.0.1:5000/api/knowledge/entities/extract", headers={}, json={"text": self.current_user_input})
|
||||
|
||||
docs = self.rag_engine.search(query=self.current_user_input)
|
||||
docs = self.knowledge_embedding_client.similar_search(
|
||||
self.current_user_input, self.top_k
|
||||
)
|
||||
|
@ -205,12 +205,6 @@ def similar_query(space_name: str, query_request: KnowledgeQueryRequest):
|
||||
async def entity_extract(request: EntityExtractRequest):
|
||||
logger.info(f"Received params: {request}")
|
||||
try:
|
||||
# from pilot.graph_engine.graph_factory import RAGGraphFactory
|
||||
# from pilot.component import ComponentType
|
||||
# rag_engine = CFG.SYSTEM_APP.get_component(
|
||||
# ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
|
||||
# ).create()
|
||||
# return Result.succ(await rag_engine.search(request.text))
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.common.chat_util import llm_chat_response_nostream
|
||||
import uuid
|
||||
@ -222,11 +216,6 @@ async def entity_extract(request: EntityExtractRequest):
|
||||
"model_name": request.model_name,
|
||||
}
|
||||
|
||||
# import nest_asyncio
|
||||
# nest_asyncio.apply()
|
||||
# loop = asyncio.get_event_loop()
|
||||
# loop.stop()
|
||||
# loop = utils.get_or_create_event_loop()
|
||||
res = await llm_chat_response_nostream(
|
||||
ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user