mirror of
				https://github.com/csunny/DB-GPT.git
				synced 2025-10-31 06:39:43 +00:00 
			
		
		
		
	Co-authored-by: hzh97 <2976151305@qq.com> Co-authored-by: Fangyin Cheng <staneyffer@gmail.com> Co-authored-by: licunxing <864255598@qq.com>
		
			
				
	
	
		
			220 lines
		
	
	
		
			6.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			220 lines
		
	
	
		
			6.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """Module for ChunkManager."""
 | |
| 
 | |
| from enum import Enum
 | |
| from typing import Any, List, Optional
 | |
| 
 | |
| from dbgpt._private.pydantic import BaseModel, Field
 | |
| from dbgpt.core import Chunk, Document
 | |
| from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
 | |
| from dbgpt.rag.extractor.base import Extractor
 | |
| from dbgpt.rag.knowledge.base import ChunkStrategy, Knowledge
 | |
| from dbgpt.rag.text_splitter import TextSplitter
 | |
| from dbgpt.util.i18n_utils import _
 | |
| 
 | |
| 
 | |
class SplitterType(str, Enum):
    """Enumerate the origins a text splitter implementation can have.

    The enum mixes in ``str``, so every member compares equal to its plain
    string value (for example ``SplitterType.LANGCHAIN == "langchain"``).
    """

    # Splitter whose output is in LangChain's document format.
    LANGCHAIN = "langchain"
    # Splitter whose output is in LlamaIndex's node format.
    LLAMA_INDEX = "llama-index"
    # Any other (user supplied / built-in) splitter.
    USER_DEFINE = "user_define"
 | |
| 
 | |
| 
 | |
@register_resource(
    _("Chunk Parameters"),
    "chunk_parameters",
    category=ResourceCategory.RAG,
    parameters=[
        Parameter.build_from(
            _("Chunk Strategy"),
            "chunk_strategy",
            str,
            description=_("chunk strategy"),
            optional=True,
            default=None,
        ),
        Parameter.build_from(
            _("Text Splitter"),
            "text_splitter",
            TextSplitter,
            description=_(
                "Text splitter, if not set, will use the default text splitter."
            ),
            optional=True,
            default=None,
        ),
        Parameter.build_from(
            _("Splitter Type"),
            "splitter_type",
            str,
            description=_("Splitter type"),
            optional=True,
            default=SplitterType.USER_DEFINE.value,
        ),
        Parameter.build_from(
            _("Chunk Size"),
            "chunk_size",
            int,
            description=_("Chunk size"),
            optional=True,
            default=512,
        ),
        Parameter.build_from(
            _("Chunk Overlap"),
            "chunk_overlap",
            int,
            # Wrapped in _() like every other description so it is translatable.
            description=_("Chunk overlap"),
            optional=True,
            default=50,
        ),
        Parameter.build_from(
            _("Separator"),
            "separator",
            str,
            description=_("Chunk separator"),
            optional=True,
            default="\n",
        ),
        Parameter.build_from(
            _("Enable Merge"),
            "enable_merge",
            bool,
            description=_("Enable chunk merge by chunk_size."),
            optional=True,
            default=False,
        ),
    ],
)
class ChunkParameters(BaseModel):
    """The parameters for chunking.

    Every field has a default, so ``ChunkParameters()`` is a valid instance.
    The class is also registered as an AWEL flow resource named
    ``chunk_parameters`` with one flow parameter per field.
    """

    # Name of the chunk strategy to use. ``None`` means "not chosen"; the
    # annotation is Optional so that passing ``None`` explicitly validates,
    # matching the ``default=None`` advertised in the resource registration.
    chunk_strategy: Optional[str] = Field(
        default=None,
        description="chunk strategy",
    )
    # Ready-made splitter instance; kept as ``Any`` so that splitters from
    # other libraries can be supplied as well.
    text_splitter: Optional[Any] = Field(
        default=None,
        description="text splitter",
    )

    # Which family ``text_splitter`` belongs to.
    splitter_type: SplitterType = Field(
        default=SplitterType.USER_DEFINE,
        description="splitter type",
    )

    chunk_size: int = Field(
        default=512,
        description="chunk size",
    )
    chunk_overlap: int = Field(
        default=50,
        description="chunk overlap",
    )
    separator: str = Field(
        default="\n",
        description="chunk separator",
    )
    # The default stays ``None`` (unset) to preserve existing behaviour; the
    # annotation is Optional so the declared type agrees with that default.
    # NOTE(review): the flow registration above advertises ``False`` as the
    # default for this parameter - confirm which one is intended.
    enable_merge: Optional[bool] = Field(
        default=None,
        description="enable chunk merge by chunk_size.",
    )
 | |
