mirror of
				https://github.com/csunny/DB-GPT.git
				synced 2025-10-30 22:28:48 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			571 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			571 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """Base schema for data structures."""
 | |
| import json
 | |
| import textwrap
 | |
| import uuid
 | |
| from abc import abstractmethod
 | |
| from enum import Enum, auto
 | |
| from hashlib import sha256
 | |
| from typing import Any, Dict, List, Optional, Union
 | |
| 
 | |
| from langchain.schema import Document
 | |
| from pydantic import BaseModel, Field, root_validator
 | |
| from typing_extensions import Self
 | |
| 
 | |
| 
 | |
# Template used by TextNode to combine the rendered metadata block with the
# node text; placeholders are ``{metadata_str}`` and ``{content}``.
DEFAULT_TEXT_NODE_TMPL = "{metadata_str}\n\n{content}"
# Template used to render one metadata entry; placeholders are ``{key}`` and
# ``{value}``.
DEFAULT_METADATA_TMPL = "{key}: {value}"

# Pretty-printing limits used by the node ``__str__`` implementations: the
# maximum number of content characters shown and the wrap width.
TRUNCATE_LENGTH = 350
WRAP_WIDTH = 70
 | |
| 
 | |
| 
 | |
class BaseComponent(BaseModel):
    """Base component object to capture class names.

    Adapted from llama-index. Subclasses implement :meth:`class_name`; its
    value is written into serialized output under the ``"class_name"`` key.
    """

    @classmethod
    @abstractmethod
    def class_name(cls) -> str:
        """Get class name."""

    def to_dict(self, **kwargs: Any) -> Dict[str, Any]:
        """Serialize the model to a dict tagged with its class name.

        Args:
            **kwargs: Forwarded unchanged to pydantic's ``BaseModel.dict``.

        Returns:
            The model's field dict plus a ``"class_name"`` entry.
        """
        data = self.dict(**kwargs)
        data["class_name"] = self.class_name()
        return data

    def to_json(self, **kwargs: Any) -> str:
        """Serialize the model to a JSON string built from :meth:`to_dict`."""
        data = self.to_dict(**kwargs)
        return json.dumps(data)

    # TODO: return type here not supported by current mypy version
    @classmethod
    def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self:  # type: ignore
        """Build an instance from a dict such as one produced by :meth:`to_dict`.

        Args:
            data: Field values. A ``"class_name"`` entry, if present, is
                ignored. The caller's dict is not modified.
            **kwargs: Extra field values that override entries in ``data``.

        Returns:
            A new instance of ``cls``.
        """
        # Merge into a new dict: the previous implementation popped
        # "class_name" from, and wrote the overrides into, the caller's dict.
        merged = {**data, **kwargs}
        merged.pop("class_name", None)
        return cls(**merged)

    @classmethod
    def from_json(cls, data_str: str, **kwargs: Any) -> Self:  # type: ignore
        """Build an instance from a JSON string such as one from :meth:`to_json`.

        Args:
            data_str: JSON text decoding to a dict of field values.
            **kwargs: Extra field values forwarded to :meth:`from_dict`.
        """
        data = json.loads(data_str)
        return cls.from_dict(data, **kwargs)
 | |
| 
 | |
| 
 | |
class NodeRelationship(str, Enum):
    """Kinds of links a ``BaseNode`` can hold in its ``relationships`` mapping.

    Attributes:
        SOURCE: The node is the source document.
        PREVIOUS: The node is the previous node in the document.
        NEXT: The node is the next node in the document.
        PARENT: The node is the parent node in the document.
        CHILD: The node is a child node in the document.

    The member values are the strings "1" through "5". These are exactly the
    values ``auto()`` yields for a ``str``-mixin ``Enum``; they are spelled out
    so the keys that appear in serialized relationship mappings are visible.
    """

    SOURCE = "1"
    PREVIOUS = "2"
    NEXT = "3"
    PARENT = "4"
    CHILD = "5"
 | |
| 
 | |
| 
 | |
class ObjectType(str, Enum):
    """Type tag returned by the ``get_type`` methods of the node classes.

    The member values are the strings "1" through "4", which is what
    ``auto()`` yields for a ``str``-mixin ``Enum``.
    """

    TEXT = "1"
    IMAGE = "2"
    INDEX = "3"
    DOCUMENT = "4"
 | |
