mirror of https://github.com/csunny/DB-GPT.git
synced 2025-07-23 20:26:15 +00:00
fix:update awel embedding examples and delete unuseful code. (#1073)
This commit is contained in:
parent cb9c34abb9
commit 3a54d1ef9a
@@ -28,11 +28,6 @@ def initialize_components(
     system_app.register(DefaultExecutorFactory)
     system_app.register_instance(controller)

-    # Register global default RAGGraphFactory
-    # from dbgpt.graph.graph_factory import DefaultRAGGraphFactory
-
-    # system_app.register(DefaultRAGGraphFactory)
-
     from dbgpt.serve.agent.hub.controller import module_agent

     system_app.register_instance(module_agent)
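
For context: DB-GPT components are registered on the SystemApp once at startup (by class via register, or by instance via register_instance) and fetched later by their ComponentType key, which is how the removed RAGGraphFactory lookup below worked. A minimal sketch of the pattern, assuming SystemApp can be constructed directly; MyComponent is illustrative, not part of DB-GPT:

    from dbgpt.component import BaseComponent, ComponentType, SystemApp

    class MyComponent(BaseComponent):
        # Real components set `name` to a ComponentType value.
        name = ComponentType.RAG_GRAPH_DEFAULT.value

        def init_app(self, system_app: SystemApp):
            pass  # hook called when the component is attached to the app

    system_app = SystemApp()
    system_app.register(MyComponent)  # register by class; the app instantiates it
    component = system_app.get_component(
        ComponentType.RAG_GRAPH_DEFAULT.value, MyComponent
    )
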
@@ -497,30 +497,6 @@ class KnowledgeService:
         res.page = request.page
         return res

-    def async_knowledge_graph(self, chunk_docs, doc):
-        """Async extract triplets from documents and save them into the graph db.
-
-        Args:
-            - chunk_docs: List[Document]
-            - doc: KnowledgeDocumentEntity
-        """
-        logger.info(
-            f"async_knowledge_graph, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to graph store"
-        )
-        try:
-            from dbgpt.rag.graph.graph_factory import RAGGraphFactory
-
-            rag_engine = CFG.SYSTEM_APP.get_component(
-                ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
-            ).create()
-            rag_engine.knowledge_graph(chunk_docs)
-            doc.status = SyncStatus.FINISHED.name
-            doc.result = "document build graph success"
-        except Exception as e:
-            doc.status = SyncStatus.FAILED.name
-            doc.result = "document build graph failed" + str(e)
-            logger.error(f"document build graph failed:{doc.doc_name}, {str(e)}")
-        return knowledge_document_dao.update_knowledge_document(doc)
-
     def async_doc_embedding(self, assembler, chunk_docs, doc):
         """Async embed documents into the vector db.

         Args:
Deleted: dbgpt/rag/graph/graph_engine.py
@@ -1,148 +0,0 @@

import logging
from typing import Any, Callable, List, Optional, Tuple

from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter

from dbgpt.rag.embedding import KnowledgeType
from dbgpt.rag.embedding.knowledge_type import get_knowledge_embedding
from dbgpt.rag.graph.index_struct import KG
from dbgpt.rag.graph.node import TextNode
from dbgpt.util import utils

logger = logging.getLogger(__name__)


class RAGGraphEngine:
    """Knowledge RAG Graph Engine.

    Builds a RAG graph client that can extract triplets and insert them into a
    graph store.

    Args:
        knowledge_type (Optional[str]): Defaults to KnowledgeType.DOCUMENT.value.
        knowledge_source (Optional[str]): The knowledge source to load from.
        model_name (Optional[str]): LLM model name.
        graph_store (Optional[GraphStore]): The graph store to use (reference: llama-index).
        include_embeddings (bool): Whether to include embeddings in the index.
            Defaults to False.
        max_object_length (int): The maximum length of the object in a triplet.
            Defaults to 128.
        extract_triplet_fn (Optional[Callable]): The function to use for
            extracting triplets. Defaults to None.
    """

    index_struct_cls = KG

    def __init__(
        self,
        knowledge_type: Optional[str] = KnowledgeType.DOCUMENT.value,
        knowledge_source: Optional[str] = None,
        text_splitter=None,
        graph_store=None,
        index_struct: Optional[KG] = None,
        model_name: Optional[str] = None,
        max_triplets_per_chunk: int = 10,
        include_embeddings: bool = False,
        max_object_length: int = 128,
        extract_triplet_fn: Optional[Callable] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize params."""
        from llama_index.graph_stores import SimpleGraphStore

        # need to set parameters before building index in base class.
        self.knowledge_source = knowledge_source
        self.knowledge_type = knowledge_type
        self.model_name = model_name
        self.text_splitter = text_splitter
        self.index_struct = index_struct
        self.include_embeddings = include_embeddings
        self.graph_store = graph_store or SimpleGraphStore()
        # self.graph_store = graph_store
        self.max_triplets_per_chunk = max_triplets_per_chunk
        self._max_object_length = max_object_length
        self._extract_triplet_fn = extract_triplet_fn

    def knowledge_graph(self, docs=None):
        """Load knowledge docs into the graph store."""
        if not docs:
            if self.text_splitter:
                self.text_splitter = RecursiveCharacterTextSplitter(
                    chunk_size=2000, chunk_overlap=100
                )
            knowledge_source = get_knowledge_embedding(
                knowledge_type=self.knowledge_type,
                knowledge_source=self.knowledge_source,
                text_splitter=self.text_splitter,
            )
            docs = knowledge_source.read()
        if self.index_struct is None:
            self.index_struct = self._build_index_from_docs(docs)

    def _extract_triplets(self, text: str) -> List[Tuple[str, str, str]]:
        """Extract triplets from text by function or LLM."""
        if self._extract_triplet_fn is not None:
            return self._extract_triplet_fn(text)
        else:
            return self._llm_extract_triplets(text)

    def _llm_extract_triplets(self, text: str) -> List[Tuple[str, str, str]]:
        """Extract triplets from text by LLM."""
        import uuid

        from dbgpt.app.scene import ChatScene
        from dbgpt.util.chat_util import llm_chat_response_nostream

        chat_param = {
            "chat_session_id": uuid.uuid1(),
            "current_user_input": text,
            "select_param": "triplet",
            "model_name": self.model_name,
        }
        loop = utils.get_or_create_event_loop()
        triplets = loop.run_until_complete(
            llm_chat_response_nostream(
                ChatScene.ExtractTriplet.value(), **{"chat_param": chat_param}
            )
        )
        return triplets

    def _build_index_from_docs(self, documents: List[Document]) -> KG:
        """Build the index from nodes.

        Args:
            documents: List[Document]
        """
        index_struct = self.index_struct_cls()
        triplets = []
        for doc in documents:
            trips = self._extract_triplets_task([doc], index_struct)
            triplets.extend(trips)
            print(triplets)
            text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
            for triplet in triplets:
                subj, _, obj = triplet
                self.graph_store.upsert_triplet(*triplet)
                index_struct.add_node([subj, obj], text_node)
        return index_struct

    def search(self, query):
        from dbgpt.rag.graph.graph_search import RAGGraphSearch

        graph_search = RAGGraphSearch(graph_engine=self)
        return graph_search.search(query)

    def _extract_triplets_task(self, docs, index_struct):
        triple_results = []
        for doc in docs:
            import threading

            thread_id = threading.get_ident()
            print(f"current thread-{thread_id} begin extract triplets task")
            triplets = self._extract_triplets(doc.page_content)
            if len(triplets) == 0:
                triplets = []
            text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
            logger.info(f"extracted knowledge triplets: {triplets}")
            print(
                f"current thread-{thread_id} end extract triplets tasks, triplets-{triplets}"
            )
            triple_results.extend(triplets)
        return triple_results
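
Before its removal, the engine was driven roughly as follows; a minimal sketch, assuming the pre-removal module path and an illustrative knowledge source (note that search() returns a coroutine, since RAGGraphSearch is async):

    import asyncio

    from dbgpt.rag.graph.graph_engine import RAGGraphEngine

    engine = RAGGraphEngine(
        knowledge_source="./docs/example.md",  # illustrative path
        model_name="proxyllm",
    )
    engine.knowledge_graph()  # extract triplets and upsert them into the graph store
    docs = asyncio.run(engine.search("What does DB-GPT use AWEL for?"))
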
Deleted: dbgpt/rag/graph/graph_factory.py
@@ -1,35 +0,0 @@

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Type

from dbgpt.component import BaseComponent, ComponentType


class RAGGraphFactory(BaseComponent, ABC):
    name = ComponentType.RAG_GRAPH_DEFAULT.value

    @abstractmethod
    def create(self, model_name: str = None, embedding_cls: Type = None):
        """Create RAG Graph Engine."""


class DefaultRAGGraphFactory(RAGGraphFactory):
    def __init__(
        self, system_app=None, default_model_name: str = None, **kwargs: Any
    ) -> None:
        super().__init__(system_app=system_app)
        self._default_model_name = default_model_name
        self.kwargs = kwargs
        from dbgpt.rag.graph.graph_engine import RAGGraphEngine

        self.rag_engine = RAGGraphEngine(model_name="proxyllm")

    def init_app(self, system_app):
        pass

    def create(self, model_name: str = None, rag_cls: Type = None):
        if not model_name:
            model_name = self._default_model_name

        return self.rag_engine
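
Wiring the factory into the app mirrored the commented-out registration removed from initialize_components above. A hypothetical sketch, assuming a system_app in scope and that register forwards constructor kwargs (both assumptions, not confirmed by this diff):

    from dbgpt.component import ComponentType

    system_app.register(DefaultRAGGraphFactory, default_model_name="proxyllm")
    rag_engine = system_app.get_component(
        ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
    ).create()
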
Deleted: dbgpt/rag/graph/graph_search.py
@@ -1,198 +0,0 @@

import logging
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Optional, Set

from langchain.schema import Document

from dbgpt.rag.graph.node import BaseNode, NodeWithScore, TextNode
from dbgpt.rag.graph.search import BaseSearch, SearchMode

logger = logging.getLogger(__name__)
DEFAULT_NODE_SCORE = 1000.0
GLOBAL_EXPLORE_NODE_LIMIT = 3
REL_TEXT_LIMIT = 30


class RAGGraphSearch(BaseSearch):
    """RAG Graph Search.

    Args:
        graph_engine (RAGGraphEngine): The graph engine to search.
        model_name (str): model name (see :ref:`Prompt-Templates`).
        text_qa_template (Optional[BasePromptTemplate]): A Question Answering Prompt
            (see :ref:`Prompt-Templates`).
        max_keywords_per_query (int): Maximum number of keywords to extract from query.
        num_chunks_per_query (int): Maximum number of text chunks to query.
        search_mode (Optional[SearchMode]): Specifies whether to use keywords,
            embeddings, or both to find relevant triplets. Should be one of
            "keyword", "embedding", or "hybrid". Defaults to SearchMode.KEYWORD.
        graph_store_query_depth (int): The depth of the graph store query.
        extract_subject_entities_fn (Optional[Callable]): extract_subject_entities callback.
    """

    def __init__(
        self,
        graph_engine,
        model_name: str = None,
        max_keywords_per_query: int = 10,
        num_chunks_per_query: int = 10,
        search_mode: Optional[SearchMode] = SearchMode.KEYWORD,
        graph_store_query_depth: int = 2,
        extract_subject_entities_fn: Optional[Callable] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize params."""
        from dbgpt.rag.graph.graph_engine import RAGGraphEngine

        self.graph_engine: RAGGraphEngine = graph_engine
        self.model_name = model_name or self.graph_engine.model_name
        self._index_struct = self.graph_engine.index_struct
        self.max_keywords_per_query = max_keywords_per_query
        self.num_chunks_per_query = num_chunks_per_query
        self._search_mode = search_mode

        self._graph_store = self.graph_engine.graph_store
        self.graph_store_query_depth = graph_store_query_depth
        self._verbose = kwargs.get("verbose", False)
        refresh_schema = kwargs.get("refresh_schema", False)
        self.extract_subject_entities_fn = extract_subject_entities_fn
        self.executor = ThreadPoolExecutor(max_workers=os.cpu_count() * 5)
        try:
            self._graph_schema = self._graph_store.get_schema(refresh=refresh_schema)
        except NotImplementedError:
            self._graph_schema = ""
        except Exception as e:
            logger.warning(f"cannot find graph schema: {e}")
            self._graph_schema = ""

    async def _extract_subject_entities(self, query_str: str) -> Set[str]:
        """Extract subject entities."""
        if self.extract_subject_entities_fn is not None:
            return await self.extract_subject_entities_fn(query_str)
        else:
            return await self._extract_entities_by_llm(query_str)

    async def _extract_entities_by_llm(self, text: str) -> Set[str]:
        """Extract subject entities from text by LLM."""
        import uuid

        from dbgpt.app.scene import ChatScene
        from dbgpt.util.chat_util import llm_chat_response_nostream

        chat_param = {
            "chat_session_id": uuid.uuid1(),
            "current_user_input": text,
            "select_param": "entity",
            "model_name": self.model_name,
        }
        # loop = util.get_or_create_event_loop()
        # entities = loop.run_until_complete(
        #     llm_chat_response_nostream(
        #         ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
        #     )
        # )
        return await llm_chat_response_nostream(
            ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
        )

    async def _search(
        self,
        query_str: str,
    ) -> List[Document]:
        """Get nodes for response."""
        node_visited = set()
        keywords = await self._extract_subject_entities(query_str)
        print(f"extract entities: {keywords}\n")
        rel_texts = []
        cur_rel_map = {}
        chunk_indices_count: Dict[str, int] = defaultdict(int)
        if self._search_mode != SearchMode.EMBEDDING:
            for keyword in keywords:
                keyword = keyword.lower()
                subjs = set((keyword,))
                # node_ids = self._index_struct.search_node_by_keyword(keyword)
                # for node_id in node_ids[:GLOBAL_EXPLORE_NODE_LIMIT]:
                #     if node_id in node_visited:
                #         continue
                #
                #     # if self._include_text:
                #     #     chunk_indices_count[node_id] += 1
                #
                #     node_visited.add(node_id)

                rel_map = self._graph_store.get_rel_map(
                    list(subjs), self.graph_store_query_depth
                )
                logger.debug(f"rel_map: {rel_map}")

                if not rel_map:
                    continue
                rel_texts.extend(
                    [
                        str(rel_obj)
                        for rel_objs in rel_map.values()
                        for rel_obj in rel_objs
                    ]
                )
                cur_rel_map.update(rel_map)

        sorted_nodes_with_scores = []
        if not rel_texts:
            logger.info("> No relationships found, returning nodes found by keywords.")
            if len(sorted_nodes_with_scores) == 0:
                logger.info("> No nodes found by keywords, returning empty response.")
            return [Document(page_content="No relationships found.")]

        # add relationships as Node
        # TODO: make initial text customizable
        rel_initial_text = (
            f"The following are knowledge sequence in max depth"
            f" {self.graph_store_query_depth} "
            f"in the form of directed graph like:\n"
            f"`subject -[predicate]->, object, <-[predicate_next_hop]-,"
            f" object_next_hop ...`"
        )
        rel_info = [rel_initial_text] + rel_texts
        rel_node_info = {
            "kg_rel_texts": rel_texts,
            "kg_rel_map": cur_rel_map,
        }
        if self._graph_schema != "":
            rel_node_info["kg_schema"] = {"schema": self._graph_schema}
        rel_info_text = "\n".join(
            [
                str(item)
                for sublist in rel_info
                for item in (sublist if isinstance(sublist, list) else [sublist])
            ]
        )
        if self._verbose:
            print(f"KG context:\n{rel_info_text}\n", color="blue")
        rel_text_node = TextNode(
            text=rel_info_text,
            metadata=rel_node_info,
            excluded_embed_metadata_keys=["kg_rel_map", "kg_rel_texts"],
            excluded_llm_metadata_keys=["kg_rel_map", "kg_rel_texts"],
        )
        # this node is constructed from rel_texts, give high confidence to avoid cutoff
        sorted_nodes_with_scores.append(
            NodeWithScore(node=rel_text_node, score=DEFAULT_NODE_SCORE)
        )
        docs = [
            Document(page_content=node.text, metadata=node.metadata)
            for node in sorted_nodes_with_scores
        ]
        return docs

    def _get_metadata_for_response(
        self, nodes: List[BaseNode]
    ) -> Optional[Dict[str, Any]]:
        """Get metadata for response."""
        for node in nodes:
            if node.metadata is None or "kg_rel_map" not in node.metadata:
                continue
            return node.metadata
        raise ValueError("kg_rel_map must be found in at least one Node.")
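
The search path is async end to end; a minimal sketch of driving it directly, assuming an engine whose index_struct has already been built:

    import asyncio

    from dbgpt.rag.graph.graph_search import RAGGraphSearch

    async def ask(engine, question: str):
        searcher = RAGGraphSearch(graph_engine=engine, model_name="proxyllm")
        # Extracts subject entities from the query via the LLM, walks the graph
        # store up to graph_store_query_depth hops, and returns langchain Documents.
        return await searcher.search(question)

    # docs = asyncio.run(ask(engine, "Who maintains DB-GPT?"))
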
Deleted: dbgpt/rag/graph/index_struct.py
@@ -1,258 +0,0 @@

"""Data structures.

Nodes are decoupled from the indices.
"""

import uuid
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Sequence, Set

from dataclasses_json import DataClassJsonMixin

from dbgpt.rag.graph.index_type import IndexStructType
from dbgpt.rag.graph.node import BaseNode, TextNode

# TODO: legacy backport of old Node class
Node = TextNode


@dataclass
class IndexStruct(DataClassJsonMixin):
    """A base data struct for a LlamaIndex."""

    index_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    summary: Optional[str] = None

    def get_summary(self) -> str:
        """Get text summary."""
        if self.summary is None:
            raise ValueError("summary field of the index_struct not set.")
        return self.summary

    @classmethod
    @abstractmethod
    def get_type(cls):
        """Get index struct type."""


@dataclass
class IndexGraph(IndexStruct):
    """A graph representing the tree-structured index."""

    # mapping from index in tree to Node doc id.
    all_nodes: Dict[int, str] = field(default_factory=dict)
    root_nodes: Dict[int, str] = field(default_factory=dict)
    node_id_to_children_ids: Dict[str, List[str]] = field(default_factory=dict)

    @property
    def node_id_to_index(self) -> Dict[str, int]:
        """Map from node id to index."""
        return {node_id: index for index, node_id in self.all_nodes.items()}

    @property
    def size(self) -> int:
        """Get the size of the graph."""
        return len(self.all_nodes)

    def get_index(self, node: BaseNode) -> int:
        """Get index of node."""
        return self.node_id_to_index[node.node_id]

    def insert(
        self,
        node: BaseNode,
        index: Optional[int] = None,
        children_nodes: Optional[Sequence[BaseNode]] = None,
    ) -> None:
        """Insert node."""
        index = index or self.size
        node_id = node.node_id

        self.all_nodes[index] = node_id

        if children_nodes is None:
            children_nodes = []
        children_ids = [n.node_id for n in children_nodes]
        self.node_id_to_children_ids[node_id] = children_ids

    def get_children(self, parent_node: Optional[BaseNode]) -> Dict[int, str]:
        """Get children nodes."""
        if parent_node is None:
            return self.root_nodes
        else:
            parent_id = parent_node.node_id
            children_ids = self.node_id_to_children_ids[parent_id]
            return {
                self.node_id_to_index[child_id]: child_id for child_id in children_ids
            }

    def insert_under_parent(
        self,
        node: BaseNode,
        parent_node: Optional[BaseNode],
        new_index: Optional[int] = None,
    ) -> None:
        """Insert under parent node."""
        new_index = new_index or self.size
        if parent_node is None:
            self.root_nodes[new_index] = node.node_id
            self.node_id_to_children_ids[node.node_id] = []
        else:
            if parent_node.node_id not in self.node_id_to_children_ids:
                self.node_id_to_children_ids[parent_node.node_id] = []
            self.node_id_to_children_ids[parent_node.node_id].append(node.node_id)

        self.all_nodes[new_index] = node.node_id

    @classmethod
    def get_type(cls) -> IndexStructType:
        """Get type."""
        return IndexStructType.TREE


@dataclass
class KeywordTable(IndexStruct):
    """A table of keywords mapping keywords to text chunks."""

    table: Dict[str, Set[str]] = field(default_factory=dict)

    def add_node(self, keywords: List[str], node: BaseNode) -> None:
        """Add text to table."""
        for keyword in keywords:
            if keyword not in self.table:
                self.table[keyword] = set()
            self.table[keyword].add(node.node_id)

    @property
    def node_ids(self) -> Set[str]:
        """Get all node ids."""
        return set.union(*self.table.values())

    @property
    def keywords(self) -> Set[str]:
        """Get all keywords in the table."""
        return set(self.table.keys())

    @property
    def size(self) -> int:
        """Get the size of the table."""
        return len(self.table)

    @classmethod
    def get_type(cls) -> IndexStructType:
        """Get type."""
        return IndexStructType.KEYWORD_TABLE


@dataclass
class IndexList(IndexStruct):
    """A list of documents."""

    nodes: List[str] = field(default_factory=list)

    def add_node(self, node: BaseNode) -> None:
        """Add text to table, return current position in list."""
        # don't worry about child indices for now, nodes are all in order
        self.nodes.append(node.node_id)

    @classmethod
    def get_type(cls) -> IndexStructType:
        """Get type."""
        return IndexStructType.LIST


@dataclass
class IndexDict(IndexStruct):
    """A simple dictionary of documents."""

    # TODO: slightly deprecated, should likely be a list or set now
    # mapping from vector store id to node doc_id
    nodes_dict: Dict[str, str] = field(default_factory=dict)

    # TODO: deprecated, not used
    # mapping from node doc_id to vector store id
    doc_id_dict: Dict[str, List[str]] = field(default_factory=dict)

    # TODO: deprecated, not used
    # this should be empty for all other indices
    embeddings_dict: Dict[str, List[float]] = field(default_factory=dict)

    def add_node(
        self,
        node: BaseNode,
        text_id: Optional[str] = None,
    ) -> str:
        """Add text to table, return current position in list."""
        # # don't worry about child indices for now, nodes are all in order
        # self.nodes_dict[int_id] = node
        vector_id = text_id if text_id is not None else node.node_id
        self.nodes_dict[vector_id] = node.node_id

        return vector_id

    def delete(self, doc_id: str) -> None:
        """Delete a Node."""
        del self.nodes_dict[doc_id]

    @classmethod
    def get_type(cls) -> IndexStructType:
        """Get type."""
        return IndexStructType.VECTOR_STORE


@dataclass
class KG(IndexStruct):
    """A table of keywords mapping keywords to text chunks."""

    # Unidirectional

    # table of keywords to node ids
    table: Dict[str, Set[str]] = field(default_factory=dict)

    # TODO: legacy attribute, remove in future releases
    rel_map: Dict[str, List[List[str]]] = field(default_factory=dict)

    # TBD, should support vector store, now we just persist the embedding memory
    # maybe chainable abstractions for *_stores could be designed
    embedding_dict: Dict[str, List[float]] = field(default_factory=dict)

    @property
    def node_ids(self) -> Set[str]:
        """Get all node ids."""
        return set.union(*self.table.values())

    def add_to_embedding_dict(self, triplet_str: str, embedding: List[float]) -> None:
        """Add embedding to dict."""
        self.embedding_dict[triplet_str] = embedding

    def add_node(self, keywords: List[str], node: BaseNode) -> None:
        """Add text to table."""
        node_id = node.node_id
        for keyword in keywords:
            keyword = keyword.lower()
            if keyword not in self.table:
                self.table[keyword] = set()
            self.table[keyword].add(node_id)

    def search_node_by_keyword(self, keyword: str) -> List[str]:
        """Search for nodes by keyword."""
        if keyword not in self.table:
            return []
        return list(self.table[keyword])

    @classmethod
    def get_type(cls) -> IndexStructType:
        """Get type."""
        return IndexStructType.KG


@dataclass
class EmptyIndexStruct(IndexStruct):
    """Empty index."""

    @classmethod
    def get_type(cls) -> IndexStructType:
        """Get type."""
        return IndexStructType.EMPTY
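
The KG struct is the keyword-to-node index that _build_index_from_docs filled via add_node; a small self-contained sketch of its lookup behavior:

    from dbgpt.rag.graph.index_struct import KG
    from dbgpt.rag.graph.node import TextNode

    kg = KG()
    node = TextNode(text="DB-GPT supports AWEL workflows.")
    kg.add_node(["DB-GPT", "AWEL"], node)  # keywords are lower-cased on insert
    assert kg.search_node_by_keyword("awel") == [node.node_id]
    assert kg.search_node_by_keyword("AWEL") == []  # lookup is case-sensitive
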
Deleted: dbgpt/rag/graph/index_type.py
@@ -1,48 +0,0 @@

"""IndexStructType class."""

from enum import Enum


class IndexStructType(str, Enum):
    """Index struct type. Identifier for a "type" of index.

    Attributes:
        TREE ("tree"): Tree index. See :ref:`Ref-Indices-Tree` for tree indices.
        LIST ("list"): Summary index. See :ref:`Ref-Indices-List` for summary indices.
        KEYWORD_TABLE ("keyword_table"): Keyword table index. See
            :ref:`Ref-Indices-Table` for keyword table indices.
        DICT ("dict"): Faiss Vector Store Index. See :ref:`Ref-Indices-VectorStore`
            for more information on the faiss vector store index.
        SIMPLE_DICT ("simple_dict"): Simple Vector Store Index. See
            :ref:`Ref-Indices-VectorStore` for more information on the simple
            vector store index.
        KG ("kg"): Knowledge Graph index.
            See :ref:`Ref-Indices-Knowledge-Graph` for KG indices.
        DOCUMENT_SUMMARY ("document_summary"): Document Summary Index.
            See :ref:`Ref-Indices-Document-Summary` for Summary Indices.
    """

    # TODO: refactor so these are properties on the base class

    NODE = "node"
    TREE = "tree"
    LIST = "list"
    KEYWORD_TABLE = "keyword_table"

    DICT = "dict"
    # simple
    SIMPLE_DICT = "simple_dict"
    # for KG index
    KG = "kg"
    SIMPLE_KG = "simple_kg"
    NEBULAGRAPH = "nebulagraph"
    FALKORDB = "falkordb"

    # EMPTY
    EMPTY = "empty"
    COMPOSITE = "composite"

    DOCUMENT_SUMMARY = "document_summary"
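
Because IndexStructType subclasses str, its members compare equal to their raw string values, so index structs can be matched by plain string keys; for example:

    from dbgpt.rag.graph.index_struct import KG
    from dbgpt.rag.graph.index_type import IndexStructType

    assert IndexStructType.KG == "kg"
    assert KG.get_type() is IndexStructType.KG
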
Deleted: KVIndexStore module (file path not shown in this diff)
@@ -1,75 +0,0 @@

from typing import List, Optional

from llama_index.data_structs.data_structs import IndexStruct
from llama_index.storage.index_store.utils import (
    index_struct_to_json,
    json_to_index_struct,
)
from llama_index.storage.kvstore.types import BaseKVStore

DEFAULT_NAMESPACE = "index_store"


class KVIndexStore:
    """Key-Value Index store.

    Args:
        kvstore (BaseKVStore): key-value store
        namespace (str): namespace for the index store
    """

    def __init__(self, kvstore: BaseKVStore, namespace: Optional[str] = None) -> None:
        """Init a KVIndexStore."""
        self._kvstore = kvstore
        self._namespace = namespace or DEFAULT_NAMESPACE
        self._collection = f"{self._namespace}/data"

    def add_index_struct(self, index_struct: IndexStruct) -> None:
        """Add an index struct.

        Args:
            index_struct (IndexStruct): index struct
        """
        key = index_struct.index_id
        data = index_struct_to_json(index_struct)
        self._kvstore.put(key, data, collection=self._collection)

    def delete_index_struct(self, key: str) -> None:
        """Delete an index struct.

        Args:
            key (str): index struct key
        """
        self._kvstore.delete(key, collection=self._collection)

    def get_index_struct(
        self, struct_id: Optional[str] = None
    ) -> Optional[IndexStruct]:
        """Get an index struct.

        Args:
            struct_id (Optional[str]): index struct id
        """
        if struct_id is None:
            structs = self.index_structs()
            assert len(structs) == 1
            return structs[0]
        else:
            json = self._kvstore.get(struct_id, collection=self._collection)
            if json is None:
                return None
            return json_to_index_struct(json)

    def index_structs(self) -> List[IndexStruct]:
        """Get all index structs.

        Returns:
            List[IndexStruct]: index structs
        """
        jsons = self._kvstore.get_all(collection=self._collection)
        return [json_to_index_struct(json) for json in jsons.values()]
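
KVIndexStore relies only on put/get/get_all/delete with a collection keyword, so a dict-backed stand-in is enough to exercise it; a minimal sketch (InMemoryKVStore is illustrative, not a llama-index class):

    from collections import defaultdict

    class InMemoryKVStore:
        """Dict-backed stand-in for the subset of BaseKVStore used above."""

        def __init__(self):
            self._data = defaultdict(dict)

        def put(self, key, val, collection="data"):
            self._data[collection][key] = val

        def get(self, key, collection="data"):
            return self._data[collection].get(key)

        def get_all(self, collection="data"):
            return dict(self._data[collection])

        def delete(self, key, collection="data"):
            return self._data[collection].pop(key, None) is not None

    # store = KVIndexStore(InMemoryKVStore(), namespace="demo")
    # store.add_index_struct(index_struct)  # keyed by index_struct.index_id
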
Deleted: dbgpt/rag/graph/node.py
@@ -1,570 +0,0 @@

"""Base schema for data structures."""
import json
import textwrap
import uuid
from abc import abstractmethod
from enum import Enum, auto
from hashlib import sha256
from typing import Any, Dict, List, Optional, Union

from langchain.schema import Document as LCDocument
from typing_extensions import Self

from dbgpt._private.pydantic import BaseModel, Field, root_validator

DEFAULT_TEXT_NODE_TMPL = "{metadata_str}\n\n{content}"
DEFAULT_METADATA_TMPL = "{key}: {value}"
# NOTE: for pretty printing
TRUNCATE_LENGTH = 350
WRAP_WIDTH = 70


def truncate_text(text: str, max_length: int) -> str:
    """Truncate text to a maximum length."""
    if len(text) <= max_length:
        return text
    return text[: max_length - 3] + "..."


class BaseComponent(BaseModel):
    """Base component object to capture class names (reference: llama-index)."""

    @classmethod
    @abstractmethod
    def class_name(cls) -> str:
        """Get class name."""

    def to_dict(self, **kwargs: Any) -> Dict[str, Any]:
        data = self.dict(**kwargs)
        data["class_name"] = self.class_name()
        return data

    def to_json(self, **kwargs: Any) -> str:
        data = self.to_dict(**kwargs)
        return json.dumps(data)

    # TODO: return type here not supported by current mypy version
    @classmethod
    def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self:  # type: ignore
        if isinstance(kwargs, dict):
            data.update(kwargs)

        data.pop("class_name", None)
        return cls(**data)

    @classmethod
    def from_json(cls, data_str: str, **kwargs: Any) -> Self:  # type: ignore
        data = json.loads(data_str)
        return cls.from_dict(data, **kwargs)


class NodeRelationship(str, Enum):
    """Node relationships used in `BaseNode` class.

    Attributes:
        SOURCE: The node is the source document.
        PREVIOUS: The node is the previous node in the document.
        NEXT: The node is the next node in the document.
        PARENT: The node is the parent node in the document.
        CHILD: The node is a child node in the document.
    """

    SOURCE = auto()
    PREVIOUS = auto()
    NEXT = auto()
    PARENT = auto()
    CHILD = auto()


class ObjectType(str, Enum):
    TEXT = auto()
    IMAGE = auto()
    INDEX = auto()
    DOCUMENT = auto()


class MetadataMode(str, Enum):
    ALL = auto()
    EMBED = auto()
    LLM = auto()
    NONE = auto()


class RelatedNodeInfo(BaseComponent):
    node_id: str
    node_type: Optional[ObjectType] = None
    metadata: Dict[str, Any] = Field(default_factory=dict)
    hash: Optional[str] = None

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "RelatedNodeInfo"


RelatedNodeType = Union[RelatedNodeInfo, List[RelatedNodeInfo]]


# Node classes for indexes
class BaseNode(BaseComponent):
    """Base node object.

    Generic abstract interface for retrievable nodes.
    """

    class Config:
        allow_population_by_field_name = True

    id_: str = Field(
        default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the node."
    )
    embedding: Optional[List[float]] = Field(
        default=None, description="Embedding of the node."
    )

    # Metadata fields:
    # - injected as part of the text shown to LLMs as context
    # - injected as part of the text for generating embeddings
    # - used by vector DBs for metadata filtering
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="A flat dictionary of metadata fields",
        alias="extra_info",
    )
    excluded_embed_metadata_keys: List[str] = Field(
        default_factory=list,
        description="Metadata keys that are excluded from text for the embed model.",
    )
    excluded_llm_metadata_keys: List[str] = Field(
        default_factory=list,
        description="Metadata keys that are excluded from text for the LLM.",
    )
    relationships: Dict[NodeRelationship, RelatedNodeType] = Field(
        default_factory=dict,
        description="A mapping of relationships to other node information.",
    )
    hash: str = Field(default="", description="Hash of the node content.")

    @classmethod
    @abstractmethod
    def get_type(cls) -> str:
        """Get Object type."""

    @abstractmethod
    def get_content(self, metadata_mode: MetadataMode = MetadataMode.ALL) -> str:
        """Get object content."""

    @abstractmethod
    def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
        """Metadata string."""

    @abstractmethod
    def set_content(self, value: Any) -> None:
        """Set the content of the node."""

    @property
    def node_id(self) -> str:
        return self.id_

    @node_id.setter
    def node_id(self, value: str) -> None:
        self.id_ = value

    @property
    def source_node(self) -> Optional[RelatedNodeInfo]:
        """Source object node.

        Extracted from the relationships field.
        """
        if NodeRelationship.SOURCE not in self.relationships:
            return None

        relation = self.relationships[NodeRelationship.SOURCE]
        if isinstance(relation, list):
            raise ValueError("Source object must be a single RelatedNodeInfo object")
        return relation

    @property
    def prev_node(self) -> Optional[RelatedNodeInfo]:
        """Prev node."""
        if NodeRelationship.PREVIOUS not in self.relationships:
            return None

        relation = self.relationships[NodeRelationship.PREVIOUS]
        if not isinstance(relation, RelatedNodeInfo):
            raise ValueError("Previous object must be a single RelatedNodeInfo object")
        return relation

    @property
    def next_node(self) -> Optional[RelatedNodeInfo]:
        """Next node."""
        if NodeRelationship.NEXT not in self.relationships:
            return None

        relation = self.relationships[NodeRelationship.NEXT]
        if not isinstance(relation, RelatedNodeInfo):
            raise ValueError("Next object must be a single RelatedNodeInfo object")
        return relation

    @property
    def parent_node(self) -> Optional[RelatedNodeInfo]:
        """Parent node."""
        if NodeRelationship.PARENT not in self.relationships:
            return None

        relation = self.relationships[NodeRelationship.PARENT]
        if not isinstance(relation, RelatedNodeInfo):
            raise ValueError("Parent object must be a single RelatedNodeInfo object")
        return relation

    @property
    def child_nodes(self) -> Optional[List[RelatedNodeInfo]]:
        """Child nodes."""
        if NodeRelationship.CHILD not in self.relationships:
            return None

        relation = self.relationships[NodeRelationship.CHILD]
        if not isinstance(relation, list):
            raise ValueError("Child objects must be a list of RelatedNodeInfo objects.")
        return relation

    @property
    def ref_doc_id(self) -> Optional[str]:
        """Deprecated: Get ref doc id."""
        source_node = self.source_node
        if source_node is None:
            return None
        return source_node.node_id

    @property
    def extra_info(self) -> Dict[str, Any]:
        """TODO: DEPRECATED: Extra info."""
        return self.metadata

    def __str__(self) -> str:
        source_text_truncated = truncate_text(
            self.get_content().strip(), TRUNCATE_LENGTH
        )
        source_text_wrapped = textwrap.fill(
            f"Text: {source_text_truncated}\n", width=WRAP_WIDTH
        )
        return f"Node ID: {self.node_id}\n{source_text_wrapped}"

    def get_embedding(self) -> List[float]:
        """Get embedding.

        Errors if embedding is None.
        """
        if self.embedding is None:
            raise ValueError("embedding not set.")
        return self.embedding

    def as_related_node_info(self) -> RelatedNodeInfo:
        """Get node as RelatedNodeInfo."""
        return RelatedNodeInfo(
            node_id=self.node_id, metadata=self.metadata, hash=self.hash
        )


class TextNode(BaseNode):
    text: str = Field(default="", description="Text content of the node.")
    start_char_idx: Optional[int] = Field(
        default=None, description="Start char index of the node."
    )
    end_char_idx: Optional[int] = Field(
        default=None, description="End char index of the node."
    )
    text_template: str = Field(
        default=DEFAULT_TEXT_NODE_TMPL,
        description=(
            "Template for how text is formatted, with {content} and "
            "{metadata_str} placeholders."
        ),
    )
    metadata_template: str = Field(
        default=DEFAULT_METADATA_TMPL,
        description=(
            "Template for how metadata is formatted, with {key} and "
            "{value} placeholders."
        ),
    )
    metadata_seperator: str = Field(
        default="\n",
        description="Separator between metadata fields when converting to string.",
    )

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "TextNode"

    @root_validator
    def _check_hash(cls, values: dict) -> dict:
        """Generate a hash to represent the node."""
        text = values.get("text", "")
        metadata = values.get("metadata", {})
        doc_identity = str(text) + str(metadata)
        values["hash"] = str(
            sha256(doc_identity.encode("utf-8", "surrogatepass")).hexdigest()
        )
        return values

    @classmethod
    def get_type(cls) -> str:
        """Get Object type."""
        return ObjectType.TEXT

    def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
        """Get object content."""
        metadata_str = self.get_metadata_str(mode=metadata_mode).strip()
        if not metadata_str:
            return self.text

        return self.text_template.format(
            content=self.text, metadata_str=metadata_str
        ).strip()

    def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
        """Metadata info string."""
        if mode == MetadataMode.NONE:
            return ""

        usable_metadata_keys = set(self.metadata.keys())
        if mode == MetadataMode.LLM:
            for key in self.excluded_llm_metadata_keys:
                if key in usable_metadata_keys:
                    usable_metadata_keys.remove(key)
        elif mode == MetadataMode.EMBED:
            for key in self.excluded_embed_metadata_keys:
                if key in usable_metadata_keys:
                    usable_metadata_keys.remove(key)

        return self.metadata_seperator.join(
            [
                self.metadata_template.format(key=key, value=str(value))
                for key, value in self.metadata.items()
                if key in usable_metadata_keys
            ]
        )

    def set_content(self, value: str) -> None:
        """Set the content of the node."""
        self.text = value

    def get_node_info(self) -> Dict[str, Any]:
        """Get node info."""
        return {"start": self.start_char_idx, "end": self.end_char_idx}

    def get_text(self) -> str:
        return self.get_content(metadata_mode=MetadataMode.NONE)

    @property
    def node_info(self) -> Dict[str, Any]:
        """Deprecated: Get node info."""
        return self.get_node_info()


# TODO: legacy backport of old Node class
Node = TextNode


class ImageNode(TextNode):
    """Node with image."""

    # TODO: store reference instead of actual image
    # base64 encoded image str
    image: Optional[str] = None

    @classmethod
    def get_type(cls) -> str:
        return ObjectType.IMAGE

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "ImageNode"


class IndexNode(TextNode):
    """Node with reference to any object.

    This can include other indices, query engines, retrievers.

    This can also include other nodes (though this is overlapping with `relationships`
    on the Node class).
    """

    index_id: str

    @classmethod
    def from_text_node(
        cls,
        node: TextNode,
        index_id: str,
    ) -> "IndexNode":
        """Create index node from text node."""
        # copy all attributes from text node, add index id
        return cls(
            **node.dict(),
            index_id=index_id,
        )

    @classmethod
    def get_type(cls) -> str:
        return ObjectType.INDEX

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "IndexNode"


class NodeWithScore(BaseComponent):
    node: BaseNode
    score: Optional[float] = None

    def __str__(self) -> str:
        return f"{self.node}\nScore: {self.score: 0.3f}\n"

    def get_score(self, raise_error: bool = False) -> float:
        """Get score."""
        if self.score is None:
            if raise_error:
                raise ValueError("Score not set.")
            else:
                return 0.0
        else:
            return self.score

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "NodeWithScore"

    ##### pass through methods to BaseNode #####
    @property
    def node_id(self) -> str:
        return self.node.node_id

    @property
    def id_(self) -> str:
        return self.node.id_

    @property
    def text(self) -> str:
        if isinstance(self.node, TextNode):
            return self.node.text
        else:
            raise ValueError("Node must be a TextNode to get text.")

    @property
    def metadata(self) -> Dict[str, Any]:
        return self.node.metadata

    @property
    def embedding(self) -> Optional[List[float]]:
        return self.node.embedding

    def get_text(self) -> str:
        if isinstance(self.node, TextNode):
            return self.node.get_text()
        else:
            raise ValueError("Node must be a TextNode to get text.")

    def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
        return self.node.get_content(metadata_mode=metadata_mode)

    def get_embedding(self) -> List[float]:
        return self.node.get_embedding()


# Document Classes for Readers


class Document(TextNode):
    """Generic interface for a data document.

    This document connects to data sources.
    """

    # TODO: A lot of backwards compatibility logic here, clean up
    id_: str = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="Unique ID of the node.",
        alias="doc_id",
    )

    _compat_fields = {"doc_id": "id_", "extra_info": "metadata"}

    @classmethod
    def get_type(cls) -> str:
        """Get Document type."""
        return ObjectType.DOCUMENT

    @property
    def doc_id(self) -> str:
        """Get document ID."""
        return self.id_

    def __str__(self) -> str:
        source_text_truncated = truncate_text(
            self.get_content().strip(), TRUNCATE_LENGTH
        )
        source_text_wrapped = textwrap.fill(
            f"Text: {source_text_truncated}\n", width=WRAP_WIDTH
        )
        return f"Doc ID: {self.doc_id}\n{source_text_wrapped}"

    def get_doc_id(self) -> str:
        """TODO: Deprecated: Get document ID."""
        return self.id_

    def __setattr__(self, name: str, value: object) -> None:
        if name in self._compat_fields:
            name = self._compat_fields[name]
        super().__setattr__(name, value)

    def to_langchain_format(self) -> LCDocument:
        """Convert struct to LangChain document format."""
        metadata = self.metadata or {}
        return LCDocument(page_content=self.text, metadata=metadata)

    @classmethod
    def from_langchain_format(cls, doc: LCDocument) -> "Document":
        """Convert struct from LangChain document format."""
        return cls(text=doc.page_content, metadata=doc.metadata)

    @classmethod
    def example(cls) -> "Document":
        document = Document(
            text="",
            metadata={"filename": "README.md", "category": "codebase"},
        )
        return document

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "Document"


class ImageDocument(Document):
    """Data document containing an image."""

    # base64 encoded image str
    image: Optional[str] = None

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "ImageDocument"
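
A short sketch of how the node pieces compose; the metadata-exclusion behavior is exactly what RAGGraphSearch relied on above when it built its rel_text_node:

    from dbgpt.rag.graph.node import MetadataMode, TextNode

    node = TextNode(
        text="AWEL is DB-GPT's workflow expression language.",
        metadata={"source": "docs", "kg_rel_map": {"AWEL": []}},
        excluded_llm_metadata_keys=["kg_rel_map"],
    )
    print(node.get_content(metadata_mode=MetadataMode.ALL))   # metadata + text
    print(node.get_content(metadata_mode=MetadataMode.LLM))   # only "source: docs" survives
    print(node.get_content(metadata_mode=MetadataMode.NONE))  # raw text only
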
Deleted: dbgpt/rag/graph/search.py
@@ -1,44 +0,0 @@

from abc import ABC, abstractmethod
from enum import Enum


class SearchMode(str, Enum):
    """Query mode enum for Knowledge Graphs.

    Can be passed as the enum struct, or as the underlying string.

    Attributes:
        KEYWORD ("keyword"): Default query mode, using keywords to find triplets.
        EMBEDDING ("embedding"): Embedding mode, using embeddings to find
            similar triplets.
        HYBRID ("hybrid"): Hybrid mode, combining both keywords and embeddings
            to find relevant triplets.
    """

    KEYWORD = "keyword"
    EMBEDDING = "embedding"
    HYBRID = "hybrid"


class BaseSearch(ABC):
    """Base Search."""

    async def search(self, query: str):
        """Retrieve nodes given query.

        Args:
            query (QueryType): Either a query string or a QueryBundle object.
        """
        # if isinstance(query, str):
        return await self._search(query)

    @abstractmethod
    async def _search(self, query: str):
        """Search nodes given query.

        Implemented by the user.
        """
        pass
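
Concrete searchers only implement the async _search hook; a minimal sketch:

    import asyncio
    from typing import List

    from dbgpt.rag.graph.search import BaseSearch

    class EchoSearch(BaseSearch):
        """Toy subclass for illustration; RAGGraphSearch above is the real one."""

        async def _search(self, query: str) -> List[str]:
            return [f"matched: {query}"]

    print(asyncio.run(EchoSearch().search("hello")))  # ['matched: hello']
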
Modified: examples/awel/simple_rag_embedding_example.py
@@ -1,40 +1,32 @@
-import asyncio
 import os
 from typing import Dict, List

+from pydantic import BaseModel, Field
+
 from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
-from dbgpt.core.awel import DAG, InputOperator, MapOperator, SimpleCallDataInputSource
-from dbgpt.rag.chunk import Chunk
+from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
 from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
 from dbgpt.rag.knowledge.base import KnowledgeType
 from dbgpt.rag.operator.knowledge import KnowledgeOperator
 from dbgpt.serve.rag.operators.embedding import EmbeddingAssemblerOperator
 from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
 from dbgpt.storage.vector_store.connector import VectorStoreConnector

 """AWEL: Simple rag embedding operator example

-pre-requirements:
-    set your file path in your example code.
-
 Examples:
-    pre-requirements:
-    python examples/awel/simple_rag_embedding_example.py
+    ..code-block:: shell
+        python examples/awel/simple_rag_embedding_example.py
+        curl --location --request POST 'http://127.0.0.1:5555/api/v1/awel/trigger/examples/rag/embedding' \
+        --header 'Content-Type: application/json' \
+        --data-raw '{
+            "url": "https://docs.dbgpt.site/docs/awel"
+        }'
 """


-def _context_join_fn(context_dict: Dict, chunks: List[Chunk]) -> Dict:
-    """Context join function for JoinOperator.
-
-    Args:
-        context_dict (Dict): context dict
-        chunks (List[Chunk]): chunks
-    Returns:
-        Dict: context dict
-    """
-    context_dict["context"] = "\n".join([chunk.content for chunk in chunks])
-    return context_dict
-
-
-def _create_vector_connector():
+def _create_vector_connector() -> VectorStoreConnector:
     """Create vector connector."""
     return VectorStoreConnector.from_default(
         "Chroma",
@@ -48,6 +40,22 @@ def _create_vector_connector():
     )


+class TriggerReqBody(BaseModel):
+    url: str = Field(..., description="url")
+
+
+class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    async def map(self, input_value: TriggerReqBody) -> Dict:
+        params = {
+            "url": input_value.url,
+        }
+        print(f"Receive input value: {input_value}")
+        return params
+
+
 class ResultOperator(MapOperator):
     """The Result Operator."""

@@ -61,26 +69,31 @@


 with DAG("simple_sdk_rag_embedding_example") as dag:
-    knowledge_operator = KnowledgeOperator()
+    trigger = HttpTrigger(
+        "/examples/rag/embedding", methods="POST", request_body=TriggerReqBody
+    )
+    request_handle_task = RequestHandleOperator()
+    knowledge_operator = KnowledgeOperator(knowledge_type=KnowledgeType.URL)
     vector_connector = _create_vector_connector()
-    input_task = InputOperator(input_source=SimpleCallDataInputSource())
-    file_path_parser = MapOperator(map_function=lambda x: x["file_path"])
+    url_parser_operator = MapOperator(map_function=lambda x: x["url"])
     embedding_operator = EmbeddingAssemblerOperator(
         vector_store_connector=vector_connector,
     )
    output_task = ResultOperator()
     (
-        input_task
-        >> file_path_parser
+        trigger
+        >> request_handle_task
+        >> url_parser_operator
         >> knowledge_operator
         >> embedding_operator
        >> output_task
     )

 if __name__ == "__main__":
-    input_data = {
-        "data": {
-            "file_path": "docs/docs/awel.md",
-        }
-    }
-    output = asyncio.run(output_task.call(call_data=input_data))
+    if dag.leaf_nodes[0].dev_mode:
+        # Development mode, you can run the dag locally for debugging.
+        from dbgpt.core.awel import setup_dev_environment
+
+        setup_dev_environment([dag], port=5555)
+    else:
+        pass
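
To trace the new pipeline's data flow: HttpTrigger parses the POST body into a TriggerReqBody, RequestHandleOperator maps it to a plain dict, and url_parser_operator pulls out the URL string that KnowledgeOperator consumes. A minimal sketch of that hand-off, with plain functions standing in for the AWEL operators:

    from pydantic import BaseModel, Field

    class TriggerReqBody(BaseModel):
        url: str = Field(..., description="url")

    def request_handle(body: TriggerReqBody) -> dict:
        return {"url": body.url}  # what RequestHandleOperator.map does

    def url_parse(params: dict) -> str:
        return params["url"]  # MapOperator(map_function=lambda x: x["url"])

    body = TriggerReqBody(url="https://docs.dbgpt.site/docs/awel")
    assert url_parse(request_handle(body)) == "https://docs.dbgpt.site/docs/awel"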