feat: Run AWEL flow in CLI (#1341)

This commit is contained in:
Fangyin Cheng
2024-03-27 12:50:05 +08:00
committed by GitHub
parent 340a9fbc35
commit 3a7a2cbbb8
42 changed files with 1454 additions and 422 deletions

247
dbgpt/client/_cli.py Normal file
View File

@@ -0,0 +1,247 @@
"""CLI for DB-GPT client."""
import functools
import json
import time
import uuid
from typing import Any, Dict
import click
from dbgpt.util import get_or_create_event_loop
from dbgpt.util.console import CliLogger
from dbgpt.util.i18n_utils import _
from .client import Client
from .flow import list_flow
cl = CliLogger()
def add_base_flow_options(func):
"""Add base flow options to the command."""
@click.option(
"-n",
"--name",
type=str,
default=None,
required=False,
help=_("The name of the AWEL flow"),
)
@click.option(
"--uid",
type=str,
default=None,
required=False,
help=_("The uid of the AWEL flow"),
)
@functools.wraps(func)
def _wrapper(*args, **kwargs):
return func(*args, **kwargs)
return _wrapper
def add_chat_options(func):
"""Add chat options to the command."""
@click.option(
"-m",
"--messages",
type=str,
default=None,
required=False,
help=_("The messages to run AWEL flow"),
)
@click.option(
"--model",
type=str,
default=None,
required=False,
help=_("The model name of AWEL flow"),
)
@click.option(
"-s",
"--stream",
type=bool,
default=False,
required=False,
is_flag=True,
help=_("Whether use stream mode to run AWEL flow"),
)
@click.option(
"-t",
"--temperature",
type=float,
default=None,
required=False,
help=_("The temperature to run AWEL flow"),
)
@click.option(
"--max_new_tokens",
type=int,
default=None,
required=False,
help=_("The max new tokens to run AWEL flow"),
)
@click.option(
"--conv_uid",
type=str,
default=None,
required=False,
help=_("The conversation id of the AWEL flow"),
)
@click.option(
"-d",
"--data",
type=str,
default=None,
required=False,
help=_("The json data to run AWEL flow, if set, will overwrite other options"),
)
@click.option(
"-e",
"--extra",
type=str,
default=None,
required=False,
help=_("The extra json data to run AWEL flow."),
)
@click.option(
"-i",
"--interactive",
type=bool,
default=False,
required=False,
is_flag=True,
help=_("Whether use interactive mode to run AWEL flow"),
)
@functools.wraps(func)
def _wrapper(*args, **kwargs):
return func(*args, **kwargs)
return _wrapper
@click.command(name="flow")
@add_base_flow_options
@add_chat_options
def run_flow(name: str, uid: str, data: str, interactive: bool, **kwargs):
"""Run a AWEL flow."""
client = Client()
loop = get_or_create_event_loop()
res = loop.run_until_complete(list_flow(client, name, uid))
if not res:
cl.error("Flow not found with the given name or uid", exit_code=1)
if len(res) > 1:
cl.error("More than one flow found", exit_code=1)
flow = res[0]
json_data = _parse_json_data(data, **kwargs)
json_data["chat_param"] = flow.uid
json_data["chat_mode"] = "chat_flow"
stream = "stream" in json_data and str(json_data["stream"]).lower() in ["true", "1"]
if stream:
loop.run_until_complete(_chat_stream(client, interactive, json_data))
else:
loop.run_until_complete(_chat(client, interactive, json_data))
def _parse_json_data(data: str, **kwargs):
json_data = {}
if data:
try:
json_data = json.loads(data)
except Exception as e:
cl.error(f"Invalid JSON data: {data}, {e}", exit_code=1)
if "extra" in kwargs and kwargs["extra"]:
try:
extra = json.loads(kwargs["extra"])
kwargs["extra"] = extra
except Exception as e:
cl.error(f"Invalid extra JSON data: {kwargs['extra']}, {e}", exit_code=1)
for k, v in kwargs.items():
if v is not None and k not in json_data:
json_data[k] = v
if "model" not in json_data:
json_data["model"] = "__empty__model__"
return json_data
async def _chat_stream(client: Client, interactive: bool, json_data: Dict[str, Any]):
user_input = json_data.get("messages", "")
if "conv_uid" not in json_data and interactive:
json_data["conv_uid"] = str(uuid.uuid4())
first_message = True
while True:
try:
if interactive and not user_input:
cl.print("Type 'exit' or 'quit' to exit.")
while not user_input:
user_input = cl.ask("You")
if user_input.lower() in ["exit", "quit", "q"]:
break
start_time = time.time()
json_data["messages"] = user_input
if first_message:
cl.info("You: " + user_input)
cl.info("Chat stream started")
cl.debug(f"JSON data: {json.dumps(json_data, ensure_ascii=False)}")
full_text = ""
cl.print("Bot: ")
async for out in client.chat_stream(**json_data):
if out.choices:
text = out.choices[0].delta.content
if text:
full_text += text
cl.print(text, end="")
end_time = time.time()
time_cost = round(end_time - start_time, 2)
cl.success(f"\n:tada: Chat stream finished, timecost: {time_cost} s")
except Exception as e:
cl.error(f"Chat stream failed: {e}", exit_code=1)
finally:
first_message = False
if interactive:
user_input = ""
else:
break
async def _chat(client: Client, interactive: bool, json_data: Dict[str, Any]):
user_input = json_data.get("messages", "")
if "conv_uid" not in json_data and interactive:
json_data["conv_uid"] = str(uuid.uuid4())
first_message = True
while True:
try:
if interactive and not user_input:
cl.print("Type 'exit' or 'quit' to exit.")
while not user_input:
user_input = cl.ask("You")
if user_input.lower() in ["exit", "quit", "q"]:
break
start_time = time.time()
json_data["messages"] = user_input
if first_message:
cl.info("You: " + user_input)
cl.info("Chat started")
cl.debug(f"JSON data: {json.dumps(json_data, ensure_ascii=False)}")
res = await client.chat(**json_data)
cl.print("Bot: ")
if res.choices:
text = res.choices[0].message.content
cl.markdown(text)
time_cost = round(time.time() - start_time, 2)
cl.success(f"\n:tada: Chat stream finished, timecost: {time_cost} s")
except Exception as e:
cl.error(f"Chat failed: {e}", exit_code=1)
finally:
first_message = False
if interactive:
user_input = ""
else:
break

View File

@@ -1,7 +1,8 @@
"""This module contains the client for the DB-GPT API."""
import atexit
import json
import os
from typing import Any, AsyncGenerator, List, Optional, Union
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from urllib.parse import urlparse
import httpx
@@ -98,6 +99,7 @@ class Client:
self._http_client = httpx.AsyncClient(
headers=headers, timeout=timeout if timeout else httpx.Timeout(None)
)
atexit.register(self.close)
async def chat(
self,
@@ -113,6 +115,7 @@ class Client:
span_id: Optional[str] = None,
incremental: bool = True,
enable_vis: bool = True,
**kwargs,
) -> ChatCompletionResponse:
"""
Chat Completion.
@@ -187,6 +190,7 @@ class Client:
span_id: Optional[str] = None,
incremental: bool = True,
enable_vis: bool = True,
**kwargs,
) -> AsyncGenerator[ChatCompletionStreamResponse, None]:
"""
Chat Stream Completion.
@@ -238,10 +242,23 @@ class Client:
incremental=incremental,
enable_vis=enable_vis,
)
async for chat_completion_response in self._chat_stream(request.dict()):
yield chat_completion_response
async def _chat_stream(
self, data: Dict[str, Any]
) -> AsyncGenerator[ChatCompletionStreamResponse, None]:
"""Chat Stream Completion.
Args:
data: dict, The data to send to the API.
Returns:
AsyncGenerator[dict, None]: The chat completion response.
"""
async with self._http_client.stream(
method="POST",
url=self._api_url + "/chat/completions",
json=request.dict(),
json=data,
headers={},
) as response:
if response.status_code == 200:
@@ -250,7 +267,11 @@ class Client:
if line.strip() == "data: [DONE]":
break
if line.startswith("data:"):
json_data = json.loads(line[len("data: ") :])
if line.startswith("data: "):
sse_data = line[len("data: ") :]
else:
sse_data = line[len("data:") :]
json_data = json.loads(sse_data)
chat_completion_response = ChatCompletionStreamResponse(
**json_data
)
@@ -265,21 +286,20 @@ class Client:
except Exception as e:
raise e
async def get(self, path: str, *args):
async def get(self, path: str, *args, **kwargs):
"""Get method.
Args:
path: str, The path to get.
args: Any, The arguments to pass to the get method.
"""
try:
response = await self._http_client.get(
f"{self._api_url}/{CLIENT_SERVE_PATH}{path}",
*args,
)
return response
finally:
await self._http_client.aclose()
kwargs = {k: v for k, v in kwargs.items() if v is not None}
response = await self._http_client.get(
f"{self._api_url}/{CLIENT_SERVE_PATH}{path}",
*args,
params=kwargs,
)
return response
async def post(self, path: str, args):
"""Post method.
@@ -288,13 +308,10 @@ class Client:
path: str, The path to post.
args: Any, The arguments to pass to the post
"""
try:
return await self._http_client.post(
f"{self._api_url}/{CLIENT_SERVE_PATH}{path}",
json=args,
)
finally:
await self._http_client.aclose()
return await self._http_client.post(
f"{self._api_url}/{CLIENT_SERVE_PATH}{path}",
json=args,
)
async def post_param(self, path: str, args):
"""Post method.
@@ -303,13 +320,10 @@ class Client:
path: str, The path to post.
args: Any, The arguments to pass to the post
"""
try:
return await self._http_client.post(
f"{self._api_url}/{CLIENT_SERVE_PATH}{path}",
params=args,
)
finally:
await self._http_client.aclose()
return await self._http_client.post(
f"{self._api_url}/{CLIENT_SERVE_PATH}{path}",
params=args,
)
async def patch(self, path: str, *args):
"""Patch method.
@@ -329,26 +343,20 @@ class Client:
path: str, The path to put.
args: Any, The arguments to pass to the put.
"""
try:
return await self._http_client.put(
f"{self._api_url}/{CLIENT_SERVE_PATH}{path}", json=args
)
finally:
await self._http_client.aclose()
return await self._http_client.put(
f"{self._api_url}/{CLIENT_SERVE_PATH}{path}", json=args
)
async def delete(self, path: str, *args):
"""Delete method.
Args:
path: str, The path to delete.
args: Any, The arguments to pass to the delete.
args: Any, The arguments to pass to delete.
"""
try:
return await self._http_client.delete(
f"{self._api_url}/{CLIENT_SERVE_PATH}{path}", *args
)
finally:
await self._http_client.aclose()
return await self._http_client.delete(
f"{self._api_url}/{CLIENT_SERVE_PATH}{path}", *args
)
async def head(self, path: str, *args):
"""Head method.
@@ -359,6 +367,18 @@ class Client:
"""
return self._http_client.head(self._api_url + path, *args)
def close(self):
"""Close the client."""
from dbgpt.util import get_or_create_event_loop
if not self._http_client.is_closed:
loop = get_or_create_event_loop()
loop.run_until_complete(self._http_client.aclose())
async def aclose(self):
"""Close the client."""
await self._http_client.aclose()
def is_valid_url(api_url: Any) -> bool:
"""Check if the given URL is valid.

View File

@@ -93,19 +93,23 @@ async def get_flow(client: Client, flow_id: str) -> FlowPanel:
raise ClientException(f"Failed to get flow: {e}")
async def list_flow(client: Client) -> List[FlowPanel]:
async def list_flow(
client: Client, name: str | None = None, uid: str | None = None
) -> List[FlowPanel]:
"""
List flows.
Args:
client (Client): The dbgpt client.
name (str): The name of the flow.
uid (str): The uid of the flow.
Returns:
List[FlowPanel]: The list of flow panels.
Raises:
ClientException: If the request failed.
"""
try:
res = await client.get("/awel/flows")
res = await client.get("/awel/flows", **{"name": name, "uid": uid})
result: Result = res.json()
if result["success"]:
return [FlowPanel(**flow) for flow in result["data"]["items"]]