mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-23 20:26:15 +00:00
feat(core): Support RAG chat flow (#1185)
This commit is contained in:
parent
21682575f5
commit
e0986198a6
@ -3,6 +3,7 @@
|
|||||||
This runner will run the workflow in the current process.
|
This runner will run the workflow in the current process.
|
||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
|
import traceback
|
||||||
from typing import Any, Dict, List, Optional, Set, cast
|
from typing import Any, Dict, List, Optional, Set, cast
|
||||||
|
|
||||||
from dbgpt.component import SystemApp
|
from dbgpt.component import SystemApp
|
||||||
@ -143,7 +144,11 @@ class DefaultWorkflowRunner(WorkflowRunner):
|
|||||||
)
|
)
|
||||||
_skip_current_downstream_by_node_name(node, skip_nodes, skip_node_ids)
|
_skip_current_downstream_by_node_name(node, skip_nodes, skip_node_ids)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.info(f"Run operator {node.node_id} error, error message: {str(e)}")
|
msg = traceback.format_exc()
|
||||||
|
logger.info(
|
||||||
|
f"Run operator {type(node)}({node.node_id}) error, error message: "
|
||||||
|
f"{msg}"
|
||||||
|
)
|
||||||
task_ctx.set_current_state(TaskState.FAILED)
|
task_ctx.set_current_state(TaskState.FAILED)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
@ -370,22 +370,16 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
|||||||
return FlowCategory.COMMON
|
return FlowCategory.COMMON
|
||||||
|
|
||||||
|
|
||||||
def _is_chat_flow_type(obj: Any, is_class: bool = False) -> bool:
|
def _is_chat_flow_type(output_obj: Any, is_class: bool = False) -> bool:
|
||||||
try:
|
|
||||||
from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
|
|
||||||
except ImportError:
|
|
||||||
OpenAIStreamingOutputOperator = None
|
|
||||||
if is_class:
|
if is_class:
|
||||||
return (
|
return (
|
||||||
obj == str
|
output_obj == str
|
||||||
or obj == CommonLLMHttpResponseBody
|
or output_obj == CommonLLMHttpResponseBody
|
||||||
or (OpenAIStreamingOutputOperator and obj == OpenAIStreamingOutputOperator)
|
or output_obj == ModelOutput
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
chat_types = (str, CommonLLMHttpResponseBody)
|
chat_types = (str, CommonLLMHttpResponseBody)
|
||||||
if OpenAIStreamingOutputOperator:
|
return isinstance(output_obj, chat_types)
|
||||||
chat_types += (OpenAIStreamingOutputOperator,)
|
|
||||||
return isinstance(obj, chat_types)
|
|
||||||
|
|
||||||
|
|
||||||
async def _chat_with_dag_task(
|
async def _chat_with_dag_task(
|
||||||
@ -439,29 +433,50 @@ async def _chat_with_dag_task(
|
|||||||
yield f"data:{full_text}\n\n"
|
yield f"data:{full_text}\n\n"
|
||||||
else:
|
else:
|
||||||
async for output in await task.call_stream(request):
|
async for output in await task.call_stream(request):
|
||||||
|
str_msg = ""
|
||||||
|
should_return = False
|
||||||
if isinstance(output, str):
|
if isinstance(output, str):
|
||||||
if output.strip():
|
if output.strip():
|
||||||
yield output
|
str_msg = output
|
||||||
|
elif isinstance(output, ModelOutput):
|
||||||
|
if output.error_code != 0:
|
||||||
|
str_msg = f"[SERVER_ERROR]{output.text}"
|
||||||
|
should_return = True
|
||||||
|
else:
|
||||||
|
str_msg = output.text
|
||||||
else:
|
else:
|
||||||
yield "data:[SERVER_ERROR]The output is not a stream format\n\n"
|
str_msg = (
|
||||||
return
|
f"[SERVER_ERROR]The output is not a valid format"
|
||||||
|
f"({type(output)})"
|
||||||
|
)
|
||||||
|
should_return = True
|
||||||
|
if str_msg:
|
||||||
|
str_msg = str_msg.replace("\n", "\\n")
|
||||||
|
yield f"data:{str_msg}\n\n"
|
||||||
|
if should_return:
|
||||||
|
return
|
||||||
else:
|
else:
|
||||||
result = await task.call(request)
|
result = await task.call(request)
|
||||||
|
str_msg = ""
|
||||||
if result is None:
|
if result is None:
|
||||||
yield "data:[SERVER_ERROR]The result is None\n\n"
|
str_msg = "[SERVER_ERROR]The result is None!"
|
||||||
elif isinstance(result, str):
|
elif isinstance(result, str):
|
||||||
yield f"data:{result}\n\n"
|
str_msg = result
|
||||||
elif isinstance(result, ModelOutput):
|
elif isinstance(result, ModelOutput):
|
||||||
if result.error_code != 0:
|
if result.error_code != 0:
|
||||||
yield f"data:[SERVER_ERROR]{result.text}\n\n"
|
str_msg = f"[SERVER_ERROR]{result.text}"
|
||||||
else:
|
else:
|
||||||
yield f"data:{result.text}\n\n"
|
str_msg = result.text
|
||||||
elif isinstance(result, CommonLLMHttpResponseBody):
|
elif isinstance(result, CommonLLMHttpResponseBody):
|
||||||
if result.error_code != 0:
|
if result.error_code != 0:
|
||||||
yield f"data:[SERVER_ERROR]{result.text}\n\n"
|
str_msg = f"[SERVER_ERROR]{result.text}"
|
||||||
else:
|
else:
|
||||||
yield f"data:{result.text}\n\n"
|
str_msg = result.text
|
||||||
elif isinstance(result, dict):
|
elif isinstance(result, dict):
|
||||||
yield f"data:{json.dumps(result, ensure_ascii=False)}\n\n"
|
str_msg = json.dumps(result, ensure_ascii=False)
|
||||||
else:
|
else:
|
||||||
yield f"data:[SERVER_ERROR]The result is not a valid format({type(result)})\n\n"
|
str_msg = f"[SERVER_ERROR]The result is not a valid format({type(result)})"
|
||||||
|
|
||||||
|
if str_msg:
|
||||||
|
str_msg = str_msg.replace("\n", "\\n")
|
||||||
|
yield f"data:{str_msg}\n\n"
|
||||||
|
@ -339,9 +339,7 @@ class MilvusStore(VectorStoreBase):
|
|||||||
self.vector_field = x.name
|
self.vector_field = x.name
|
||||||
_, docs_and_scores = self._search(text, topk)
|
_, docs_and_scores = self._search(text, topk)
|
||||||
if any(score < 0.0 or score > 1.0 for _, score, id in docs_and_scores):
|
if any(score < 0.0 or score > 1.0 for _, score, id in docs_and_scores):
|
||||||
import warnings
|
logger.warning(
|
||||||
|
|
||||||
warnings.warn(
|
|
||||||
"similarity score need between" f" 0 and 1, got {docs_and_scores}"
|
"similarity score need between" f" 0 and 1, got {docs_and_scores}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -357,7 +355,7 @@ class MilvusStore(VectorStoreBase):
|
|||||||
if score >= score_threshold
|
if score >= score_threshold
|
||||||
]
|
]
|
||||||
if len(docs_and_scores) == 0:
|
if len(docs_and_scores) == 0:
|
||||||
warnings.warn(
|
logger.warning(
|
||||||
"No relevant docs were retrieved using the relevance score"
|
"No relevant docs were retrieved using the relevance score"
|
||||||
f" threshold {score_threshold}"
|
f" threshold {score_threshold}"
|
||||||
)
|
)
|
||||||
|
@ -56,17 +56,25 @@ def list_repos():
|
|||||||
|
|
||||||
@click.command(name="add")
|
@click.command(name="add")
|
||||||
@add_tap_options
|
@add_tap_options
|
||||||
|
@click.option(
|
||||||
|
"-b",
|
||||||
|
"--branch",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
required=False,
|
||||||
|
help="The branch of the repository(Just for git repo)",
|
||||||
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--url",
|
"--url",
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help="The URL of the repo",
|
help="The URL of the repo",
|
||||||
)
|
)
|
||||||
def add_repo(repo: str, url: str):
|
def add_repo(repo: str, branch: str | None, url: str):
|
||||||
"""Add a new repo"""
|
"""Add a new repo"""
|
||||||
from .repo import add_repo
|
from .repo import add_repo
|
||||||
|
|
||||||
add_repo(repo, url)
|
add_repo(repo, url, branch)
|
||||||
|
|
||||||
|
|
||||||
@click.command(name="remove")
|
@click.command(name="remove")
|
||||||
|
@ -63,12 +63,13 @@ def _list_repos_details() -> List[Tuple[str, str]]:
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def add_repo(repo: str, repo_url: str):
|
def add_repo(repo: str, repo_url: str, branch: str | None = None):
|
||||||
"""Add a new repo
|
"""Add a new repo
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
repo (str): The name of the repo
|
repo (str): The name of the repo
|
||||||
repo_url (str): The URL of the repo
|
repo_url (str): The URL of the repo
|
||||||
|
branch (str): The branch of the repo
|
||||||
"""
|
"""
|
||||||
exist_repos = list_repos()
|
exist_repos = list_repos()
|
||||||
if repo in exist_repos and repo_url not in DEFAULT_REPO_MAP.values():
|
if repo in exist_repos and repo_url not in DEFAULT_REPO_MAP.values():
|
||||||
@ -84,7 +85,7 @@ def add_repo(repo: str, repo_url: str):
|
|||||||
repo_group_dir = os.path.join(DBGPTS_REPO_HOME, repo_arr[0])
|
repo_group_dir = os.path.join(DBGPTS_REPO_HOME, repo_arr[0])
|
||||||
os.makedirs(repo_group_dir, exist_ok=True)
|
os.makedirs(repo_group_dir, exist_ok=True)
|
||||||
if repo_url.startswith("http") or repo_url.startswith("git"):
|
if repo_url.startswith("http") or repo_url.startswith("git"):
|
||||||
clone_repo(repo, repo_group_dir, repo_name, repo_url)
|
clone_repo(repo, repo_group_dir, repo_name, repo_url, branch)
|
||||||
elif os.path.isdir(repo_url):
|
elif os.path.isdir(repo_url):
|
||||||
# Create soft link
|
# Create soft link
|
||||||
os.symlink(repo_url, os.path.join(repo_group_dir, repo_name))
|
os.symlink(repo_url, os.path.join(repo_group_dir, repo_name))
|
||||||
@ -106,7 +107,13 @@ def remove_repo(repo: str):
|
|||||||
logger.info(f"Repo '{repo}' removed successfully.")
|
logger.info(f"Repo '{repo}' removed successfully.")
|
||||||
|
|
||||||
|
|
||||||
def clone_repo(repo: str, repo_group_dir: str, repo_name: str, repo_url: str):
|
def clone_repo(
|
||||||
|
repo: str,
|
||||||
|
repo_group_dir: str,
|
||||||
|
repo_name: str,
|
||||||
|
repo_url: str,
|
||||||
|
branch: str | None = None,
|
||||||
|
):
|
||||||
"""Clone the specified repo
|
"""Clone the specified repo
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -114,10 +121,22 @@ def clone_repo(repo: str, repo_group_dir: str, repo_name: str, repo_url: str):
|
|||||||
repo_group_dir (str): The directory of the repo group
|
repo_group_dir (str): The directory of the repo group
|
||||||
repo_name (str): The name of the repo
|
repo_name (str): The name of the repo
|
||||||
repo_url (str): The URL of the repo
|
repo_url (str): The URL of the repo
|
||||||
|
branch (str): The branch of the repo
|
||||||
"""
|
"""
|
||||||
os.chdir(repo_group_dir)
|
os.chdir(repo_group_dir)
|
||||||
subprocess.run(["git", "clone", repo_url, repo_name], check=True)
|
clone_command = ["git", "clone", repo_url, repo_name]
|
||||||
logger.info(f"Repo '{repo}' cloned from {repo_url} successfully.")
|
|
||||||
|
# If the branch is specified, add it to the clone command
|
||||||
|
if branch:
|
||||||
|
clone_command += ["-b", branch]
|
||||||
|
|
||||||
|
subprocess.run(clone_command, check=True)
|
||||||
|
if branch:
|
||||||
|
click.echo(
|
||||||
|
f"Repo '{repo}' cloned from {repo_url} with branch '{branch}' successfully."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
click.echo(f"Repo '{repo}' cloned from {repo_url} successfully.")
|
||||||
|
|
||||||
|
|
||||||
def update_repo(repo: str):
|
def update_repo(repo: str):
|
||||||
@ -217,7 +236,7 @@ def _write_install_metadata(name: str, repo: str, install_path: Path):
|
|||||||
|
|
||||||
def check_with_retry(
|
def check_with_retry(
|
||||||
name: str,
|
name: str,
|
||||||
repo: str | None = None,
|
spec_repo: str | None = None,
|
||||||
with_update: bool = False,
|
with_update: bool = False,
|
||||||
is_first: bool = False,
|
is_first: bool = False,
|
||||||
) -> Tuple[str, Path] | None:
|
) -> Tuple[str, Path] | None:
|
||||||
@ -225,18 +244,17 @@ def check_with_retry(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
name (str): The name of the dbgpt
|
name (str): The name of the dbgpt
|
||||||
repo (str): The name of the repo
|
spec_repo (str): The name of the repo
|
||||||
with_update (bool): Whether to update the repo before installing
|
with_update (bool): Whether to update the repo before installing
|
||||||
is_first (bool): Whether it's the first time to check the dbgpt
|
is_first (bool): Whether it's the first time to check the dbgpt
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[str, Path] | None: The repo and the path of the dbgpt
|
Tuple[str, Path] | None: The repo and the path of the dbgpt
|
||||||
"""
|
"""
|
||||||
repos = _list_repos_details()
|
repos = _list_repos_details()
|
||||||
if repo:
|
if spec_repo:
|
||||||
repos = list(filter(lambda x: x[0] == repo, repos))
|
repos = list(filter(lambda x: x[0] == repo, repos))
|
||||||
if not repos:
|
if not repos:
|
||||||
logger.error(f"The specified repo '{repo}' does not exist.")
|
logger.error(f"The specified repo '{spec_repo}' does not exist.")
|
||||||
return
|
return
|
||||||
if is_first and with_update:
|
if is_first and with_update:
|
||||||
for repo in repos:
|
for repo in repos:
|
||||||
@ -253,7 +271,9 @@ def check_with_retry(
|
|||||||
):
|
):
|
||||||
return repo[0], dbgpt_path
|
return repo[0], dbgpt_path
|
||||||
if is_first:
|
if is_first:
|
||||||
return check_with_retry(name, repo, with_update=with_update, is_first=False)
|
return check_with_retry(
|
||||||
|
name, spec_repo, with_update=with_update, is_first=False
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@ -3,6 +3,14 @@ import inspect
|
|||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, get_args, get_origin, get_type_hints
|
from typing import Any, get_args, get_origin, get_type_hints
|
||||||
|
|
||||||
|
from typeguard import check_type
|
||||||
|
|
||||||
|
|
||||||
|
def _is_typing(obj):
|
||||||
|
from typing import _Final # type: ignore
|
||||||
|
|
||||||
|
return isinstance(obj, _Final)
|
||||||
|
|
||||||
|
|
||||||
def _is_instance_of_generic_type(obj, generic_type):
|
def _is_instance_of_generic_type(obj, generic_type):
|
||||||
"""Check if an object is an instance of a generic type."""
|
"""Check if an object is an instance of a generic type."""
|
||||||
@ -18,18 +26,44 @@ def _is_instance_of_generic_type(obj, generic_type):
|
|||||||
return isinstance(obj, origin)
|
return isinstance(obj, origin)
|
||||||
|
|
||||||
# Check if object matches the generic origin (like list, dict)
|
# Check if object matches the generic origin (like list, dict)
|
||||||
if not isinstance(obj, origin):
|
if not _is_typing(origin):
|
||||||
return False
|
return isinstance(obj, origin)
|
||||||
|
|
||||||
|
objs = [obj for _ in range(len(args))]
|
||||||
|
|
||||||
# For each item in the object, check if it matches the corresponding type argument
|
# For each item in the object, check if it matches the corresponding type argument
|
||||||
for sub_obj, arg in zip(obj, args):
|
for sub_obj, arg in zip(objs, args):
|
||||||
# Skip check if the type argument is Any
|
# Skip check if the type argument is Any
|
||||||
if arg is not Any and not isinstance(sub_obj, arg):
|
if arg is not Any:
|
||||||
return False
|
if _is_typing(arg):
|
||||||
|
sub_args = get_args(arg)
|
||||||
|
if (
|
||||||
|
sub_args
|
||||||
|
and not _is_typing(sub_args[0])
|
||||||
|
and not isinstance(sub_obj, sub_args[0])
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
elif not isinstance(sub_obj, arg):
|
||||||
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _check_type(obj, t) -> bool:
|
||||||
|
try:
|
||||||
|
check_type(obj, t)
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _get_orders(obj, arg_types):
|
||||||
|
try:
|
||||||
|
orders = [i for i, t in enumerate(arg_types) if _check_type(obj, t)]
|
||||||
|
return orders[0] if orders else int(1e8)
|
||||||
|
except Exception:
|
||||||
|
return int(1e8)
|
||||||
|
|
||||||
|
|
||||||
def _sort_args(func, args, kwargs):
|
def _sort_args(func, args, kwargs):
|
||||||
sig = inspect.signature(func)
|
sig = inspect.signature(func)
|
||||||
type_hints = get_type_hints(func)
|
type_hints = get_type_hints(func)
|
||||||
@ -49,9 +83,7 @@ def _sort_args(func, args, kwargs):
|
|||||||
|
|
||||||
sorted_args = sorted(
|
sorted_args = sorted(
|
||||||
other_args,
|
other_args,
|
||||||
key=lambda x: next(
|
key=lambda x: _get_orders(x, arg_types),
|
||||||
i for i, t in enumerate(arg_types) if _is_instance_of_generic_type(x, t)
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
return (*self_arg, *sorted_args), kwargs
|
return (*self_arg, *sorted_args), kwargs
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
At first, install dbgpt, and necessary dependencies:
|
At first, install dbgpt, and necessary dependencies:
|
||||||
|
|
||||||
```python
|
```shell
|
||||||
pip install dbgpt --upgrade
|
pip install dbgpt --upgrade
|
||||||
pip install openai
|
pip install openai
|
||||||
```
|
```
|
||||||
@ -14,7 +14,7 @@ Create a python file `simple_sdk_llm_example_dag.py` and write the following con
|
|||||||
```python
|
```python
|
||||||
from dbgpt.core import BaseOutputParser
|
from dbgpt.core import BaseOutputParser
|
||||||
from dbgpt.core.awel import DAG
|
from dbgpt.core.awel import DAG
|
||||||
from dbgpt.core.operator import (
|
from dbgpt.core.operators import (
|
||||||
PromptBuilderOperator,
|
PromptBuilderOperator,
|
||||||
RequestBuilderOperator,
|
RequestBuilderOperator,
|
||||||
)
|
)
|
||||||
|
@ -35,14 +35,14 @@ clone_repositories() {
|
|||||||
cd /root && git clone https://github.com/eosphoros-ai/DB-GPT.git
|
cd /root && git clone https://github.com/eosphoros-ai/DB-GPT.git
|
||||||
mkdir -p /root/DB-GPT/models && cd /root/DB-GPT/models
|
mkdir -p /root/DB-GPT/models && cd /root/DB-GPT/models
|
||||||
git clone https://huggingface.co/GanymedeNil/text2vec-large-chinese
|
git clone https://huggingface.co/GanymedeNil/text2vec-large-chinese
|
||||||
git clone https://huggingface.co/THUDM/chatglm2-6b
|
git clone https://huggingface.co/Qwen/Qwen-1_8B-Chat
|
||||||
rm -rf /root/DB-GPT/models/text2vec-large-chinese/.git
|
rm -rf /root/DB-GPT/models/text2vec-large-chinese/.git
|
||||||
rm -rf /root/DB-GPT/models/chatglm2-6b/.git
|
rm -rf /root/DB-GPT/models/Qwen-1_8B-Chat/.git
|
||||||
}
|
}
|
||||||
|
|
||||||
install_dbgpt_packages() {
|
install_dbgpt_packages() {
|
||||||
conda activate dbgpt && cd /root/DB-GPT && pip install -e ".[default]"
|
conda activate dbgpt && cd /root/DB-GPT && pip install -e ".[default]"
|
||||||
cp .env.template .env && sed -i 's/LLM_MODEL=vicuna-13b-v1.5/LLM_MODEL=chatglm2-6b/' .env
|
cp .env.template .env && sed -i 's/LLM_MODEL=vicuna-13b-v1.5/LLM_MODEL=qwen-1.8b-chat/' .env
|
||||||
}
|
}
|
||||||
|
|
||||||
clean_up() {
|
clean_up() {
|
||||||
|
2
setup.py
2
setup.py
@ -367,6 +367,8 @@ def core_requires():
|
|||||||
"python-dotenv==1.0.0",
|
"python-dotenv==1.0.0",
|
||||||
"cachetools",
|
"cachetools",
|
||||||
"pydantic<2,>=1",
|
"pydantic<2,>=1",
|
||||||
|
# For AWEL type checking
|
||||||
|
"typeguard",
|
||||||
]
|
]
|
||||||
# Simple command line dependencies
|
# Simple command line dependencies
|
||||||
setup_spec.extras["cli"] = setup_spec.extras["core"] + [
|
setup_spec.extras["cli"] = setup_spec.extras["core"] + [
|
||||||
|
Loading…
Reference in New Issue
Block a user