| 
 | |
| 
 | |
class MetadataMode(str, Enum):
    """Selects which metadata entries are rendered alongside node content.

    As interpreted by ``TextNode.get_metadata_str``:
      ALL   -- every metadata key;
      EMBED -- all keys except ``excluded_embed_metadata_keys``;
      LLM   -- all keys except ``excluded_llm_metadata_keys``;
      NONE  -- no metadata at all.

    The member values are the strings "1" through "4", which is what
    ``auto()`` yields for a ``str``-mixin ``Enum``.
    """

    ALL = "1"
    EMBED = "2"
    LLM = "3"
    NONE = "4"
 | |
| 
 | |
| 
 | |
class RelatedNodeInfo(BaseComponent):
    """Lightweight reference to another node.

    Used as the value type of ``BaseNode.relationships``. It carries the
    target node's id, plus an optional object type, metadata and content hash.
    """

    node_id: str
    node_type: Optional[ObjectType] = None
    metadata: Dict[str, Any] = Field(default_factory=dict)
    hash: Optional[str] = None

    @classmethod
    def class_name(cls) -> str:
        """Return the serialization tag for this component."""
        tag = "RelatedNodeInfo"
        return tag
 | |
| 
 | |
| 
 | |
# Value type of ``BaseNode.relationships``: either a single node reference, or
# a list of them (``BaseNode.child_nodes`` expects the list form).
# NOTE(review): pydantic tries Union members in order, so keep RelatedNodeInfo
# first unless reordering has been verified to be safe.
RelatedNodeType = Union[RelatedNodeInfo, List[RelatedNodeInfo]]
 | |
| 
 | |
| 
 | |
| # Node classes for indexes
 | |
