From 2f82f98e315d81129987c28196fd66e62a75f56b Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Fri, 13 Oct 2023 17:13:51 +0800 Subject: [PATCH] feat:knowledge rag graph --- pilot/common/chat_util.py | 20 + pilot/graph_engine/__init__.py | 0 pilot/graph_engine/graph_engine.py | 137 +++++ pilot/graph_engine/graph_factory.py | 34 ++ pilot/graph_engine/graph_search.py | 193 ++++++ pilot/graph_engine/index_struct.py | 259 ++++++++ pilot/graph_engine/index_type.py | 48 ++ pilot/graph_engine/kv_index.py | 74 +++ pilot/graph_engine/node.py | 569 ++++++++++++++++++ pilot/graph_engine/search.py | 44 ++ .../chat_knowledge/extract_entity/__init__.py | 0 .../chat_knowledge/extract_entity/chat.py | 35 ++ .../extract_entity/out_parser.py | 39 ++ .../chat_knowledge/extract_entity/prompt.py | 52 ++ .../extract_triplet/__init__.py | 0 .../chat_knowledge/extract_triplet/chat.py | 35 ++ .../extract_triplet/out_parser.py | 57 ++ .../chat_knowledge/extract_triplet/prompt.py | 57 ++ pilot/scene/chat_knowledge/v1/chat.py | 20 +- pilot/server/knowledge/api.py | 11 - 20 files changed, 1654 insertions(+), 30 deletions(-) create mode 100644 pilot/common/chat_util.py create mode 100644 pilot/graph_engine/__init__.py create mode 100644 pilot/graph_engine/graph_engine.py create mode 100644 pilot/graph_engine/graph_factory.py create mode 100644 pilot/graph_engine/graph_search.py create mode 100644 pilot/graph_engine/index_struct.py create mode 100644 pilot/graph_engine/index_type.py create mode 100644 pilot/graph_engine/kv_index.py create mode 100644 pilot/graph_engine/node.py create mode 100644 pilot/graph_engine/search.py create mode 100644 pilot/scene/chat_knowledge/extract_entity/__init__.py create mode 100644 pilot/scene/chat_knowledge/extract_entity/chat.py create mode 100644 pilot/scene/chat_knowledge/extract_entity/out_parser.py create mode 100644 pilot/scene/chat_knowledge/extract_entity/prompt.py create mode 100644 pilot/scene/chat_knowledge/extract_triplet/__init__.py 
create mode 100644 pilot/scene/chat_knowledge/extract_triplet/chat.py create mode 100644 pilot/scene/chat_knowledge/extract_triplet/out_parser.py create mode 100644 pilot/scene/chat_knowledge/extract_triplet/prompt.py diff --git a/pilot/common/chat_util.py b/pilot/common/chat_util.py new file mode 100644 index 000000000..159db99d0 --- /dev/null +++ b/pilot/common/chat_util.py @@ -0,0 +1,20 @@ +import asyncio + +from starlette.responses import StreamingResponse + +from pilot.scene.base_chat import BaseChat +from pilot.scene.chat_factory import ChatFactory + +chat_factory = ChatFactory() + + +async def llm_chat_response_nostream(chat_scene: str, **chat_param): + """ llm_chat_response_nostream """ + chat: BaseChat = chat_factory.get_implementation(chat_scene, **chat_param) + res = await chat.get_llm_response() + return res + + +async def llm_chat_response(chat_scene: str, **chat_param): + chat: BaseChat = chat_factory.get_implementation(chat_scene, **chat_param) + return chat.stream_call() diff --git a/pilot/graph_engine/__init__.py b/pilot/graph_engine/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/graph_engine/graph_engine.py b/pilot/graph_engine/graph_engine.py new file mode 100644 index 000000000..c20142123 --- /dev/null +++ b/pilot/graph_engine/graph_engine.py @@ -0,0 +1,137 @@ +import logging +from typing import Any, Optional, Callable, Tuple, List + +from langchain.schema import Document +from langchain.text_splitter import RecursiveCharacterTextSplitter + +from pilot.embedding_engine import KnowledgeType +from pilot.embedding_engine.knowledge_type import get_knowledge_embedding +from pilot.graph_engine.index_struct import KG +from pilot.graph_engine.node import TextNode +from pilot.utils import utils + +logger = logging.getLogger(__name__) + + +class RAGGraphEngine: + """Knowledge RAG Graph Engine. + Build a KG by extracting triplets, and leveraging the KG during query-time. 
+ Args: + knowledge_type (Optional[str]): Default: KnowledgeType.DOCUMENT.value + extracting triplets. + graph_store (Optional[GraphStore]): The graph store to use.refrence:llama-index + include_embeddings (bool): Whether to include embeddings in the index. + Defaults to False. + max_object_length (int): The maximum length of the object in a triplet. + Defaults to 128. + extract_triplet_fn (Optional[Callable]): The function to use for + extracting triplets. Defaults to None. + """ + + index_struct_cls = KG + + def __init__( + self, + knowledge_type: Optional[str] = KnowledgeType.DOCUMENT.value, + knowledge_source: Optional[str] = None, + text_splitter=None, + graph_store=None, + index_struct: Optional[KG] = None, + model_name: Optional[str] = None, + max_triplets_per_chunk: int = 10, + include_embeddings: bool = False, + max_object_length: int = 128, + extract_triplet_fn: Optional[Callable] = None, + **kwargs: Any, + ) -> None: + """Initialize params.""" + # from llama_index.graph_stores import SimpleGraphStore + # from llama_index.graph_stores.types import GraphStore + + # need to set parameters before building index in base class. 
+ self.knowledge_source = knowledge_source + self.knowledge_type = knowledge_type + self.model_name = model_name + self.text_splitter = text_splitter + self.index_struct = index_struct + self.include_embeddings = include_embeddings + # self.graph_store = graph_store or SimpleGraphStore() + self.graph_store = graph_store + self.max_triplets_per_chunk = max_triplets_per_chunk + self._max_object_length = max_object_length + self._extract_triplet_fn = extract_triplet_fn + + def knowledge_graph(self, docs=None): + """knowledge docs into graph store""" + if not docs: + if self.text_splitter: + self.text_splitter = RecursiveCharacterTextSplitter( + chunk_size=2000, chunk_overlap=100 + ) + knowledge_source = get_knowledge_embedding( + knowledge_type=self.knowledge_type, + knowledge_source=self.knowledge_source, + text_splitter=self.text_splitter, + ) + docs = knowledge_source.read() + if self.index_struct is None: + self.index_struct = self._build_index_from_docs(docs) + + def _extract_triplets(self, text: str) -> List[Tuple[str, str, str]]: + """Extract triplets from text by function or llm""" + if self._extract_triplet_fn is not None: + return self._extract_triplet_fn(text) + else: + return self._llm_extract_triplets(text) + + def _llm_extract_triplets(self, text: str) -> List[Tuple[str, str, str]]: + """Extract triplets from text by llm""" + from pilot.scene.base import ChatScene + from pilot.common.chat_util import llm_chat_response_nostream + import uuid + + chat_param = { + "chat_session_id": uuid.uuid1(), + "current_user_input": text, + "select_param": "triplet", + "model_name": self.model_name, + } + loop = utils.get_or_create_event_loop() + triplets = loop.run_until_complete( + llm_chat_response_nostream( + ChatScene.ExtractTriplet.value(), **{"chat_param": chat_param} + ) + ) + return triplets + # response = self._service_context.llm_predictor.predict( + # self.kg_triple_extract_template, + # text=text, + # ) + # print(response, flush=True) + # return 
self._parse_triplet_response( + # response, max_length=self._max_object_length + # ) + + def _build_index_from_docs(self, documents: List[Document]) -> KG: + """Build the index from nodes.""" + index_struct = self.index_struct_cls() + for doc in documents: + triplets = self._extract_triplets(doc.page_content) + if len(triplets) == 0: + continue + text_node = TextNode(text=doc.page_content, metadata=doc.metadata) + logger.info(f"extracted knowledge triplets: {triplets}") + for triplet in triplets: + subj, _, obj = triplet + self.graph_store.upsert_triplet(*triplet) + index_struct.add_node([subj, obj], text_node) + + + return index_struct + + def search(self, query): + from pilot.graph_engine.graph_search import RAGGraphSearch + + graph_search = RAGGraphSearch(graph_engine=self) + return graph_search.search(query) + diff --git a/pilot/graph_engine/graph_factory.py b/pilot/graph_engine/graph_factory.py new file mode 100644 index 000000000..3a8b99c17 --- /dev/null +++ b/pilot/graph_engine/graph_factory.py @@ -0,0 +1,34 @@ +from __future__ import annotations +from abc import ABC, abstractmethod +from typing import Any, Type + +from pilot.component import BaseComponent, ComponentType + + +class RAGGraphFactory(BaseComponent, ABC): + name = ComponentType.RAG_GRAPH_DEFAULT.value + + @abstractmethod + def create(self, model_name: str = None, embedding_cls: Type = None): + """Create RAG Graph Engine""" + + +class DefaultRAGGraphFactory(RAGGraphFactory): + def __init__( + self, system_app=None, default_model_name: str = None, **kwargs: Any + ) -> None: + super().__init__(system_app=system_app) + self._default_model_name = default_model_name + self.kwargs = kwargs + from pilot.graph_engine.graph_engine import RAGGraphEngine + + self.rag_engine = RAGGraphEngine(model_name="proxyllm") + + def init_app(self, system_app): + pass + + def create(self, model_name: str = None, rag_cls: Type = None): + if not model_name: + model_name = self._default_model_name + + return 
self.rag_engine diff --git a/pilot/graph_engine/graph_search.py b/pilot/graph_engine/graph_search.py new file mode 100644 index 000000000..9b06fd234 --- /dev/null +++ b/pilot/graph_engine/graph_search.py @@ -0,0 +1,193 @@ +import logging +import os +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor +from typing import List, Optional, Dict, Any, Set, Callable + +from pilot.graph_engine.node import BaseNode, TextNode, NodeWithScore +from pilot.graph_engine.search import BaseSearch, SearchMode +from pilot.utils import utils + +logger = logging.getLogger(__name__) +DEFAULT_NODE_SCORE = 1000.0 +GLOBAL_EXPLORE_NODE_LIMIT = 3 +REL_TEXT_LIMIT = 30 + + +class RAGGraphSearch(BaseSearch): + """RAG Graph Search. + + args: + graph_engine RAGGraphEngine. + model_name (str): model name + (see :ref:`Prompt-Templates`). + text_qa_template (Optional[BasePromptTemplate]): A Question Answering Prompt + (see :ref:`Prompt-Templates`). + max_keywords_per_query (int): Maximum number of keywords to extract from query. + num_chunks_per_query (int): Maximum number of text chunks to query. + search_mode (Optional[SearchMode]): Specifies whether to use keyowrds, default SearchMode.KEYWORD + embeddings, or both to find relevant triplets. Should be one of "keyword", + "embedding", or "hybrid". + graph_store_query_depth (int): The depth of the graph store query. + extract_subject_entities_fn (Optional[Callback]): extract_subject_entities callback. 
+ """ + + def __init__( + self, + graph_engine, + model_name: str = None, + max_keywords_per_query: int = 10, + num_chunks_per_query: int = 10, + search_mode: Optional[SearchMode] = SearchMode.KEYWORD, + graph_store_query_depth: int = 2, + extract_subject_entities_fn: Optional[Callable] = None, + **kwargs: Any, + ) -> None: + """Initialize params.""" + from pilot.graph_engine.graph_engine import RAGGraphEngine + + self.graph_engine: RAGGraphEngine = graph_engine + self.model_name = model_name or self.graph_engine.model_name + self._index_struct = self.graph_engine.index_struct + self.max_keywords_per_query = max_keywords_per_query + self.num_chunks_per_query = num_chunks_per_query + self._search_mode = search_mode + + self._graph_store = self.graph_engine.graph_store + self.graph_store_query_depth = graph_store_query_depth + self._verbose = kwargs.get("verbose", False) + refresh_schema = kwargs.get("refresh_schema", False) + self.extract_subject_entities_fn = extract_subject_entities_fn + self.executor = ThreadPoolExecutor(max_workers=os.cpu_count() * 5) + try: + self._graph_schema = self._graph_store.get_schema(refresh=refresh_schema) + except NotImplementedError: + self._graph_schema = "" + except Exception as e: + logger.warn(f"can not to find graph schema: {e}") + self._graph_schema = "" + + def _extract_subject_entities(self, query_str: str) -> Set[str]: + """extract subject entities.""" + if self.extract_subject_entities_fn is not None: + return self.extract_subject_entities_fn(query_str) + else: + return self._extract_entities_by_llm(query_str) + + def _extract_entities_by_llm(self, text: str) -> Set[str]: + """extract subject entities from text by llm""" + from pilot.scene.base import ChatScene + from pilot.common.chat_util import llm_chat_response_nostream + import uuid + + chat_param = { + "chat_session_id": uuid.uuid1(), + "current_user_input": text, + "select_param": "entity", + "model_name": self.model_name, + } + loop = 
utils.get_or_create_event_loop() + entities = loop.run_until_complete( + llm_chat_response_nostream( + ChatScene.ExtractEntity.value(), **{"chat_param": chat_param} + ) + ) + return entities + + def _search( + self, + query_str: str, + ) -> List[NodeWithScore]: + """Get nodes for response.""" + node_visited = set() + keywords = self._extract_subject_entities(query_str) + print(f"extract entities: {keywords}\n") + rel_texts = [] + cur_rel_map = {} + chunk_indices_count: Dict[str, int] = defaultdict(int) + if self._search_mode != SearchMode.EMBEDDING: + for keyword in keywords: + keyword = keyword.lower() + subjs = set((keyword,)) + node_ids = self._index_struct.search_node_by_keyword(keyword) + for node_id in node_ids[:GLOBAL_EXPLORE_NODE_LIMIT]: + if node_id in node_visited: + continue + + if self._include_text: + chunk_indices_count[node_id] += 1 + + node_visited.add(node_id) + + rel_map = self._graph_store.get_rel_map( + list(subjs), self.graph_store_query_depth + ) + logger.debug(f"rel_map: {rel_map}") + + if not rel_map: + continue + rel_texts.extend( + [ + str(rel_obj) + for rel_objs in rel_map.values() + for rel_obj in rel_objs + ] + ) + cur_rel_map.update(rel_map) + + sorted_nodes_with_scores = [] + if not rel_texts: + logger.info("> No relationships found, returning nodes found by keywords.") + if len(sorted_nodes_with_scores) == 0: + logger.info("> No nodes found by keywords, returning empty response.") + return [ + NodeWithScore(node=TextNode(text="No relationships found."), score=1.0) + ] + + # add relationships as Node + # TODO: make initial text customizable + rel_initial_text = ( + f"The following are knowledge sequence in max depth" + f" {self.graph_store_query_depth} " + f"in the form of directed graph like:\n" + f"`subject -[predicate]->, object, <-[predicate_next_hop]-," + f" object_next_hop ...`" + ) + rel_info = [rel_initial_text] + rel_texts + rel_node_info = { + "kg_rel_texts": rel_texts, + "kg_rel_map": cur_rel_map, + } + if 
self._graph_schema != "": + rel_node_info["kg_schema"] = {"schema": self._graph_schema} + rel_info_text = "\n".join( + [ + str(item) + for sublist in rel_info + for item in (sublist if isinstance(sublist, list) else [sublist]) + ] + ) + if self._verbose: + print(f"KG context:\n{rel_info_text}\n", color="blue") + rel_text_node = TextNode( + text=rel_info_text, + metadata=rel_node_info, + excluded_embed_metadata_keys=["kg_rel_map", "kg_rel_texts"], + excluded_llm_metadata_keys=["kg_rel_map", "kg_rel_texts"], + ) + # this node is constructed from rel_texts, give high confidence to avoid cutoff + sorted_nodes_with_scores.append( + NodeWithScore(node=rel_text_node, score=DEFAULT_NODE_SCORE) + ) + + return sorted_nodes_with_scores + + def _get_metadata_for_response( + self, nodes: List[BaseNode] + ) -> Optional[Dict[str, Any]]: + """Get metadata for response.""" + for node in nodes: + if node.metadata is None or "kg_rel_map" not in node.metadata: + continue + return node.metadata + raise ValueError("kg_rel_map must be found in at least one Node.") \ No newline at end of file diff --git a/pilot/graph_engine/index_struct.py b/pilot/graph_engine/index_struct.py new file mode 100644 index 000000000..edc47a7ac --- /dev/null +++ b/pilot/graph_engine/index_struct.py @@ -0,0 +1,259 @@ +"""Data structures. + +Nodes are decoupled from the indices. 
+ +""" + +import uuid +from abc import abstractmethod +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Sequence, Set + +from dataclasses_json import DataClassJsonMixin + + +from pilot.graph_engine.index_type import IndexStructType +from pilot.graph_engine.node import TextNode, BaseNode + +# TODO: legacy backport of old Node class +Node = TextNode + + +@dataclass +class IndexStruct(DataClassJsonMixin): + """A base data struct for a LlamaIndex.""" + + index_id: str = field(default_factory=lambda: str(uuid.uuid4())) + summary: Optional[str] = None + + def get_summary(self) -> str: + """Get text summary.""" + if self.summary is None: + raise ValueError("summary field of the index_struct not set.") + return self.summary + + @classmethod + @abstractmethod + def get_type(cls): + """Get index struct type.""" + + +@dataclass +class IndexGraph(IndexStruct): + """A graph representing the tree-structured index.""" + + # mapping from index in tree to Node doc id. + all_nodes: Dict[int, str] = field(default_factory=dict) + root_nodes: Dict[int, str] = field(default_factory=dict) + node_id_to_children_ids: Dict[str, List[str]] = field(default_factory=dict) + + @property + def node_id_to_index(self) -> Dict[str, int]: + """Map from node id to index.""" + return {node_id: index for index, node_id in self.all_nodes.items()} + + @property + def size(self) -> int: + """Get the size of the graph.""" + return len(self.all_nodes) + + def get_index(self, node: BaseNode) -> int: + """Get index of node.""" + return self.node_id_to_index[node.node_id] + + def insert( + self, + node: BaseNode, + index: Optional[int] = None, + children_nodes: Optional[Sequence[BaseNode]] = None, + ) -> None: + """Insert node.""" + index = index or self.size + node_id = node.node_id + + self.all_nodes[index] = node_id + + if children_nodes is None: + children_nodes = [] + children_ids = [n.node_id for n in children_nodes] + self.node_id_to_children_ids[node_id] = children_ids 
+ + def get_children(self, parent_node: Optional[BaseNode]) -> Dict[int, str]: + """Get children nodes.""" + if parent_node is None: + return self.root_nodes + else: + parent_id = parent_node.node_id + children_ids = self.node_id_to_children_ids[parent_id] + return { + self.node_id_to_index[child_id]: child_id for child_id in children_ids + } + + def insert_under_parent( + self, + node: BaseNode, + parent_node: Optional[BaseNode], + new_index: Optional[int] = None, + ) -> None: + """Insert under parent node.""" + new_index = new_index or self.size + if parent_node is None: + self.root_nodes[new_index] = node.node_id + self.node_id_to_children_ids[node.node_id] = [] + else: + if parent_node.node_id not in self.node_id_to_children_ids: + self.node_id_to_children_ids[parent_node.node_id] = [] + self.node_id_to_children_ids[parent_node.node_id].append(node.node_id) + + self.all_nodes[new_index] = node.node_id + + @classmethod + def get_type(cls) -> IndexStructType: + """Get type.""" + return IndexStructType.TREE + + +@dataclass +class KeywordTable(IndexStruct): + """A table of keywords mapping keywords to text chunks.""" + + table: Dict[str, Set[str]] = field(default_factory=dict) + + def add_node(self, keywords: List[str], node: BaseNode) -> None: + """Add text to table.""" + for keyword in keywords: + if keyword not in self.table: + self.table[keyword] = set() + self.table[keyword].add(node.node_id) + + @property + def node_ids(self) -> Set[str]: + """Get all node ids.""" + return set.union(*self.table.values()) + + @property + def keywords(self) -> Set[str]: + """Get all keywords in the table.""" + return set(self.table.keys()) + + @property + def size(self) -> int: + """Get the size of the table.""" + return len(self.table) + + @classmethod + def get_type(cls) -> IndexStructType: + """Get type.""" + return IndexStructType.KEYWORD_TABLE + + +@dataclass +class IndexList(IndexStruct): + """A list of documents.""" + + nodes: List[str] = field(default_factory=list) + + 
def add_node(self, node: BaseNode) -> None: + """Add text to table, return current position in list.""" + # don't worry about child indices for now, nodes are all in order + self.nodes.append(node.node_id) + + @classmethod + def get_type(cls) -> IndexStructType: + """Get type.""" + return IndexStructType.LIST + + +@dataclass +class IndexDict(IndexStruct): + """A simple dictionary of documents.""" + + # TODO: slightly deprecated, should likely be a list or set now + # mapping from vector store id to node doc_id + nodes_dict: Dict[str, str] = field(default_factory=dict) + + # TODO: deprecated, not used + # mapping from node doc_id to vector store id + doc_id_dict: Dict[str, List[str]] = field(default_factory=dict) + + # TODO: deprecated, not used + # this should be empty for all other indices + embeddings_dict: Dict[str, List[float]] = field(default_factory=dict) + + def add_node( + self, + node: BaseNode, + text_id: Optional[str] = None, + ) -> str: + """Add text to table, return current position in list.""" + # # don't worry about child indices for now, nodes are all in order + # self.nodes_dict[int_id] = node + vector_id = text_id if text_id is not None else node.node_id + self.nodes_dict[vector_id] = node.node_id + + return vector_id + + def delete(self, doc_id: str) -> None: + """Delete a Node.""" + del self.nodes_dict[doc_id] + + @classmethod + def get_type(cls) -> IndexStructType: + """Get type.""" + return IndexStructType.VECTOR_STORE + + +@dataclass +class KG(IndexStruct): + """A table of keywords mapping keywords to text chunks.""" + + # Unidirectional + + # table of keywords to node ids + table: Dict[str, Set[str]] = field(default_factory=dict) + + # TODO: legacy attribute, remove in future releases + rel_map: Dict[str, List[List[str]]] = field(default_factory=dict) + + # TBD, should support vector store, now we just persist the embedding memory + # maybe chainable abstractions for *_stores could be designed + embedding_dict: Dict[str, List[float]] = 
field(default_factory=dict) + + @property + def node_ids(self) -> Set[str]: + """Get all node ids.""" + return set.union(*self.table.values()) + + def add_to_embedding_dict(self, triplet_str: str, embedding: List[float]) -> None: + """Add embedding to dict.""" + self.embedding_dict[triplet_str] = embedding + + def add_node(self, keywords: List[str], node: BaseNode) -> None: + """Add text to table.""" + node_id = node.node_id + for keyword in keywords: + keyword = keyword.lower() + if keyword not in self.table: + self.table[keyword] = set() + self.table[keyword].add(node_id) + + def search_node_by_keyword(self, keyword: str) -> List[str]: + """Search for nodes by keyword.""" + if keyword not in self.table: + return [] + return list(self.table[keyword]) + + @classmethod + def get_type(cls) -> IndexStructType: + """Get type.""" + return IndexStructType.KG + + +@dataclass +class EmptyIndexStruct(IndexStruct): + """Empty index.""" + + @classmethod + def get_type(cls) -> IndexStructType: + """Get type.""" + return IndexStructType.EMPTY diff --git a/pilot/graph_engine/index_type.py b/pilot/graph_engine/index_type.py new file mode 100644 index 000000000..939066be9 --- /dev/null +++ b/pilot/graph_engine/index_type.py @@ -0,0 +1,48 @@ +"""IndexStructType class.""" + +from enum import Enum + + +class IndexStructType(str, Enum): + """Index struct type. Identifier for a "type" of index. + + Attributes: + TREE ("tree"): Tree index. See :ref:`Ref-Indices-Tree` for tree indices. + LIST ("list"): Summary index. See :ref:`Ref-Indices-List` for summary indices. + KEYWORD_TABLE ("keyword_table"): Keyword table index. See + :ref:`Ref-Indices-Table` + for keyword table indices. + DICT ("dict"): Faiss Vector Store Index. See + :ref:`Ref-Indices-VectorStore` + for more information on the faiss vector store index. + SIMPLE_DICT ("simple_dict"): Simple Vector Store Index. See + :ref:`Ref-Indices-VectorStore` + for more information on the simple vector store index. 
+ KG ("kg"): Knowledge Graph index. + See :ref:`Ref-Indices-Knowledge-Graph` for KG indices. + DOCUMENT_SUMMARY ("document_summary"): Document Summary Index. + See :ref:`Ref-Indices-Document-Summary` for Summary Indices. + + """ + + # TODO: refactor so these are properties on the base class + + NODE = "node" + TREE = "tree" + LIST = "list" + KEYWORD_TABLE = "keyword_table" + + DICT = "dict" + # simple + SIMPLE_DICT = "simple_dict" + # for KG index + KG = "kg" + SIMPLE_KG = "simple_kg" + NEBULAGRAPH = "nebulagraph" + FALKORDB = "falkordb" + + # EMPTY + EMPTY = "empty" + COMPOSITE = "composite" + + DOCUMENT_SUMMARY = "document_summary" diff --git a/pilot/graph_engine/kv_index.py b/pilot/graph_engine/kv_index.py new file mode 100644 index 000000000..7b44b7d04 --- /dev/null +++ b/pilot/graph_engine/kv_index.py @@ -0,0 +1,74 @@ +from typing import List, Optional +from llama_index.data_structs.data_structs import IndexStruct +from llama_index.storage.index_store.utils import ( + index_struct_to_json, + json_to_index_struct, +) +from llama_index.storage.kvstore.types import BaseKVStore + +DEFAULT_NAMESPACE = "index_store" + + +class KVIndexStore: + """Key-Value Index store. + + Args: + kvstore (BaseKVStore): key-value store + namespace (str): namespace for the index store + + """ + + def __init__(self, kvstore: BaseKVStore, namespace: Optional[str] = None) -> None: + """Init a KVIndexStore.""" + self._kvstore = kvstore + self._namespace = namespace or DEFAULT_NAMESPACE + self._collection = f"{self._namespace}/data" + + def add_index_struct(self, index_struct: IndexStruct) -> None: + """Add an index struct. + + Args: + index_struct (IndexStruct): index struct + + """ + key = index_struct.index_id + data = index_struct_to_json(index_struct) + self._kvstore.put(key, data, collection=self._collection) + + def delete_index_struct(self, key: str) -> None: + """Delete an index struct. 
+ + Args: + key (str): index struct key + + """ + self._kvstore.delete(key, collection=self._collection) + + def get_index_struct( + self, struct_id: Optional[str] = None + ) -> Optional[IndexStruct]: + """Get an index struct. + + Args: + struct_id (Optional[str]): index struct id + + """ + if struct_id is None: + structs = self.index_structs() + assert len(structs) == 1 + return structs[0] + else: + json = self._kvstore.get(struct_id, collection=self._collection) + if json is None: + return None + return json_to_index_struct(json) + + def index_structs(self) -> List[IndexStruct]: + """Get all index structs. + + Returns: + List[IndexStruct]: index structs + + """ + jsons = self._kvstore.get_all(collection=self._collection) + return [json_to_index_struct(json) for json in jsons.values()] diff --git a/pilot/graph_engine/node.py b/pilot/graph_engine/node.py new file mode 100644 index 000000000..6f6d45ae4 --- /dev/null +++ b/pilot/graph_engine/node.py @@ -0,0 +1,569 @@ +"""Base schema for data structures.""" +import json +import textwrap +import uuid +from abc import abstractmethod +from enum import Enum, auto +from hashlib import sha256 +from typing import Any, Dict, List, Optional, Union + +from langchain.schema import Document +from pydantic import BaseModel, Field, root_validator +from typing_extensions import Self + + +DEFAULT_TEXT_NODE_TMPL = "{metadata_str}\n\n{content}" +DEFAULT_METADATA_TMPL = "{key}: {value}" +# NOTE: for pretty printing +TRUNCATE_LENGTH = 350 +WRAP_WIDTH = 70 + + +class BaseComponent(BaseModel): + """Base component object to caputure class names.""" + """reference llama-index""" + + @classmethod + @abstractmethod + def class_name(cls) -> str: + """Get class name.""" + + def to_dict(self, **kwargs: Any) -> Dict[str, Any]: + data = self.dict(**kwargs) + data["class_name"] = self.class_name() + return data + + def to_json(self, **kwargs: Any) -> str: + data = self.to_dict(**kwargs) + return json.dumps(data) + + # TODO: return type here not 
supported by current mypy version + @classmethod + def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self: # type: ignore + if isinstance(kwargs, dict): + data.update(kwargs) + + data.pop("class_name", None) + return cls(**data) + + @classmethod + def from_json(cls, data_str: str, **kwargs: Any) -> Self: # type: ignore + data = json.loads(data_str) + return cls.from_dict(data, **kwargs) + + +class NodeRelationship(str, Enum): + """Node relationships used in `BaseNode` class. + + Attributes: + SOURCE: The node is the source document. + PREVIOUS: The node is the previous node in the document. + NEXT: The node is the next node in the document. + PARENT: The node is the parent node in the document. + CHILD: The node is a child node in the document. + + """ + + SOURCE = auto() + PREVIOUS = auto() + NEXT = auto() + PARENT = auto() + CHILD = auto() + + +class ObjectType(str, Enum): + TEXT = auto() + IMAGE = auto() + INDEX = auto() + DOCUMENT = auto() + + +class MetadataMode(str, Enum): + ALL = auto() + EMBED = auto() + LLM = auto() + NONE = auto() + + +class RelatedNodeInfo(BaseComponent): + node_id: str + node_type: Optional[ObjectType] = None + metadata: Dict[str, Any] = Field(default_factory=dict) + hash: Optional[str] = None + + @classmethod + def class_name(cls) -> str: + """Get class name.""" + return "RelatedNodeInfo" + + +RelatedNodeType = Union[RelatedNodeInfo, List[RelatedNodeInfo]] + + +# Node classes for indexes +class BaseNode(BaseComponent): + """Base node Object. + + Generic abstract interface for retrievable nodes + + """ + + class Config: + allow_population_by_field_name = True + + id_: str = Field( + default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the node." + ) + embedding: Optional[List[float]] = Field( + default=None, description="Embedding of the node." 
+ ) + + """" + metadata fields + - injected as part of the text shown to LLMs as context + - injected as part of the text for generating embeddings + - used by vector DBs for metadata filtering + + """ + metadata: Dict[str, Any] = Field( + default_factory=dict, + description="A flat dictionary of metadata fields", + alias="extra_info", + ) + excluded_embed_metadata_keys: List[str] = Field( + default_factory=list, + description="Metadata keys that are exluded from text for the embed model.", + ) + excluded_llm_metadata_keys: List[str] = Field( + default_factory=list, + description="Metadata keys that are exluded from text for the LLM.", + ) + relationships: Dict[NodeRelationship, RelatedNodeType] = Field( + default_factory=dict, + description="A mapping of relationships to other node information.", + ) + hash: str = Field(default="", description="Hash of the node content.") + + @classmethod + @abstractmethod + def get_type(cls) -> str: + """Get Object type.""" + + @abstractmethod + def get_content(self, metadata_mode: MetadataMode = MetadataMode.ALL) -> str: + """Get object content.""" + + @abstractmethod + def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str: + """Metadata string.""" + + @abstractmethod + def set_content(self, value: Any) -> None: + """Set the content of the node.""" + + @property + def node_id(self) -> str: + return self.id_ + + @node_id.setter + def node_id(self, value: str) -> None: + self.id_ = value + + @property + def source_node(self) -> Optional[RelatedNodeInfo]: + """Source object node. + + Extracted from the relationships field. 
+ + """ + if NodeRelationship.SOURCE not in self.relationships: + return None + + relation = self.relationships[NodeRelationship.SOURCE] + if isinstance(relation, list): + raise ValueError("Source object must be a single RelatedNodeInfo object") + return relation + + @property + def prev_node(self) -> Optional[RelatedNodeInfo]: + """Prev node.""" + if NodeRelationship.PREVIOUS not in self.relationships: + return None + + relation = self.relationships[NodeRelationship.PREVIOUS] + if not isinstance(relation, RelatedNodeInfo): + raise ValueError("Previous object must be a single RelatedNodeInfo object") + return relation + + @property + def next_node(self) -> Optional[RelatedNodeInfo]: + """Next node.""" + if NodeRelationship.NEXT not in self.relationships: + return None + + relation = self.relationships[NodeRelationship.NEXT] + if not isinstance(relation, RelatedNodeInfo): + raise ValueError("Next object must be a single RelatedNodeInfo object") + return relation + + @property + def parent_node(self) -> Optional[RelatedNodeInfo]: + """Parent node.""" + if NodeRelationship.PARENT not in self.relationships: + return None + + relation = self.relationships[NodeRelationship.PARENT] + if not isinstance(relation, RelatedNodeInfo): + raise ValueError("Parent object must be a single RelatedNodeInfo object") + return relation + + @property + def child_nodes(self) -> Optional[List[RelatedNodeInfo]]: + """Child nodes.""" + if NodeRelationship.CHILD not in self.relationships: + return None + + relation = self.relationships[NodeRelationship.CHILD] + if not isinstance(relation, list): + raise ValueError("Child objects must be a list of RelatedNodeInfo objects.") + return relation + + @property + def ref_doc_id(self) -> Optional[str]: + """Deprecated: Get ref doc id.""" + source_node = self.source_node + if source_node is None: + return None + return source_node.node_id + + @property + def extra_info(self) -> Dict[str, Any]: + """TODO: DEPRECATED: Extra info.""" + return 
self.metadata + + def __str__(self) -> str: + source_text_truncated = truncate_text( + self.get_content().strip(), TRUNCATE_LENGTH + ) + source_text_wrapped = textwrap.fill( + f"Text: {source_text_truncated}\n", width=WRAP_WIDTH + ) + return f"Node ID: {self.node_id}\n{source_text_wrapped}" + + def truncate_text(text: str, max_length: int) -> str: + """Truncate text to a maximum length.""" + if len(text) <= max_length: + return text + return text[: max_length - 3] + "..." + + def get_embedding(self) -> List[float]: + """Get embedding. + + Errors if embedding is None. + + """ + if self.embedding is None: + raise ValueError("embedding not set.") + return self.embedding + + def as_related_node_info(self) -> RelatedNodeInfo: + """Get node as RelatedNodeInfo.""" + return RelatedNodeInfo( + node_id=self.node_id, metadata=self.metadata, hash=self.hash + ) + + +class TextNode(BaseNode): + text: str = Field(default="", description="Text content of the node.") + start_char_idx: Optional[int] = Field( + default=None, description="Start char index of the node." + ) + end_char_idx: Optional[int] = Field( + default=None, description="End char index of the node." + ) + text_template: str = Field( + default=DEFAULT_TEXT_NODE_TMPL, + description=( + "Template for how text is formatted, with {content} and " + "{metadata_str} placeholders." + ), + ) + metadata_template: str = Field( + default=DEFAULT_METADATA_TMPL, + description=( + "Template for how metadata is formatted, with {key} and " + "{value} placeholders." 
def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
    """Render the node's metadata as a single string.

    Each retained ``metadata`` entry is formatted with ``metadata_template``
    (``{key}`` / ``{value}`` placeholders) and the pieces are joined with
    ``metadata_seperator``, in the dictionary's own order.

    Args:
        mode: ``MetadataMode.NONE`` yields an empty string;
            ``MetadataMode.LLM`` drops ``excluded_llm_metadata_keys``;
            ``MetadataMode.EMBED`` drops ``excluded_embed_metadata_keys``;
            any other mode keeps every key.

    Returns:
        The joined metadata text, possibly empty.
    """
    if mode == MetadataMode.NONE:
        return ""

    # Pick the exclusion list that applies to the requested audience.
    if mode == MetadataMode.LLM:
        hidden = set(self.excluded_llm_metadata_keys)
    elif mode == MetadataMode.EMBED:
        hidden = set(self.excluded_embed_metadata_keys)
    else:
        hidden = set()

    rendered = []
    for name, raw in self.metadata.items():
        if name in hidden:
            continue
        rendered.append(self.metadata_template.format(key=name, value=str(raw)))
    return self.metadata_seperator.join(rendered)
def __str__(self) -> str:
    """Return the wrapped node's text followed by a ``Score:`` line.

    ``score`` is optional and defaults to ``None``; a ``None`` score is
    printed literally as ``None`` because the float format spec ``0.3f``
    cannot be applied to it (doing so raises ``TypeError``).
    """
    score_str = "None" if self.score is None else f"{self.score: 0.3f}"
    return f"{self.node}\nScore: {score_str}\n"
def to_langchain_format(self) -> Any:
    """Convert this document into a LangChain ``Document``.

    Returns:
        A ``langchain.schema.Document`` whose ``page_content`` is this
        document's ``text`` and whose ``metadata`` is this document's
        ``metadata`` (an empty dict when that is falsy).
    """
    # At call time the bare name ``Document`` in this module is this very
    # class, which has no ``page_content`` field, so the LangChain class is
    # imported under an alias. The import is local so that loading this
    # module does not itself require LangChain.
    from langchain.schema import Document as LCDocument

    metadata = self.metadata or {}
    return LCDocument(page_content=self.text, metadata=metadata)
class SearchMode(str, Enum):
    """Query mode enum for Knowledge Graphs.

    Can be passed as the enum struct, or as the underlying string.

    Attributes:
        KEYWORD ("keyword"): Default query mode, using keywords to find triplets.
        EMBEDDING ("embedding"): Embedding mode, using embeddings to find
            similar triplets.
        HYBRID ("hybrid"): Hybrid mode, combining both keywords and embeddings
            to find relevant triplets.
    """

    KEYWORD = "keyword"
    EMBEDDING = "embedding"
    HYBRID = "hybrid"
def parse_prompt_response(self, response, max_length: int = 128) -> Set[str]:
    """Turn the model's keyword reply into a set of lower-cased keywords.

    Args:
        response: Raw model output, a comma-separated keyword list that may
            start with the ``KEYWORDS:`` prefix requested by the prompt.
        max_length: Not used by this parser; kept so the signature stays
            unchanged for existing callers.

    Returns:
        The distinct keywords, stripped of surrounding whitespace and
        lower-cased. Empty entries (an empty reply, a bare ``KEYWORDS:``
        prefix, doubled or trailing commas) are dropped, so the set never
        contains ``""``.
    """
    # Log instead of printing so library code does not write to stdout.
    logging.getLogger(__name__).debug("clean prompt response: %s", response)

    if not response:
        return set()

    response = response.strip()  # Strip newlines from responses.
    if response.startswith("KEYWORDS:"):
        response = response[len("KEYWORDS:") :]

    keywords = (part.strip().lower() for part in response.split(","))
    return {keyword for keyword in keywords if keyword}
class ExtractTriplet(BaseChat):
    """Chat scene that extracts triplets from the user's text via the LLM."""

    chat_scene: str = ChatScene.ExtractTriplet.value()

    def __init__(self, chat_param: Dict):
        """Bind the chat to the triplet-extraction scene.

        Args:
            chat_param: Chat parameters. ``chat_mode`` is overwritten with
                ``ChatScene.ExtractTriplet`` before being passed to the base
                class; ``current_user_input`` and ``select_param`` are read
                and stored on the instance.
        """
        chat_param["chat_mode"] = ChatScene.ExtractTriplet
        super().__init__(
            chat_param=chat_param,
        )

        self.user_input = chat_param["current_user_input"]
        self.extract_mode = chat_param["select_param"]

    def generate_input_values(self):
        """Return the prompt variables: the user's text under ``text``."""
        input_values = {
            "text": self.user_input,
        }
        return input_values

    @property
    def chat_type(self) -> str:
        """Scene identifier of this chat."""
        # ``value`` is a method on ChatScene members (it is called for
        # ``chat_scene`` above); returning it uncalled would hand back a
        # bound method rather than the declared ``str``.
        return ChatScene.ExtractTriplet.value()
def parse_prompt_response(
    self, response, max_length: int = 128
) -> List[Tuple[str, str, str]]:
    """Extract ``(subject, predicate, object)`` triplets from the model reply.

    Every parenthesised group without nested parentheses is a candidate. A
    candidate is kept only if it splits on commas into exactly three parts,
    no part exceeds ``max_length`` bytes when UTF-8 encoded, and no part is
    empty after stripping whitespace.

    Args:
        response: Raw model output, optionally prefixed with ``Triplets:``.
        max_length: Per-part size limit, measured in UTF-8 bytes.

    Returns:
        The accepted triplets in order of appearance, each part stripped
        and lower-cased.
    """
    print("clean prompt response:", response)

    prefix = "Triplets:"
    body = response[len(prefix) :] if response.startswith(prefix) else response
    candidates = re.findall(r"\([^()]+\)", body)
    print("parse prompt response:", candidates)

    triplets = []
    for candidate in candidates:
        # Guard against anything that is not a "(...)" group.
        if not (candidate.startswith("(") and candidate.endswith(")")):
            continue

        parts = candidate[1:-1].split(",")
        if len(parts) != 3:
            continue

        # Byte length rather than len(), so multi-byte characters count
        # fully; an oversized part usually signals a malformed extraction.
        if any(len(part.encode("utf-8")) > max_length for part in parts):
            continue

        cleaned = tuple(part.strip() for part in parts)
        if not all(cleaned):
            # Partial triplet: at least one part is blank.
            continue

        triplets.append(tuple(part.lower() for part in cleaned))
    return triplets
+ "Triplets:\n(Alice, is mother of, Bob)\n" + "Text: Philz is a coffee shop founded in Berkeley in 1982.\n" + "Triplets:\n" + "(Philz, is, coffee shop)\n" + "(Philz, founded in, Berkeley)\n" + "(Philz, founded in, 1982)\n" + "---------------------\n" + "Text: {text}\n" + "Triplets:\n" + ensure Respond in the following List(Tuple) format: + '(Stephen Curry, plays for, Golden State Warriors)\n(Stephen Curry, known for, shooting skills)\n(Stephen Curry, attended, Davidson College)\n(Stephen Curry, led, team to success)' +""" +PROMPT_RESPONSE = """""" + + +RESPONSE_FORMAT = """""" + + +PROMPT_SEP = SeparatorStyle.SINGLE.value + +PROMPT_NEED_NEED_STREAM_OUT = False + +prompt = PromptTemplate( + template_scene=ChatScene.ExtractTriplet.value(), + input_variables=["text"], + response_format="", + template_define=PROMPT_SCENE_DEFINE, + template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE, + stream_out=PROMPT_NEED_NEED_STREAM_OUT, + output_parser=ExtractTripleParser( + sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT + ), +) + +CFG.prompt_template_registry.register(prompt, is_default=True) diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py index ebecddd19..c381546f8 100644 --- a/pilot/scene/chat_knowledge/v1/chat.py +++ b/pilot/scene/chat_knowledge/v1/chat.py @@ -88,25 +88,7 @@ class ChatKnowledge(BaseChat): if self.space_context: self.prompt_template.template_define = self.space_context["prompt"]["scene"] self.prompt_template.template = self.space_context["prompt"]["template"] - # docs = self.rag_engine.search(query=self.current_user_input) - # import httpx - # with httpx.Client() as client: - # request = client.build_request( - # "post", - # "http://127.0.0.1/api/knowledge/entities/extract", - # json="application/json", # using json for data to ensure it sends as application/json - # params={"text": self.current_user_input}, - # headers={}, - # ) - # - # response = client.send(request) - # if response.status_code != 200: - # 
error_msg = f"request /api/knowledge/entities/extract failed, error: {response.text}" - # raise Exception(error_msg) - # docs = response.json() - # import requests - # docs = requests.post("http://127.0.0.1:5000/api/knowledge/entities/extract", headers={}, json={"text": self.current_user_input}) - + docs = self.rag_engine.search(query=self.current_user_input) docs = self.knowledge_embedding_client.similar_search( self.current_user_input, self.top_k ) diff --git a/pilot/server/knowledge/api.py b/pilot/server/knowledge/api.py index e0f31031e..8e5e52b58 100644 --- a/pilot/server/knowledge/api.py +++ b/pilot/server/knowledge/api.py @@ -205,12 +205,6 @@ def similar_query(space_name: str, query_request: KnowledgeQueryRequest): async def entity_extract(request: EntityExtractRequest): logger.info(f"Received params: {request}") try: - # from pilot.graph_engine.graph_factory import RAGGraphFactory - # from pilot.component import ComponentType - # rag_engine = CFG.SYSTEM_APP.get_component( - # ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory - # ).create() - # return Result.succ(await rag_engine.search(request.text)) from pilot.scene.base import ChatScene from pilot.common.chat_util import llm_chat_response_nostream import uuid @@ -222,11 +216,6 @@ async def entity_extract(request: EntityExtractRequest): "model_name": request.model_name, } - # import nest_asyncio - # nest_asyncio.apply() - # loop = asyncio.get_event_loop() - # loop.stop() - # loop = utils.get_or_create_event_loop() res = await llm_chat_response_nostream( ChatScene.ExtractEntity.value(), **{"chat_param": chat_param} )