mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-13 13:10:29 +00:00
feat: Run AWEL flow in CLI (#1341)
This commit is contained in:
247
dbgpt/client/_cli.py
Normal file
247
dbgpt/client/_cli.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""CLI for DB-GPT client."""
|
||||
|
||||
import functools
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict
|
||||
|
||||
import click
|
||||
|
||||
from dbgpt.util import get_or_create_event_loop
|
||||
from dbgpt.util.console import CliLogger
|
||||
from dbgpt.util.i18n_utils import _
|
||||
|
||||
from .client import Client
|
||||
from .flow import list_flow
|
||||
|
||||
cl = CliLogger()
|
||||
|
||||
|
||||
def add_base_flow_options(func):
|
||||
"""Add base flow options to the command."""
|
||||
|
||||
@click.option(
|
||||
"-n",
|
||||
"--name",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help=_("The name of the AWEL flow"),
|
||||
)
|
||||
@click.option(
|
||||
"--uid",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help=_("The uid of the AWEL flow"),
|
||||
)
|
||||
@functools.wraps(func)
|
||||
def _wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return _wrapper
|
||||
|
||||
|
||||
def add_chat_options(func):
|
||||
"""Add chat options to the command."""
|
||||
|
||||
@click.option(
|
||||
"-m",
|
||||
"--messages",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help=_("The messages to run AWEL flow"),
|
||||
)
|
||||
@click.option(
|
||||
"--model",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help=_("The model name of AWEL flow"),
|
||||
)
|
||||
@click.option(
|
||||
"-s",
|
||||
"--stream",
|
||||
type=bool,
|
||||
default=False,
|
||||
required=False,
|
||||
is_flag=True,
|
||||
help=_("Whether use stream mode to run AWEL flow"),
|
||||
)
|
||||
@click.option(
|
||||
"-t",
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=None,
|
||||
required=False,
|
||||
help=_("The temperature to run AWEL flow"),
|
||||
)
|
||||
@click.option(
|
||||
"--max_new_tokens",
|
||||
type=int,
|
||||
default=None,
|
||||
required=False,
|
||||
help=_("The max new tokens to run AWEL flow"),
|
||||
)
|
||||
@click.option(
|
||||
"--conv_uid",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help=_("The conversation id of the AWEL flow"),
|
||||
)
|
||||
@click.option(
|
||||
"-d",
|
||||
"--data",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help=_("The json data to run AWEL flow, if set, will overwrite other options"),
|
||||
)
|
||||
@click.option(
|
||||
"-e",
|
||||
"--extra",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help=_("The extra json data to run AWEL flow."),
|
||||
)
|
||||
@click.option(
|
||||
"-i",
|
||||
"--interactive",
|
||||
type=bool,
|
||||
default=False,
|
||||
required=False,
|
||||
is_flag=True,
|
||||
help=_("Whether use interactive mode to run AWEL flow"),
|
||||
)
|
||||
@functools.wraps(func)
|
||||
def _wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return _wrapper
|
||||
|
||||
|
||||
@click.command(name="flow")
|
||||
@add_base_flow_options
|
||||
@add_chat_options
|
||||
def run_flow(name: str, uid: str, data: str, interactive: bool, **kwargs):
|
||||
"""Run a AWEL flow."""
|
||||
client = Client()
|
||||
|
||||
loop = get_or_create_event_loop()
|
||||
res = loop.run_until_complete(list_flow(client, name, uid))
|
||||
|
||||
if not res:
|
||||
cl.error("Flow not found with the given name or uid", exit_code=1)
|
||||
if len(res) > 1:
|
||||
cl.error("More than one flow found", exit_code=1)
|
||||
flow = res[0]
|
||||
json_data = _parse_json_data(data, **kwargs)
|
||||
json_data["chat_param"] = flow.uid
|
||||
json_data["chat_mode"] = "chat_flow"
|
||||
stream = "stream" in json_data and str(json_data["stream"]).lower() in ["true", "1"]
|
||||
if stream:
|
||||
loop.run_until_complete(_chat_stream(client, interactive, json_data))
|
||||
else:
|
||||
loop.run_until_complete(_chat(client, interactive, json_data))
|
||||
|
||||
|
||||
def _parse_json_data(data: str, **kwargs):
|
||||
json_data = {}
|
||||
if data:
|
||||
try:
|
||||
json_data = json.loads(data)
|
||||
except Exception as e:
|
||||
cl.error(f"Invalid JSON data: {data}, {e}", exit_code=1)
|
||||
if "extra" in kwargs and kwargs["extra"]:
|
||||
try:
|
||||
extra = json.loads(kwargs["extra"])
|
||||
kwargs["extra"] = extra
|
||||
except Exception as e:
|
||||
cl.error(f"Invalid extra JSON data: {kwargs['extra']}, {e}", exit_code=1)
|
||||
for k, v in kwargs.items():
|
||||
if v is not None and k not in json_data:
|
||||
json_data[k] = v
|
||||
if "model" not in json_data:
|
||||
json_data["model"] = "__empty__model__"
|
||||
return json_data
|
||||
|
||||
|
||||
async def _chat_stream(client: Client, interactive: bool, json_data: Dict[str, Any]):
|
||||
user_input = json_data.get("messages", "")
|
||||
if "conv_uid" not in json_data and interactive:
|
||||
json_data["conv_uid"] = str(uuid.uuid4())
|
||||
first_message = True
|
||||
while True:
|
||||
try:
|
||||
if interactive and not user_input:
|
||||
cl.print("Type 'exit' or 'quit' to exit.")
|
||||
while not user_input:
|
||||
user_input = cl.ask("You")
|
||||
if user_input.lower() in ["exit", "quit", "q"]:
|
||||
break
|
||||
start_time = time.time()
|
||||
json_data["messages"] = user_input
|
||||
if first_message:
|
||||
cl.info("You: " + user_input)
|
||||
cl.info("Chat stream started")
|
||||
cl.debug(f"JSON data: {json.dumps(json_data, ensure_ascii=False)}")
|
||||
full_text = ""
|
||||
cl.print("Bot: ")
|
||||
async for out in client.chat_stream(**json_data):
|
||||
if out.choices:
|
||||
text = out.choices[0].delta.content
|
||||
if text:
|
||||
full_text += text
|
||||
cl.print(text, end="")
|
||||
end_time = time.time()
|
||||
time_cost = round(end_time - start_time, 2)
|
||||
cl.success(f"\n:tada: Chat stream finished, timecost: {time_cost} s")
|
||||
except Exception as e:
|
||||
cl.error(f"Chat stream failed: {e}", exit_code=1)
|
||||
finally:
|
||||
first_message = False
|
||||
if interactive:
|
||||
user_input = ""
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
async def _chat(client: Client, interactive: bool, json_data: Dict[str, Any]):
|
||||
user_input = json_data.get("messages", "")
|
||||
if "conv_uid" not in json_data and interactive:
|
||||
json_data["conv_uid"] = str(uuid.uuid4())
|
||||
first_message = True
|
||||
while True:
|
||||
try:
|
||||
if interactive and not user_input:
|
||||
cl.print("Type 'exit' or 'quit' to exit.")
|
||||
while not user_input:
|
||||
user_input = cl.ask("You")
|
||||
if user_input.lower() in ["exit", "quit", "q"]:
|
||||
break
|
||||
start_time = time.time()
|
||||
json_data["messages"] = user_input
|
||||
if first_message:
|
||||
cl.info("You: " + user_input)
|
||||
|
||||
cl.info("Chat started")
|
||||
cl.debug(f"JSON data: {json.dumps(json_data, ensure_ascii=False)}")
|
||||
res = await client.chat(**json_data)
|
||||
cl.print("Bot: ")
|
||||
if res.choices:
|
||||
text = res.choices[0].message.content
|
||||
cl.markdown(text)
|
||||
time_cost = round(time.time() - start_time, 2)
|
||||
cl.success(f"\n:tada: Chat stream finished, timecost: {time_cost} s")
|
||||
except Exception as e:
|
||||
cl.error(f"Chat failed: {e}", exit_code=1)
|
||||
finally:
|
||||
first_message = False
|
||||
if interactive:
|
||||
user_input = ""
|
||||
else:
|
||||
break
|
@@ -1,7 +1,8 @@
|
||||
"""This module contains the client for the DB-GPT API."""
|
||||
import atexit
|
||||
import json
|
||||
import os
|
||||
from typing import Any, AsyncGenerator, List, Optional, Union
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
@@ -98,6 +99,7 @@ class Client:
|
||||
self._http_client = httpx.AsyncClient(
|
||||
headers=headers, timeout=timeout if timeout else httpx.Timeout(None)
|
||||
)
|
||||
atexit.register(self.close)
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
@@ -113,6 +115,7 @@ class Client:
|
||||
span_id: Optional[str] = None,
|
||||
incremental: bool = True,
|
||||
enable_vis: bool = True,
|
||||
**kwargs,
|
||||
) -> ChatCompletionResponse:
|
||||
"""
|
||||
Chat Completion.
|
||||
@@ -187,6 +190,7 @@ class Client:
|
||||
span_id: Optional[str] = None,
|
||||
incremental: bool = True,
|
||||
enable_vis: bool = True,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[ChatCompletionStreamResponse, None]:
|
||||
"""
|
||||
Chat Stream Completion.
|
||||
@@ -238,10 +242,23 @@ class Client:
|
||||
incremental=incremental,
|
||||
enable_vis=enable_vis,
|
||||
)
|
||||
async for chat_completion_response in self._chat_stream(request.dict()):
|
||||
yield chat_completion_response
|
||||
|
||||
async def _chat_stream(
|
||||
self, data: Dict[str, Any]
|
||||
) -> AsyncGenerator[ChatCompletionStreamResponse, None]:
|
||||
"""Chat Stream Completion.
|
||||
|
||||
Args:
|
||||
data: dict, The data to send to the API.
|
||||
Returns:
|
||||
AsyncGenerator[dict, None]: The chat completion response.
|
||||
"""
|
||||
async with self._http_client.stream(
|
||||
method="POST",
|
||||
url=self._api_url + "/chat/completions",
|
||||
json=request.dict(),
|
||||
json=data,
|
||||
headers={},
|
||||
) as response:
|
||||
if response.status_code == 200:
|
||||
@@ -250,7 +267,11 @@ class Client:
|
||||
if line.strip() == "data: [DONE]":
|
||||
break
|
||||
if line.startswith("data:"):
|
||||
json_data = json.loads(line[len("data: ") :])
|
||||
if line.startswith("data: "):
|
||||
sse_data = line[len("data: ") :]
|
||||
else:
|
||||
sse_data = line[len("data:") :]
|
||||
json_data = json.loads(sse_data)
|
||||
chat_completion_response = ChatCompletionStreamResponse(
|
||||
**json_data
|
||||
)
|
||||
@@ -265,21 +286,20 @@ class Client:
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def get(self, path: str, *args):
|
||||
async def get(self, path: str, *args, **kwargs):
|
||||
"""Get method.
|
||||
|
||||
Args:
|
||||
path: str, The path to get.
|
||||
args: Any, The arguments to pass to the get method.
|
||||
"""
|
||||
try:
|
||||
response = await self._http_client.get(
|
||||
f"{self._api_url}/{CLIENT_SERVE_PATH}{path}",
|
||||
*args,
|
||||
)
|
||||
return response
|
||||
finally:
|
||||
await self._http_client.aclose()
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
response = await self._http_client.get(
|
||||
f"{self._api_url}/{CLIENT_SERVE_PATH}{path}",
|
||||
*args,
|
||||
params=kwargs,
|
||||
)
|
||||
return response
|
||||
|
||||
async def post(self, path: str, args):
|
||||
"""Post method.
|
||||
@@ -288,13 +308,10 @@ class Client:
|
||||
path: str, The path to post.
|
||||
args: Any, The arguments to pass to the post
|
||||
"""
|
||||
try:
|
||||
return await self._http_client.post(
|
||||
f"{self._api_url}/{CLIENT_SERVE_PATH}{path}",
|
||||
json=args,
|
||||
)
|
||||
finally:
|
||||
await self._http_client.aclose()
|
||||
return await self._http_client.post(
|
||||
f"{self._api_url}/{CLIENT_SERVE_PATH}{path}",
|
||||
json=args,
|
||||
)
|
||||
|
||||
async def post_param(self, path: str, args):
|
||||
"""Post method.
|
||||
@@ -303,13 +320,10 @@ class Client:
|
||||
path: str, The path to post.
|
||||
args: Any, The arguments to pass to the post
|
||||
"""
|
||||
try:
|
||||
return await self._http_client.post(
|
||||
f"{self._api_url}/{CLIENT_SERVE_PATH}{path}",
|
||||
params=args,
|
||||
)
|
||||
finally:
|
||||
await self._http_client.aclose()
|
||||
return await self._http_client.post(
|
||||
f"{self._api_url}/{CLIENT_SERVE_PATH}{path}",
|
||||
params=args,
|
||||
)
|
||||
|
||||
async def patch(self, path: str, *args):
|
||||
"""Patch method.
|
||||
@@ -329,26 +343,20 @@ class Client:
|
||||
path: str, The path to put.
|
||||
args: Any, The arguments to pass to the put.
|
||||
"""
|
||||
try:
|
||||
return await self._http_client.put(
|
||||
f"{self._api_url}/{CLIENT_SERVE_PATH}{path}", json=args
|
||||
)
|
||||
finally:
|
||||
await self._http_client.aclose()
|
||||
return await self._http_client.put(
|
||||
f"{self._api_url}/{CLIENT_SERVE_PATH}{path}", json=args
|
||||
)
|
||||
|
||||
async def delete(self, path: str, *args):
|
||||
"""Delete method.
|
||||
|
||||
Args:
|
||||
path: str, The path to delete.
|
||||
args: Any, The arguments to pass to the delete.
|
||||
args: Any, The arguments to pass to delete.
|
||||
"""
|
||||
try:
|
||||
return await self._http_client.delete(
|
||||
f"{self._api_url}/{CLIENT_SERVE_PATH}{path}", *args
|
||||
)
|
||||
finally:
|
||||
await self._http_client.aclose()
|
||||
return await self._http_client.delete(
|
||||
f"{self._api_url}/{CLIENT_SERVE_PATH}{path}", *args
|
||||
)
|
||||
|
||||
async def head(self, path: str, *args):
|
||||
"""Head method.
|
||||
@@ -359,6 +367,18 @@ class Client:
|
||||
"""
|
||||
return self._http_client.head(self._api_url + path, *args)
|
||||
|
||||
def close(self):
|
||||
"""Close the client."""
|
||||
from dbgpt.util import get_or_create_event_loop
|
||||
|
||||
if not self._http_client.is_closed:
|
||||
loop = get_or_create_event_loop()
|
||||
loop.run_until_complete(self._http_client.aclose())
|
||||
|
||||
async def aclose(self):
|
||||
"""Close the client."""
|
||||
await self._http_client.aclose()
|
||||
|
||||
|
||||
def is_valid_url(api_url: Any) -> bool:
|
||||
"""Check if the given URL is valid.
|
||||
|
@@ -93,19 +93,23 @@ async def get_flow(client: Client, flow_id: str) -> FlowPanel:
|
||||
raise ClientException(f"Failed to get flow: {e}")
|
||||
|
||||
|
||||
async def list_flow(client: Client) -> List[FlowPanel]:
|
||||
async def list_flow(
|
||||
client: Client, name: str | None = None, uid: str | None = None
|
||||
) -> List[FlowPanel]:
|
||||
"""
|
||||
List flows.
|
||||
|
||||
Args:
|
||||
client (Client): The dbgpt client.
|
||||
name (str): The name of the flow.
|
||||
uid (str): The uid of the flow.
|
||||
Returns:
|
||||
List[FlowPanel]: The list of flow panels.
|
||||
Raises:
|
||||
ClientException: If the request failed.
|
||||
"""
|
||||
try:
|
||||
res = await client.get("/awel/flows")
|
||||
res = await client.get("/awel/flows", **{"name": name, "uid": uid})
|
||||
result: Result = res.json()
|
||||
if result["success"]:
|
||||
return [FlowPanel(**flow) for flow in result["data"]["items"]]
|
||||
|
Reference in New Issue
Block a user