class BaseNode(BaseComponent):
    """Base node Object.

    Generic abstract interface for retrievable nodes.
    """

    class Config:
        allow_population_by_field_name = True

    id_: str = Field(
        default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the node."
    )
    embedding: Optional[List[float]] = Field(
        default=None, description="Embedding of the node."
    )

    # metadata fields
    # - injected as part of the text shown to LLMs as context
    # - injected as part of the text for generating embeddings
    # - used by vector DBs for metadata filtering
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="A flat dictionary of metadata fields",
        alias="extra_info",
    )
    excluded_embed_metadata_keys: List[str] = Field(
        default_factory=list,
        description="Metadata keys that are exluded from text for the embed model.",
    )
    excluded_llm_metadata_keys: List[str] = Field(
        default_factory=list,
        description="Metadata keys that are exluded from text for the LLM.",
    )
    relationships: Dict[NodeRelationship, RelatedNodeType] = Field(
        default_factory=dict,
        description="A mapping of relationships to other node information.",
    )
    hash: str = Field(default="", description="Hash of the node content.")

    @classmethod
    @abstractmethod
    def get_type(cls) -> str:
        """Get Object type."""

    @abstractmethod
    def get_content(self, metadata_mode: MetadataMode = MetadataMode.ALL) -> str:
        """Get object content."""

    @abstractmethod
    def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
        """Metadata string."""

    @abstractmethod
    def set_content(self, value: Any) -> None:
        """Set the content of the node."""

    @property
    def node_id(self) -> str:
        """Unique ID of the node (alias of ``id_``)."""
        return self.id_

    @node_id.setter
    def node_id(self, value: str) -> None:
        self.id_ = value

    def _get_single_relationship(
        self, relationship: NodeRelationship, label: str
    ) -> Optional[RelatedNodeInfo]:
        """Return the single RelatedNodeInfo stored under ``relationship``.

        Args:
            relationship: Key to look up in ``self.relationships``.
            label: Human-readable name used in the error message.

        Returns:
            None when the relationship is absent, otherwise the stored value.

        Raises:
            ValueError: If the stored value is not a single RelatedNodeInfo
                (e.g. a list).
        """
        if relationship not in self.relationships:
            return None

        relation = self.relationships[relationship]
        if not isinstance(relation, RelatedNodeInfo):
            raise ValueError(f"{label} object must be a single RelatedNodeInfo object")
        return relation

    @property
    def source_node(self) -> Optional[RelatedNodeInfo]:
        """Source object node.

        Extracted from the relationships field.
        """
        return self._get_single_relationship(NodeRelationship.SOURCE, "Source")

    @property
    def prev_node(self) -> Optional[RelatedNodeInfo]:
        """Prev node."""
        return self._get_single_relationship(NodeRelationship.PREVIOUS, "Previous")

    @property
    def next_node(self) -> Optional[RelatedNodeInfo]:
        """Next node."""
        return self._get_single_relationship(NodeRelationship.NEXT, "Next")

    @property
    def parent_node(self) -> Optional[RelatedNodeInfo]:
        """Parent node."""
        return self._get_single_relationship(NodeRelationship.PARENT, "Parent")

    @property
    def child_nodes(self) -> Optional[List[RelatedNodeInfo]]:
        """Child nodes.

        Raises:
            ValueError: If the CHILD relationship is present but not a list.
        """
        if NodeRelationship.CHILD not in self.relationships:
            return None

        relation = self.relationships[NodeRelationship.CHILD]
        if not isinstance(relation, list):
            raise ValueError("Child objects must be a list of RelatedNodeInfo objects.")
        return relation

    @property
    def ref_doc_id(self) -> Optional[str]:
        """Deprecated: Get ref doc id (the source node's id, or None)."""
        source_node = self.source_node
        if source_node is None:
            return None
        return source_node.node_id

    @property
    def extra_info(self) -> Dict[str, Any]:
        """TODO: DEPRECATED: Extra info."""
        return self.metadata

    def __str__(self) -> str:
        # The helper lives on the class; the previous bare `truncate_text(...)`
        # call referenced a module-level name that does not exist (NameError).
        source_text_truncated = self.truncate_text(
            self.get_content().strip(), TRUNCATE_LENGTH
        )
        source_text_wrapped = textwrap.fill(
            f"Text: {source_text_truncated}\n", width=WRAP_WIDTH
        )
        return f"Node ID: {self.node_id}\n{source_text_wrapped}"

    @staticmethod
    def truncate_text(text: str, max_length: int) -> str:
        """Truncate text to at most ``max_length`` characters.

        Text longer than ``max_length`` is cut and suffixed with ``"..."`` so
        the result is exactly ``max_length`` characters long.
        """
        if len(text) <= max_length:
            return text
        if max_length < 3:
            # No room for the ellipsis; a plain cut keeps the length bound
            # (the slice arithmetic below would go negative here).
            return text[: max(max_length, 0)]
        return text[: max_length - 3] + "..."

    def get_embedding(self) -> List[float]:
        """Get embedding.

        Raises:
            ValueError: If the embedding is None.
        """
        if self.embedding is None:
            raise ValueError("embedding not set.")
        return self.embedding

    def as_related_node_info(self) -> RelatedNodeInfo:
        """Get node as RelatedNodeInfo (id, metadata and hash)."""
        return RelatedNodeInfo(
            node_id=self.node_id, metadata=self.metadata, hash=self.hash
        )
 | |
| 
 | |
| 
 | |
