feat(core): Support RAG chat flow (#1185)

This commit is contained in:
Fangyin Cheng 2024-02-23 11:44:44 +08:00 committed by GitHub
parent 21682575f5
commit e0986198a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 134 additions and 54 deletions

View File

@ -3,6 +3,7 @@
This runner will run the workflow in the current process. This runner will run the workflow in the current process.
""" """
import logging import logging
import traceback
from typing import Any, Dict, List, Optional, Set, cast from typing import Any, Dict, List, Optional, Set, cast
from dbgpt.component import SystemApp from dbgpt.component import SystemApp
@ -143,7 +144,11 @@ class DefaultWorkflowRunner(WorkflowRunner):
) )
_skip_current_downstream_by_node_name(node, skip_nodes, skip_node_ids) _skip_current_downstream_by_node_name(node, skip_nodes, skip_node_ids)
except Exception as e: except Exception as e:
logger.info(f"Run operator {node.node_id} error, error message: {str(e)}") msg = traceback.format_exc()
logger.info(
f"Run operator {type(node)}({node.node_id}) error, error message: "
f"{msg}"
)
task_ctx.set_current_state(TaskState.FAILED) task_ctx.set_current_state(TaskState.FAILED)
raise e raise e

View File

@ -370,22 +370,16 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
return FlowCategory.COMMON return FlowCategory.COMMON
def _is_chat_flow_type(output_obj: Any, is_class: bool = False) -> bool:
    """Tell whether an output (or an output type) counts as a chat-flow output.

    Args:
        output_obj (Any): Either a type (when ``is_class`` is True) or a value.
        is_class (bool): Treat ``output_obj`` as a type and compare it for
            equality instead of running an ``isinstance`` check.

    Returns:
        bool: With ``is_class`` set, True when ``output_obj`` equals ``str``,
            ``CommonLLMHttpResponseBody`` or ``ModelOutput``. Otherwise True
            when ``output_obj`` is an instance of ``str`` or
            ``CommonLLMHttpResponseBody``.
    """
    if not is_class:
        # Value check: only these two types are accepted for instances.
        return isinstance(output_obj, (str, CommonLLMHttpResponseBody))
    # Type check: compare against each accepted output type in turn.
    accepted_types = (str, CommonLLMHttpResponseBody, ModelOutput)
    return any(output_obj == accepted for accepted in accepted_types)
