mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-05 11:01:09 +00:00
feat(core): Support RAG chat flow (#1185)
This commit is contained in:
@@ -56,17 +56,25 @@ def list_repos():
|
||||
|
||||
@click.command(name="add")
|
||||
@add_tap_options
|
||||
@click.option(
|
||||
"-b",
|
||||
"--branch",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="The branch of the repository(Just for git repo)",
|
||||
)
|
||||
@click.option(
|
||||
"--url",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The URL of the repo",
|
||||
)
|
||||
def add_repo(repo: str, url: str):
|
||||
def add_repo(repo: str, branch: str | None, url: str):
|
||||
"""Add a new repo"""
|
||||
from .repo import add_repo
|
||||
|
||||
add_repo(repo, url)
|
||||
add_repo(repo, url, branch)
|
||||
|
||||
|
||||
@click.command(name="remove")
|
||||
|
@@ -63,12 +63,13 @@ def _list_repos_details() -> List[Tuple[str, str]]:
|
||||
return results
|
||||
|
||||
|
||||
def add_repo(repo: str, repo_url: str):
|
||||
def add_repo(repo: str, repo_url: str, branch: str | None = None):
|
||||
"""Add a new repo
|
||||
|
||||
Args:
|
||||
repo (str): The name of the repo
|
||||
repo_url (str): The URL of the repo
|
||||
branch (str): The branch of the repo
|
||||
"""
|
||||
exist_repos = list_repos()
|
||||
if repo in exist_repos and repo_url not in DEFAULT_REPO_MAP.values():
|
||||
@@ -84,7 +85,7 @@ def add_repo(repo: str, repo_url: str):
|
||||
repo_group_dir = os.path.join(DBGPTS_REPO_HOME, repo_arr[0])
|
||||
os.makedirs(repo_group_dir, exist_ok=True)
|
||||
if repo_url.startswith("http") or repo_url.startswith("git"):
|
||||
clone_repo(repo, repo_group_dir, repo_name, repo_url)
|
||||
clone_repo(repo, repo_group_dir, repo_name, repo_url, branch)
|
||||
elif os.path.isdir(repo_url):
|
||||
# Create soft link
|
||||
os.symlink(repo_url, os.path.join(repo_group_dir, repo_name))
|
||||
@@ -106,7 +107,13 @@ def remove_repo(repo: str):
|
||||
logger.info(f"Repo '{repo}' removed successfully.")
|
||||
|
||||
|
||||
def clone_repo(repo: str, repo_group_dir: str, repo_name: str, repo_url: str):
|
||||
def clone_repo(
|
||||
repo: str,
|
||||
repo_group_dir: str,
|
||||
repo_name: str,
|
||||
repo_url: str,
|
||||
branch: str | None = None,
|
||||
):
|
||||
"""Clone the specified repo
|
||||
|
||||
Args:
|
||||
@@ -114,10 +121,22 @@ def clone_repo(repo: str, repo_group_dir: str, repo_name: str, repo_url: str):
|
||||
repo_group_dir (str): The directory of the repo group
|
||||
repo_name (str): The name of the repo
|
||||
repo_url (str): The URL of the repo
|
||||
branch (str): The branch of the repo
|
||||
"""
|
||||
os.chdir(repo_group_dir)
|
||||
subprocess.run(["git", "clone", repo_url, repo_name], check=True)
|
||||
logger.info(f"Repo '{repo}' cloned from {repo_url} successfully.")
|
||||
clone_command = ["git", "clone", repo_url, repo_name]
|
||||
|
||||
# If the branch is specified, add it to the clone command
|
||||
if branch:
|
||||
clone_command += ["-b", branch]
|
||||
|
||||
subprocess.run(clone_command, check=True)
|
||||
if branch:
|
||||
click.echo(
|
||||
f"Repo '{repo}' cloned from {repo_url} with branch '{branch}' successfully."
|
||||
)
|
||||
else:
|
||||
click.echo(f"Repo '{repo}' cloned from {repo_url} successfully.")
|
||||
|
||||
|
||||
def update_repo(repo: str):
|
||||
@@ -217,7 +236,7 @@ def _write_install_metadata(name: str, repo: str, install_path: Path):
|
||||
|
||||
def check_with_retry(
|
||||
name: str,
|
||||
repo: str | None = None,
|
||||
spec_repo: str | None = None,
|
||||
with_update: bool = False,
|
||||
is_first: bool = False,
|
||||
) -> Tuple[str, Path] | None:
|
||||
@@ -225,18 +244,17 @@ def check_with_retry(
|
||||
|
||||
Args:
|
||||
name (str): The name of the dbgpt
|
||||
repo (str): The name of the repo
|
||||
spec_repo (str): The name of the repo
|
||||
with_update (bool): Whether to update the repo before installing
|
||||
is_first (bool): Whether it's the first time to check the dbgpt
|
||||
|
||||
Returns:
|
||||
Tuple[str, Path] | None: The repo and the path of the dbgpt
|
||||
"""
|
||||
repos = _list_repos_details()
|
||||
if repo:
|
||||
if spec_repo:
|
||||
repos = list(filter(lambda x: x[0] == repo, repos))
|
||||
if not repos:
|
||||
logger.error(f"The specified repo '{repo}' does not exist.")
|
||||
logger.error(f"The specified repo '{spec_repo}' does not exist.")
|
||||
return
|
||||
if is_first and with_update:
|
||||
for repo in repos:
|
||||
@@ -253,7 +271,9 @@ def check_with_retry(
|
||||
):
|
||||
return repo[0], dbgpt_path
|
||||
if is_first:
|
||||
return check_with_retry(name, repo, with_update=with_update, is_first=False)
|
||||
return check_with_retry(
|
||||
name, spec_repo, with_update=with_update, is_first=False
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
|
@@ -3,6 +3,14 @@ import inspect
|
||||
from functools import wraps
|
||||
from typing import Any, get_args, get_origin, get_type_hints
|
||||
|
||||
from typeguard import check_type
|
||||
|
||||
|
||||
def _is_typing(obj):
|
||||
from typing import _Final # type: ignore
|
||||
|
||||
return isinstance(obj, _Final)
|
||||
|
||||
|
||||
def _is_instance_of_generic_type(obj, generic_type):
|
||||
"""Check if an object is an instance of a generic type."""
|
||||
@@ -18,18 +26,44 @@ def _is_instance_of_generic_type(obj, generic_type):
|
||||
return isinstance(obj, origin)
|
||||
|
||||
# Check if object matches the generic origin (like list, dict)
|
||||
if not isinstance(obj, origin):
|
||||
return False
|
||||
if not _is_typing(origin):
|
||||
return isinstance(obj, origin)
|
||||
|
||||
objs = [obj for _ in range(len(args))]
|
||||
|
||||
# For each item in the object, check if it matches the corresponding type argument
|
||||
for sub_obj, arg in zip(obj, args):
|
||||
for sub_obj, arg in zip(objs, args):
|
||||
# Skip check if the type argument is Any
|
||||
if arg is not Any and not isinstance(sub_obj, arg):
|
||||
return False
|
||||
|
||||
if arg is not Any:
|
||||
if _is_typing(arg):
|
||||
sub_args = get_args(arg)
|
||||
if (
|
||||
sub_args
|
||||
and not _is_typing(sub_args[0])
|
||||
and not isinstance(sub_obj, sub_args[0])
|
||||
):
|
||||
return False
|
||||
elif not isinstance(sub_obj, arg):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _check_type(obj, t) -> bool:
|
||||
try:
|
||||
check_type(obj, t)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _get_orders(obj, arg_types):
|
||||
try:
|
||||
orders = [i for i, t in enumerate(arg_types) if _check_type(obj, t)]
|
||||
return orders[0] if orders else int(1e8)
|
||||
except Exception:
|
||||
return int(1e8)
|
||||
|
||||
|
||||
def _sort_args(func, args, kwargs):
|
||||
sig = inspect.signature(func)
|
||||
type_hints = get_type_hints(func)
|
||||
@@ -49,9 +83,7 @@ def _sort_args(func, args, kwargs):
|
||||
|
||||
sorted_args = sorted(
|
||||
other_args,
|
||||
key=lambda x: next(
|
||||
i for i, t in enumerate(arg_types) if _is_instance_of_generic_type(x, t)
|
||||
),
|
||||
key=lambda x: _get_orders(x, arg_types),
|
||||
)
|
||||
return (*self_arg, *sorted_args), kwargs
|
||||
|
||||
|
Reference in New Issue
Block a user