class TextNode(BaseNode):
    """Node whose content is a plain text string.

    When metadata is rendered, entries are formatted with ``metadata_template``,
    joined with ``metadata_seperator``, and combined with ``text`` through
    ``text_template``. ``hash`` is recomputed on validation from the text and
    the metadata.
    """

    text: str = Field(default="", description="Text content of the node.")
    start_char_idx: Optional[int] = Field(
        default=None, description="Start char index of the node."
    )
    end_char_idx: Optional[int] = Field(
        default=None, description="End char index of the node."
    )
    text_template: str = Field(
        default=DEFAULT_TEXT_NODE_TMPL,
        description=(
            "Template for how text is formatted, with {content} and "
            "{metadata_str} placeholders."
        ),
    )
    metadata_template: str = Field(
        default=DEFAULT_METADATA_TMPL,
        description=(
            "Template for how metadata is formatted, with {key} and "
            "{value} placeholders."
        ),
    )
    metadata_seperator: str = Field(
        default="\n",
        description="Seperator between metadata fields when converting to string.",
    )

    @classmethod
    def class_name(cls) -> str:
        """Return the serialization tag for this node type."""
        return "TextNode"

    @root_validator
    def _check_hash(cls, values: dict) -> dict:
        """Set ``hash`` to the SHA-256 hex digest of str(text) + str(metadata)."""
        identity = f"{values.get('text', '')!s}{values.get('metadata', {})!s}"
        payload = identity.encode("utf-8", "surrogatepass")
        values["hash"] = sha256(payload).hexdigest()
        return values

    @classmethod
    def get_type(cls) -> str:
        """Return ``ObjectType.TEXT``."""
        return ObjectType.TEXT

    def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
        """Return the text, wrapped with its metadata block when that is non-empty."""
        rendered_metadata = self.get_metadata_str(mode=metadata_mode).strip()
        if rendered_metadata:
            return self.text_template.format(
                content=self.text, metadata_str=rendered_metadata
            ).strip()
        return self.text

    def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
        """Render the metadata entries that are visible under ``mode``."""
        if mode == MetadataMode.NONE:
            return ""

        # Keys hidden for the requested audience; ALL hides nothing.
        hidden: set = set()
        if mode == MetadataMode.LLM:
            hidden = set(self.excluded_llm_metadata_keys)
        elif mode == MetadataMode.EMBED:
            hidden = set(self.excluded_embed_metadata_keys)

        rendered_entries = (
            self.metadata_template.format(key=key, value=str(value))
            for key, value in self.metadata.items()
            if key not in hidden
        )
        return self.metadata_seperator.join(rendered_entries)

    def set_content(self, value: str) -> None:
        """Replace the node text with ``value``."""
        self.text = value

    def get_node_info(self) -> Dict[str, Any]:
        """Return the start/end character offsets of this node."""
        return dict(start=self.start_char_idx, end=self.end_char_idx)

    def get_text(self) -> str:
        """Return the raw text without any metadata."""
        return self.get_content(metadata_mode=MetadataMode.NONE)

    @property
    def node_info(self) -> Dict[str, Any]:
        """Deprecated alias of :meth:`get_node_info`."""
        return self.get_node_info()
 | |
| 
 | |
| 
 | |
# TODO: legacy backport of old Node class
# Backward-compatible alias: ``Node`` is the very same class object as
# ``TextNode``.
Node = TextNode
 | |
| 
 | |
| 
 | |
class ImageNode(TextNode):
    """Text node that additionally carries an image.

    ``image`` holds a base64-encoded image string, or None.
    """

    # TODO: store reference instead of actual image
    image: Optional[str] = None

    @classmethod
    def class_name(cls) -> str:
        """Return the serialization tag for this node type."""
        return "ImageNode"

    @classmethod
    def get_type(cls) -> str:
        """Return ``ObjectType.IMAGE``."""
        return ObjectType.IMAGE
 | |
| 
 | |
| 
 | |
class IndexNode(TextNode):
    """Node with reference to any object.

    This can include other indices, query engines, retrievers.

    This can also include other nodes (though this is overlapping with
    `relationships` on the Node class).
    """

    index_id: str

    @classmethod
    def from_text_node(
        cls,
        node: TextNode,
        index_id: str,
    ) -> "IndexNode":
        """Create an index node from a text node.

        All fields of ``node`` are copied and ``index_id`` is set on the
        result. ``node`` may itself be an IndexNode; its existing ``index_id``
        is replaced.

        Args:
            node: Node whose fields are copied.
            index_id: Identifier of the referenced object.

        Returns:
            A new IndexNode.
        """
        data = node.dict()
        # Overwrite rather than passing index_id as a separate keyword: when
        # `node` is already an IndexNode, node.dict() contains "index_id", and
        # cls(**data, index_id=...) raises TypeError for the duplicate keyword.
        data["index_id"] = index_id
        return cls(**data)

    @classmethod
    def get_type(cls) -> str:
        """Return ``ObjectType.INDEX``."""
        return ObjectType.INDEX

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "IndexNode"
 | |
| 
 | |
| 
 | |
