Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-07-29 06:47:30 +00:00)

fix: update awel embedding examples and delete unused code. (#1073)

This commit is contained in:
parent cb9c34abb9
commit 3a54d1ef9a
@@ -28,11 +28,6 @@ def initialize_components(
     system_app.register(DefaultExecutorFactory)
     system_app.register_instance(controller)
 
-    # Register global default RAGGraphFactory
-    # from dbgpt.graph.graph_factory import DefaultRAGGraphFactory
-
-    # system_app.register(DefaultRAGGraphFactory)
-
     from dbgpt.serve.agent.hub.controller import module_agent
 
     system_app.register_instance(module_agent)
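
For context on the registry calls kept above: `register` takes a component class and lets the `SystemApp` construct and own the instance, while `register_instance` registers an already-built object. A minimal sketch, assuming the import path for `DefaultExecutorFactory`; the standalone `SystemApp` construction is illustrative, not part of this commit:

from dbgpt.component import SystemApp
from dbgpt.util.executor_utils import DefaultExecutorFactory  # assumed import path

system_app = SystemApp()
# Register by class; the app instantiates the component itself.
system_app.register(DefaultExecutorFactory)
# An existing object (like `controller` above) would be registered as-is:
# system_app.register_instance(controller)
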
@@ -497,30 +497,6 @@ class KnowledgeService:
         res.page = request.page
         return res
 
-    def async_knowledge_graph(self, chunk_docs, doc):
-        """async document extract triplets and save into graph db
-        Args:
-            - chunk_docs: List[Document]
-            - doc: KnowledgeDocumentEntity
-        """
-        logger.info(
-            f"async_knowledge_graph, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to graph store"
-        )
-        try:
-            from dbgpt.rag.graph.graph_factory import RAGGraphFactory
-
-            rag_engine = CFG.SYSTEM_APP.get_component(
-                ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
-            ).create()
-            rag_engine.knowledge_graph(chunk_docs)
-            doc.status = SyncStatus.FINISHED.name
-            doc.result = "document build graph success"
-        except Exception as e:
-            doc.status = SyncStatus.FAILED.name
-            doc.result = "document build graph failed" + str(e)
-            logger.error(f"document build graph failed:{doc.doc_name}, {str(e)}")
-        return knowledge_document_dao.update_knowledge_document(doc)
-
     def async_doc_embedding(self, assembler, chunk_docs, doc):
         """async document embedding into vector db
         Args:
@@ -1,148 +0,0 @@
-import logging
-from typing import Any, Callable, List, Optional, Tuple
-
-from langchain.schema import Document
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-
-from dbgpt.rag.embedding import KnowledgeType
-from dbgpt.rag.embedding.knowledge_type import get_knowledge_embedding
-from dbgpt.rag.graph.index_struct import KG
-from dbgpt.rag.graph.node import TextNode
-from dbgpt.util import utils
-
-logger = logging.getLogger(__name__)
-
-
-class RAGGraphEngine:
-    """Knowledge RAG Graph Engine.
-    Build a RAG Graph Client can extract triplets and insert into graph store.
-    Args:
-        knowledge_type (Optional[str]): Default: KnowledgeType.DOCUMENT.value
-            extracting triplets.
-        knowledge_source (Optional[str]):
-        model_name (Optional[str]): llm model name
-        graph_store (Optional[GraphStore]): The graph store to use.refrence:llama-index
-        include_embeddings (bool): Whether to include embeddings in the index.
-            Defaults to False.
-        max_object_length (int): The maximum length of the object in a triplet.
-            Defaults to 128.
-        extract_triplet_fn (Optional[Callable]): The function to use for
-            extracting triplets. Defaults to None.
-    """
-
-    index_struct_cls = KG
-
-    def __init__(
-        self,
-        knowledge_type: Optional[str] = KnowledgeType.DOCUMENT.value,
-        knowledge_source: Optional[str] = None,
-        text_splitter=None,
-        graph_store=None,
-        index_struct: Optional[KG] = None,
-        model_name: Optional[str] = None,
-        max_triplets_per_chunk: int = 10,
-        include_embeddings: bool = False,
-        max_object_length: int = 128,
-        extract_triplet_fn: Optional[Callable] = None,
-        **kwargs: Any,
-    ) -> None:
-        """Initialize params."""
-        from llama_index.graph_stores import SimpleGraphStore
-
-        # need to set parameters before building index in base class.
-        self.knowledge_source = knowledge_source
-        self.knowledge_type = knowledge_type
-        self.model_name = model_name
-        self.text_splitter = text_splitter
-        self.index_struct = index_struct
-        self.include_embeddings = include_embeddings
-        self.graph_store = graph_store or SimpleGraphStore()
-        # self.graph_store = graph_store
-        self.max_triplets_per_chunk = max_triplets_per_chunk
-        self._max_object_length = max_object_length
-        self._extract_triplet_fn = extract_triplet_fn
-
-    def knowledge_graph(self, docs=None):
-        """knowledge docs into graph store"""
-        if not docs:
-            if self.text_splitter:
-                self.text_splitter = RecursiveCharacterTextSplitter(
-                    chunk_size=2000, chunk_overlap=100
-                )
-            knowledge_source = get_knowledge_embedding(
-                knowledge_type=self.knowledge_type,
-                knowledge_source=self.knowledge_source,
-                text_splitter=self.text_splitter,
-            )
-            docs = knowledge_source.read()
-        if self.index_struct is None:
-            self.index_struct = self._build_index_from_docs(docs)
-
-    def _extract_triplets(self, text: str) -> List[Tuple[str, str, str]]:
-        """Extract triplets from text by function or llm"""
-        if self._extract_triplet_fn is not None:
-            return self._extract_triplet_fn(text)
-        else:
-            return self._llm_extract_triplets(text)
-
-    def _llm_extract_triplets(self, text: str) -> List[Tuple[str, str, str]]:
-        """Extract triplets from text by llm"""
-        import uuid
-
-        from dbgpt.app.scene import ChatScene
-        from dbgpt.util.chat_util import llm_chat_response_nostream
-
-        chat_param = {
-            "chat_session_id": uuid.uuid1(),
-            "current_user_input": text,
-            "select_param": "triplet",
-            "model_name": self.model_name,
-        }
-        loop = utils.get_or_create_event_loop()
-        triplets = loop.run_until_complete(
-            llm_chat_response_nostream(
-                ChatScene.ExtractTriplet.value(), **{"chat_param": chat_param}
-            )
-        )
-        return triplets
-
-    def _build_index_from_docs(self, documents: List[Document]) -> KG:
-        """Build the index from nodes.
-        Args:documents:List[Document]
-        """
-        index_struct = self.index_struct_cls()
-        triplets = []
-        for doc in documents:
-            trips = self._extract_triplets_task([doc], index_struct)
-            triplets.extend(trips)
-            print(triplets)
-            text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
-            for triplet in triplets:
-                subj, _, obj = triplet
-                self.graph_store.upsert_triplet(*triplet)
-                index_struct.add_node([subj, obj], text_node)
-        return index_struct
-
-    def search(self, query):
-        from dbgpt.rag.graph.graph_search import RAGGraphSearch
-
-        graph_search = RAGGraphSearch(graph_engine=self)
-        return graph_search.search(query)
-
-    def _extract_triplets_task(self, docs, index_struct):
-        triple_results = []
-        for doc in docs:
-            import threading
-
-            thread_id = threading.get_ident()
-            print(f"current thread-{thread_id} begin extract triplets task")
-            triplets = self._extract_triplets(doc.page_content)
-            if len(triplets) == 0:
-                triplets = []
-            text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
-            logger.info(f"extracted knowledge triplets: {triplets}")
-            print(
-                f"current thread-{thread_id} end extract triplets tasks, triplets-{triplets}"
-            )
-            triple_results.extend(triplets)
-        return triple_results
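
Before its removal, the engine above could be driven directly. A minimal usage sketch, assuming the old `dbgpt.rag.graph` package still resolves; the document path, model name, and question are illustrative only:

import asyncio

from dbgpt.rag.graph.graph_engine import RAGGraphEngine

engine = RAGGraphEngine(
    knowledge_source="docs/docs/awel.md",  # hypothetical input document
    model_name="proxyllm",
)
engine.knowledge_graph()  # read, split, extract triplets, upsert into the graph store
docs = asyncio.run(engine.search("What is AWEL?"))  # RAGGraphSearch.search is async
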
@@ -1,35 +0,0 @@
-from __future__ import annotations
-
-from abc import ABC, abstractmethod
-from typing import Any, Type
-
-from dbgpt.component import BaseComponent, ComponentType
-
-
-class RAGGraphFactory(BaseComponent, ABC):
-    name = ComponentType.RAG_GRAPH_DEFAULT.value
-
-    @abstractmethod
-    def create(self, model_name: str = None, embedding_cls: Type = None):
-        """Create RAG Graph Engine"""
-
-
-class DefaultRAGGraphFactory(RAGGraphFactory):
-    def __init__(
-        self, system_app=None, default_model_name: str = None, **kwargs: Any
-    ) -> None:
-        super().__init__(system_app=system_app)
-        self._default_model_name = default_model_name
-        self.kwargs = kwargs
-        from dbgpt.rag.graph.graph_engine import RAGGraphEngine
-
-        self.rag_engine = RAGGraphEngine(model_name="proxyllm")
-
-    def init_app(self, system_app):
-        pass
-
-    def create(self, model_name: str = None, rag_cls: Type = None):
-        if not model_name:
-            model_name = self._default_model_name
-
-        return self.rag_engine
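
This factory is the piece that was wired into `initialize_components` and `KnowledgeService` above. Reassembled from the deleted call sites (not a new API), the old flow looked roughly like this:

# system_app.register(DefaultRAGGraphFactory)  # was already commented out before removal
rag_engine = CFG.SYSTEM_APP.get_component(
    ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
).create()
rag_engine.knowledge_graph(chunk_docs)
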
@@ -1,198 +0,0 @@
-import logging
-import os
-from collections import defaultdict
-from concurrent.futures import ThreadPoolExecutor
-from typing import Any, Callable, Dict, List, Optional, Set
-
-from langchain.schema import Document
-
-from dbgpt.rag.graph.node import BaseNode, NodeWithScore, TextNode
-from dbgpt.rag.graph.search import BaseSearch, SearchMode
-
-logger = logging.getLogger(__name__)
-DEFAULT_NODE_SCORE = 1000.0
-GLOBAL_EXPLORE_NODE_LIMIT = 3
-REL_TEXT_LIMIT = 30
-
-
-class RAGGraphSearch(BaseSearch):
-    """RAG Graph Search.
-
-    args:
-        graph RAGGraphEngine.
-        model_name (str): model name
-        (see :ref:`Prompt-Templates`).
-        text_qa_template (Optional[BasePromptTemplate]): A Question Answering Prompt
-            (see :ref:`Prompt-Templates`).
-        max_keywords_per_query (int): Maximum number of keywords to extract from query.
-        num_chunks_per_query (int): Maximum number of text chunks to query.
-        search_mode (Optional[SearchMode]): Specifies whether to use keyowrds, default SearchMode.KEYWORD
-            embeddings, or both to find relevant triplets. Should be one of "keyword",
-            "embedding", or "hybrid".
-        graph_store_query_depth (int): The depth of the graph store query.
-        extract_subject_entities_fn (Optional[Callback]): extract_subject_entities callback.
-    """
-
-    def __init__(
-        self,
-        graph_engine,
-        model_name: str = None,
-        max_keywords_per_query: int = 10,
-        num_chunks_per_query: int = 10,
-        search_mode: Optional[SearchMode] = SearchMode.KEYWORD,
-        graph_store_query_depth: int = 2,
-        extract_subject_entities_fn: Optional[Callable] = None,
-        **kwargs: Any,
-    ) -> None:
-        """Initialize params."""
-        from dbgpt.rag.graph.graph_engine import RAGGraphEngine
-
-        self.graph_engine: RAGGraphEngine = graph_engine
-        self.model_name = model_name or self.graph_engine.model_name
-        self._index_struct = self.graph_engine.index_struct
-        self.max_keywords_per_query = max_keywords_per_query
-        self.num_chunks_per_query = num_chunks_per_query
-        self._search_mode = search_mode
-
-        self._graph_store = self.graph_engine.graph_store
-        self.graph_store_query_depth = graph_store_query_depth
-        self._verbose = kwargs.get("verbose", False)
-        refresh_schema = kwargs.get("refresh_schema", False)
-        self.extract_subject_entities_fn = extract_subject_entities_fn
-        self.executor = ThreadPoolExecutor(max_workers=os.cpu_count() * 5)
-        try:
-            self._graph_schema = self._graph_store.get_schema(refresh=refresh_schema)
-        except NotImplementedError:
-            self._graph_schema = ""
-        except Exception as e:
-            logger.warn(f"can not to find graph schema: {e}")
-            self._graph_schema = ""
-
-    async def _extract_subject_entities(self, query_str: str) -> Set[str]:
-        """extract subject entities."""
-        if self.extract_subject_entities_fn is not None:
-            return await self.extract_subject_entities_fn(query_str)
-        else:
-            return await self._extract_entities_by_llm(query_str)
-
-    async def _extract_entities_by_llm(self, text: str) -> Set[str]:
-        """extract subject entities from text by llm"""
-        import uuid
-
-        from dbgpt.app.scene import ChatScene
-        from dbgpt.util.chat_util import llm_chat_response_nostream
-
-        chat_param = {
-            "chat_session_id": uuid.uuid1(),
-            "current_user_input": text,
-            "select_param": "entity",
-            "model_name": self.model_name,
-        }
-        # loop = util.get_or_create_event_loop()
-        # entities = loop.run_until_complete(
-        #     llm_chat_response_nostream(
-        #         ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
-        #     )
-        # )
-        return await llm_chat_response_nostream(
-            ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
-        )
-
-    async def _search(
-        self,
-        query_str: str,
-    ) -> List[Document]:
-        """Get nodes for response."""
-        node_visited = set()
-        keywords = await self._extract_subject_entities(query_str)
-        print(f"extract entities: {keywords}\n")
-        rel_texts = []
-        cur_rel_map = {}
-        chunk_indices_count: Dict[str, int] = defaultdict(int)
-        if self._search_mode != SearchMode.EMBEDDING:
-            for keyword in keywords:
-                keyword = keyword.lower()
-                subjs = set((keyword,))
-                # node_ids = self._index_struct.search_node_by_keyword(keyword)
-                # for node_id in node_ids[:GLOBAL_EXPLORE_NODE_LIMIT]:
-                #     if node_id in node_visited:
-                #         continue
-                #
-                #     # if self._include_text:
-                #     #     chunk_indices_count[node_id] += 1
-                #
-                #     node_visited.add(node_id)
-
                rel_map = self._graph_store.get_rel_map(
-                    list(subjs), self.graph_store_query_depth
-                )
-                logger.debug(f"rel_map: {rel_map}")
-
-                if not rel_map:
-                    continue
-                rel_texts.extend(
-                    [
-                        str(rel_obj)
-                        for rel_objs in rel_map.values()
-                        for rel_obj in rel_objs
-                    ]
-                )
-                cur_rel_map.update(rel_map)
-
-        sorted_nodes_with_scores = []
-        if not rel_texts:
-            logger.info("> No relationships found, returning nodes found by keywords.")
-            if len(sorted_nodes_with_scores) == 0:
-                logger.info("> No nodes found by keywords, returning empty response.")
-                return [Document(page_content="No relationships found.")]
-
-        # add relationships as Node
-        # TODO: make initial text customizable
-        rel_initial_text = (
-            f"The following are knowledge sequence in max depth"
-            f" {self.graph_store_query_depth} "
-            f"in the form of directed graph like:\n"
-            f"`subject -[predicate]->, object, <-[predicate_next_hop]-,"
-            f" object_next_hop ...`"
-        )
-        rel_info = [rel_initial_text] + rel_texts
-        rel_node_info = {
-            "kg_rel_texts": rel_texts,
-            "kg_rel_map": cur_rel_map,
-        }
-        if self._graph_schema != "":
-            rel_node_info["kg_schema"] = {"schema": self._graph_schema}
-        rel_info_text = "\n".join(
-            [
-                str(item)
-                for sublist in rel_info
-                for item in (sublist if isinstance(sublist, list) else [sublist])
-            ]
-        )
-        if self._verbose:
-            print(f"KG context:\n{rel_info_text}\n", color="blue")
-        rel_text_node = TextNode(
-            text=rel_info_text,
-            metadata=rel_node_info,
-            excluded_embed_metadata_keys=["kg_rel_map", "kg_rel_texts"],
-            excluded_llm_metadata_keys=["kg_rel_map", "kg_rel_texts"],
-        )
-        # this node is constructed from rel_texts, give high confidence to avoid cutoff
-        sorted_nodes_with_scores.append(
-            NodeWithScore(node=rel_text_node, score=DEFAULT_NODE_SCORE)
-        )
-        docs = [
-            Document(page_content=node.text, metadata=node.metadata)
-            for node in sorted_nodes_with_scores
-        ]
-        return docs
-
-    def _get_metadata_for_response(
-        self, nodes: List[BaseNode]
-    ) -> Optional[Dict[str, Any]]:
-        """Get metadata for response."""
-        for node in nodes:
-            if node.metadata is None or "kg_rel_map" not in node.metadata:
-                continue
-            return node.metadata
-        raise ValueError("kg_rel_map must be found in at least one Node.")
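
The class above was normally reached through `RAGGraphEngine.search`, but could also be driven directly. A short sketch, assuming an `engine` built as in the previous file; the query is illustrative:

import asyncio

from dbgpt.rag.graph.graph_search import RAGGraphSearch

search = RAGGraphSearch(graph_engine=engine)
docs = asyncio.run(search.search("What is AWEL?"))  # returns langchain Documents
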
@@ -1,258 +0,0 @@
-"""Data structures.
-
-Nodes are decoupled from the indices.
-
-"""
-
-import uuid
-from abc import abstractmethod
-from dataclasses import dataclass, field
-from typing import Dict, List, Optional, Sequence, Set
-
-from dataclasses_json import DataClassJsonMixin
-
-from dbgpt.rag.graph.index_type import IndexStructType
-from dbgpt.rag.graph.node import BaseNode, TextNode
-
-# TODO: legacy backport of old Node class
-Node = TextNode
-
-
-@dataclass
-class IndexStruct(DataClassJsonMixin):
-    """A base data struct for a LlamaIndex."""
-
-    index_id: str = field(default_factory=lambda: str(uuid.uuid4()))
-    summary: Optional[str] = None
-
-    def get_summary(self) -> str:
-        """Get text summary."""
-        if self.summary is None:
-            raise ValueError("summary field of the index_struct not set.")
-        return self.summary
-
-    @classmethod
-    @abstractmethod
-    def get_type(cls):
-        """Get index struct type."""
-
-
-@dataclass
-class IndexGraph(IndexStruct):
-    """A graph representing the tree-structured index."""
-
-    # mapping from index in tree to Node doc id.
-    all_nodes: Dict[int, str] = field(default_factory=dict)
-    root_nodes: Dict[int, str] = field(default_factory=dict)
-    node_id_to_children_ids: Dict[str, List[str]] = field(default_factory=dict)
-
-    @property
-    def node_id_to_index(self) -> Dict[str, int]:
-        """Map from node id to index."""
-        return {node_id: index for index, node_id in self.all_nodes.items()}
-
-    @property
-    def size(self) -> int:
-        """Get the size of the graph."""
-        return len(self.all_nodes)
-
-    def get_index(self, node: BaseNode) -> int:
-        """Get index of node."""
-        return self.node_id_to_index[node.node_id]
-
-    def insert(
-        self,
-        node: BaseNode,
-        index: Optional[int] = None,
-        children_nodes: Optional[Sequence[BaseNode]] = None,
-    ) -> None:
-        """Insert node."""
-        index = index or self.size
-        node_id = node.node_id
-
-        self.all_nodes[index] = node_id
-
-        if children_nodes is None:
-            children_nodes = []
-        children_ids = [n.node_id for n in children_nodes]
-        self.node_id_to_children_ids[node_id] = children_ids
-
-    def get_children(self, parent_node: Optional[BaseNode]) -> Dict[int, str]:
-        """Get children nodes."""
-        if parent_node is None:
-            return self.root_nodes
-        else:
-            parent_id = parent_node.node_id
-            children_ids = self.node_id_to_children_ids[parent_id]
-            return {
-                self.node_id_to_index[child_id]: child_id for child_id in children_ids
-            }
-
-    def insert_under_parent(
-        self,
-        node: BaseNode,
-        parent_node: Optional[BaseNode],
-        new_index: Optional[int] = None,
-    ) -> None:
-        """Insert under parent node."""
-        new_index = new_index or self.size
-        if parent_node is None:
-            self.root_nodes[new_index] = node.node_id
-            self.node_id_to_children_ids[node.node_id] = []
-        else:
-            if parent_node.node_id not in self.node_id_to_children_ids:
-                self.node_id_to_children_ids[parent_node.node_id] = []
-            self.node_id_to_children_ids[parent_node.node_id].append(node.node_id)
-
-        self.all_nodes[new_index] = node.node_id
-
-    @classmethod
-    def get_type(cls) -> IndexStructType:
-        """Get type."""
-        return IndexStructType.TREE
-
-
-@dataclass
-class KeywordTable(IndexStruct):
-    """A table of keywords mapping keywords to text chunks."""
-
-    table: Dict[str, Set[str]] = field(default_factory=dict)
-
-    def add_node(self, keywords: List[str], node: BaseNode) -> None:
-        """Add text to table."""
-        for keyword in keywords:
-            if keyword not in self.table:
-                self.table[keyword] = set()
-            self.table[keyword].add(node.node_id)
-
-    @property
-    def node_ids(self) -> Set[str]:
-        """Get all node ids."""
-        return set.union(*self.table.values())
-
-    @property
-    def keywords(self) -> Set[str]:
-        """Get all keywords in the table."""
-        return set(self.table.keys())
-
-    @property
-    def size(self) -> int:
-        """Get the size of the table."""
-        return len(self.table)
-
-    @classmethod
-    def get_type(cls) -> IndexStructType:
-        """Get type."""
-        return IndexStructType.KEYWORD_TABLE
-
-
-@dataclass
-class IndexList(IndexStruct):
-    """A list of documents."""
-
-    nodes: List[str] = field(default_factory=list)
-
-    def add_node(self, node: BaseNode) -> None:
-        """Add text to table, return current position in list."""
-        # don't worry about child indices for now, nodes are all in order
-        self.nodes.append(node.node_id)
-
-    @classmethod
-    def get_type(cls) -> IndexStructType:
-        """Get type."""
-        return IndexStructType.LIST
-
-
-@dataclass
-class IndexDict(IndexStruct):
-    """A simple dictionary of documents."""
-
-    # TODO: slightly deprecated, should likely be a list or set now
-    # mapping from vector store id to node doc_id
-    nodes_dict: Dict[str, str] = field(default_factory=dict)
-
-    # TODO: deprecated, not used
-    # mapping from node doc_id to vector store id
-    doc_id_dict: Dict[str, List[str]] = field(default_factory=dict)
-
-    # TODO: deprecated, not used
-    # this should be empty for all other indices
-    embeddings_dict: Dict[str, List[float]] = field(default_factory=dict)
-
-    def add_node(
-        self,
-        node: BaseNode,
-        text_id: Optional[str] = None,
-    ) -> str:
-        """Add text to table, return current position in list."""
-        # # don't worry about child indices for now, nodes are all in order
-        # self.nodes_dict[int_id] = node
-        vector_id = text_id if text_id is not None else node.node_id
-        self.nodes_dict[vector_id] = node.node_id
-
-        return vector_id
-
-    def delete(self, doc_id: str) -> None:
-        """Delete a Node."""
-        del self.nodes_dict[doc_id]
-
-    @classmethod
-    def get_type(cls) -> IndexStructType:
-        """Get type."""
-        return IndexStructType.VECTOR_STORE
-
-
-@dataclass
-class KG(IndexStruct):
-    """A table of keywords mapping keywords to text chunks."""
-
-    # Unidirectional
-
-    # table of keywords to node ids
-    table: Dict[str, Set[str]] = field(default_factory=dict)
-
-    # TODO: legacy attribute, remove in future releases
-    rel_map: Dict[str, List[List[str]]] = field(default_factory=dict)
-
-    # TBD, should support vector store, now we just persist the embedding memory
-    # maybe chainable abstractions for *_stores could be designed
-    embedding_dict: Dict[str, List[float]] = field(default_factory=dict)
-
-    @property
-    def node_ids(self) -> Set[str]:
-        """Get all node ids."""
-        return set.union(*self.table.values())
-
-    def add_to_embedding_dict(self, triplet_str: str, embedding: List[float]) -> None:
-        """Add embedding to dict."""
-        self.embedding_dict[triplet_str] = embedding
-
-    def add_node(self, keywords: List[str], node: BaseNode) -> None:
-        """Add text to table."""
-        node_id = node.node_id
-        for keyword in keywords:
-            keyword = keyword.lower()
-            if keyword not in self.table:
-                self.table[keyword] = set()
-            self.table[keyword].add(node_id)
-
-    def search_node_by_keyword(self, keyword: str) -> List[str]:
-        """Search for nodes by keyword."""
-        if keyword not in self.table:
-            return []
-        return list(self.table[keyword])
-
-    @classmethod
-    def get_type(cls) -> IndexStructType:
-        """Get type."""
-        return IndexStructType.KG
-
-
-@dataclass
-class EmptyIndexStruct(IndexStruct):
-    """Empty index."""
-
-    @classmethod
-    def get_type(cls) -> IndexStructType:
-        """Get type."""
-        return IndexStructType.EMPTY
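
A small sketch of how the deleted `KG` struct's keyword table behaved, using only the methods defined above; the node text is illustrative:

from dbgpt.rag.graph.node import TextNode

kg = KG()
node = TextNode(text="DB-GPT supports AWEL workflows.")
kg.add_node(["DB-GPT", "AWEL"], node)  # keywords are lowercased before storage
assert kg.search_node_by_keyword("db-gpt") == [node.node_id]
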
@@ -1,48 +0,0 @@
-"""IndexStructType class."""
-
-from enum import Enum
-
-
-class IndexStructType(str, Enum):
-    """Index struct type. Identifier for a "type" of index.
-
-    Attributes:
-        TREE ("tree"): Tree index. See :ref:`Ref-Indices-Tree` for tree indices.
-        LIST ("list"): Summary index. See :ref:`Ref-Indices-List` for summary indices.
-        KEYWORD_TABLE ("keyword_table"): Keyword table index. See
-            :ref:`Ref-Indices-Table`
-            for keyword table indices.
-        DICT ("dict"): Faiss Vector Store Index. See
-            :ref:`Ref-Indices-VectorStore`
-            for more information on the faiss vector store index.
-        SIMPLE_DICT ("simple_dict"): Simple Vector Store Index. See
-            :ref:`Ref-Indices-VectorStore`
-            for more information on the simple vector store index.
-        KG ("kg"): Knowledge Graph index.
-            See :ref:`Ref-Indices-Knowledge-Graph` for KG indices.
-        DOCUMENT_SUMMARY ("document_summary"): Document Summary Index.
-            See :ref:`Ref-Indices-Document-Summary` for Summary Indices.
-
-    """
-
-    # TODO: refactor so these are properties on the base class
-
-    NODE = "node"
-    TREE = "tree"
-    LIST = "list"
-    KEYWORD_TABLE = "keyword_table"
-
-    DICT = "dict"
-    # simple
-    SIMPLE_DICT = "simple_dict"
-    # for KG index
-    KG = "kg"
-    SIMPLE_KG = "simple_kg"
-    NEBULAGRAPH = "nebulagraph"
-    FALKORDB = "falkordb"
-
-    # EMPTY
-    EMPTY = "empty"
-    COMPOSITE = "composite"
-
-    DOCUMENT_SUMMARY = "document_summary"
@@ -1,75 +0,0 @@
-from typing import List, Optional
-
-from llama_index.data_structs.data_structs import IndexStruct
-from llama_index.storage.index_store.utils import (
-    index_struct_to_json,
-    json_to_index_struct,
-)
-from llama_index.storage.kvstore.types import BaseKVStore
-
-DEFAULT_NAMESPACE = "index_store"
-
-
-class KVIndexStore:
-    """Key-Value Index store.
-
-    Args:
-        kvstore (BaseKVStore): key-value store
-        namespace (str): namespace for the index store
-
-    """
-
-    def __init__(self, kvstore: BaseKVStore, namespace: Optional[str] = None) -> None:
-        """Init a KVIndexStore."""
-        self._kvstore = kvstore
-        self._namespace = namespace or DEFAULT_NAMESPACE
-        self._collection = f"{self._namespace}/data"
-
-    def add_index_struct(self, index_struct: IndexStruct) -> None:
-        """Add an index struct.
-
-        Args:
-            index_struct (IndexStruct): index struct
-
-        """
-        key = index_struct.index_id
-        data = index_struct_to_json(index_struct)
-        self._kvstore.put(key, data, collection=self._collection)
-
-    def delete_index_struct(self, key: str) -> None:
-        """Delete an index struct.
-
-        Args:
-            key (str): index struct key
-
-        """
-        self._kvstore.delete(key, collection=self._collection)
-
-    def get_index_struct(
-        self, struct_id: Optional[str] = None
-    ) -> Optional[IndexStruct]:
-        """Get an index struct.
-
-        Args:
-            struct_id (Optional[str]): index struct id
-
-        """
-        if struct_id is None:
-            structs = self.index_structs()
-            assert len(structs) == 1
-            return structs[0]
-        else:
-            json = self._kvstore.get(struct_id, collection=self._collection)
-            if json is None:
-                return None
-            return json_to_index_struct(json)
-
-    def index_structs(self) -> List[IndexStruct]:
-        """Get all index structs.
-
-        Returns:
-            List[IndexStruct]: index structs
-
-        """
-        jsons = self._kvstore.get_all(collection=self._collection)
-        return [json_to_index_struct(json) for json in jsons.values()]
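
A usage sketch pairing the store with llama-index's in-memory `SimpleKVStore`; the import path is assumed to match the llama-index versions this file targeted, and `index_struct` stands for any struct built elsewhere:

from llama_index.storage.kvstore import SimpleKVStore  # assumed import path

store = KVIndexStore(SimpleKVStore(), namespace="my_index")
store.add_index_struct(index_struct)  # keyed by index_struct.index_id
restored = store.get_index_struct(index_struct.index_id)
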
@@ -1,570 +0,0 @@
-"""Base schema for data structures."""
-import json
-import textwrap
-import uuid
-from abc import abstractmethod
-from enum import Enum, auto
-from hashlib import sha256
-from typing import Any, Dict, List, Optional, Union
-
-from langchain.schema import Document
-from typing_extensions import Self
-
-from dbgpt._private.pydantic import BaseModel, Field, root_validator
-
-DEFAULT_TEXT_NODE_TMPL = "{metadata_str}\n\n{content}"
-DEFAULT_METADATA_TMPL = "{key}: {value}"
-# NOTE: for pretty printing
-TRUNCATE_LENGTH = 350
-WRAP_WIDTH = 70
-
-
-class BaseComponent(BaseModel):
-    """Base component object to caputure class names."""
-
-    """reference llama-index"""
-
-    @classmethod
-    @abstractmethod
-    def class_name(cls) -> str:
-        """Get class name."""
-
-    def to_dict(self, **kwargs: Any) -> Dict[str, Any]:
-        data = self.dict(**kwargs)
-        data["class_name"] = self.class_name()
-        return data
-
-    def to_json(self, **kwargs: Any) -> str:
-        data = self.to_dict(**kwargs)
-        return json.dumps(data)
-
-    # TODO: return type here not supported by current mypy version
-    @classmethod
-    def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self:  # type: ignore
-        if isinstance(kwargs, dict):
-            data.update(kwargs)
-
-        data.pop("class_name", None)
-        return cls(**data)
-
-    @classmethod
-    def from_json(cls, data_str: str, **kwargs: Any) -> Self:  # type: ignore
-        data = json.loads(data_str)
-        return cls.from_dict(data, **kwargs)
-
-
-class NodeRelationship(str, Enum):
-    """Node relationships used in `BaseNode` class.
-
-    Attributes:
-        SOURCE: The node is the source document.
-        PREVIOUS: The node is the previous node in the document.
-        NEXT: The node is the next node in the document.
-        PARENT: The node is the parent node in the document.
-        CHILD: The node is a child node in the document.
-
-    """
-
-    SOURCE = auto()
-    PREVIOUS = auto()
-    NEXT = auto()
-    PARENT = auto()
-    CHILD = auto()
-
-
-class ObjectType(str, Enum):
-    TEXT = auto()
-    IMAGE = auto()
-    INDEX = auto()
-    DOCUMENT = auto()
-
-
-class MetadataMode(str, Enum):
-    ALL = auto()
-    EMBED = auto()
-    LLM = auto()
-    NONE = auto()
-
-
-class RelatedNodeInfo(BaseComponent):
-    node_id: str
-    node_type: Optional[ObjectType] = None
-    metadata: Dict[str, Any] = Field(default_factory=dict)
-    hash: Optional[str] = None
-
-    @classmethod
-    def class_name(cls) -> str:
-        """Get class name."""
-        return "RelatedNodeInfo"
-
-
-RelatedNodeType = Union[RelatedNodeInfo, List[RelatedNodeInfo]]
-
-
-# Node classes for indexes
-class BaseNode(BaseComponent):
-    """Base node Object.
-
-    Generic abstract interface for retrievable nodes
-
-    """
-
-    class Config:
-        allow_population_by_field_name = True
-
-    id_: str = Field(
-        default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the node."
-    )
-    embedding: Optional[List[float]] = Field(
-        default=None, description="Embedding of the node."
-    )
-
-    """"
-    metadata fields
-    - injected as part of the text shown to LLMs as context
-    - injected as part of the text for generating embeddings
-    - used by vector DBs for metadata filtering
-
-    """
-    metadata: Dict[str, Any] = Field(
-        default_factory=dict,
-        description="A flat dictionary of metadata fields",
-        alias="extra_info",
-    )
-    excluded_embed_metadata_keys: List[str] = Field(
-        default_factory=list,
-        description="Metadata keys that are exluded from text for the embed model.",
-    )
-    excluded_llm_metadata_keys: List[str] = Field(
-        default_factory=list,
-        description="Metadata keys that are exluded from text for the LLM.",
-    )
-    relationships: Dict[NodeRelationship, RelatedNodeType] = Field(
-        default_factory=dict,
-        description="A mapping of relationships to other node information.",
-    )
-    hash: str = Field(default="", description="Hash of the node content.")
-
-    @classmethod
-    @abstractmethod
-    def get_type(cls) -> str:
-        """Get Object type."""
-
-    @abstractmethod
-    def get_content(self, metadata_mode: MetadataMode = MetadataMode.ALL) -> str:
-        """Get object content."""
-
-    @abstractmethod
-    def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
-        """Metadata string."""
-
-    @abstractmethod
-    def set_content(self, value: Any) -> None:
-        """Set the content of the node."""
-
-    @property
-    def node_id(self) -> str:
-        return self.id_
-
-    @node_id.setter
-    def node_id(self, value: str) -> None:
-        self.id_ = value
-
-    @property
-    def source_node(self) -> Optional[RelatedNodeInfo]:
-        """Source object node.
-
-        Extracted from the relationships field.
-
-        """
-        if NodeRelationship.SOURCE not in self.relationships:
-            return None
-
-        relation = self.relationships[NodeRelationship.SOURCE]
-        if isinstance(relation, list):
-            raise ValueError("Source object must be a single RelatedNodeInfo object")
-        return relation
-
-    @property
-    def prev_node(self) -> Optional[RelatedNodeInfo]:
-        """Prev node."""
-        if NodeRelationship.PREVIOUS not in self.relationships:
-            return None
-
-        relation = self.relationships[NodeRelationship.PREVIOUS]
-        if not isinstance(relation, RelatedNodeInfo):
-            raise ValueError("Previous object must be a single RelatedNodeInfo object")
-        return relation
-
-    @property
-    def next_node(self) -> Optional[RelatedNodeInfo]:
-        """Next node."""
-        if NodeRelationship.NEXT not in self.relationships:
-            return None
-
-        relation = self.relationships[NodeRelationship.NEXT]
-        if not isinstance(relation, RelatedNodeInfo):
-            raise ValueError("Next object must be a single RelatedNodeInfo object")
-        return relation
-
-    @property
-    def parent_node(self) -> Optional[RelatedNodeInfo]:
-        """Parent node."""
-        if NodeRelationship.PARENT not in self.relationships:
-            return None
-
-        relation = self.relationships[NodeRelationship.PARENT]
-        if not isinstance(relation, RelatedNodeInfo):
-            raise ValueError("Parent object must be a single RelatedNodeInfo object")
-        return relation
-
-    @property
-    def child_nodes(self) -> Optional[List[RelatedNodeInfo]]:
-        """Child nodes."""
-        if NodeRelationship.CHILD not in self.relationships:
-            return None
-
-        relation = self.relationships[NodeRelationship.CHILD]
-        if not isinstance(relation, list):
-            raise ValueError("Child objects must be a list of RelatedNodeInfo objects.")
-        return relation
-
-    @property
-    def ref_doc_id(self) -> Optional[str]:
-        """Deprecated: Get ref doc id."""
-        source_node = self.source_node
-        if source_node is None:
-            return None
-        return source_node.node_id
-
-    @property
-    def extra_info(self) -> Dict[str, Any]:
-        """TODO: DEPRECATED: Extra info."""
-        return self.metadata
-
-    def __str__(self) -> str:
-        source_text_truncated = truncate_text(
-            self.get_content().strip(), TRUNCATE_LENGTH
-        )
-        source_text_wrapped = textwrap.fill(
-            f"Text: {source_text_truncated}\n", width=WRAP_WIDTH
-        )
-        return f"Node ID: {self.node_id}\n{source_text_wrapped}"
-
-    def truncate_text(text: str, max_length: int) -> str:
-        """Truncate text to a maximum length."""
-        if len(text) <= max_length:
-            return text
-        return text[: max_length - 3] + "..."
-
-    def get_embedding(self) -> List[float]:
-        """Get embedding.
-
-        Errors if embedding is None.
-
-        """
-        if self.embedding is None:
-            raise ValueError("embedding not set.")
-        return self.embedding
-
-    def as_related_node_info(self) -> RelatedNodeInfo:
-        """Get node as RelatedNodeInfo."""
-        return RelatedNodeInfo(
-            node_id=self.node_id, metadata=self.metadata, hash=self.hash
-        )
-
-
-class TextNode(BaseNode):
-    text: str = Field(default="", description="Text content of the node.")
-    start_char_idx: Optional[int] = Field(
-        default=None, description="Start char index of the node."
-    )
-    end_char_idx: Optional[int] = Field(
-        default=None, description="End char index of the node."
-    )
-    text_template: str = Field(
-        default=DEFAULT_TEXT_NODE_TMPL,
-        description=(
-            "Template for how text is formatted, with {content} and "
-            "{metadata_str} placeholders."
-        ),
-    )
-    metadata_template: str = Field(
-        default=DEFAULT_METADATA_TMPL,
-        description=(
-            "Template for how metadata is formatted, with {key} and "
-            "{value} placeholders."
-        ),
-    )
-    metadata_seperator: str = Field(
-        default="\n",
-        description="Seperator between metadata fields when converting to string.",
-    )
-
-    @classmethod
-    def class_name(cls) -> str:
-        """Get class name."""
-        return "TextNode"
-
-    @root_validator
-    def _check_hash(cls, values: dict) -> dict:
-        """Generate a hash to represent the node."""
-        text = values.get("text", "")
-        metadata = values.get("metadata", {})
-        doc_identity = str(text) + str(metadata)
-        values["hash"] = str(
-            sha256(doc_identity.encode("utf-8", "surrogatepass")).hexdigest()
-        )
-        return values
-
-    @classmethod
-    def get_type(cls) -> str:
-        """Get Object type."""
-        return ObjectType.TEXT
-
-    def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
-        """Get object content."""
-        metadata_str = self.get_metadata_str(mode=metadata_mode).strip()
-        if not metadata_str:
-            return self.text
-
-        return self.text_template.format(
-            content=self.text, metadata_str=metadata_str
-        ).strip()
-
-    def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
-        """metadata info string."""
-        if mode == MetadataMode.NONE:
-            return ""
-
-        usable_metadata_keys = set(self.metadata.keys())
-        if mode == MetadataMode.LLM:
-            for key in self.excluded_llm_metadata_keys:
-                if key in usable_metadata_keys:
-                    usable_metadata_keys.remove(key)
-        elif mode == MetadataMode.EMBED:
-            for key in self.excluded_embed_metadata_keys:
-                if key in usable_metadata_keys:
-                    usable_metadata_keys.remove(key)
-
-        return self.metadata_seperator.join(
-            [
-                self.metadata_template.format(key=key, value=str(value))
-                for key, value in self.metadata.items()
-                if key in usable_metadata_keys
-            ]
-        )
-
-    def set_content(self, value: str) -> None:
-        """Set the content of the node."""
-        self.text = value
-
-    def get_node_info(self) -> Dict[str, Any]:
-        """Get node info."""
-        return {"start": self.start_char_idx, "end": self.end_char_idx}
-
-    def get_text(self) -> str:
-        return self.get_content(metadata_mode=MetadataMode.NONE)
-
-    @property
-    def node_info(self) -> Dict[str, Any]:
-        """Deprecated: Get node info."""
-        return self.get_node_info()
-
-
-# TODO: legacy backport of old Node class
-Node = TextNode
-
-
-class ImageNode(TextNode):
-    """Node with image."""
-
-    # TODO: store reference instead of actual image
-    # base64 encoded image str
-    image: Optional[str] = None
-
-    @classmethod
-    def get_type(cls) -> str:
-        return ObjectType.IMAGE
-
-    @classmethod
-    def class_name(cls) -> str:
-        """Get class name."""
-        return "ImageNode"
-
-
-class IndexNode(TextNode):
-    """Node with reference to any object.
-
-    This can include other indices, query engines, retrievers.
-
-    This can also include other nodes (though this is overlapping with `relationships`
-    on the Node class).
-
-    """
-
-    index_id: str
-
-    @classmethod
-    def from_text_node(
-        cls,
-        node: TextNode,
-        index_id: str,
-    ) -> "IndexNode":
-        """Create index node from text node."""
-        # copy all attributes from text node, add index id
-        return cls(
-            **node.dict(),
-            index_id=index_id,
-        )
-
-    @classmethod
-    def get_type(cls) -> str:
-        return ObjectType.INDEX
-
-    @classmethod
-    def class_name(cls) -> str:
-        """Get class name."""
-        return "IndexNode"
-
-
-class NodeWithScore(BaseComponent):
-    node: BaseNode
-    score: Optional[float] = None
-
-    def __str__(self) -> str:
-        return f"{self.node}\nScore: {self.score: 0.3f}\n"
-
-    def get_score(self, raise_error: bool = False) -> float:
-        """Get score."""
-        if self.score is None:
-            if raise_error:
-                raise ValueError("Score not set.")
-            else:
-                return 0.0
-        else:
-            return self.score
-
-    @classmethod
-    def class_name(cls) -> str:
-        """Get class name."""
-        return "NodeWithScore"
-
-    ##### pass through methods to BaseNode #####
-    @property
-    def node_id(self) -> str:
-        return self.node.node_id
-
-    @property
-    def id_(self) -> str:
-        return self.node.id_
-
-    @property
-    def text(self) -> str:
-        if isinstance(self.node, TextNode):
-            return self.node.text
-        else:
-            raise ValueError("Node must be a TextNode to get text.")
-
-    @property
-    def metadata(self) -> Dict[str, Any]:
-        return self.node.metadata
-
-    @property
-    def embedding(self) -> Optional[List[float]]:
-        return self.node.embedding
-
-    def get_text(self) -> str:
-        if isinstance(self.node, TextNode):
-            return self.node.get_text()
-        else:
-            raise ValueError("Node must be a TextNode to get text.")
-
-    def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
-        return self.node.get_content(metadata_mode=metadata_mode)
-
-    def get_embedding(self) -> List[float]:
-        return self.node.get_embedding()
-
-
-# Document Classes for Readers
-
-
-class Document(TextNode):
-    """Generic interface for a data document.
-
-    This document connects to data sources.
-
-    """
-
-    # TODO: A lot of backwards compatibility logic here, clean up
-    id_: str = Field(
-        default_factory=lambda: str(uuid.uuid4()),
-        description="Unique ID of the node.",
-        alias="doc_id",
-    )
-
-    _compat_fields = {"doc_id": "id_", "extra_info": "metadata"}
-
-    @classmethod
-    def get_type(cls) -> str:
-        """Get Document type."""
-        return ObjectType.DOCUMENT
-
-    @property
-    def doc_id(self) -> str:
-        """Get document ID."""
-        return self.id_
-
-    def __str__(self) -> str:
-        source_text_truncated = truncate_text(
-            self.get_content().strip(), TRUNCATE_LENGTH
-        )
-        source_text_wrapped = textwrap.fill(
-            f"Text: {source_text_truncated}\n", width=WRAP_WIDTH
-        )
-        return f"Doc ID: {self.doc_id}\n{source_text_wrapped}"
-
-    def get_doc_id(self) -> str:
-        """TODO: Deprecated: Get document ID."""
-        return self.id_
-
-    def __setattr__(self, name: str, value: object) -> None:
-        if name in self._compat_fields:
-            name = self._compat_fields[name]
-        super().__setattr__(name, value)
-
-    def to_langchain_format(self) -> Document:
-        """Convert struct to LangChain document format."""
-        metadata = self.metadata or {}
-        return Document(page_content=self.text, metadata=metadata)
-
-    @classmethod
-    def from_langchain_format(cls, doc: Document) -> "Document":
-        """Convert struct from LangChain document format."""
-        return cls(text=doc.page_content, metadata=doc.metadata)
-
-    @classmethod
-    def example(cls) -> "Document":
-        document = Document(
-            text="",
-            metadata={"filename": "README.md", "category": "codebase"},
-        )
-        return document
-
-    @classmethod
-    def class_name(cls) -> str:
-        """Get class name."""
-        return "Document"
-
-
-class ImageDocument(Document):
-    """Data document containing an image."""
-
-    # base64 encoded image str
-    image: Optional[str] = None
-
-    @classmethod
-    def class_name(cls) -> str:
-        """Get class name."""
-        return "ImageDocument"
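
To make the metadata templating above concrete, a short sketch of the deleted `TextNode` (values illustrative):

node = TextNode(
    text="hello world",
    metadata={"filename": "README.md"},
    excluded_llm_metadata_keys=["filename"],
)
node.get_content(metadata_mode=MetadataMode.ALL)  # "filename: README.md\n\nhello world"
node.get_content(metadata_mode=MetadataMode.LLM)  # excluded key filtered out -> "hello world"
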
@@ -1,44 +0,0 @@
-from abc import ABC, abstractmethod
-from enum import Enum
-
-
-class SearchMode(str, Enum):
-    """Query mode enum for Knowledge Graphs.
-
-    Can be passed as the enum struct, or as the underlying string.
-
-    Attributes:
-        KEYWORD ("keyword"): Default query mode, using keywords to find triplets.
-        EMBEDDING ("embedding"): Embedding mode, using embeddings to find
-            similar triplets.
-        HYBRID ("hybrid"): Hyrbid mode, combining both keywords and embeddings
-            to find relevant triplets.
-    """
-
-    KEYWORD = "keyword"
-    EMBEDDING = "embedding"
-    HYBRID = "hybrid"
-
-
-class BaseSearch(ABC):
-    """Base Search."""
-
-    async def search(self, query: str):
-        """Retrieve nodes given query.
-
-        Args:
-            query (QueryType): Either a query string or
-                a QueryBundle object.
-
-        """
-        # if isinstance(query, str):
-        return await self._search(query)
-
-    @abstractmethod
-    async def _search(self, query: str):
-        """search nodes given query.
-
-        Implemented by the user.
-
-        """
-        pass
@@ -1,11 +1,12 @@
-import asyncio
 import os
 from typing import Dict, List
 
+from pydantic import BaseModel, Field
+
 from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
-from dbgpt.core.awel import DAG, InputOperator, MapOperator, SimpleCallDataInputSource
-from dbgpt.rag.chunk import Chunk
+from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
 from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
+from dbgpt.rag.knowledge.base import KnowledgeType
 from dbgpt.rag.operator.knowledge import KnowledgeOperator
 from dbgpt.serve.rag.operators.embedding import EmbeddingAssemblerOperator
 from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
@@ -13,28 +14,19 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector
 
 """AWEL: Simple rag embedding operator example
 
-pre-requirements:
-set your file path in your example code.
 Examples:
+    pre-requirements:
+        python examples/awel/simple_rag_embedding_example.py
     ..code-block:: shell
-        python examples/awel/simple_rag_embedding_example.py
+        curl --location --request POST 'http://127.0.0.1:5555/api/v1/awel/trigger/examples/rag/embedding' \
+        --header 'Content-Type: application/json' \
+        --data-raw '{
+            "url": "https://docs.dbgpt.site/docs/awel"
+        }'
 """
 
 
-def _context_join_fn(context_dict: Dict, chunks: List[Chunk]) -> Dict:
-    """context Join function for JoinOperator.
-
-    Args:
-        context_dict (Dict): context dict
-        chunks (List[Chunk]): chunks
-    Returns:
-        Dict: context dict
-    """
-    context_dict["context"] = "\n".join([chunk.content for chunk in chunks])
-    return context_dict
-
-
-def _create_vector_connector():
+def _create_vector_connector() -> VectorStoreConnector:
     """Create vector connector."""
     return VectorStoreConnector.from_default(
         "Chroma",
@@ -48,6 +40,22 @@ def _create_vector_connector():
     )
 
 
+class TriggerReqBody(BaseModel):
+    url: str = Field(..., description="url")
+
+
+class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    async def map(self, input_value: TriggerReqBody) -> Dict:
+        params = {
+            "url": input_value.url,
+        }
+        print(f"Receive input value: {input_value}")
+        return params
+
+
 class ResultOperator(MapOperator):
     """The Result Operator."""
 
@@ -61,26 +69,31 @@ class ResultOperator(MapOperator):
 
 
 with DAG("simple_sdk_rag_embedding_example") as dag:
-    knowledge_operator = KnowledgeOperator()
+    trigger = HttpTrigger(
+        "/examples/rag/embedding", methods="POST", request_body=TriggerReqBody
+    )
+    request_handle_task = RequestHandleOperator()
+    knowledge_operator = KnowledgeOperator(knowledge_type=KnowledgeType.URL)
     vector_connector = _create_vector_connector()
-    input_task = InputOperator(input_source=SimpleCallDataInputSource())
-    file_path_parser = MapOperator(map_function=lambda x: x["file_path"])
+    url_parser_operator = MapOperator(map_function=lambda x: x["url"])
     embedding_operator = EmbeddingAssemblerOperator(
         vector_store_connector=vector_connector,
    )
     output_task = ResultOperator()
     (
-        input_task
-        >> file_path_parser
+        trigger
+        >> request_handle_task
+        >> url_parser_operator
        >> knowledge_operator
        >> embedding_operator
        >> output_task
     )
 
 if __name__ == "__main__":
-    input_data = {
-        "data": {
-            "file_path": "docs/docs/awel.md",
-        }
-    }
-    output = asyncio.run(output_task.call(call_data=input_data))
+    if dag.leaf_nodes[0].dev_mode:
+        # Development mode, you can run the dag locally for debugging.
+        from dbgpt.core.awel import setup_dev_environment
+
+        setup_dev_environment([dag], port=5555)
+    else:
+        pass
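
With the new trigger in place, the curl call from the docstring can equally be issued from Python. A sketch assuming the example is running in dev mode on port 5555; `requests` is an extra dependency here, not part of this example:

import requests

resp = requests.post(
    "http://127.0.0.1:5555/api/v1/awel/trigger/examples/rag/embedding",
    json={"url": "https://docs.dbgpt.site/docs/awel"},
)
print(resp.text)  # the DAG's final output from ResultOperator
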