fix: update awel embedding examples and delete unused code. (#1073)

Aries-ckt 2024-01-15 23:22:52 +08:00 committed by GitHub
parent cb9c34abb9
commit 3a54d1ef9a
11 changed files with 45 additions and 1437 deletions

View File

@@ -28,11 +28,6 @@ def initialize_components(
     system_app.register(DefaultExecutorFactory)
     system_app.register_instance(controller)
-    # Register global default RAGGraphFactory
-    # from dbgpt.graph.graph_factory import DefaultRAGGraphFactory
-    # system_app.register(DefaultRAGGraphFactory)
     from dbgpt.serve.agent.hub.controller import module_agent
     system_app.register_instance(module_agent)

View File

@@ -497,30 +497,6 @@ class KnowledgeService:
         res.page = request.page
         return res
-    def async_knowledge_graph(self, chunk_docs, doc):
-        """Async extract triplets from document chunks and save them into the graph db.
-        Args:
-            - chunk_docs: List[Document]
-            - doc: KnowledgeDocumentEntity
-        """
-        logger.info(
-            f"async_knowledge_graph, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to graph store"
-        )
-        try:
-            from dbgpt.rag.graph.graph_factory import RAGGraphFactory
-            rag_engine = CFG.SYSTEM_APP.get_component(
-                ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
-            ).create()
-            rag_engine.knowledge_graph(chunk_docs)
-            doc.status = SyncStatus.FINISHED.name
-            doc.result = "document build graph success"
-        except Exception as e:
-            doc.status = SyncStatus.FAILED.name
-            doc.result = "document build graph failed: " + str(e)
-            logger.error(f"document build graph failed:{doc.doc_name}, {str(e)}")
-        return knowledge_document_dao.update_knowledge_document(doc)
     def async_doc_embedding(self, assembler, chunk_docs, doc):
         """async document embedding into vector db
         Args:
View File

@@ -1,148 +0,0 @@
import logging
from typing import Any, Callable, List, Optional, Tuple
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from dbgpt.rag.embedding import KnowledgeType
from dbgpt.rag.embedding.knowledge_type import get_knowledge_embedding
from dbgpt.rag.graph.index_struct import KG
from dbgpt.rag.graph.node import TextNode
from dbgpt.util import utils
logger = logging.getLogger(__name__)
class RAGGraphEngine:
"""Knowledge RAG Graph Engine.
Build a RAG Graph Client can extract triplets and insert into graph store.
Args:
knowledge_type (Optional[str]): Default: KnowledgeType.DOCUMENT.value
extracting triplets.
knowledge_source (Optional[str]):
model_name (Optional[str]): llm model name
graph_store (Optional[GraphStore]): The graph store to use.refrence:llama-index
include_embeddings (bool): Whether to include embeddings in the index.
Defaults to False.
max_object_length (int): The maximum length of the object in a triplet.
Defaults to 128.
extract_triplet_fn (Optional[Callable]): The function to use for
extracting triplets. Defaults to None.
"""
index_struct_cls = KG
def __init__(
self,
knowledge_type: Optional[str] = KnowledgeType.DOCUMENT.value,
knowledge_source: Optional[str] = None,
text_splitter=None,
graph_store=None,
index_struct: Optional[KG] = None,
model_name: Optional[str] = None,
max_triplets_per_chunk: int = 10,
include_embeddings: bool = False,
max_object_length: int = 128,
extract_triplet_fn: Optional[Callable] = None,
**kwargs: Any,
) -> None:
"""Initialize params."""
from llama_index.graph_stores import SimpleGraphStore
# need to set parameters before building index in base class.
self.knowledge_source = knowledge_source
self.knowledge_type = knowledge_type
self.model_name = model_name
self.text_splitter = text_splitter
self.index_struct = index_struct
self.include_embeddings = include_embeddings
self.graph_store = graph_store or SimpleGraphStore()
# self.graph_store = graph_store
self.max_triplets_per_chunk = max_triplets_per_chunk
self._max_object_length = max_object_length
self._extract_triplet_fn = extract_triplet_fn
def knowledge_graph(self, docs=None):
"""knowledge docs into graph store"""
if not docs:
if self.text_splitter:
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=2000, chunk_overlap=100
)
knowledge_source = get_knowledge_embedding(
knowledge_type=self.knowledge_type,
knowledge_source=self.knowledge_source,
text_splitter=self.text_splitter,
)
docs = knowledge_source.read()
if self.index_struct is None:
self.index_struct = self._build_index_from_docs(docs)
def _extract_triplets(self, text: str) -> List[Tuple[str, str, str]]:
"""Extract triplets from text by function or llm"""
if self._extract_triplet_fn is not None:
return self._extract_triplet_fn(text)
else:
return self._llm_extract_triplets(text)
def _llm_extract_triplets(self, text: str) -> List[Tuple[str, str, str]]:
"""Extract triplets from text by llm"""
import uuid
from dbgpt.app.scene import ChatScene
from dbgpt.util.chat_util import llm_chat_response_nostream
chat_param = {
"chat_session_id": uuid.uuid1(),
"current_user_input": text,
"select_param": "triplet",
"model_name": self.model_name,
}
loop = utils.get_or_create_event_loop()
triplets = loop.run_until_complete(
llm_chat_response_nostream(
ChatScene.ExtractTriplet.value(), **{"chat_param": chat_param}
)
)
return triplets
    def _build_index_from_docs(self, documents: List[Document]) -> KG:
        """Build the index from documents.
        Args:
            documents (List[Document]): the documents to index
        """
index_struct = self.index_struct_cls()
        triplets = []
        for doc in documents:
            trips = self._extract_triplets_task([doc], index_struct)
            triplets.extend(trips)
            logger.debug(f"extracted triplets: {trips}")
            text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
            # only upsert the triplets extracted from this document
            for triplet in trips:
                subj, _, obj = triplet
                self.graph_store.upsert_triplet(*triplet)
                index_struct.add_node([subj, obj], text_node)
        return index_struct
def search(self, query):
from dbgpt.rag.graph.graph_search import RAGGraphSearch
graph_search = RAGGraphSearch(graph_engine=self)
return graph_search.search(query)
    def _extract_triplets_task(self, docs, index_struct):
        triple_results = []
        for doc in docs:
            import threading
            thread_id = threading.get_ident()
            logger.info(f"thread-{thread_id} begin extract triplets task")
            triplets = self._extract_triplets(doc.page_content)
            logger.info(f"thread-{thread_id} extracted triplets: {triplets}")
            triple_results.extend(triplets)
        return triple_results
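
For context, a hedged sketch of how this deleted engine was driven, using only the constructor arguments shown above (the source path and the stand-in extractor are hypothetical, and the module no longer exists after this commit):

import asyncio
from dbgpt.rag.graph.graph_engine import RAGGraphEngine

def extract_fn(text):
    # hypothetical stand-in extractor: (subject, predicate, object) triplets
    return [("DB-GPT", "supports", "AWEL")]

engine = RAGGraphEngine(
    knowledge_source="docs/docs/awel.md",  # hypothetical source document
    model_name="proxyllm",
    extract_triplet_fn=extract_fn,  # bypasses the LLM-based extractor
)
engine.knowledge_graph()  # read, split, extract triplets, fill the graph store
# searching still requires the LLM service for entity extraction
docs = asyncio.run(engine.search("What does DB-GPT support?"))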

View File

@@ -1,35 +0,0 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Type
from dbgpt.component import BaseComponent, ComponentType
class RAGGraphFactory(BaseComponent, ABC):
name = ComponentType.RAG_GRAPH_DEFAULT.value
@abstractmethod
def create(self, model_name: str = None, embedding_cls: Type = None):
"""Create RAG Graph Engine"""
class DefaultRAGGraphFactory(RAGGraphFactory):
def __init__(
self, system_app=None, default_model_name: str = None, **kwargs: Any
) -> None:
super().__init__(system_app=system_app)
self._default_model_name = default_model_name
self.kwargs = kwargs
from dbgpt.rag.graph.graph_engine import RAGGraphEngine
self.rag_engine = RAGGraphEngine(model_name="proxyllm")
def init_app(self, system_app):
pass
def create(self, model_name: str = None, rag_cls: Type = None):
if not model_name:
model_name = self._default_model_name
return self.rag_engine
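
Callers resolved this factory through the component system rather than constructing the engine directly; a sketch mirroring the deleted KnowledgeService code above (CFG.SYSTEM_APP and chunk_docs come from that service context):

from dbgpt.component import ComponentType
from dbgpt.rag.graph.graph_factory import RAGGraphFactory

rag_engine = CFG.SYSTEM_APP.get_component(
    ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
).create()
rag_engine.knowledge_graph(chunk_docs)  # chunk_docs: List[Document]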

View File

@@ -1,198 +0,0 @@
import logging
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Optional, Set
from langchain.schema import Document
from dbgpt.rag.graph.node import BaseNode, NodeWithScore, TextNode
from dbgpt.rag.graph.search import BaseSearch, SearchMode
logger = logging.getLogger(__name__)
DEFAULT_NODE_SCORE = 1000.0
GLOBAL_EXPLORE_NODE_LIMIT = 3
REL_TEXT_LIMIT = 30
class RAGGraphSearch(BaseSearch):
"""RAG Graph Search.
args:
graph RAGGraphEngine.
model_name (str): model name
(see :ref:`Prompt-Templates`).
text_qa_template (Optional[BasePromptTemplate]): A Question Answering Prompt
(see :ref:`Prompt-Templates`).
max_keywords_per_query (int): Maximum number of keywords to extract from query.
num_chunks_per_query (int): Maximum number of text chunks to query.
search_mode (Optional[SearchMode]): Specifies whether to use keyowrds, default SearchMode.KEYWORD
embeddings, or both to find relevant triplets. Should be one of "keyword",
"embedding", or "hybrid".
graph_store_query_depth (int): The depth of the graph store query.
extract_subject_entities_fn (Optional[Callback]): extract_subject_entities callback.
"""
def __init__(
self,
graph_engine,
model_name: str = None,
max_keywords_per_query: int = 10,
num_chunks_per_query: int = 10,
search_mode: Optional[SearchMode] = SearchMode.KEYWORD,
graph_store_query_depth: int = 2,
extract_subject_entities_fn: Optional[Callable] = None,
**kwargs: Any,
) -> None:
"""Initialize params."""
from dbgpt.rag.graph.graph_engine import RAGGraphEngine
self.graph_engine: RAGGraphEngine = graph_engine
self.model_name = model_name or self.graph_engine.model_name
self._index_struct = self.graph_engine.index_struct
self.max_keywords_per_query = max_keywords_per_query
self.num_chunks_per_query = num_chunks_per_query
self._search_mode = search_mode
self._graph_store = self.graph_engine.graph_store
self.graph_store_query_depth = graph_store_query_depth
self._verbose = kwargs.get("verbose", False)
refresh_schema = kwargs.get("refresh_schema", False)
self.extract_subject_entities_fn = extract_subject_entities_fn
self.executor = ThreadPoolExecutor(max_workers=os.cpu_count() * 5)
try:
self._graph_schema = self._graph_store.get_schema(refresh=refresh_schema)
except NotImplementedError:
self._graph_schema = ""
except Exception as e:
            logger.warning(f"cannot find graph schema: {e}")
self._graph_schema = ""
async def _extract_subject_entities(self, query_str: str) -> Set[str]:
"""extract subject entities."""
if self.extract_subject_entities_fn is not None:
return await self.extract_subject_entities_fn(query_str)
else:
return await self._extract_entities_by_llm(query_str)
async def _extract_entities_by_llm(self, text: str) -> Set[str]:
"""extract subject entities from text by llm"""
import uuid
from dbgpt.app.scene import ChatScene
from dbgpt.util.chat_util import llm_chat_response_nostream
chat_param = {
"chat_session_id": uuid.uuid1(),
"current_user_input": text,
"select_param": "entity",
"model_name": self.model_name,
}
# loop = util.get_or_create_event_loop()
# entities = loop.run_until_complete(
# llm_chat_response_nostream(
# ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
# )
# )
return await llm_chat_response_nostream(
ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
)
async def _search(
self,
query_str: str,
) -> List[Document]:
"""Get nodes for response."""
node_visited = set()
keywords = await self._extract_subject_entities(query_str)
print(f"extract entities: {keywords}\n")
rel_texts = []
cur_rel_map = {}
chunk_indices_count: Dict[str, int] = defaultdict(int)
if self._search_mode != SearchMode.EMBEDDING:
for keyword in keywords:
keyword = keyword.lower()
subjs = set((keyword,))
# node_ids = self._index_struct.search_node_by_keyword(keyword)
# for node_id in node_ids[:GLOBAL_EXPLORE_NODE_LIMIT]:
# if node_id in node_visited:
# continue
#
# # if self._include_text:
# # chunk_indices_count[node_id] += 1
#
# node_visited.add(node_id)
rel_map = self._graph_store.get_rel_map(
list(subjs), self.graph_store_query_depth
)
logger.debug(f"rel_map: {rel_map}")
if not rel_map:
continue
rel_texts.extend(
[
str(rel_obj)
for rel_objs in rel_map.values()
for rel_obj in rel_objs
]
)
cur_rel_map.update(rel_map)
sorted_nodes_with_scores = []
if not rel_texts:
logger.info("> No relationships found, returning nodes found by keywords.")
if len(sorted_nodes_with_scores) == 0:
logger.info("> No nodes found by keywords, returning empty response.")
return [Document(page_content="No relationships found.")]
# add relationships as Node
# TODO: make initial text customizable
rel_initial_text = (
f"The following are knowledge sequence in max depth"
f" {self.graph_store_query_depth} "
f"in the form of directed graph like:\n"
f"`subject -[predicate]->, object, <-[predicate_next_hop]-,"
f" object_next_hop ...`"
)
rel_info = [rel_initial_text] + rel_texts
rel_node_info = {
"kg_rel_texts": rel_texts,
"kg_rel_map": cur_rel_map,
}
if self._graph_schema != "":
rel_node_info["kg_schema"] = {"schema": self._graph_schema}
rel_info_text = "\n".join(
[
str(item)
for sublist in rel_info
for item in (sublist if isinstance(sublist, list) else [sublist])
]
)
if self._verbose:
print(f"KG context:\n{rel_info_text}\n", color="blue")
rel_text_node = TextNode(
text=rel_info_text,
metadata=rel_node_info,
excluded_embed_metadata_keys=["kg_rel_map", "kg_rel_texts"],
excluded_llm_metadata_keys=["kg_rel_map", "kg_rel_texts"],
)
# this node is constructed from rel_texts, give high confidence to avoid cutoff
sorted_nodes_with_scores.append(
NodeWithScore(node=rel_text_node, score=DEFAULT_NODE_SCORE)
)
docs = [
Document(page_content=node.text, metadata=node.metadata)
for node in sorted_nodes_with_scores
]
return docs
def _get_metadata_for_response(
self, nodes: List[BaseNode]
) -> Optional[Dict[str, Any]]:
"""Get metadata for response."""
for node in nodes:
if node.metadata is None or "kg_rel_map" not in node.metadata:
continue
return node.metadata
raise ValueError("kg_rel_map must be found in at least one Node.")
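
A hedged sketch of using the searcher directly, mirroring RAGGraphEngine.search above (engine is assumed to be a RAGGraphEngine with a built index, as in the earlier sketch):

import asyncio
from dbgpt.rag.graph.graph_search import RAGGraphSearch

graph_search = RAGGraphSearch(graph_engine=engine)
docs = asyncio.run(graph_search.search("Which entities relate to AWEL?"))
for doc in docs:
    print(doc.page_content)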

View File

@@ -1,258 +0,0 @@
"""Data structures.
Nodes are decoupled from the indices.
"""
import uuid
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Sequence, Set
from dataclasses_json import DataClassJsonMixin
from dbgpt.rag.graph.index_type import IndexStructType
from dbgpt.rag.graph.node import BaseNode, TextNode
# TODO: legacy backport of old Node class
Node = TextNode
@dataclass
class IndexStruct(DataClassJsonMixin):
"""A base data struct for a LlamaIndex."""
index_id: str = field(default_factory=lambda: str(uuid.uuid4()))
summary: Optional[str] = None
def get_summary(self) -> str:
"""Get text summary."""
if self.summary is None:
raise ValueError("summary field of the index_struct not set.")
return self.summary
@classmethod
@abstractmethod
def get_type(cls):
"""Get index struct type."""
@dataclass
class IndexGraph(IndexStruct):
"""A graph representing the tree-structured index."""
# mapping from index in tree to Node doc id.
all_nodes: Dict[int, str] = field(default_factory=dict)
root_nodes: Dict[int, str] = field(default_factory=dict)
node_id_to_children_ids: Dict[str, List[str]] = field(default_factory=dict)
@property
def node_id_to_index(self) -> Dict[str, int]:
"""Map from node id to index."""
return {node_id: index for index, node_id in self.all_nodes.items()}
@property
def size(self) -> int:
"""Get the size of the graph."""
return len(self.all_nodes)
def get_index(self, node: BaseNode) -> int:
"""Get index of node."""
return self.node_id_to_index[node.node_id]
def insert(
self,
node: BaseNode,
index: Optional[int] = None,
children_nodes: Optional[Sequence[BaseNode]] = None,
) -> None:
"""Insert node."""
index = index or self.size
node_id = node.node_id
self.all_nodes[index] = node_id
if children_nodes is None:
children_nodes = []
children_ids = [n.node_id for n in children_nodes]
self.node_id_to_children_ids[node_id] = children_ids
def get_children(self, parent_node: Optional[BaseNode]) -> Dict[int, str]:
"""Get children nodes."""
if parent_node is None:
return self.root_nodes
else:
parent_id = parent_node.node_id
children_ids = self.node_id_to_children_ids[parent_id]
return {
self.node_id_to_index[child_id]: child_id for child_id in children_ids
}
def insert_under_parent(
self,
node: BaseNode,
parent_node: Optional[BaseNode],
new_index: Optional[int] = None,
) -> None:
"""Insert under parent node."""
new_index = new_index or self.size
if parent_node is None:
self.root_nodes[new_index] = node.node_id
self.node_id_to_children_ids[node.node_id] = []
else:
if parent_node.node_id not in self.node_id_to_children_ids:
self.node_id_to_children_ids[parent_node.node_id] = []
self.node_id_to_children_ids[parent_node.node_id].append(node.node_id)
self.all_nodes[new_index] = node.node_id
@classmethod
def get_type(cls) -> IndexStructType:
"""Get type."""
return IndexStructType.TREE
@dataclass
class KeywordTable(IndexStruct):
"""A table of keywords mapping keywords to text chunks."""
table: Dict[str, Set[str]] = field(default_factory=dict)
def add_node(self, keywords: List[str], node: BaseNode) -> None:
"""Add text to table."""
for keyword in keywords:
if keyword not in self.table:
self.table[keyword] = set()
self.table[keyword].add(node.node_id)
@property
def node_ids(self) -> Set[str]:
"""Get all node ids."""
return set.union(*self.table.values())
@property
def keywords(self) -> Set[str]:
"""Get all keywords in the table."""
return set(self.table.keys())
@property
def size(self) -> int:
"""Get the size of the table."""
return len(self.table)
@classmethod
def get_type(cls) -> IndexStructType:
"""Get type."""
return IndexStructType.KEYWORD_TABLE
@dataclass
class IndexList(IndexStruct):
"""A list of documents."""
nodes: List[str] = field(default_factory=list)
def add_node(self, node: BaseNode) -> None:
"""Add text to table, return current position in list."""
# don't worry about child indices for now, nodes are all in order
self.nodes.append(node.node_id)
@classmethod
def get_type(cls) -> IndexStructType:
"""Get type."""
return IndexStructType.LIST
@dataclass
class IndexDict(IndexStruct):
"""A simple dictionary of documents."""
# TODO: slightly deprecated, should likely be a list or set now
# mapping from vector store id to node doc_id
nodes_dict: Dict[str, str] = field(default_factory=dict)
# TODO: deprecated, not used
# mapping from node doc_id to vector store id
doc_id_dict: Dict[str, List[str]] = field(default_factory=dict)
# TODO: deprecated, not used
# this should be empty for all other indices
embeddings_dict: Dict[str, List[float]] = field(default_factory=dict)
def add_node(
self,
node: BaseNode,
text_id: Optional[str] = None,
) -> str:
"""Add text to table, return current position in list."""
# # don't worry about child indices for now, nodes are all in order
# self.nodes_dict[int_id] = node
vector_id = text_id if text_id is not None else node.node_id
self.nodes_dict[vector_id] = node.node_id
return vector_id
def delete(self, doc_id: str) -> None:
"""Delete a Node."""
del self.nodes_dict[doc_id]
@classmethod
def get_type(cls) -> IndexStructType:
"""Get type."""
return IndexStructType.VECTOR_STORE
@dataclass
class KG(IndexStruct):
"""A table of keywords mapping keywords to text chunks."""
# Unidirectional
# table of keywords to node ids
table: Dict[str, Set[str]] = field(default_factory=dict)
# TODO: legacy attribute, remove in future releases
rel_map: Dict[str, List[List[str]]] = field(default_factory=dict)
# TBD, should support vector store, now we just persist the embedding memory
# maybe chainable abstractions for *_stores could be designed
embedding_dict: Dict[str, List[float]] = field(default_factory=dict)
@property
def node_ids(self) -> Set[str]:
"""Get all node ids."""
return set.union(*self.table.values())
def add_to_embedding_dict(self, triplet_str: str, embedding: List[float]) -> None:
"""Add embedding to dict."""
self.embedding_dict[triplet_str] = embedding
def add_node(self, keywords: List[str], node: BaseNode) -> None:
"""Add text to table."""
node_id = node.node_id
for keyword in keywords:
keyword = keyword.lower()
if keyword not in self.table:
self.table[keyword] = set()
self.table[keyword].add(node_id)
def search_node_by_keyword(self, keyword: str) -> List[str]:
"""Search for nodes by keyword."""
if keyword not in self.table:
return []
return list(self.table[keyword])
@classmethod
def get_type(cls) -> IndexStructType:
"""Get type."""
return IndexStructType.KG
@dataclass
class EmptyIndexStruct(IndexStruct):
"""Empty index."""
@classmethod
def get_type(cls) -> IndexStructType:
"""Get type."""
return IndexStructType.EMPTY
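
Two of these structs can be exercised standalone; a minimal sketch of the tree bookkeeping in IndexGraph and the keyword-to-node table in KG (node texts are illustrative):

from dbgpt.rag.graph.index_struct import KG, IndexGraph
from dbgpt.rag.graph.node import TextNode

# IndexGraph: parent/child bookkeeping
graph = IndexGraph()
root = TextNode(text="root summary")
graph.insert_under_parent(root, parent_node=None)  # register as a root node
child = TextNode(text="leaf chunk")
graph.insert_under_parent(child, parent_node=root)
assert list(graph.get_children(root).values()) == [child.node_id]

# KG: keyword -> node-id table (keywords are lower-cased on insert)
kg = KG()
node = TextNode(text="DB-GPT supports AWEL workflows.")
kg.add_node(["DB-GPT", "AWEL"], node)
assert kg.search_node_by_keyword("awel") == [node.node_id]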

View File

@@ -1,48 +0,0 @@
"""IndexStructType class."""
from enum import Enum
class IndexStructType(str, Enum):
"""Index struct type. Identifier for a "type" of index.
Attributes:
TREE ("tree"): Tree index. See :ref:`Ref-Indices-Tree` for tree indices.
LIST ("list"): Summary index. See :ref:`Ref-Indices-List` for summary indices.
KEYWORD_TABLE ("keyword_table"): Keyword table index. See
:ref:`Ref-Indices-Table`
for keyword table indices.
DICT ("dict"): Faiss Vector Store Index. See
:ref:`Ref-Indices-VectorStore`
for more information on the faiss vector store index.
SIMPLE_DICT ("simple_dict"): Simple Vector Store Index. See
:ref:`Ref-Indices-VectorStore`
for more information on the simple vector store index.
KG ("kg"): Knowledge Graph index.
See :ref:`Ref-Indices-Knowledge-Graph` for KG indices.
DOCUMENT_SUMMARY ("document_summary"): Document Summary Index.
See :ref:`Ref-Indices-Document-Summary` for Summary Indices.
"""
# TODO: refactor so these are properties on the base class
NODE = "node"
TREE = "tree"
LIST = "list"
KEYWORD_TABLE = "keyword_table"
DICT = "dict"
# simple
SIMPLE_DICT = "simple_dict"
# for KG index
KG = "kg"
SIMPLE_KG = "simple_kg"
NEBULAGRAPH = "nebulagraph"
FALKORDB = "falkordb"
# EMPTY
EMPTY = "empty"
COMPOSITE = "composite"
DOCUMENT_SUMMARY = "document_summary"
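
Because the enum mixes in str, members compare equal to their raw string values, which is how persisted type tags round-trip:

from dbgpt.rag.graph.index_type import IndexStructType

assert IndexStructType.KG == "kg"  # str-enum equality
assert IndexStructType("keyword_table") is IndexStructType.KEYWORD_TABLE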

View File

@@ -1,75 +0,0 @@
from typing import List, Optional
from llama_index.data_structs.data_structs import IndexStruct
from llama_index.storage.index_store.utils import (
index_struct_to_json,
json_to_index_struct,
)
from llama_index.storage.kvstore.types import BaseKVStore
DEFAULT_NAMESPACE = "index_store"
class KVIndexStore:
"""Key-Value Index store.
Args:
kvstore (BaseKVStore): key-value store
namespace (str): namespace for the index store
"""
def __init__(self, kvstore: BaseKVStore, namespace: Optional[str] = None) -> None:
"""Init a KVIndexStore."""
self._kvstore = kvstore
self._namespace = namespace or DEFAULT_NAMESPACE
self._collection = f"{self._namespace}/data"
def add_index_struct(self, index_struct: IndexStruct) -> None:
"""Add an index struct.
Args:
index_struct (IndexStruct): index struct
"""
key = index_struct.index_id
data = index_struct_to_json(index_struct)
self._kvstore.put(key, data, collection=self._collection)
def delete_index_struct(self, key: str) -> None:
"""Delete an index struct.
Args:
key (str): index struct key
"""
self._kvstore.delete(key, collection=self._collection)
def get_index_struct(
self, struct_id: Optional[str] = None
) -> Optional[IndexStruct]:
"""Get an index struct.
Args:
struct_id (Optional[str]): index struct id
"""
if struct_id is None:
structs = self.index_structs()
assert len(structs) == 1
return structs[0]
else:
json = self._kvstore.get(struct_id, collection=self._collection)
if json is None:
return None
return json_to_index_struct(json)
def index_structs(self) -> List[IndexStruct]:
"""Get all index structs.
Returns:
List[IndexStruct]: index structs
"""
jsons = self._kvstore.get_all(collection=self._collection)
return [json_to_index_struct(json) for json in jsons.values()]
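
A hedged usage sketch, assuming the legacy (pre-0.10) llama-index layout this file imports from; SimpleKVStore is the in-memory key-value store llama-index shipped at the time:

from llama_index.data_structs.data_structs import KG as LlamaKG
from llama_index.storage.kvstore import SimpleKVStore

store = KVIndexStore(SimpleKVStore(), namespace="demo")
index_struct = LlamaKG()  # any llama-index IndexStruct subclass should work here
store.add_index_struct(index_struct)
assert store.get_index_struct(index_struct.index_id) is not None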

View File

@@ -1,570 +0,0 @@
"""Base schema for data structures."""
import json
import textwrap
import uuid
from abc import abstractmethod
from enum import Enum, auto
from hashlib import sha256
from typing import Any, Dict, List, Optional, Union
from langchain.schema import Document
from typing_extensions import Self
from dbgpt._private.pydantic import BaseModel, Field, root_validator
DEFAULT_TEXT_NODE_TMPL = "{metadata_str}\n\n{content}"
DEFAULT_METADATA_TMPL = "{key}: {value}"
# NOTE: for pretty printing
TRUNCATE_LENGTH = 350
WRAP_WIDTH = 70
class BaseComponent(BaseModel):
"""Base component object to caputure class names."""
"""reference llama-index"""
@classmethod
@abstractmethod
def class_name(cls) -> str:
"""Get class name."""
def to_dict(self, **kwargs: Any) -> Dict[str, Any]:
data = self.dict(**kwargs)
data["class_name"] = self.class_name()
return data
def to_json(self, **kwargs: Any) -> str:
data = self.to_dict(**kwargs)
return json.dumps(data)
# TODO: return type here not supported by current mypy version
@classmethod
def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self: # type: ignore
if isinstance(kwargs, dict):
data.update(kwargs)
data.pop("class_name", None)
return cls(**data)
@classmethod
def from_json(cls, data_str: str, **kwargs: Any) -> Self: # type: ignore
data = json.loads(data_str)
return cls.from_dict(data, **kwargs)
class NodeRelationship(str, Enum):
"""Node relationships used in `BaseNode` class.
Attributes:
SOURCE: The node is the source document.
PREVIOUS: The node is the previous node in the document.
NEXT: The node is the next node in the document.
PARENT: The node is the parent node in the document.
CHILD: The node is a child node in the document.
"""
SOURCE = auto()
PREVIOUS = auto()
NEXT = auto()
PARENT = auto()
CHILD = auto()
class ObjectType(str, Enum):
TEXT = auto()
IMAGE = auto()
INDEX = auto()
DOCUMENT = auto()
class MetadataMode(str, Enum):
ALL = auto()
EMBED = auto()
LLM = auto()
NONE = auto()
class RelatedNodeInfo(BaseComponent):
node_id: str
node_type: Optional[ObjectType] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
hash: Optional[str] = None
@classmethod
def class_name(cls) -> str:
"""Get class name."""
return "RelatedNodeInfo"
RelatedNodeType = Union[RelatedNodeInfo, List[RelatedNodeInfo]]
# Node classes for indexes
class BaseNode(BaseComponent):
"""Base node Object.
Generic abstract interface for retrievable nodes
"""
class Config:
allow_population_by_field_name = True
id_: str = Field(
default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the node."
)
embedding: Optional[List[float]] = Field(
default=None, description="Embedding of the node."
)
""""
metadata fields
- injected as part of the text shown to LLMs as context
- injected as part of the text for generating embeddings
- used by vector DBs for metadata filtering
"""
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="A flat dictionary of metadata fields",
alias="extra_info",
)
excluded_embed_metadata_keys: List[str] = Field(
default_factory=list,
description="Metadata keys that are exluded from text for the embed model.",
)
excluded_llm_metadata_keys: List[str] = Field(
default_factory=list,
description="Metadata keys that are exluded from text for the LLM.",
)
relationships: Dict[NodeRelationship, RelatedNodeType] = Field(
default_factory=dict,
description="A mapping of relationships to other node information.",
)
hash: str = Field(default="", description="Hash of the node content.")
@classmethod
@abstractmethod
def get_type(cls) -> str:
"""Get Object type."""
@abstractmethod
def get_content(self, metadata_mode: MetadataMode = MetadataMode.ALL) -> str:
"""Get object content."""
@abstractmethod
def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
"""Metadata string."""
@abstractmethod
def set_content(self, value: Any) -> None:
"""Set the content of the node."""
@property
def node_id(self) -> str:
return self.id_
@node_id.setter
def node_id(self, value: str) -> None:
self.id_ = value
@property
def source_node(self) -> Optional[RelatedNodeInfo]:
"""Source object node.
Extracted from the relationships field.
"""
if NodeRelationship.SOURCE not in self.relationships:
return None
relation = self.relationships[NodeRelationship.SOURCE]
if isinstance(relation, list):
raise ValueError("Source object must be a single RelatedNodeInfo object")
return relation
@property
def prev_node(self) -> Optional[RelatedNodeInfo]:
"""Prev node."""
if NodeRelationship.PREVIOUS not in self.relationships:
return None
relation = self.relationships[NodeRelationship.PREVIOUS]
if not isinstance(relation, RelatedNodeInfo):
raise ValueError("Previous object must be a single RelatedNodeInfo object")
return relation
@property
def next_node(self) -> Optional[RelatedNodeInfo]:
"""Next node."""
if NodeRelationship.NEXT not in self.relationships:
return None
relation = self.relationships[NodeRelationship.NEXT]
if not isinstance(relation, RelatedNodeInfo):
raise ValueError("Next object must be a single RelatedNodeInfo object")
return relation
@property
def parent_node(self) -> Optional[RelatedNodeInfo]:
"""Parent node."""
if NodeRelationship.PARENT not in self.relationships:
return None
relation = self.relationships[NodeRelationship.PARENT]
if not isinstance(relation, RelatedNodeInfo):
raise ValueError("Parent object must be a single RelatedNodeInfo object")
return relation
@property
def child_nodes(self) -> Optional[List[RelatedNodeInfo]]:
"""Child nodes."""
if NodeRelationship.CHILD not in self.relationships:
return None
relation = self.relationships[NodeRelationship.CHILD]
if not isinstance(relation, list):
raise ValueError("Child objects must be a list of RelatedNodeInfo objects.")
return relation
@property
def ref_doc_id(self) -> Optional[str]:
"""Deprecated: Get ref doc id."""
source_node = self.source_node
if source_node is None:
return None
return source_node.node_id
@property
def extra_info(self) -> Dict[str, Any]:
"""TODO: DEPRECATED: Extra info."""
return self.metadata
def __str__(self) -> str:
source_text_truncated = truncate_text(
self.get_content().strip(), TRUNCATE_LENGTH
)
source_text_wrapped = textwrap.fill(
f"Text: {source_text_truncated}\n", width=WRAP_WIDTH
)
return f"Node ID: {self.node_id}\n{source_text_wrapped}"
def truncate_text(text: str, max_length: int) -> str:
"""Truncate text to a maximum length."""
if len(text) <= max_length:
return text
return text[: max_length - 3] + "..."
def get_embedding(self) -> List[float]:
"""Get embedding.
Errors if embedding is None.
"""
if self.embedding is None:
raise ValueError("embedding not set.")
return self.embedding
def as_related_node_info(self) -> RelatedNodeInfo:
"""Get node as RelatedNodeInfo."""
return RelatedNodeInfo(
node_id=self.node_id, metadata=self.metadata, hash=self.hash
)
class TextNode(BaseNode):
text: str = Field(default="", description="Text content of the node.")
start_char_idx: Optional[int] = Field(
default=None, description="Start char index of the node."
)
end_char_idx: Optional[int] = Field(
default=None, description="End char index of the node."
)
text_template: str = Field(
default=DEFAULT_TEXT_NODE_TMPL,
description=(
"Template for how text is formatted, with {content} and "
"{metadata_str} placeholders."
),
)
metadata_template: str = Field(
default=DEFAULT_METADATA_TMPL,
description=(
"Template for how metadata is formatted, with {key} and "
"{value} placeholders."
),
)
metadata_seperator: str = Field(
default="\n",
description="Seperator between metadata fields when converting to string.",
)
@classmethod
def class_name(cls) -> str:
"""Get class name."""
return "TextNode"
@root_validator
def _check_hash(cls, values: dict) -> dict:
"""Generate a hash to represent the node."""
text = values.get("text", "")
metadata = values.get("metadata", {})
doc_identity = str(text) + str(metadata)
values["hash"] = str(
sha256(doc_identity.encode("utf-8", "surrogatepass")).hexdigest()
)
return values
@classmethod
def get_type(cls) -> str:
"""Get Object type."""
return ObjectType.TEXT
def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
"""Get object content."""
metadata_str = self.get_metadata_str(mode=metadata_mode).strip()
if not metadata_str:
return self.text
return self.text_template.format(
content=self.text, metadata_str=metadata_str
).strip()
def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
"""metadata info string."""
if mode == MetadataMode.NONE:
return ""
usable_metadata_keys = set(self.metadata.keys())
if mode == MetadataMode.LLM:
for key in self.excluded_llm_metadata_keys:
if key in usable_metadata_keys:
usable_metadata_keys.remove(key)
elif mode == MetadataMode.EMBED:
for key in self.excluded_embed_metadata_keys:
if key in usable_metadata_keys:
usable_metadata_keys.remove(key)
return self.metadata_seperator.join(
[
self.metadata_template.format(key=key, value=str(value))
for key, value in self.metadata.items()
if key in usable_metadata_keys
]
)
def set_content(self, value: str) -> None:
"""Set the content of the node."""
self.text = value
def get_node_info(self) -> Dict[str, Any]:
"""Get node info."""
return {"start": self.start_char_idx, "end": self.end_char_idx}
def get_text(self) -> str:
return self.get_content(metadata_mode=MetadataMode.NONE)
@property
def node_info(self) -> Dict[str, Any]:
"""Deprecated: Get node info."""
return self.get_node_info()
# TODO: legacy backport of old Node class
Node = TextNode
class ImageNode(TextNode):
"""Node with image."""
# TODO: store reference instead of actual image
# base64 encoded image str
image: Optional[str] = None
@classmethod
def get_type(cls) -> str:
return ObjectType.IMAGE
@classmethod
def class_name(cls) -> str:
"""Get class name."""
return "ImageNode"
class IndexNode(TextNode):
"""Node with reference to any object.
This can include other indices, query engines, retrievers.
This can also include other nodes (though this is overlapping with `relationships`
on the Node class).
"""
index_id: str
@classmethod
def from_text_node(
cls,
node: TextNode,
index_id: str,
) -> "IndexNode":
"""Create index node from text node."""
# copy all attributes from text node, add index id
return cls(
**node.dict(),
index_id=index_id,
)
@classmethod
def get_type(cls) -> str:
return ObjectType.INDEX
@classmethod
def class_name(cls) -> str:
"""Get class name."""
return "IndexNode"
class NodeWithScore(BaseComponent):
node: BaseNode
score: Optional[float] = None
def __str__(self) -> str:
return f"{self.node}\nScore: {self.score: 0.3f}\n"
def get_score(self, raise_error: bool = False) -> float:
"""Get score."""
if self.score is None:
if raise_error:
raise ValueError("Score not set.")
else:
return 0.0
else:
return self.score
@classmethod
def class_name(cls) -> str:
"""Get class name."""
return "NodeWithScore"
##### pass through methods to BaseNode #####
@property
def node_id(self) -> str:
return self.node.node_id
@property
def id_(self) -> str:
return self.node.id_
@property
def text(self) -> str:
if isinstance(self.node, TextNode):
return self.node.text
else:
raise ValueError("Node must be a TextNode to get text.")
@property
def metadata(self) -> Dict[str, Any]:
return self.node.metadata
@property
def embedding(self) -> Optional[List[float]]:
return self.node.embedding
def get_text(self) -> str:
if isinstance(self.node, TextNode):
return self.node.get_text()
else:
raise ValueError("Node must be a TextNode to get text.")
def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
return self.node.get_content(metadata_mode=metadata_mode)
def get_embedding(self) -> List[float]:
return self.node.get_embedding()
# Document Classes for Readers
class Document(TextNode):
"""Generic interface for a data document.
This document connects to data sources.
"""
# TODO: A lot of backwards compatibility logic here, clean up
id_: str = Field(
default_factory=lambda: str(uuid.uuid4()),
description="Unique ID of the node.",
alias="doc_id",
)
_compat_fields = {"doc_id": "id_", "extra_info": "metadata"}
@classmethod
def get_type(cls) -> str:
"""Get Document type."""
return ObjectType.DOCUMENT
@property
def doc_id(self) -> str:
"""Get document ID."""
return self.id_
def __str__(self) -> str:
source_text_truncated = truncate_text(
self.get_content().strip(), TRUNCATE_LENGTH
)
source_text_wrapped = textwrap.fill(
f"Text: {source_text_truncated}\n", width=WRAP_WIDTH
)
return f"Doc ID: {self.doc_id}\n{source_text_wrapped}"
def get_doc_id(self) -> str:
"""TODO: Deprecated: Get document ID."""
return self.id_
def __setattr__(self, name: str, value: object) -> None:
if name in self._compat_fields:
name = self._compat_fields[name]
super().__setattr__(name, value)
def to_langchain_format(self) -> Document:
"""Convert struct to LangChain document format."""
metadata = self.metadata or {}
return Document(page_content=self.text, metadata=metadata)
@classmethod
def from_langchain_format(cls, doc: Document) -> "Document":
"""Convert struct from LangChain document format."""
return cls(text=doc.page_content, metadata=doc.metadata)
@classmethod
def example(cls) -> "Document":
document = Document(
text="",
metadata={"filename": "README.md", "category": "codebase"},
)
return document
@classmethod
def class_name(cls) -> str:
"""Get class name."""
return "Document"
class ImageDocument(Document):
"""Data document containing an image."""
# base64 encoded image str
image: Optional[str] = None
@classmethod
def class_name(cls) -> str:
"""Get class name."""
return "ImageDocument"

View File

@@ -1,44 +0,0 @@
from abc import ABC, abstractmethod
from enum import Enum
class SearchMode(str, Enum):
"""Query mode enum for Knowledge Graphs.
Can be passed as the enum struct, or as the underlying string.
Attributes:
KEYWORD ("keyword"): Default query mode, using keywords to find triplets.
EMBEDDING ("embedding"): Embedding mode, using embeddings to find
similar triplets.
HYBRID ("hybrid"): Hyrbid mode, combining both keywords and embeddings
to find relevant triplets.
"""
KEYWORD = "keyword"
EMBEDDING = "embedding"
HYBRID = "hybrid"
class BaseSearch(ABC):
"""Base Search."""
async def search(self, query: str):
"""Retrieve nodes given query.
Args:
query (QueryType): Either a query string or
a QueryBundle object.
"""
# if isinstance(query, str):
return await self._search(query)
@abstractmethod
async def _search(self, query: str):
"""search nodes given query.
Implemented by the user.
"""
pass
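
BaseSearch leaves only _search to implement; a toy subclass illustrating the async contract (the echo logic is purely illustrative):

import asyncio
from langchain.schema import Document

class EchoSearch(BaseSearch):
    # Toy searcher: wraps the query in a single Document.
    async def _search(self, query: str):
        return [Document(page_content=f"echo: {query}")]

print(asyncio.run(EchoSearch().search("hello")))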

View File

@@ -1,40 +1,32 @@
-import asyncio
 import os
 from typing import Dict, List
+from pydantic import BaseModel, Field
 from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
-from dbgpt.core.awel import DAG, InputOperator, MapOperator, SimpleCallDataInputSource
-from dbgpt.rag.chunk import Chunk
+from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
 from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
+from dbgpt.rag.knowledge.base import KnowledgeType
 from dbgpt.rag.operator.knowledge import KnowledgeOperator
 from dbgpt.serve.rag.operators.embedding import EmbeddingAssemblerOperator
 from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
 from dbgpt.storage.vector_store.connector import VectorStoreConnector
 """AWEL: Simple rag embedding operator example
-pre-requirements:
-set your file path in your example code.
 Examples:
+    pre-requirements:
+        python examples/awel/simple_rag_embedding_example.py
     ..code-block:: shell
-        python examples/awel/simple_rag_embedding_example.py
+        curl --location --request POST 'http://127.0.0.1:5555/api/v1/awel/trigger/examples/rag/embedding' \
+        --header 'Content-Type: application/json' \
+        --data-raw '{
+            "url": "https://docs.dbgpt.site/docs/awel"
+        }'
 """
-def _context_join_fn(context_dict: Dict, chunks: List[Chunk]) -> Dict:
-    """Context join function for JoinOperator.
-    Args:
-        context_dict (Dict): context dict
-        chunks (List[Chunk]): chunks
-    Returns:
-        Dict: context dict
-    """
-    context_dict["context"] = "\n".join([chunk.content for chunk in chunks])
-    return context_dict
-def _create_vector_connector():
+def _create_vector_connector() -> VectorStoreConnector:
     """Create vector connector."""
     return VectorStoreConnector.from_default(
         "Chroma",
@@ -48,6 +40,22 @@ def _create_vector_connector():
     )
+class TriggerReqBody(BaseModel):
+    url: str = Field(..., description="url")
+class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+    async def map(self, input_value: TriggerReqBody) -> Dict:
+        params = {
+            "url": input_value.url,
+        }
+        print(f"Receive input value: {input_value}")
+        return params
 class ResultOperator(MapOperator):
     """The Result Operator."""
@@ -61,26 +69,31 @@ class ResultOperator(MapOperator):
 with DAG("simple_sdk_rag_embedding_example") as dag:
-    knowledge_operator = KnowledgeOperator()
+    trigger = HttpTrigger(
+        "/examples/rag/embedding", methods="POST", request_body=TriggerReqBody
+    )
+    request_handle_task = RequestHandleOperator()
+    knowledge_operator = KnowledgeOperator(knowledge_type=KnowledgeType.URL)
     vector_connector = _create_vector_connector()
-    input_task = InputOperator(input_source=SimpleCallDataInputSource())
-    file_path_parser = MapOperator(map_function=lambda x: x["file_path"])
+    url_parser_operator = MapOperator(map_function=lambda x: x["url"])
     embedding_operator = EmbeddingAssemblerOperator(
         vector_store_connector=vector_connector,
     )
     output_task = ResultOperator()
     (
-        input_task
-        >> file_path_parser
+        trigger
+        >> request_handle_task
+        >> url_parser_operator
         >> knowledge_operator
         >> embedding_operator
         >> output_task
     )
 if __name__ == "__main__":
-    input_data = {
-        "data": {
-            "file_path": "docs/docs/awel.md",
-        }
-    }
-    output = asyncio.run(output_task.call(call_data=input_data))
+    if dag.leaf_nodes[0].dev_mode:
+        # Development mode, you can run the dag locally for debugging.
+        from dbgpt.core.awel import setup_dev_environment
+        setup_dev_environment([dag], port=5555)
+    else:
+        pass
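
Once the example runs in development mode on port 5555, the trigger can be exercised with the curl command from the module docstring, or equivalently from Python (a client-side sketch; requests is assumed to be installed):

import requests

resp = requests.post(
    "http://127.0.0.1:5555/api/v1/awel/trigger/examples/rag/embedding",
    json={"url": "https://docs.dbgpt.site/docs/awel"},
)
print(resp.status_code, resp.text)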