class NodeWithScore(BaseComponent):
    """A node paired with an optional relevance score.

    Provides pass-through accessors to the wrapped node's id, text, metadata,
    content and embedding.
    """

    node: BaseNode
    score: Optional[float] = None

    def __str__(self) -> str:
        # `score` defaults to None, and the float format spec raises TypeError
        # on None, so only apply it when a score is actually set.
        score_str = "None" if self.score is None else f"{self.score: 0.3f}"
        return f"{self.node}\nScore: {score_str}\n"

    def get_score(self, raise_error: bool = False) -> float:
        """Get score.

        Args:
            raise_error: If True, raise when no score is set; otherwise fall
                back to 0.0.

        Raises:
            ValueError: If the score is unset and ``raise_error`` is True.
        """
        if self.score is not None:
            return self.score
        if raise_error:
            raise ValueError("Score not set.")
        return 0.0

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "NodeWithScore"

    ##### pass through methods to BaseNode #####
    @property
    def node_id(self) -> str:
        return self.node.node_id

    @property
    def id_(self) -> str:
        return self.node.id_

    @property
    def text(self) -> str:
        """Text of the wrapped node; raises ValueError if it is not a TextNode."""
        if isinstance(self.node, TextNode):
            return self.node.text
        raise ValueError("Node must be a TextNode to get text.")

    @property
    def metadata(self) -> Dict[str, Any]:
        return self.node.metadata

    @property
    def embedding(self) -> Optional[List[float]]:
        return self.node.embedding

    def get_text(self) -> str:
        """Raw text of the wrapped node; raises ValueError if it is not a TextNode."""
        if isinstance(self.node, TextNode):
            return self.node.get_text()
        raise ValueError("Node must be a TextNode to get text.")

    def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
        return self.node.get_content(metadata_mode=metadata_mode)

    def get_embedding(self) -> List[float]:
        return self.node.get_embedding()
 | |
| 
 | |
| 
 | |
| # Document Classes for Readers
 | |
| 
 | |
| 
 | |
class Document(TextNode):
    """Generic interface for a data document.

    This document connects to data sources.
    """

    # TODO: A lot of backwards compatibility logic here, clean up
    id_: str = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="Unique ID of the node.",
        alias="doc_id",
    )

    # Legacy attribute names that __setattr__ redirects to their current fields.
    _compat_fields = {"doc_id": "id_", "extra_info": "metadata"}

    @classmethod
    def get_type(cls) -> str:
        """Get Document type."""
        return ObjectType.DOCUMENT

    @property
    def doc_id(self) -> str:
        """Get document ID."""
        return self.id_

    def __str__(self) -> str:
        # Call the truncation helper through BaseNode: the previous bare
        # `truncate_text(...)` referenced an undefined module-level name.
        source_text_truncated = BaseNode.truncate_text(
            self.get_content().strip(), TRUNCATE_LENGTH
        )
        source_text_wrapped = textwrap.fill(
            f"Text: {source_text_truncated}\n", width=WRAP_WIDTH
        )
        return f"Doc ID: {self.doc_id}\n{source_text_wrapped}"

    def get_doc_id(self) -> str:
        """TODO: Deprecated: Get document ID."""
        return self.id_

    def __setattr__(self, name: str, value: object) -> None:
        # Map legacy names (doc_id, extra_info) onto the real fields.
        if name in self._compat_fields:
            name = self._compat_fields[name]
        super().__setattr__(name, value)

    def to_langchain_format(self) -> Document:
        """Convert struct to LangChain document format.

        Returns:
            A LangChain Document whose ``page_content`` is this document's
            text and whose metadata is this document's metadata (or ``{}``).
        """
        # Once this class body finishes, the module-level name `Document` is
        # rebound to this class, shadowing LangChain's Document. Calling it
        # here would build a DB-GPT Document and silently drop `page_content`.
        # Import LangChain's class explicitly under a distinct name instead.
        from langchain.schema import Document as LCDocument

        metadata = self.metadata or {}
        return LCDocument(page_content=self.text, metadata=metadata)

    @classmethod
    def from_langchain_format(cls, doc: Document) -> "Document":
        """Convert struct from LangChain document format."""
        return cls(text=doc.page_content, metadata=doc.metadata)

    @classmethod
    def example(cls) -> "Document":
        """Return a sample Document with empty text and README-style metadata."""
        document = Document(
            text="",
            metadata={"filename": "README.md", "category": "codebase"},
        )
        return document

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "Document"
 | |
| 
 | |
| 
 | |
class ImageDocument(Document):
    """Document variant whose payload includes an image.

    ``image`` is a base64-encoded image string, or None when absent.
    """

    image: Optional[str] = None

    @classmethod
    def class_name(cls) -> str:
        """Return the serialization tag for this document type."""
        tag = "ImageDocument"
        return tag
 |