feat:knowledge rag graph

aries_ckt 2023-10-13 17:13:51 +08:00
parent fa6a9040d5
commit 2f82f98e31
20 changed files with 1654 additions and 30 deletions

pilot/common/chat_util.py (new file, +20 lines)

@@ -0,0 +1,20 @@
import asyncio
from starlette.responses import StreamingResponse
from pilot.scene.base_chat import BaseChat
from pilot.scene.chat_factory import ChatFactory
chat_factory = ChatFactory()
async def llm_chat_response_nostream(chat_scene: str, **chat_param):
""" llm_chat_response_nostream """
chat: BaseChat = chat_factory.get_implementation(chat_scene, **chat_param)
res = await chat.get_llm_response()
return res
async def llm_chat_response(chat_scene: str, **chat_param):
chat: BaseChat = chat_factory.get_implementation(chat_scene, **chat_param)
return chat.stream_call()
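Below is a minimal usage sketch (not part of this commit) showing how the non-streaming helper can be called for the triplet-extraction scene; it assumes the pilot package from this commit is importable and that a model worker is serving the named model.

import asyncio
import uuid

from pilot.common.chat_util import llm_chat_response_nostream
from pilot.scene.base import ChatScene


async def extract_triplets_once(text: str, model_name: str = "proxyllm"):
    # The chat_param dict mirrors the one used elsewhere in this commit.
    chat_param = {
        "chat_session_id": str(uuid.uuid1()),
        "current_user_input": text,
        "select_param": "triplet",
        "model_name": model_name,
    }
    return await llm_chat_response_nostream(
        ChatScene.ExtractTriplet.value(), **{"chat_param": chat_param}
    )


if __name__ == "__main__":
    # Requires a running DB-GPT model worker for the chosen model.
    print(asyncio.run(extract_triplets_once("Alice is Bob's mother.")))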

@@ -0,0 +1,137 @@
import logging
from typing import Any, Optional, Callable, Tuple, List
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from pilot.embedding_engine import KnowledgeType
from pilot.embedding_engine.knowledge_type import get_knowledge_embedding
from pilot.graph_engine.index_struct import KG
from pilot.graph_engine.node import TextNode
from pilot.utils import utils
logger = logging.getLogger(__name__)
class RAGGraphEngine:
"""Knowledge RAG Graph Engine.
Build a KG by extracting triplets, and leveraging the KG during query-time.
Args:
knowledge_type (Optional[str]): Default: KnowledgeType.DOCUMENT.value
extracting triplets.
graph_store (Optional[GraphStore]): The graph store to use.refrence:llama-index
include_embeddings (bool): Whether to include embeddings in the index.
Defaults to False.
max_object_length (int): The maximum length of the object in a triplet.
Defaults to 128.
extract_triplet_fn (Optional[Callable]): The function to use for
extracting triplets. Defaults to None.
"""
index_struct_cls = KG
def __init__(
self,
knowledge_type: Optional[str] = KnowledgeType.DOCUMENT.value,
knowledge_source: Optional[str] = None,
text_splitter=None,
graph_store=None,
index_struct: Optional[KG] = None,
model_name: Optional[str] = None,
max_triplets_per_chunk: int = 10,
include_embeddings: bool = False,
max_object_length: int = 128,
extract_triplet_fn: Optional[Callable] = None,
**kwargs: Any,
) -> None:
"""Initialize params."""
# from llama_index.graph_stores import SimpleGraphStore
# from llama_index.graph_stores.types import GraphStore
# need to set parameters before building index in base class.
self.knowledge_source = knowledge_source
self.knowledge_type = knowledge_type
self.model_name = model_name
self.text_splitter = text_splitter
self.index_struct = index_struct
self.include_embeddings = include_embeddings
# self.graph_store = graph_store or SimpleGraphStore()
self.graph_store = graph_store
self.max_triplets_per_chunk = max_triplets_per_chunk
self._max_object_length = max_object_length
self._extract_triplet_fn = extract_triplet_fn
def knowledge_graph(self, docs=None):
"""Load knowledge docs, extract triplets, and build them into the graph store."""
if not docs:
if not self.text_splitter:
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=2000, chunk_overlap=100
)
knowledge_source = get_knowledge_embedding(
knowledge_type=self.knowledge_type,
knowledge_source=self.knowledge_source,
text_splitter=self.text_splitter,
)
docs = knowledge_source.read()
if self.index_struct is None:
self.index_struct = self._build_index_from_docs(docs)
def _extract_triplets(self, text: str) -> List[Tuple[str, str, str]]:
"""Extract triplets from text by function or llm"""
if self._extract_triplet_fn is not None:
return self._extract_triplet_fn(text)
else:
return self._llm_extract_triplets(text)
def _llm_extract_triplets(self, text: str) -> List[Tuple[str, str, str]]:
"""Extract triplets from text by llm"""
from pilot.scene.base import ChatScene
from pilot.common.chat_util import llm_chat_response_nostream
import uuid
chat_param = {
"chat_session_id": uuid.uuid1(),
"current_user_input": text,
"select_param": "triplet",
"model_name": self.model_name,
}
loop = utils.get_or_create_event_loop()
triplets = loop.run_until_complete(
llm_chat_response_nostream(
ChatScene.ExtractTriplet.value(), **{"chat_param": chat_param}
)
)
return triplets
# response = self._service_context.llm_predictor.predict(
# self.kg_triple_extract_template,
# text=text,
# )
# print(response, flush=True)
# return self._parse_triplet_response(
# response, max_length=self._max_object_length
# )
def _build_index_from_docs(self, documents: List[Document]) -> KG:
"""Build the index from nodes."""
index_struct = self.index_struct_cls()
for doc in documents:
triplets = self._extract_triplets(doc.page_content)
if len(triplets) == 0:
continue
text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
logger.info(f"extracted knowledge triplets: {triplets}")
for triplet in triplets:
subj, _, obj = triplet
self.graph_store.upsert_triplet(*triplet)
index_struct.add_node([subj, obj], text_node)
return index_struct
def search(self, query):
from pilot.graph_engine.graph_search import RAGGraphSearch
graph_search = RAGGraphSearch(graph_engine=self)
return graph_search.search(query)
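To make the engine's flow concrete, here is a small usage sketch (not part of this commit). InMemoryGraphStore is a hypothetical stand-in that only implements the calls this engine and its search path use (upsert_triplet, get_rel_map, get_schema); the knowledge_source path is illustrative, and building the graph still requires the full DB-GPT runtime with a reachable LLM for triplet extraction.

from typing import Dict, List

from pilot.graph_engine.graph_engine import RAGGraphEngine


class InMemoryGraphStore:
    """Hypothetical toy graph store exposing only the calls used by this engine."""

    def __init__(self) -> None:
        self._rel_map: Dict[str, List[List[str]]] = {}

    def upsert_triplet(self, subj: str, rel: str, obj: str) -> None:
        self._rel_map.setdefault(subj.lower(), []).append([rel, obj])

    def get_rel_map(self, subjs: List[str], depth: int = 2) -> Dict[str, List[List[str]]]:
        return {s: self._rel_map.get(s, []) for s in subjs}

    def get_schema(self, refresh: bool = False) -> str:
        return ""


engine = RAGGraphEngine(
    knowledge_source="./docs/example.md",  # illustrative path, not from this commit
    model_name="proxyllm",
    graph_store=InMemoryGraphStore(),
)
engine.knowledge_graph()  # split the source, extract triplets via the LLM, upsert them
nodes = engine.search("Who founded Philz?")  # keyword lookup plus graph expansion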

@@ -0,0 +1,34 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Type
from pilot.component import BaseComponent, ComponentType
class RAGGraphFactory(BaseComponent, ABC):
name = ComponentType.RAG_GRAPH_DEFAULT.value
@abstractmethod
def create(self, model_name: str = None, embedding_cls: Type = None):
"""Create RAG Graph Engine"""
class DefaultRAGGraphFactory(RAGGraphFactory):
def __init__(
self, system_app=None, default_model_name: str = None, **kwargs: Any
) -> None:
super().__init__(system_app=system_app)
self._default_model_name = default_model_name
self.kwargs = kwargs
from pilot.graph_engine.graph_engine import RAGGraphEngine
self.rag_engine = RAGGraphEngine(model_name="proxyllm")
def init_app(self, system_app):
pass
def create(self, model_name: str = None, rag_cls: Type = None):
if not model_name:
model_name = self._default_model_name
return self.rag_engine

@@ -0,0 +1,193 @@
import logging
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional, Dict, Any, Set, Callable
from pilot.graph_engine.node import BaseNode, TextNode, NodeWithScore
from pilot.graph_engine.search import BaseSearch, SearchMode
from pilot.utils import utils
logger = logging.getLogger(__name__)
DEFAULT_NODE_SCORE = 1000.0
GLOBAL_EXPLORE_NODE_LIMIT = 3
REL_TEXT_LIMIT = 30
class RAGGraphSearch(BaseSearch):
"""RAG Graph Search.
args:
graph_engine RAGGraphEngine.
model_name (str): model name
(see :ref:`Prompt-Templates`).
text_qa_template (Optional[BasePromptTemplate]): A Question Answering Prompt
(see :ref:`Prompt-Templates`).
max_keywords_per_query (int): Maximum number of keywords to extract from query.
num_chunks_per_query (int): Maximum number of text chunks to query.
search_mode (Optional[SearchMode]): Specifies whether to use keyowrds, default SearchMode.KEYWORD
embeddings, or both to find relevant triplets. Should be one of "keyword",
"embedding", or "hybrid".
graph_store_query_depth (int): The depth of the graph store query.
extract_subject_entities_fn (Optional[Callback]): extract_subject_entities callback.
"""
def __init__(
self,
graph_engine,
model_name: str = None,
max_keywords_per_query: int = 10,
num_chunks_per_query: int = 10,
search_mode: Optional[SearchMode] = SearchMode.KEYWORD,
graph_store_query_depth: int = 2,
extract_subject_entities_fn: Optional[Callable] = None,
**kwargs: Any,
) -> None:
"""Initialize params."""
from pilot.graph_engine.graph_engine import RAGGraphEngine
self.graph_engine: RAGGraphEngine = graph_engine
self.model_name = model_name or self.graph_engine.model_name
self._index_struct = self.graph_engine.index_struct
self.max_keywords_per_query = max_keywords_per_query
self.num_chunks_per_query = num_chunks_per_query
self._search_mode = search_mode
self._graph_store = self.graph_engine.graph_store
self.graph_store_query_depth = graph_store_query_depth
self._verbose = kwargs.get("verbose", False)
refresh_schema = kwargs.get("refresh_schema", False)
self.extract_subject_entities_fn = extract_subject_entities_fn
# whether keyword-matched source chunks should be counted (read in _search)
self._include_text = kwargs.get("include_text", True)
self.executor = ThreadPoolExecutor(max_workers=os.cpu_count() * 5)
try:
self._graph_schema = self._graph_store.get_schema(refresh=refresh_schema)
except NotImplementedError:
self._graph_schema = ""
except Exception as e:
logger.warning(f"cannot find graph schema: {e}")
self._graph_schema = ""
def _extract_subject_entities(self, query_str: str) -> Set[str]:
"""extract subject entities."""
if self.extract_subject_entities_fn is not None:
return self.extract_subject_entities_fn(query_str)
else:
return self._extract_entities_by_llm(query_str)
def _extract_entities_by_llm(self, text: str) -> Set[str]:
"""extract subject entities from text by llm"""
from pilot.scene.base import ChatScene
from pilot.common.chat_util import llm_chat_response_nostream
import uuid
chat_param = {
"chat_session_id": uuid.uuid1(),
"current_user_input": text,
"select_param": "entity",
"model_name": self.model_name,
}
loop = utils.get_or_create_event_loop()
entities = loop.run_until_complete(
llm_chat_response_nostream(
ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
)
)
return entities
def _search(
self,
query_str: str,
) -> List[NodeWithScore]:
"""Get nodes for response."""
node_visited = set()
keywords = self._extract_subject_entities(query_str)
print(f"extract entities: {keywords}\n")
rel_texts = []
cur_rel_map = {}
chunk_indices_count: Dict[str, int] = defaultdict(int)
if self._search_mode != SearchMode.EMBEDDING:
for keyword in keywords:
keyword = keyword.lower()
subjs = set((keyword,))
node_ids = self._index_struct.search_node_by_keyword(keyword)
for node_id in node_ids[:GLOBAL_EXPLORE_NODE_LIMIT]:
if node_id in node_visited:
continue
if self._include_text:
chunk_indices_count[node_id] += 1
node_visited.add(node_id)
rel_map = self._graph_store.get_rel_map(
list(subjs), self.graph_store_query_depth
)
logger.debug(f"rel_map: {rel_map}")
if not rel_map:
continue
rel_texts.extend(
[
str(rel_obj)
for rel_objs in rel_map.values()
for rel_obj in rel_objs
]
)
cur_rel_map.update(rel_map)
sorted_nodes_with_scores = []
if not rel_texts:
logger.info("> No relationships found, returning nodes found by keywords.")
if len(sorted_nodes_with_scores) == 0:
logger.info("> No nodes found by keywords, returning empty response.")
return [
NodeWithScore(node=TextNode(text="No relationships found."), score=1.0)
]
# add relationships as Node
# TODO: make initial text customizable
rel_initial_text = (
f"The following are knowledge sequence in max depth"
f" {self.graph_store_query_depth} "
f"in the form of directed graph like:\n"
f"`subject -[predicate]->, object, <-[predicate_next_hop]-,"
f" object_next_hop ...`"
)
rel_info = [rel_initial_text] + rel_texts
rel_node_info = {
"kg_rel_texts": rel_texts,
"kg_rel_map": cur_rel_map,
}
if self._graph_schema != "":
rel_node_info["kg_schema"] = {"schema": self._graph_schema}
rel_info_text = "\n".join(
[
str(item)
for sublist in rel_info
for item in (sublist if isinstance(sublist, list) else [sublist])
]
)
if self._verbose:
print(f"KG context:\n{rel_info_text}\n", color="blue")
rel_text_node = TextNode(
text=rel_info_text,
metadata=rel_node_info,
excluded_embed_metadata_keys=["kg_rel_map", "kg_rel_texts"],
excluded_llm_metadata_keys=["kg_rel_map", "kg_rel_texts"],
)
# this node is constructed from rel_texts, give high confidence to avoid cutoff
sorted_nodes_with_scores.append(
NodeWithScore(node=rel_text_node, score=DEFAULT_NODE_SCORE)
)
return sorted_nodes_with_scores
def _get_metadata_for_response(
self, nodes: List[BaseNode]
) -> Optional[Dict[str, Any]]:
"""Get metadata for response."""
for node in nodes:
if node.metadata is None or "kg_rel_map" not in node.metadata:
continue
return node.metadata
raise ValueError("kg_rel_map must be found in at least one Node.")
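Since subject-entity extraction at query time otherwise goes through the LLM, here is a sketch (not part of this commit) of plugging in a custom extract_subject_entities_fn; the toy extractor and helper names below are assumptions for illustration, and the engine argument is a built RAGGraphEngine such as the one sketched after graph_engine.py above.

from typing import Set

from pilot.graph_engine.graph_search import RAGGraphSearch


def naive_entity_extractor(query: str) -> Set[str]:
    # Toy extractor: every word longer than three characters counts as an entity.
    return {word.lower().strip("?.,") for word in query.split() if len(word) > 3}


def search_without_llm(engine, query: str):
    # `engine` is a built RAGGraphEngine; no LLM call is needed at query time here.
    searcher = RAGGraphSearch(
        graph_engine=engine,
        extract_subject_entities_fn=naive_entity_extractor,
    )
    return searcher.search(query)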

@@ -0,0 +1,259 @@
"""Data structures.
Nodes are decoupled from the indices.
"""
import uuid
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Sequence, Set
from dataclasses_json import DataClassJsonMixin
from pilot.graph_engine.index_type import IndexStructType
from pilot.graph_engine.node import TextNode, BaseNode
# TODO: legacy backport of old Node class
Node = TextNode
@dataclass
class IndexStruct(DataClassJsonMixin):
"""A base data struct for a LlamaIndex."""
index_id: str = field(default_factory=lambda: str(uuid.uuid4()))
summary: Optional[str] = None
def get_summary(self) -> str:
"""Get text summary."""
if self.summary is None:
raise ValueError("summary field of the index_struct not set.")
return self.summary
@classmethod
@abstractmethod
def get_type(cls):
"""Get index struct type."""
@dataclass
class IndexGraph(IndexStruct):
"""A graph representing the tree-structured index."""
# mapping from index in tree to Node doc id.
all_nodes: Dict[int, str] = field(default_factory=dict)
root_nodes: Dict[int, str] = field(default_factory=dict)
node_id_to_children_ids: Dict[str, List[str]] = field(default_factory=dict)
@property
def node_id_to_index(self) -> Dict[str, int]:
"""Map from node id to index."""
return {node_id: index for index, node_id in self.all_nodes.items()}
@property
def size(self) -> int:
"""Get the size of the graph."""
return len(self.all_nodes)
def get_index(self, node: BaseNode) -> int:
"""Get index of node."""
return self.node_id_to_index[node.node_id]
def insert(
self,
node: BaseNode,
index: Optional[int] = None,
children_nodes: Optional[Sequence[BaseNode]] = None,
) -> None:
"""Insert node."""
index = index or self.size
node_id = node.node_id
self.all_nodes[index] = node_id
if children_nodes is None:
children_nodes = []
children_ids = [n.node_id for n in children_nodes]
self.node_id_to_children_ids[node_id] = children_ids
def get_children(self, parent_node: Optional[BaseNode]) -> Dict[int, str]:
"""Get children nodes."""
if parent_node is None:
return self.root_nodes
else:
parent_id = parent_node.node_id
children_ids = self.node_id_to_children_ids[parent_id]
return {
self.node_id_to_index[child_id]: child_id for child_id in children_ids
}
def insert_under_parent(
self,
node: BaseNode,
parent_node: Optional[BaseNode],
new_index: Optional[int] = None,
) -> None:
"""Insert under parent node."""
new_index = new_index or self.size
if parent_node is None:
self.root_nodes[new_index] = node.node_id
self.node_id_to_children_ids[node.node_id] = []
else:
if parent_node.node_id not in self.node_id_to_children_ids:
self.node_id_to_children_ids[parent_node.node_id] = []
self.node_id_to_children_ids[parent_node.node_id].append(node.node_id)
self.all_nodes[new_index] = node.node_id
@classmethod
def get_type(cls) -> IndexStructType:
"""Get type."""
return IndexStructType.TREE
@dataclass
class KeywordTable(IndexStruct):
"""A table of keywords mapping keywords to text chunks."""
table: Dict[str, Set[str]] = field(default_factory=dict)
def add_node(self, keywords: List[str], node: BaseNode) -> None:
"""Add text to table."""
for keyword in keywords:
if keyword not in self.table:
self.table[keyword] = set()
self.table[keyword].add(node.node_id)
@property
def node_ids(self) -> Set[str]:
"""Get all node ids."""
return set.union(*self.table.values())
@property
def keywords(self) -> Set[str]:
"""Get all keywords in the table."""
return set(self.table.keys())
@property
def size(self) -> int:
"""Get the size of the table."""
return len(self.table)
@classmethod
def get_type(cls) -> IndexStructType:
"""Get type."""
return IndexStructType.KEYWORD_TABLE
@dataclass
class IndexList(IndexStruct):
"""A list of documents."""
nodes: List[str] = field(default_factory=list)
def add_node(self, node: BaseNode) -> None:
"""Add text to table, return current position in list."""
# don't worry about child indices for now, nodes are all in order
self.nodes.append(node.node_id)
@classmethod
def get_type(cls) -> IndexStructType:
"""Get type."""
return IndexStructType.LIST
@dataclass
class IndexDict(IndexStruct):
"""A simple dictionary of documents."""
# TODO: slightly deprecated, should likely be a list or set now
# mapping from vector store id to node doc_id
nodes_dict: Dict[str, str] = field(default_factory=dict)
# TODO: deprecated, not used
# mapping from node doc_id to vector store id
doc_id_dict: Dict[str, List[str]] = field(default_factory=dict)
# TODO: deprecated, not used
# this should be empty for all other indices
embeddings_dict: Dict[str, List[float]] = field(default_factory=dict)
def add_node(
self,
node: BaseNode,
text_id: Optional[str] = None,
) -> str:
"""Add text to table, return current position in list."""
# # don't worry about child indices for now, nodes are all in order
# self.nodes_dict[int_id] = node
vector_id = text_id if text_id is not None else node.node_id
self.nodes_dict[vector_id] = node.node_id
return vector_id
def delete(self, doc_id: str) -> None:
"""Delete a Node."""
del self.nodes_dict[doc_id]
@classmethod
def get_type(cls) -> IndexStructType:
"""Get type."""
return IndexStructType.VECTOR_STORE
@dataclass
class KG(IndexStruct):
"""A table of keywords mapping keywords to text chunks."""
# Unidirectional
# table of keywords to node ids
table: Dict[str, Set[str]] = field(default_factory=dict)
# TODO: legacy attribute, remove in future releases
rel_map: Dict[str, List[List[str]]] = field(default_factory=dict)
# TBD, should support vector store, now we just persist the embedding memory
# maybe chainable abstractions for *_stores could be designed
embedding_dict: Dict[str, List[float]] = field(default_factory=dict)
@property
def node_ids(self) -> Set[str]:
"""Get all node ids."""
return set.union(*self.table.values())
def add_to_embedding_dict(self, triplet_str: str, embedding: List[float]) -> None:
"""Add embedding to dict."""
self.embedding_dict[triplet_str] = embedding
def add_node(self, keywords: List[str], node: BaseNode) -> None:
"""Add text to table."""
node_id = node.node_id
for keyword in keywords:
keyword = keyword.lower()
if keyword not in self.table:
self.table[keyword] = set()
self.table[keyword].add(node_id)
def search_node_by_keyword(self, keyword: str) -> List[str]:
"""Search for nodes by keyword."""
if keyword not in self.table:
return []
return list(self.table[keyword])
@classmethod
def get_type(cls) -> IndexStructType:
"""Get type."""
return IndexStructType.KG
@dataclass
class EmptyIndexStruct(IndexStruct):
"""Empty index."""
@classmethod
def get_type(cls) -> IndexStructType:
"""Get type."""
return IndexStructType.EMPTY
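A short illustration (not in the commit) of how the KG struct above maps triplet keywords to the text chunks they came from:

from pilot.graph_engine.index_struct import KG
from pilot.graph_engine.node import TextNode

kg = KG()
chunk = TextNode(text="Philz is a coffee shop founded in Berkeley in 1982.")
# The subject and object of an extracted triplet become lookup keywords for the chunk.
kg.add_node(["Philz", "Berkeley"], chunk)

assert kg.search_node_by_keyword("philz") == [chunk.node_id]  # keywords are lowercased on insert
assert kg.search_node_by_keyword("unknown") == []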

@@ -0,0 +1,48 @@
"""IndexStructType class."""
from enum import Enum
class IndexStructType(str, Enum):
"""Index struct type. Identifier for a "type" of index.
Attributes:
TREE ("tree"): Tree index. See :ref:`Ref-Indices-Tree` for tree indices.
LIST ("list"): Summary index. See :ref:`Ref-Indices-List` for summary indices.
KEYWORD_TABLE ("keyword_table"): Keyword table index. See
:ref:`Ref-Indices-Table`
for keyword table indices.
DICT ("dict"): Faiss Vector Store Index. See
:ref:`Ref-Indices-VectorStore`
for more information on the faiss vector store index.
SIMPLE_DICT ("simple_dict"): Simple Vector Store Index. See
:ref:`Ref-Indices-VectorStore`
for more information on the simple vector store index.
KG ("kg"): Knowledge Graph index.
See :ref:`Ref-Indices-Knowledge-Graph` for KG indices.
DOCUMENT_SUMMARY ("document_summary"): Document Summary Index.
See :ref:`Ref-Indices-Document-Summary` for Summary Indices.
"""
# TODO: refactor so these are properties on the base class
NODE = "node"
TREE = "tree"
LIST = "list"
KEYWORD_TABLE = "keyword_table"
DICT = "dict"
# simple
SIMPLE_DICT = "simple_dict"
# for KG index
KG = "kg"
SIMPLE_KG = "simple_kg"
NEBULAGRAPH = "nebulagraph"
FALKORDB = "falkordb"
# EMPTY
EMPTY = "empty"
COMPOSITE = "composite"
DOCUMENT_SUMMARY = "document_summary"

@@ -0,0 +1,74 @@
from typing import List, Optional
from llama_index.data_structs.data_structs import IndexStruct
from llama_index.storage.index_store.utils import (
index_struct_to_json,
json_to_index_struct,
)
from llama_index.storage.kvstore.types import BaseKVStore
DEFAULT_NAMESPACE = "index_store"
class KVIndexStore:
"""Key-Value Index store.
Args:
kvstore (BaseKVStore): key-value store
namespace (str): namespace for the index store
"""
def __init__(self, kvstore: BaseKVStore, namespace: Optional[str] = None) -> None:
"""Init a KVIndexStore."""
self._kvstore = kvstore
self._namespace = namespace or DEFAULT_NAMESPACE
self._collection = f"{self._namespace}/data"
def add_index_struct(self, index_struct: IndexStruct) -> None:
"""Add an index struct.
Args:
index_struct (IndexStruct): index struct
"""
key = index_struct.index_id
data = index_struct_to_json(index_struct)
self._kvstore.put(key, data, collection=self._collection)
def delete_index_struct(self, key: str) -> None:
"""Delete an index struct.
Args:
key (str): index struct key
"""
self._kvstore.delete(key, collection=self._collection)
def get_index_struct(
self, struct_id: Optional[str] = None
) -> Optional[IndexStruct]:
"""Get an index struct.
Args:
struct_id (Optional[str]): index struct id
"""
if struct_id is None:
structs = self.index_structs()
assert len(structs) == 1
return structs[0]
else:
json = self._kvstore.get(struct_id, collection=self._collection)
if json is None:
return None
return json_to_index_struct(json)
def index_structs(self) -> List[IndexStruct]:
"""Get all index structs.
Returns:
List[IndexStruct]: index structs
"""
jsons = self._kvstore.get_all(collection=self._collection)
return [json_to_index_struct(json) for json in jsons.values()]
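A sketch of exercising the KVIndexStore above (not part of this commit). DictKVStore is a hypothetical in-memory stand-in for a BaseKVStore, KeywordTable comes from the same llama_index module this file already imports from, and the snippet assumes KVIndexStore is in scope since the file's module path is not shown in this diff extract.

from llama_index.data_structs.data_structs import KeywordTable


class DictKVStore:
    """Hypothetical in-memory key-value store implementing the calls used above."""

    def __init__(self):
        self._collections = {}

    def put(self, key, val, collection="data"):
        self._collections.setdefault(collection, {})[key] = val

    def get(self, key, collection="data"):
        return self._collections.get(collection, {}).get(key)

    def delete(self, key, collection="data"):
        self._collections.get(collection, {}).pop(key, None)

    def get_all(self, collection="data"):
        return dict(self._collections.get(collection, {}))


store = KVIndexStore(kvstore=DictKVStore(), namespace="kg_index")
struct = KeywordTable()
store.add_index_struct(struct)
loaded = store.get_index_struct(struct.index_id)
assert loaded is not None and loaded.index_id == struct.index_id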

pilot/graph_engine/node.py (new file, +569 lines)

@@ -0,0 +1,569 @@
"""Base schema for data structures."""
import json
import textwrap
import uuid
from abc import abstractmethod
from enum import Enum, auto
from hashlib import sha256
from typing import Any, Dict, List, Optional, Union
from langchain.schema import Document as LCDocument
from pydantic import BaseModel, Field, root_validator
from typing_extensions import Self
DEFAULT_TEXT_NODE_TMPL = "{metadata_str}\n\n{content}"
DEFAULT_METADATA_TMPL = "{key}: {value}"
# NOTE: for pretty printing
TRUNCATE_LENGTH = 350
WRAP_WIDTH = 70
class BaseComponent(BaseModel):
"""Base component object to caputure class names."""
"""reference llama-index"""
@classmethod
@abstractmethod
def class_name(cls) -> str:
"""Get class name."""
def to_dict(self, **kwargs: Any) -> Dict[str, Any]:
data = self.dict(**kwargs)
data["class_name"] = self.class_name()
return data
def to_json(self, **kwargs: Any) -> str:
data = self.to_dict(**kwargs)
return json.dumps(data)
# TODO: return type here not supported by current mypy version
@classmethod
def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self: # type: ignore
if isinstance(kwargs, dict):
data.update(kwargs)
data.pop("class_name", None)
return cls(**data)
@classmethod
def from_json(cls, data_str: str, **kwargs: Any) -> Self: # type: ignore
data = json.loads(data_str)
return cls.from_dict(data, **kwargs)
class NodeRelationship(str, Enum):
"""Node relationships used in `BaseNode` class.
Attributes:
SOURCE: The node is the source document.
PREVIOUS: The node is the previous node in the document.
NEXT: The node is the next node in the document.
PARENT: The node is the parent node in the document.
CHILD: The node is a child node in the document.
"""
SOURCE = auto()
PREVIOUS = auto()
NEXT = auto()
PARENT = auto()
CHILD = auto()
class ObjectType(str, Enum):
TEXT = auto()
IMAGE = auto()
INDEX = auto()
DOCUMENT = auto()
class MetadataMode(str, Enum):
ALL = auto()
EMBED = auto()
LLM = auto()
NONE = auto()
class RelatedNodeInfo(BaseComponent):
node_id: str
node_type: Optional[ObjectType] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
hash: Optional[str] = None
@classmethod
def class_name(cls) -> str:
"""Get class name."""
return "RelatedNodeInfo"
RelatedNodeType = Union[RelatedNodeInfo, List[RelatedNodeInfo]]
# Node classes for indexes
class BaseNode(BaseComponent):
"""Base node Object.
Generic abstract interface for retrievable nodes
"""
class Config:
allow_population_by_field_name = True
id_: str = Field(
default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the node."
)
embedding: Optional[List[float]] = Field(
default=None, description="Embedding of the node."
)
""""
metadata fields
- injected as part of the text shown to LLMs as context
- injected as part of the text for generating embeddings
- used by vector DBs for metadata filtering
"""
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="A flat dictionary of metadata fields",
alias="extra_info",
)
excluded_embed_metadata_keys: List[str] = Field(
default_factory=list,
description="Metadata keys that are exluded from text for the embed model.",
)
excluded_llm_metadata_keys: List[str] = Field(
default_factory=list,
description="Metadata keys that are exluded from text for the LLM.",
)
relationships: Dict[NodeRelationship, RelatedNodeType] = Field(
default_factory=dict,
description="A mapping of relationships to other node information.",
)
hash: str = Field(default="", description="Hash of the node content.")
@classmethod
@abstractmethod
def get_type(cls) -> str:
"""Get Object type."""
@abstractmethod
def get_content(self, metadata_mode: MetadataMode = MetadataMode.ALL) -> str:
"""Get object content."""
@abstractmethod
def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
"""Metadata string."""
@abstractmethod
def set_content(self, value: Any) -> None:
"""Set the content of the node."""
@property
def node_id(self) -> str:
return self.id_
@node_id.setter
def node_id(self, value: str) -> None:
self.id_ = value
@property
def source_node(self) -> Optional[RelatedNodeInfo]:
"""Source object node.
Extracted from the relationships field.
"""
if NodeRelationship.SOURCE not in self.relationships:
return None
relation = self.relationships[NodeRelationship.SOURCE]
if isinstance(relation, list):
raise ValueError("Source object must be a single RelatedNodeInfo object")
return relation
@property
def prev_node(self) -> Optional[RelatedNodeInfo]:
"""Prev node."""
if NodeRelationship.PREVIOUS not in self.relationships:
return None
relation = self.relationships[NodeRelationship.PREVIOUS]
if not isinstance(relation, RelatedNodeInfo):
raise ValueError("Previous object must be a single RelatedNodeInfo object")
return relation
@property
def next_node(self) -> Optional[RelatedNodeInfo]:
"""Next node."""
if NodeRelationship.NEXT not in self.relationships:
return None
relation = self.relationships[NodeRelationship.NEXT]
if not isinstance(relation, RelatedNodeInfo):
raise ValueError("Next object must be a single RelatedNodeInfo object")
return relation
@property
def parent_node(self) -> Optional[RelatedNodeInfo]:
"""Parent node."""
if NodeRelationship.PARENT not in self.relationships:
return None
relation = self.relationships[NodeRelationship.PARENT]
if not isinstance(relation, RelatedNodeInfo):
raise ValueError("Parent object must be a single RelatedNodeInfo object")
return relation
@property
def child_nodes(self) -> Optional[List[RelatedNodeInfo]]:
"""Child nodes."""
if NodeRelationship.CHILD not in self.relationships:
return None
relation = self.relationships[NodeRelationship.CHILD]
if not isinstance(relation, list):
raise ValueError("Child objects must be a list of RelatedNodeInfo objects.")
return relation
@property
def ref_doc_id(self) -> Optional[str]:
"""Deprecated: Get ref doc id."""
source_node = self.source_node
if source_node is None:
return None
return source_node.node_id
@property
def extra_info(self) -> Dict[str, Any]:
"""TODO: DEPRECATED: Extra info."""
return self.metadata
def __str__(self) -> str:
source_text_truncated = truncate_text(
self.get_content().strip(), TRUNCATE_LENGTH
)
source_text_wrapped = textwrap.fill(
f"Text: {source_text_truncated}\n", width=WRAP_WIDTH
)
return f"Node ID: {self.node_id}\n{source_text_wrapped}"
def truncate_text(text: str, max_length: int) -> str:
"""Truncate text to a maximum length."""
if len(text) <= max_length:
return text
return text[: max_length - 3] + "..."
def get_embedding(self) -> List[float]:
"""Get embedding.
Errors if embedding is None.
"""
if self.embedding is None:
raise ValueError("embedding not set.")
return self.embedding
def as_related_node_info(self) -> RelatedNodeInfo:
"""Get node as RelatedNodeInfo."""
return RelatedNodeInfo(
node_id=self.node_id, metadata=self.metadata, hash=self.hash
)
class TextNode(BaseNode):
text: str = Field(default="", description="Text content of the node.")
start_char_idx: Optional[int] = Field(
default=None, description="Start char index of the node."
)
end_char_idx: Optional[int] = Field(
default=None, description="End char index of the node."
)
text_template: str = Field(
default=DEFAULT_TEXT_NODE_TMPL,
description=(
"Template for how text is formatted, with {content} and "
"{metadata_str} placeholders."
),
)
metadata_template: str = Field(
default=DEFAULT_METADATA_TMPL,
description=(
"Template for how metadata is formatted, with {key} and "
"{value} placeholders."
),
)
metadata_seperator: str = Field(
default="\n",
description="Seperator between metadata fields when converting to string.",
)
@classmethod
def class_name(cls) -> str:
"""Get class name."""
return "TextNode"
@root_validator
def _check_hash(cls, values: dict) -> dict:
"""Generate a hash to represent the node."""
text = values.get("text", "")
metadata = values.get("metadata", {})
doc_identity = str(text) + str(metadata)
values["hash"] = str(
sha256(doc_identity.encode("utf-8", "surrogatepass")).hexdigest()
)
return values
@classmethod
def get_type(cls) -> str:
"""Get Object type."""
return ObjectType.TEXT
def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
"""Get object content."""
metadata_str = self.get_metadata_str(mode=metadata_mode).strip()
if not metadata_str:
return self.text
return self.text_template.format(
content=self.text, metadata_str=metadata_str
).strip()
def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
"""metadata info string."""
if mode == MetadataMode.NONE:
return ""
usable_metadata_keys = set(self.metadata.keys())
if mode == MetadataMode.LLM:
for key in self.excluded_llm_metadata_keys:
if key in usable_metadata_keys:
usable_metadata_keys.remove(key)
elif mode == MetadataMode.EMBED:
for key in self.excluded_embed_metadata_keys:
if key in usable_metadata_keys:
usable_metadata_keys.remove(key)
return self.metadata_seperator.join(
[
self.metadata_template.format(key=key, value=str(value))
for key, value in self.metadata.items()
if key in usable_metadata_keys
]
)
def set_content(self, value: str) -> None:
"""Set the content of the node."""
self.text = value
def get_node_info(self) -> Dict[str, Any]:
"""Get node info."""
return {"start": self.start_char_idx, "end": self.end_char_idx}
def get_text(self) -> str:
return self.get_content(metadata_mode=MetadataMode.NONE)
@property
def node_info(self) -> Dict[str, Any]:
"""Deprecated: Get node info."""
return self.get_node_info()
# TODO: legacy backport of old Node class
Node = TextNode
class ImageNode(TextNode):
"""Node with image."""
# TODO: store reference instead of actual image
# base64 encoded image str
image: Optional[str] = None
@classmethod
def get_type(cls) -> str:
return ObjectType.IMAGE
@classmethod
def class_name(cls) -> str:
"""Get class name."""
return "ImageNode"
class IndexNode(TextNode):
"""Node with reference to any object.
This can include other indices, query engines, retrievers.
This can also include other nodes (though this is overlapping with `relationships`
on the Node class).
"""
index_id: str
@classmethod
def from_text_node(
cls,
node: TextNode,
index_id: str,
) -> "IndexNode":
"""Create index node from text node."""
# copy all attributes from text node, add index id
return cls(
**node.dict(),
index_id=index_id,
)
@classmethod
def get_type(cls) -> str:
return ObjectType.INDEX
@classmethod
def class_name(cls) -> str:
"""Get class name."""
return "IndexNode"
class NodeWithScore(BaseComponent):
node: BaseNode
score: Optional[float] = None
def __str__(self) -> str:
return f"{self.node}\nScore: {self.score: 0.3f}\n"
def get_score(self, raise_error: bool = False) -> float:
"""Get score."""
if self.score is None:
if raise_error:
raise ValueError("Score not set.")
else:
return 0.0
else:
return self.score
@classmethod
def class_name(cls) -> str:
"""Get class name."""
return "NodeWithScore"
##### pass through methods to BaseNode #####
@property
def node_id(self) -> str:
return self.node.node_id
@property
def id_(self) -> str:
return self.node.id_
@property
def text(self) -> str:
if isinstance(self.node, TextNode):
return self.node.text
else:
raise ValueError("Node must be a TextNode to get text.")
@property
def metadata(self) -> Dict[str, Any]:
return self.node.metadata
@property
def embedding(self) -> Optional[List[float]]:
return self.node.embedding
def get_text(self) -> str:
if isinstance(self.node, TextNode):
return self.node.get_text()
else:
raise ValueError("Node must be a TextNode to get text.")
def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
return self.node.get_content(metadata_mode=metadata_mode)
def get_embedding(self) -> List[float]:
return self.node.get_embedding()
# Document Classes for Readers
class Document(TextNode):
"""Generic interface for a data document.
This document connects to data sources.
"""
# TODO: A lot of backwards compatibility logic here, clean up
id_: str = Field(
default_factory=lambda: str(uuid.uuid4()),
description="Unique ID of the node.",
alias="doc_id",
)
_compat_fields = {"doc_id": "id_", "extra_info": "metadata"}
@classmethod
def get_type(cls) -> str:
"""Get Document type."""
return ObjectType.DOCUMENT
@property
def doc_id(self) -> str:
"""Get document ID."""
return self.id_
def __str__(self) -> str:
source_text_truncated = truncate_text(
self.get_content().strip(), TRUNCATE_LENGTH
)
source_text_wrapped = textwrap.fill(
f"Text: {source_text_truncated}\n", width=WRAP_WIDTH
)
return f"Doc ID: {self.doc_id}\n{source_text_wrapped}"
def get_doc_id(self) -> str:
"""TODO: Deprecated: Get document ID."""
return self.id_
def __setattr__(self, name: str, value: object) -> None:
if name in self._compat_fields:
name = self._compat_fields[name]
super().__setattr__(name, value)
def to_langchain_format(self) -> LCDocument:
"""Convert struct to LangChain document format."""
metadata = self.metadata or {}
return LCDocument(page_content=self.text, metadata=metadata)
@classmethod
def from_langchain_format(cls, doc: LCDocument) -> "Document":
"""Convert struct from LangChain document format."""
return cls(text=doc.page_content, metadata=doc.metadata)
@classmethod
def example(cls) -> "Document":
document = Document(
text="",
metadata={"filename": "README.md", "category": "codebase"},
)
return document
@classmethod
def class_name(cls) -> str:
"""Get class name."""
return "Document"
class ImageDocument(Document):
"""Data document containing an image."""
# base64 encoded image str
image: Optional[str] = None
@classmethod
def class_name(cls) -> str:
"""Get class name."""
return "ImageDocument"

@@ -0,0 +1,44 @@
from abc import ABC, abstractmethod
from enum import Enum
class SearchMode(str, Enum):
"""Query mode enum for Knowledge Graphs.
Can be passed as the enum struct, or as the underlying string.
Attributes:
KEYWORD ("keyword"): Default query mode, using keywords to find triplets.
EMBEDDING ("embedding"): Embedding mode, using embeddings to find
similar triplets.
HYBRID ("hybrid"): Hyrbid mode, combining both keywords and embeddings
to find relevant triplets.
"""
KEYWORD = "keyword"
EMBEDDING = "embedding"
HYBRID = "hybrid"
class BaseSearch(ABC):
"""Base Search."""
def search(self, query: str):
"""Retrieve nodes given query.
Args:
query (str): the query string.
"""
# if isinstance(query, str):
return self._search(query)
@abstractmethod
def _search(self, query: str):
"""search nodes given query.
Implemented by the user.
"""
pass
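Implementing a concrete search only requires overriding _search; a trivial sketch (not in the commit):

from typing import List

from pilot.graph_engine.search import BaseSearch


class EchoSearch(BaseSearch):
    """Trivial demo search that just echoes the query back as a single result."""

    def _search(self, query: str) -> List[str]:
        return [f"echo: {query}"]


assert EchoSearch().search("hello") == ["echo: hello"]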

@@ -0,0 +1,35 @@
from typing import Dict
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.configs.config import Config
from pilot.scene.chat_knowledge.extract_entity.prompt import prompt
CFG = Config()
class ExtractEntity(BaseChat):
chat_scene: str = ChatScene.ExtractEntity.value()
"""extracting entities by llm"""
def __init__(self, chat_param: Dict):
""" """
chat_param["chat_mode"] = ChatScene.ExtractEntity
super().__init__(
chat_param=chat_param,
)
self.user_input = chat_param["current_user_input"]
self.extract_mode = chat_param["select_param"]
def generate_input_values(self):
input_values = {
"text": self.user_input,
}
return input_values
@property
def chat_type(self) -> str:
return ChatScene.ExtractEntity.value

@@ -0,0 +1,39 @@
import json
import logging
from typing import Set
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.config import Config
CFG = Config()
logger = logging.getLogger(__name__)
class ExtractEntityParser(BaseOutputParser):
def __init__(self, sep: str, is_stream_out: bool):
super().__init__(sep=sep, is_stream_out=is_stream_out)
def parse_prompt_response(self, response, max_length: int = 128) -> Set[str]:
lowercase = True
# clean_str = super().parse_prompt_response(response)
print("clean prompt response:", response)
results = []
response = response.strip() # Strip newlines from responses.
if response.startswith("KEYWORDS:"):
response = response[len("KEYWORDS:") :]
keywords = response.split(",")
for k in keywords:
rk = k
if lowercase:
rk = rk.lower()
results.append(rk.strip())
return set(results)
def parse_view_response(self, speak, data) -> str:
return data
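A worked example (not in the commit) of what this parser produces for a response in the "KEYWORDS: ..." format requested by the prompt in the next file, assuming a configured DB-GPT environment so the module imports cleanly:

from pilot.scene.chat_knowledge.extract_entity.out_parser import ExtractEntityParser

parser = ExtractEntityParser(sep="\n", is_stream_out=False)
entities = parser.parse_prompt_response("KEYWORDS: Alice, mother, Bob")
assert entities == {"alice", "mother", "bob"}  # lowercased, stripped, de-duplicated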

@@ -0,0 +1,52 @@
import json
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.common.schema import SeparatorStyle
from pilot.scene.chat_knowledge.extract_entity.out_parser import ExtractEntityParser
from pilot.scene.chat_knowledge.extract_triplet.out_parser import (
ExtractTripleParser,
)
CFG = Config()
PROMPT_SCENE_DEFINE = """"""
_DEFAULT_TEMPLATE = """
"A question is provided below. Given the question, extract up to 10 "
"keywords from the text. Focus on extracting the keywords that we can use "
"to best lookup answers to the question. Avoid stopwords.\n"
"Example:"
"Text: Alice is Bob's mother."
"KEYWORDS:Alice,mother,Bob\n"
"---------------------\n"
"{text}\n"
"---------------------\n"
"Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>'\n"
"""
PROMPT_RESPONSE = """"""
RESPONSE_FORMAT = """"""
PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_NEED_STREAM_OUT = False
prompt = PromptTemplate(
template_scene=ChatScene.ExtractEntity.value(),
input_variables=["text"],
response_format="",
template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
output_parser=ExtractEntityParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
),
)
CFG.prompt_template_registry.register(prompt, is_default=True)

@@ -0,0 +1,35 @@
from typing import Dict
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.configs.config import Config
from pilot.scene.chat_knowledge.extract_triplet.prompt import prompt
CFG = Config()
class ExtractTriplet(BaseChat):
chat_scene: str = ChatScene.ExtractTriplet.value()
"""extracting triplets by llm"""
def __init__(self, chat_param: Dict):
""" """
chat_param["chat_mode"] = ChatScene.ExtractTriplet
super().__init__(
chat_param=chat_param,
)
self.user_input = chat_param["current_user_input"]
self.extract_mode = chat_param["select_param"]
def generate_input_values(self):
input_values = {
"text": self.user_input,
}
return input_values
@property
def chat_type(self) -> str:
return ChatScene.ExtractTriplet.value

@@ -0,0 +1,57 @@
import json
import logging
import re
from typing import List, Tuple
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.config import Config
CFG = Config()
logger = logging.getLogger(__name__)
class ExtractTripleParser(BaseOutputParser):
def __init__(self, sep: str, is_stream_out: bool):
super().__init__(sep=sep, is_stream_out=is_stream_out)
def parse_prompt_response(
self, response, max_length: int = 128
) -> List[Tuple[str, str, str]]:
# clean_str = super().parse_prompt_response(response)
print("clean prompt response:", response)
if response.startswith("Triplets:"):
response = response[len("Triplets:") :]
pattern = r"\([^()]+\)"
response = re.findall(pattern, response)
# response = response.strip().split("\n")
print("parse prompt response:", response)
results = []
for text in response:
if not text or text[0] != "(" or text[-1] != ")":
# skip empty lines and non-triplets
continue
tokens = text[1:-1].split(",")
if len(tokens) != 3:
continue
if any(len(s.encode("utf-8")) > max_length for s in tokens):
# We count byte-length instead of len() for UTF-8 chars,
# will skip if any of the tokens are too long.
# This is normally due to a poorly formatted triplet
# extraction, in more serious KG building cases
# we'll need NLP models to better extract triplets.
continue
subject, predicate, obj = map(str.strip, tokens)
if not subject or not predicate or not obj:
# skip partial triplets
continue
results.append((subject.lower(), predicate.lower(), obj.lower()))
return results
def parse_view_response(self, speak, data) -> str:
### tool out data to table view
return data
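A worked example (not in the commit) of how this parser turns a raw model response into lowercased (subject, predicate, object) tuples, assuming a configured DB-GPT environment so the module imports cleanly:

from pilot.scene.chat_knowledge.extract_triplet.out_parser import ExtractTripleParser

parser = ExtractTripleParser(sep="\n", is_stream_out=False)
response = (
    "Triplets:\n"
    "(Philz, is, coffee shop)\n"
    "(Philz, founded in, Berkeley)\n"
    "(Philz, founded in, 1982)\n"
)
assert parser.parse_prompt_response(response) == [
    ("philz", "is", "coffee shop"),
    ("philz", "founded in", "berkeley"),
    ("philz", "founded in", "1982"),
]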

@@ -0,0 +1,57 @@
import json
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.common.schema import SeparatorStyle
from pilot.scene.chat_knowledge.extract_triplet.out_parser import (
ExtractTripleParser,
)
CFG = Config()
PROMPT_SCENE_DEFINE = """"""
_DEFAULT_TEMPLATE = """
"Some text is provided below. Given the text, extract up to 10"
"knowledge triplets in the form of (subject, predicate, object). Avoid stopwords.\n"
"---------------------\n"
"Example:"
"Text: Alice is Bob's mother."
"Triplets:\n(Alice, is mother of, Bob)\n"
"Text: Philz is a coffee shop founded in Berkeley in 1982.\n"
"Triplets:\n"
"(Philz, is, coffee shop)\n"
"(Philz, founded in, Berkeley)\n"
"(Philz, founded in, 1982)\n"
"---------------------\n"
"Text: {text}\n"
"Triplets:\n"
Ensure you respond in the following List(Tuple) format:
'(Stephen Curry, plays for, Golden State Warriors)\n(Stephen Curry, known for, shooting skills)\n(Stephen Curry, attended, Davidson College)\n(Stephen Curry, led, team to success)'
"""
PROMPT_RESPONSE = """"""
RESPONSE_FORMAT = """"""
PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_NEED_STREAM_OUT = False
prompt = PromptTemplate(
template_scene=ChatScene.ExtractTriplet.value(),
input_variables=["text"],
response_format="",
template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
output_parser=ExtractTripleParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
),
)
CFG.prompt_template_registry.register(prompt, is_default=True)

@@ -88,25 +88,7 @@ class ChatKnowledge(BaseChat):
if self.space_context:
self.prompt_template.template_define = self.space_context["prompt"]["scene"]
self.prompt_template.template = self.space_context["prompt"]["template"]
- # docs = self.rag_engine.search(query=self.current_user_input)
- # import httpx
- # with httpx.Client() as client:
- # request = client.build_request(
- # "post",
- # "http://127.0.0.1/api/knowledge/entities/extract",
- # json="application/json", # using json for data to ensure it sends as application/json
- # params={"text": self.current_user_input},
- # headers={},
- # )
- #
- # response = client.send(request)
- # if response.status_code != 200:
- # error_msg = f"request /api/knowledge/entities/extract failed, error: {response.text}"
- # raise Exception(error_msg)
- # docs = response.json()
- # import requests
- # docs = requests.post("http://127.0.0.1:5000/api/knowledge/entities/extract", headers={}, json={"text": self.current_user_input})
+ docs = self.rag_engine.search(query=self.current_user_input)
docs = self.knowledge_embedding_client.similar_search(
self.current_user_input, self.top_k
)

@@ -205,12 +205,6 @@ def similar_query(space_name: str, query_request: KnowledgeQueryRequest):
async def entity_extract(request: EntityExtractRequest):
logger.info(f"Received params: {request}")
try:
- # from pilot.graph_engine.graph_factory import RAGGraphFactory
- # from pilot.component import ComponentType
- # rag_engine = CFG.SYSTEM_APP.get_component(
- # ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
- # ).create()
- # return Result.succ(await rag_engine.search(request.text))
from pilot.scene.base import ChatScene
from pilot.common.chat_util import llm_chat_response_nostream
import uuid

@@ -222,11 +216,6 @@ async def entity_extract(request: EntityExtractRequest):
"model_name": request.model_name,
}
- # import nest_asyncio
- # nest_asyncio.apply()
- # loop = asyncio.get_event_loop()
- # loop.stop()
- # loop = utils.get_or_create_event_loop()
res = await llm_chat_response_nostream(
ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
)