| 
 | |
| 
 | |
class ChunkManager:
    """Turn the documents of a knowledge source into chunks.

    The manager either uses an explicitly supplied text splitter or builds
    one from the chunk strategy supported by the knowledge source.
    """

    def __init__(
        self,
        knowledge: Knowledge,
        chunk_parameter: Optional[ChunkParameters] = None,
        extractor: Optional[Extractor] = None,
    ):
        """Initialise the manager.

        Args:
            knowledge: (Knowledge) Knowledge datasource.
            chunk_parameter: (Optional[ChunkParameters]) Chunking options; a
                default ``ChunkParameters()`` is used when omitted.
            extractor: (Optional[Extractor]) Extractor applied by ``extract``.
        """
        self._knowledge = knowledge
        self._extractor = extractor
        self._chunk_parameters = chunk_parameter or ChunkParameters()

        # Prefer the caller's strategy; otherwise ask the knowledge source.
        requested_strategy = chunk_parameter.chunk_strategy if chunk_parameter else None
        if not requested_strategy:
            requested_strategy = self._knowledge.default_chunk_strategy().name
        self._chunk_strategy = requested_strategy

        self._text_splitter = self._chunk_parameters.text_splitter
        self._splitter_type = self._chunk_parameters.splitter_type

    def split(self, documents: List[Document]) -> List[Chunk]:
        """Split the given documents into chunks.

        The splitter's output is converted with ``Chunk.langchain2chunk`` or
        ``Chunk.llamaindex2chunk`` when the splitter type is LANGCHAIN or
        LLAMA_INDEX; otherwise it is returned unchanged.
        """
        splitter = self._select_text_splitter()
        pieces = splitter.split_documents(documents)
        if SplitterType.LANGCHAIN == self._splitter_type:
            return [Chunk.langchain2chunk(piece) for piece in pieces]
        if SplitterType.LLAMA_INDEX == self._splitter_type:
            return [Chunk.llamaindex2chunk(piece) for piece in pieces]
        return pieces

    def split_with_summary(
        self, document: Any, chunk_strategy: ChunkStrategy
    ) -> List[Chunk]:
        """Split a document into chunks with a summary (not implemented).

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def extract(self, chunks: List[Chunk]) -> None:
        """Run the configured extractor over ``chunks``; no-op without one."""
        if not self._extractor:
            return
        self._extractor.extract(chunks)

    @property
    def chunk_parameters(self) -> ChunkParameters:
        """Return the chunk parameters in effect for this manager."""
        return self._chunk_parameters

    def set_text_splitter(
        self,
        text_splitter: TextSplitter,
        splitter_type: SplitterType = SplitterType.LANGCHAIN,
    ) -> None:
        """Replace the text splitter and record which type it is."""
        self._splitter_type = splitter_type
        self._text_splitter = text_splitter

    def get_text_splitter(
        self,
    ) -> TextSplitter:
        """Return the splitter that ``split`` would use."""
        return self._select_text_splitter()

    def _select_text_splitter(
        self,
    ) -> TextSplitter:
        """Pick the text splitter to use.

        An explicitly configured splitter wins. Otherwise the chunk strategy
        (falling back to the knowledge default when unset or ``"Automatic"``)
        is checked against the strategies the knowledge supports and used to
        build a splitter from the chunk parameters.

        Raises:
            ValueError: If the knowledge does not support the chunk strategy.
        """
        if self._text_splitter:
            return self._text_splitter

        knowledge = self._knowledge
        if not self._chunk_strategy or self._chunk_strategy == "Automatic":
            self._chunk_strategy = knowledge.default_chunk_strategy().name

        supported_names = [
            candidate.name for candidate in knowledge.support_chunk_strategy()
        ]
        if self._chunk_strategy not in supported_names:
            # Report the more specific document type when one is available.
            current_type = knowledge.type().value
            if knowledge.document_type():
                current_type = knowledge.document_type().value
            raise ValueError(
                f"{current_type} knowledge not supported chunk strategy "
                f"{self._chunk_strategy} "
            )

        params = self._chunk_parameters
        return ChunkStrategy[self._chunk_strategy].match(
            chunk_size=params.chunk_size,
            chunk_overlap=params.chunk_overlap,
            separator=params.separator,
            enable_merge=params.enable_merge,
        )
 |