Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-02 09:37:03 +00:00)
refactor: RAG Refactor (#985)
Co-authored-by: Aralhi <xiaoping0501@gmail.com>
Co-authored-by: csunny <cfqsunny@163.com>
dbgpt/util/chat_util.py (new file, 57 lines)
@@ -0,0 +1,57 @@
import asyncio
from typing import Coroutine, List, Any

from dbgpt.app.scene import BaseChat, ChatFactory

chat_factory = ChatFactory()


async def llm_chat_response_nostream(chat_scene: str, **chat_param):
    """llm_chat_response_nostream"""
    chat: BaseChat = chat_factory.get_implementation(chat_scene, **chat_param)
    res = await chat.get_llm_response()
    return res


async def llm_chat_response(chat_scene: str, **chat_param):
    chat: BaseChat = chat_factory.get_implementation(chat_scene, **chat_param)
    return chat.stream_call()


async def run_async_tasks(
    tasks: List[Coroutine],
    concurrency_limit: int = None,
) -> List[Any]:
    """Run a list of async tasks."""
    tasks_to_execute: List[Any] = tasks

    async def _gather() -> List[Any]:
        if concurrency_limit:
            semaphore = asyncio.Semaphore(concurrency_limit)

            async def _execute_task(task):
                async with semaphore:
                    return await task

            # Execute tasks with semaphore limit
            return await asyncio.gather(
                *[_execute_task(task) for task in tasks_to_execute]
            )
        else:
            return await asyncio.gather(*tasks_to_execute)

    # outputs: List[Any] = asyncio.run(_gather())
    return await _gather()


def run_tasks(
    tasks: List[Coroutine],
) -> List[Any]:
    """Run a list of async tasks."""
    tasks_to_execute: List[Any] = tasks

    async def _gather() -> List[Any]:
        return await asyncio.gather(*tasks_to_execute)

    outputs: List[Any] = asyncio.run(_gather())
    return outputs
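
For context rather than as part of the commit, a minimal sketch of how the new run_async_tasks helper could be driven; the fetch coroutine and its values are made up for illustration:

import asyncio

from dbgpt.util.chat_util import run_async_tasks


async def fetch(i: int) -> int:
    # stand-in for an LLM or retrieval call
    await asyncio.sleep(0.1)
    return i * 2


async def main():
    tasks = [fetch(i) for i in range(10)]
    # at most three of the coroutines run at the same time
    results = await run_async_tasks(tasks, concurrency_limit=3)
    print(results)  # [0, 2, 4, ..., 18]


asyncio.run(main())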
@@ -16,7 +16,7 @@ from dbgpt._private.pydantic import Field, PrivateAttr, BaseModel
 
 from dbgpt.util.global_helper import globals_helper
 from dbgpt._private.llm_metadata import LLMMetadata
-from dbgpt.rag.embedding_engine.loader.token_splitter import TokenTextSplitter
+from dbgpt.rag.text_splitter.token_splitter import TokenTextSplitter
 
 DEFAULT_PADDING = 5
 DEFAULT_CHUNK_OVERLAP_RATIO = 0.1
@@ -93,6 +93,11 @@ class PromptHelper(BaseModel):
             separator=separator,
         )
 
+    def token_count(self, prompt_template: str) -> int:
+        """Get token count of prompt template."""
+        empty_prompt_txt = get_empty_prompt_txt(prompt_template)
+        return len(self._tokenizer(empty_prompt_txt))
+
     @classmethod
     def from_llm_metadata(
         cls,
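
As an aside, the token_count method added here appears to measure the token cost of the template itself (get_empty_prompt_txt presumably renders the template with its variables emptied out). A hedged sketch, assuming a PromptHelper instance named helper has already been constructed, for example via the from_llm_metadata classmethod shown above:

# hypothetical template; {context} and {question} are placeholders filled in later
template = "Known context:\n{context}\n\nQuestion: {question}"
fixed_tokens = helper.token_count(template)
# fixed_tokens is the prompt's fixed overhead, before any retrieved chunks
# or the user question are substituted in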
dbgpt/util/splitter_utils.py (new file, 81 lines)
@@ -0,0 +1,81 @@
from typing import Callable, List


def split_text_keep_separator(text: str, separator: str) -> List[str]:
    """Split text with separator and keep the separator at the end of each split."""
    parts = text.split(separator)
    result = [separator + s if i > 0 else s for i, s in enumerate(parts)]
    return [s for s in result if s]


def split_by_sep(sep: str, keep_sep: bool = True) -> Callable[[str], List[str]]:
    """Split text by separator."""
    if keep_sep:
        return lambda text: split_text_keep_separator(text, sep)
    else:
        return lambda text: text.split(sep)


def split_by_char() -> Callable[[str], List[str]]:
    """Split text by character."""
    return lambda text: list(text)


def split_by_sentence_tokenizer() -> Callable[[str], List[str]]:
    import os

    import nltk

    from llama_index.utils import get_cache_dir

    cache_dir = get_cache_dir()
    nltk_data_dir = os.environ.get("NLTK_DATA", cache_dir)

    # update nltk path for nltk so that it finds the data
    if nltk_data_dir not in nltk.data.path:
        nltk.data.path.append(nltk_data_dir)

    try:
        nltk.data.find("tokenizers/punkt")
    except LookupError:
        nltk.download("punkt", download_dir=nltk_data_dir)

    tokenizer = nltk.tokenize.PunktSentenceTokenizer()

    # get the spans and then return the sentences
    # using the start index of each span
    # instead of using end, use the start of the next span if available
    def split(text: str) -> List[str]:
        spans = list(tokenizer.span_tokenize(text))
        sentences = []
        for i, span in enumerate(spans):
            start = span[0]
            if i < len(spans) - 1:
                end = spans[i + 1][0]
            else:
                end = len(text)
            sentences.append(text[start:end])

        return sentences

    return split


def split_by_regex(regex: str) -> Callable[[str], List[str]]:
    """Split text by regex."""
    import re

    return lambda text: re.findall(regex, text)


def split_by_phrase_regex() -> Callable[[str], List[str]]:
    """Split text by phrase regex.

    This regular expression will split the sentences into phrases,
    where each phrase is a sequence of one or more non-comma,
    non-period, and non-semicolon characters, followed by an optional comma,
    period, or semicolon. The regular expression will also capture the
    delimiters themselves as separate items in the list of phrases.
    """
    regex = "[^,.;。]+[,.;。]?"
    return split_by_regex(regex)
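
Again for illustration only, not part of the commit, a few usage sketches for the new splitter helpers; the sample strings are made up:

from dbgpt.util.splitter_utils import split_by_char, split_by_phrase_regex, split_by_sep

# separator-based split; with keep_sep=True the separator is re-attached
# to the front of each chunk after the first
split_fn = split_by_sep("||", keep_sep=True)
print(split_fn("alpha||beta||gamma"))  # ['alpha', '||beta', '||gamma']

# character-level split
print(split_by_char()("abc"))  # ['a', 'b', 'c']

# phrase split: each phrase keeps its trailing comma, period, or semicolon
print(split_by_phrase_regex()("one, two; three."))  # ['one,', ' two;', ' three.']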