async def _chat_with_dag_task( async def _chat_with_dag_task(
@ -439,29 +433,50 @@ async def _chat_with_dag_task(
yield f"data:{full_text}\n\n" yield f"data:{full_text}\n\n"
else: else:
async for output in await task.call_stream(request): async for output in await task.call_stream(request):
str_msg = ""
should_return = False
if isinstance(output, str): if isinstance(output, str):
if output.strip(): if output.strip():
yield output str_msg = output
elif isinstance(output, ModelOutput):
if output.error_code != 0:
str_msg = f"[SERVER_ERROR]{output.text}"
should_return = True
else:
str_msg = output.text
else: else:
yield "data:[SERVER_ERROR]The output is not a stream format\n\n" str_msg = (
return f"[SERVER_ERROR]The output is not a valid format"
f"({type(output)})"
)
should_return = True
if str_msg:
str_msg = str_msg.replace("\n", "\\n")
yield f"data:{str_msg}\n\n"
if should_return:
return
else: else:
result = await task.call(request) result = await task.call(request)
str_msg = ""
if result is None: if result is None:
yield "data:[SERVER_ERROR]The result is None\n\n" str_msg = "[SERVER_ERROR]The result is None!"
elif isinstance(result, str): elif isinstance(result, str):
yield f"data:{result}\n\n" str_msg = result
elif isinstance(result, ModelOutput): elif isinstance(result, ModelOutput):
if result.error_code != 0: if result.error_code != 0:
yield f"data:[SERVER_ERROR]{result.text}\n\n" str_msg = f"[SERVER_ERROR]{result.text}"
else: else:
yield f"data:{result.text}\n\n" str_msg = result.text
elif isinstance(result, CommonLLMHttpResponseBody): elif isinstance(result, CommonLLMHttpResponseBody):
if result.error_code != 0: if result.error_code != 0:
yield f"data:[SERVER_ERROR]{result.text}\n\n" str_msg = f"[SERVER_ERROR]{result.text}"
else: else:
yield f"data:{result.text}\n\n" str_msg = result.text
elif isinstance(result, dict): elif isinstance(result, dict):
yield f"data:{json.dumps(result, ensure_ascii=False)}\n\n" str_msg = json.dumps(result, ensure_ascii=False)
else: else:
yield f"data:[SERVER_ERROR]The result is not a valid format({type(result)})\n\n" str_msg = f"[SERVER_ERROR]The result is not a valid format({type(result)})"
if str_msg:
str_msg = str_msg.replace("\n", "\\n")
yield f"data:{str_msg}\n\n"

View File

@ -339,9 +339,7 @@ class MilvusStore(VectorStoreBase):
self.vector_field = x.name self.vector_field = x.name
_, docs_and_scores = self._search(text, topk) _, docs_and_scores = self._search(text, topk)
if any(score < 0.0 or score > 1.0 for _, score, id in docs_and_scores): if any(score < 0.0 or score > 1.0 for _, score, id in docs_and_scores):
import warnings logger.warning(
warnings.warn(
"similarity score need between" f" 0 and 1, got {docs_and_scores}" "similarity score need between" f" 0 and 1, got {docs_and_scores}"
) )
@ -357,7 +355,7 @@ class MilvusStore(VectorStoreBase):
if score >= score_threshold if score >= score_threshold
] ]
if len(docs_and_scores) == 0: if len(docs_and_scores) == 0:
warnings.warn( logger.warning(
"No relevant docs were retrieved using the relevance score" "No relevant docs were retrieved using the relevance score"
f" threshold {score_threshold}" f" threshold {score_threshold}"
) )

View File

@ -56,17 +56,25 @@ def list_repos():
@click.command(name="add")
@add_tap_options
# Optional branch selector; stays None when the flag is not given.
@click.option(
    "-b",
    "--branch",
    type=str,
    default=None,
    required=False,
    help="The branch of the repository(Just for git repo)",
)
# The repo location is mandatory.
@click.option(
    "--url",
    type=str,
    required=True,
    help="The URL of the repo",
)
def add_repo(repo: str, branch: str | None, url: str):
    """Add a new repo"""
    # Imported under an alias because this command function is itself
    # named ``add_repo``.
    from .repo import add_repo as _register_repo

    _register_repo(repo, repo_url=url, branch=branch)
@click.command(name="remove") @click.command(name="remove")

View File

@ -63,12 +63,13 @@ def _list_repos_details() -> List[Tuple[str, str]]:
return results return results
def add_repo(repo: str, repo_url: str): def add_repo(repo: str, repo_url: str, branch: str | None = None):
"""Add a new repo """Add a new repo
Args: Args:
repo (str): The name of the repo repo (str): The name of the repo
repo_url (str): The URL of the repo repo_url (str): The URL of the repo
branch (str): The branch of the repo
""" """
exist_repos = list_repos() exist_repos = list_repos()
if repo in exist_repos and repo_url not in DEFAULT_REPO_MAP.values(): if repo in exist_repos and repo_url not in DEFAULT_REPO_MAP.values():
@ -84,7 +85,7 @@ def add_repo(repo: str, repo_url: str):
repo_group_dir = os.path.join(DBGPTS_REPO_HOME, repo_arr[0]) repo_group_dir = os.path.join(DBGPTS_REPO_HOME, repo_arr[0])
os.makedirs(repo_group_dir, exist_ok=True) os.makedirs(repo_group_dir, exist_ok=True)
if repo_url.startswith("http") or repo_url.startswith("git"): if repo_url.startswith("http") or repo_url.startswith("git"):
clone_repo(repo, repo_group_dir, repo_name, repo_url) clone_repo(repo, repo_group_dir, repo_name, repo_url, branch)
elif os.path.isdir(repo_url): elif os.path.isdir(repo_url):
# Create soft link # Create soft link
os.symlink(repo_url, os.path.join(repo_group_dir, repo_name)) os.symlink(repo_url, os.path.join(repo_group_dir, repo_name))
@ -106,7 +107,13 @@ def remove_repo(repo: str):
logger.info(f"Repo '{repo}' removed successfully.") logger.info(f"Repo '{repo}' removed successfully.")
def clone_repo(
    repo: str,
    repo_group_dir: str,
    repo_name: str,
    repo_url: str,
    branch: str | None = None,
):
    """Clone the specified repo

    Runs ``git clone`` inside ``repo_group_dir`` and reports the result on
    the console.

    Args:
        repo (str): The name of the repo (only used in the status message)
        repo_group_dir (str): The directory of the repo group
        repo_name (str): The name of the repo, used as the clone target
            directory
        repo_url (str): The URL of the repo
        branch (str): The branch of the repo, or None to let git pick the
            remote's default branch

    Raises:
        subprocess.CalledProcessError: If ``git clone`` exits with a non-zero
            status.
    """
    clone_command = ["git", "clone"]
    # If the branch is specified, add it to the clone command. Options go
    # before the positional arguments, which is the conventional order.
    if branch:
        clone_command += ["-b", branch]
    clone_command += [repo_url, repo_name]

    # Run git in the group directory via ``cwd`` rather than ``os.chdir`` so
    # the working directory of the whole process is left untouched.
    subprocess.run(clone_command, cwd=repo_group_dir, check=True)

    if branch:
        click.echo(
            f"Repo '{repo}' cloned from {repo_url} with branch '{branch}' successfully."
        )
    else:
        click.echo(f"Repo '{repo}' cloned from {repo_url} successfully.")
def update_repo(repo: str): def update_repo(repo: str):
@ -217,7 +236,7 @@ def _write_install_metadata(name: str, repo: str, install_path: Path):
def check_with_retry( def check_with_retry(
name: str, name: str,
repo: str | None = None, spec_repo: str | None = None,
with_update: bool = False, with_update: bool = False,
is_first: bool = False, is_first: bool = False,
) -> Tuple[str, Path] | None: ) -> Tuple[str, Path] | None:
@ -225,18 +244,17 @@ def check_with_retry(
Args: Args:
name (str): The name of the dbgpt name (str): The name of the dbgpt
repo (str): The name of the repo spec_repo (str): The name of the repo
with_update (bool): Whether to update the repo before installing with_update (bool): Whether to update the repo before installing
is_first (bool): Whether it's the first time to check the dbgpt is_first (bool): Whether it's the first time to check the dbgpt
Returns: Returns:
Tuple[str, Path] | None: The repo and the path of the dbgpt Tuple[str, Path] | None: The repo and the path of the dbgpt
""" """
repos = _list_repos_details() repos = _list_repos_details()
if repo: if spec_repo:
repos = list(filter(lambda x: x[0] == repo, repos)) repos = list(filter(lambda x: x[0] == repo, repos))
if not repos: if not repos:
logger.error(f"The specified repo '{repo}' does not exist.") logger.error(f"The specified repo '{spec_repo}' does not exist.")
return return
if is_first and with_update: if is_first and with_update:
for repo in repos: for repo in repos:
@ -253,7 +271,9 @@ def check_with_retry(
): ):
return repo[0], dbgpt_path return repo[0], dbgpt_path
if is_first: if is_first:
return check_with_retry(name, repo, with_update=with_update, is_first=False) return check_with_retry(
name, spec_repo, with_update=with_update, is_first=False
)
return None return None

View File

@ -3,6 +3,14 @@ import inspect
from functools import wraps from functools import wraps
from typing import Any, get_args, get_origin, get_type_hints from typing import Any, get_args, get_origin, get_type_hints
from typeguard import check_type
def _is_typing(obj):
from typing import _Final # type: ignore
return isinstance(obj, _Final)
def _is_instance_of_generic_type(obj, generic_type): def _is_instance_of_generic_type(obj, generic_type):
"""Check if an object is an instance of a generic type.""" """Check if an object is an instance of a generic type."""
@ -18,18 +26,44 @@ def _is_instance_of_generic_type(obj, generic_type):
return isinstance(obj, origin) return isinstance(obj, origin)
# Check if object matches the generic origin (like list, dict) # Check if object matches the generic origin (like list, dict)
if not isinstance(obj, origin): if not _is_typing(origin):
return False return isinstance(obj, origin)
objs = [obj for _ in range(len(args))]
# For each item in the object, check if it matches the corresponding type argument # For each item in the object, check if it matches the corresponding type argument
for sub_obj, arg in zip(obj, args): for sub_obj, arg in zip(objs, args):
# Skip check if the type argument is Any # Skip check if the type argument is Any
if arg is not Any and not isinstance(sub_obj, arg): if arg is not Any:
return False if _is_typing(arg):
sub_args = get_args(arg)
if (
sub_args
and not _is_typing(sub_args[0])
and not isinstance(sub_obj, sub_args[0])
):
return False
elif not isinstance(sub_obj, arg):
return False
return True return True
def _check_type(obj, t) -> bool:
    """Report whether ``check_type`` accepts ``obj`` as matching type ``t``.

    Args:
        obj: The value to check.
        t: The expected type or type annotation.

    Returns:
        bool: True if ``check_type(obj, t)`` completes without raising,
            False if it raises any ``Exception``.
    """
    try:
        check_type(obj, t)
    except Exception:
        # Any failure of the check is treated as "does not match".
        return False
    else:
        return True
def _get_orders(obj, arg_types):
    """Return the position of the first type in ``arg_types`` that ``obj`` matches.

    Used as a sort key, so a value that matches no type gets a large
    sentinel and sorts after every value that does match.

    Args:
        obj: The value to classify.
        arg_types: Iterable of types/annotations, in parameter order.

    Returns:
        int: Index of the first matching type, or ``int(1e8)`` when nothing
            matches or the lookup itself raises.
    """
    no_match_order = int(1e8)
    try:
        # Stop at the first matching type instead of checking every type and
        # building a list only to take its first element.
        return next(
            (i for i, t in enumerate(arg_types) if _check_type(obj, t)),
            no_match_order,
        )
    except Exception:
        return no_match_order
def _sort_args(func, args, kwargs): def _sort_args(func, args, kwargs):
sig = inspect.signature(func) sig = inspect.signature(func)
type_hints = get_type_hints(func) type_hints = get_type_hints(func)
@ -49,9 +83,7 @@ def _sort_args(func, args, kwargs):
sorted_args = sorted( sorted_args = sorted(
other_args, other_args,
key=lambda x: next( key=lambda x: _get_orders(x, arg_types),
i for i, t in enumerate(arg_types) if _is_instance_of_generic_type(x, t)
),
) )
return (*self_arg, *sorted_args), kwargs return (*self_arg, *sorted_args), kwargs

View File

@ -4,7 +4,7 @@
First, install dbgpt and the necessary dependencies:
```python ```shell
pip install dbgpt --upgrade pip install dbgpt --upgrade
pip install openai pip install openai
``` ```
@ -14,7 +14,7 @@ Create a python file `simple_sdk_llm_example_dag.py` and write the following con
```python ```python
from dbgpt.core import BaseOutputParser from dbgpt.core import BaseOutputParser
from dbgpt.core.awel import DAG from dbgpt.core.awel import DAG
from dbgpt.core.operator import ( from dbgpt.core.operators import (
PromptBuilderOperator, PromptBuilderOperator,
RequestBuilderOperator, RequestBuilderOperator,
) )

View File

@ -35,14 +35,14 @@ clone_repositories() {
cd /root && git clone https://github.com/eosphoros-ai/DB-GPT.git cd /root && git clone https://github.com/eosphoros-ai/DB-GPT.git
mkdir -p /root/DB-GPT/models && cd /root/DB-GPT/models mkdir -p /root/DB-GPT/models && cd /root/DB-GPT/models
git clone https://huggingface.co/GanymedeNil/text2vec-large-chinese git clone https://huggingface.co/GanymedeNil/text2vec-large-chinese
git clone https://huggingface.co/THUDM/chatglm2-6b git clone https://huggingface.co/Qwen/Qwen-1_8B-Chat
rm -rf /root/DB-GPT/models/text2vec-large-chinese/.git rm -rf /root/DB-GPT/models/text2vec-large-chinese/.git
rm -rf /root/DB-GPT/models/chatglm2-6b/.git rm -rf /root/DB-GPT/models/Qwen-1_8B-Chat/.git
} }
install_dbgpt_packages() {
    # Install DB-GPT in editable mode with the "default" extras, inside the
    # "dbgpt" conda environment.
    conda activate dbgpt \
        && cd /root/DB-GPT \
        && pip install -e ".[default]"
    # Create .env from the template, then switch LLM_MODEL from
    # vicuna-13b-v1.5 to qwen-1.8b-chat.
    cp .env.template .env \
        && sed -i 's/LLM_MODEL=vicuna-13b-v1.5/LLM_MODEL=qwen-1.8b-chat/' .env
}
clean_up() { clean_up() {

View File

@ -367,6 +367,8 @@ def core_requires():
"python-dotenv==1.0.0", "python-dotenv==1.0.0",
"cachetools", "cachetools",
"pydantic<2,>=1", "pydantic<2,>=1",
# For AWEL type checking
"typeguard",
] ]
# Simple command line dependencies # Simple command line dependencies
setup_spec.extras["cli"] = setup_spec.extras["core"] + [ setup_spec.extras["cli"] = setup_spec.extras["core"] + [