mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-12 20:53:48 +00:00
feat(core): Multiple ways to run dbgpts (#1734)
This commit is contained in:
@@ -1,22 +1,55 @@
|
||||
"""CLI for DB-GPT client."""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict
|
||||
from typing import Any, AsyncIterator, Callable, Dict, Tuple, cast
|
||||
|
||||
import click
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.core.awel import DAG, BaseOperator, DAGVar
|
||||
from dbgpt.core.awel.dag.dag_manager import DAGMetadata, _parse_metadata
|
||||
from dbgpt.core.awel.flow.flow_factory import FlowFactory
|
||||
from dbgpt.util import get_or_create_event_loop
|
||||
from dbgpt.util.console import CliLogger
|
||||
from dbgpt.util.i18n_utils import _
|
||||
|
||||
from .client import Client
|
||||
from .flow import list_flow
|
||||
from .flow import run_flow_cmd as client_run_flow_cmd
|
||||
|
||||
cl = CliLogger()
|
||||
|
||||
_LOCAL_MODE: bool | None = False
|
||||
_FILE_PATH: str | None = None
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.option(
|
||||
"--local",
|
||||
required=False,
|
||||
type=bool,
|
||||
default=False,
|
||||
is_flag=True,
|
||||
help="Whether use local mode(run local AWEL file)",
|
||||
)
|
||||
@click.option(
|
||||
"-f",
|
||||
"--file",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help=_("The path of the AWEL flow"),
|
||||
)
|
||||
def flow(local: bool = False, file: str | None = None):
|
||||
"""Run a AWEL flow."""
|
||||
global _LOCAL_MODE, _FILE_PATH
|
||||
_LOCAL_MODE = local
|
||||
_FILE_PATH = file
|
||||
|
||||
|
||||
def add_base_flow_options(func):
|
||||
"""Add base flow options to the command."""
|
||||
@@ -124,32 +157,229 @@ def add_chat_options(func):
|
||||
return _wrapper
|
||||
|
||||
|
||||
@click.command(name="flow")
|
||||
@flow.command(name="chat")
|
||||
@add_base_flow_options
|
||||
@add_chat_options
|
||||
def run_flow(name: str, uid: str, data: str, interactive: bool, **kwargs):
|
||||
def run_flow_chat(name: str, uid: str, data: str, interactive: bool, **kwargs):
|
||||
"""Run a AWEL flow."""
|
||||
json_data = _parse_chat_json_data(data, **kwargs)
|
||||
stream = "stream" in json_data and str(json_data["stream"]).lower() in ["true", "1"]
|
||||
loop = get_or_create_event_loop()
|
||||
if _LOCAL_MODE:
|
||||
_run_flow_chat_local(loop, name, interactive, json_data, stream)
|
||||
return
|
||||
|
||||
client = Client()
|
||||
|
||||
loop = get_or_create_event_loop()
|
||||
res = loop.run_until_complete(list_flow(client, name, uid))
|
||||
# AWEL flow store the python module name now, so we need to replace "-" with "_"
|
||||
new_name = name.replace("-", "_")
|
||||
res = loop.run_until_complete(list_flow(client, new_name, uid))
|
||||
|
||||
if not res:
|
||||
cl.error("Flow not found with the given name or uid", exit_code=1)
|
||||
if len(res) > 1:
|
||||
cl.error("More than one flow found", exit_code=1)
|
||||
flow = res[0]
|
||||
json_data = _parse_json_data(data, **kwargs)
|
||||
json_data["chat_param"] = flow.uid
|
||||
json_data["chat_mode"] = "chat_flow"
|
||||
stream = "stream" in json_data and str(json_data["stream"]).lower() in ["true", "1"]
|
||||
if stream:
|
||||
loop.run_until_complete(_chat_stream(client, interactive, json_data))
|
||||
_run_flow_chat_stream(loop, client, interactive, json_data)
|
||||
else:
|
||||
loop.run_until_complete(_chat(client, interactive, json_data))
|
||||
_run_flow_chat(loop, client, interactive, json_data)
|
||||
|
||||
|
||||
def _parse_json_data(data: str, **kwargs):
|
||||
@flow.command(name="cmd")
|
||||
@add_base_flow_options
|
||||
@click.option(
|
||||
"-d",
|
||||
"--data",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help=_("The json data to run AWEL flow, if set, will overwrite other options"),
|
||||
)
|
||||
@click.option(
|
||||
"--output_key",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help=_(
|
||||
"The output key of the AWEL flow, if set, it will try to get the output by the "
|
||||
"key"
|
||||
),
|
||||
)
|
||||
def run_flow_cmd(
|
||||
name: str, uid: str, data: str | None = None, output_key: str | None = None
|
||||
):
|
||||
"""Run a AWEL flow with command mode."""
|
||||
json_data = _parse_json_data(data)
|
||||
loop = get_or_create_event_loop()
|
||||
|
||||
if _LOCAL_MODE:
|
||||
_run_flow_cmd_local(loop, name, json_data, output_key)
|
||||
else:
|
||||
_run_flow_cmd(loop, name, uid, json_data, output_key)
|
||||
|
||||
|
||||
def _run_flow_cmd_local(
|
||||
loop: asyncio.BaseEventLoop,
|
||||
name: str,
|
||||
data: Dict[str, Any] | None = None,
|
||||
output_key: str | None = None,
|
||||
):
|
||||
from dbgpt.core.awel.util.chat_util import safe_chat_stream_with_dag_task
|
||||
|
||||
end_node, dag, dag_metadata, call_body = _parse_and_check_local_dag(
|
||||
name, _FILE_PATH, data
|
||||
)
|
||||
|
||||
async def _streaming_call():
|
||||
start_time = time.time()
|
||||
try:
|
||||
cl.debug("[~info] Flow started")
|
||||
cl.debug(f"[~info] JSON data: {json.dumps(data, ensure_ascii=False)}")
|
||||
cl.debug("Command output: ")
|
||||
async for out in safe_chat_stream_with_dag_task(
|
||||
end_node, call_body, incremental=True, covert_to_str=True
|
||||
):
|
||||
if not out.success:
|
||||
cl.error(out.text)
|
||||
else:
|
||||
cl.print(out.text, end="")
|
||||
except Exception as e:
|
||||
cl.error(f"Failed to run flow: {e}", exit_code=1)
|
||||
finally:
|
||||
time_cost = round(time.time() - start_time, 2)
|
||||
cl.success(f"\n:tada: Flow finished, timecost: {time_cost} s")
|
||||
|
||||
loop.run_until_complete(_streaming_call())
|
||||
|
||||
|
||||
def _run_flow_cmd(
|
||||
loop: asyncio.BaseEventLoop,
|
||||
name: str | None = None,
|
||||
uid: str | None = None,
|
||||
json_data: Dict[str, Any] | None = None,
|
||||
output_key: str | None = None,
|
||||
):
|
||||
client = Client()
|
||||
|
||||
def _non_streaming_callback(text: str):
|
||||
parsed_text: Any = None
|
||||
if output_key:
|
||||
try:
|
||||
json_out = json.loads(text)
|
||||
parsed_text = json_out.get(output_key)
|
||||
except Exception as e:
|
||||
cl.warning(f"Failed to parse output by key: {output_key}, {e}")
|
||||
if not parsed_text:
|
||||
parsed_text = text
|
||||
cl.markdown(parsed_text)
|
||||
|
||||
def _streaming_callback(text: str):
|
||||
cl.print(text, end="")
|
||||
|
||||
async def _client_run_cmd():
|
||||
cl.debug("[~info] Flow started")
|
||||
cl.debug(f"[~info] JSON data: {json.dumps(json_data, ensure_ascii=False)}")
|
||||
cl.debug("Command output: ")
|
||||
start_time = time.time()
|
||||
# AWEL flow store the python module name now, so we need to replace "-" with "_"
|
||||
new_name = name.replace("-", "_")
|
||||
try:
|
||||
await client_run_flow_cmd(
|
||||
client,
|
||||
new_name,
|
||||
uid,
|
||||
json_data,
|
||||
non_streaming_callback=_non_streaming_callback,
|
||||
streaming_callback=_streaming_callback,
|
||||
)
|
||||
except Exception as e:
|
||||
cl.error(f"Failed to run flow: {e}", exit_code=1)
|
||||
finally:
|
||||
time_cost = round(time.time() - start_time, 2)
|
||||
cl.success(f"\n:tada: Flow finished, timecost: {time_cost} s")
|
||||
|
||||
loop.run_until_complete(_client_run_cmd())
|
||||
|
||||
|
||||
def _parse_and_check_local_dag(
|
||||
name: str,
|
||||
filepath: str | None = None,
|
||||
data: Dict[str, Any] | None = None,
|
||||
) -> Tuple[BaseOperator, DAG, DAGMetadata, Any]:
|
||||
|
||||
dag, dag_metadata = _parse_local_dag(name, filepath)
|
||||
|
||||
return _check_local_dag(dag, dag_metadata, data)
|
||||
|
||||
|
||||
def _check_local_dag(
|
||||
dag: DAG, dag_metadata: DAGMetadata, data: Dict[str, Any] | None = None
|
||||
) -> Tuple[BaseOperator, DAG, DAGMetadata, Any]:
|
||||
from dbgpt.core.awel import HttpTrigger
|
||||
|
||||
leaf_nodes = dag.leaf_nodes
|
||||
if not leaf_nodes:
|
||||
cl.error("No leaf nodes found in the flow", exit_code=1)
|
||||
if len(leaf_nodes) > 1:
|
||||
cl.error("More than one leaf nodes found in the flow", exit_code=1)
|
||||
if not isinstance(leaf_nodes[0], BaseOperator):
|
||||
cl.error("Unsupported leaf node type", exit_code=1)
|
||||
end_node = cast(BaseOperator, leaf_nodes[0])
|
||||
call_body: Any = data
|
||||
trigger_nodes = dag.trigger_nodes
|
||||
if trigger_nodes:
|
||||
if len(trigger_nodes) > 1:
|
||||
cl.error("More than one trigger nodes found in the flow", exit_code=1)
|
||||
trigger = trigger_nodes[0]
|
||||
if isinstance(trigger, HttpTrigger):
|
||||
http_trigger = trigger
|
||||
if http_trigger._req_body and data:
|
||||
call_body = http_trigger._req_body(**data)
|
||||
else:
|
||||
cl.error("Unsupported trigger type", exit_code=1)
|
||||
return end_node, dag, dag_metadata, call_body
|
||||
|
||||
|
||||
def _parse_local_dag(name: str, filepath: str | None = None) -> Tuple[DAG, DAGMetadata]:
|
||||
|
||||
system_app = SystemApp()
|
||||
DAGVar.set_current_system_app(system_app)
|
||||
|
||||
if not filepath:
|
||||
# Load DAG from installed package(dbgpts)
|
||||
from dbgpt.util.dbgpts.loader import (
|
||||
_flow_package_to_flow_panel,
|
||||
_load_flow_package_from_path,
|
||||
)
|
||||
|
||||
flow_panel = _flow_package_to_flow_panel(_load_flow_package_from_path(name))
|
||||
if flow_panel.define_type == "json":
|
||||
factory = FlowFactory()
|
||||
factory.pre_load_requirements(flow_panel)
|
||||
dag = factory.build(flow_panel)
|
||||
else:
|
||||
dag = flow_panel.flow_dag
|
||||
return dag, _parse_metadata(dag)
|
||||
else:
|
||||
from dbgpt.core.awel.dag.loader import _process_file
|
||||
|
||||
dags = _process_file(filepath)
|
||||
if not dags:
|
||||
cl.error("No DAG found in the file", exit_code=1)
|
||||
if len(dags) > 1:
|
||||
dags = [dag for dag in dags if dag.dag_id == name]
|
||||
# Filter by name
|
||||
if len(dags) > 1:
|
||||
cl.error("More than one DAG found in the file", exit_code=1)
|
||||
if not dags:
|
||||
cl.error("No DAG found with the given name", exit_code=1)
|
||||
return dags[0], _parse_metadata(dags[0])
|
||||
|
||||
|
||||
def _parse_chat_json_data(data: str, **kwargs):
|
||||
json_data = {}
|
||||
if data:
|
||||
try:
|
||||
@@ -170,7 +400,100 @@ def _parse_json_data(data: str, **kwargs):
|
||||
return json_data
|
||||
|
||||
|
||||
async def _chat_stream(client: Client, interactive: bool, json_data: Dict[str, Any]):
|
||||
def _parse_json_data(data: str | None) -> Dict[str, Any] | None:
|
||||
if not data:
|
||||
return None
|
||||
try:
|
||||
return json.loads(data)
|
||||
except Exception as e:
|
||||
cl.error(f"Invalid JSON data: {data}, {e}", exit_code=1)
|
||||
# Should not reach here
|
||||
return None
|
||||
|
||||
|
||||
def _run_flow_chat_local(
|
||||
loop: asyncio.BaseEventLoop,
|
||||
name: str,
|
||||
interactive: bool,
|
||||
json_data: Dict[str, Any],
|
||||
stream: bool,
|
||||
):
|
||||
from dbgpt.core.awel.util.chat_util import (
|
||||
parse_single_output,
|
||||
safe_chat_stream_with_dag_task,
|
||||
)
|
||||
|
||||
dag, dag_metadata = _parse_local_dag(name, _FILE_PATH)
|
||||
|
||||
async def _streaming_call(_call_body: Dict[str, Any]):
|
||||
nonlocal dag, dag_metadata
|
||||
|
||||
end_node, dag, dag_metadata, handled_call_body = _check_local_dag(
|
||||
dag, dag_metadata, _call_body
|
||||
)
|
||||
async for out in safe_chat_stream_with_dag_task(
|
||||
end_node, handled_call_body, incremental=True, covert_to_str=True
|
||||
):
|
||||
if not out.success:
|
||||
cl.error(f"Error: {out.text}")
|
||||
raise Exception(out.text)
|
||||
else:
|
||||
yield out.text
|
||||
|
||||
async def _call(_call_body: Dict[str, Any]):
|
||||
nonlocal dag, dag_metadata
|
||||
|
||||
end_node, dag, dag_metadata, handled_call_body = _check_local_dag(
|
||||
dag, dag_metadata, _call_body
|
||||
)
|
||||
res = await end_node.call(handled_call_body)
|
||||
parsed_res = parse_single_output(res, is_sse=False, covert_to_str=True)
|
||||
if not parsed_res.success:
|
||||
raise Exception(parsed_res.text)
|
||||
return parsed_res.text
|
||||
|
||||
if stream:
|
||||
loop.run_until_complete(_chat_stream(_streaming_call, interactive, json_data))
|
||||
else:
|
||||
loop.run_until_complete(_chat(_call, interactive, json_data))
|
||||
|
||||
|
||||
def _run_flow_chat_stream(
|
||||
loop: asyncio.BaseEventLoop,
|
||||
client: Client,
|
||||
interactive: bool,
|
||||
json_data: Dict[str, Any],
|
||||
):
|
||||
async def _streaming_call(_call_body: Dict[str, Any]):
|
||||
async for out in client.chat_stream(**_call_body):
|
||||
if out.choices:
|
||||
text = out.choices[0].delta.content
|
||||
if text:
|
||||
yield text
|
||||
|
||||
loop.run_until_complete(_chat_stream(_streaming_call, interactive, json_data))
|
||||
|
||||
|
||||
def _run_flow_chat(
|
||||
loop: asyncio.BaseEventLoop,
|
||||
client: Client,
|
||||
interactive: bool,
|
||||
json_data: Dict[str, Any],
|
||||
):
|
||||
async def _call(_call_body: Dict[str, Any]):
|
||||
res = await client.chat(**_call_body)
|
||||
if res.choices:
|
||||
text = res.choices[0].message.content
|
||||
return text
|
||||
|
||||
loop.run_until_complete(_chat(_call, interactive, json_data))
|
||||
|
||||
|
||||
async def _chat_stream(
|
||||
streaming_func: Callable[[Dict[str, Any]], AsyncIterator[str]],
|
||||
interactive: bool,
|
||||
json_data: Dict[str, Any],
|
||||
):
|
||||
user_input = json_data.get("messages", "")
|
||||
if "conv_uid" not in json_data and interactive:
|
||||
json_data["conv_uid"] = str(uuid.uuid4())
|
||||
@@ -187,16 +510,14 @@ async def _chat_stream(client: Client, interactive: bool, json_data: Dict[str, A
|
||||
json_data["messages"] = user_input
|
||||
if first_message:
|
||||
cl.info("You: " + user_input)
|
||||
cl.info("Chat stream started")
|
||||
cl.debug(f"JSON data: {json.dumps(json_data, ensure_ascii=False)}")
|
||||
cl.debug("[~info] Chat stream started")
|
||||
cl.debug(f"[~info] JSON data: {json.dumps(json_data, ensure_ascii=False)}")
|
||||
full_text = ""
|
||||
cl.print("Bot: ")
|
||||
async for out in client.chat_stream(**json_data):
|
||||
if out.choices:
|
||||
text = out.choices[0].delta.content
|
||||
if text:
|
||||
full_text += text
|
||||
cl.print(text, end="")
|
||||
async for text in streaming_func(json_data):
|
||||
if text:
|
||||
full_text += text
|
||||
cl.print(text, end="")
|
||||
end_time = time.time()
|
||||
time_cost = round(end_time - start_time, 2)
|
||||
cl.success(f"\n:tada: Chat stream finished, timecost: {time_cost} s")
|
||||
@@ -210,7 +531,11 @@ async def _chat_stream(client: Client, interactive: bool, json_data: Dict[str, A
|
||||
break
|
||||
|
||||
|
||||
async def _chat(client: Client, interactive: bool, json_data: Dict[str, Any]):
|
||||
async def _chat(
|
||||
func: Callable[[Dict[str, Any]], Any],
|
||||
interactive: bool,
|
||||
json_data: Dict[str, Any],
|
||||
):
|
||||
user_input = json_data.get("messages", "")
|
||||
if "conv_uid" not in json_data and interactive:
|
||||
json_data["conv_uid"] = str(uuid.uuid4())
|
||||
@@ -228,17 +553,19 @@ async def _chat(client: Client, interactive: bool, json_data: Dict[str, Any]):
|
||||
if first_message:
|
||||
cl.info("You: " + user_input)
|
||||
|
||||
cl.info("Chat started")
|
||||
cl.debug(f"JSON data: {json.dumps(json_data, ensure_ascii=False)}")
|
||||
res = await client.chat(**json_data)
|
||||
cl.debug("[~info] Chat started")
|
||||
cl.debug(f"[~info] JSON data: {json.dumps(json_data, ensure_ascii=False)}")
|
||||
res = await func(json_data)
|
||||
cl.print("Bot: ")
|
||||
if res.choices:
|
||||
text = res.choices[0].message.content
|
||||
cl.markdown(text)
|
||||
if res:
|
||||
cl.markdown(res)
|
||||
time_cost = round(time.time() - start_time, 2)
|
||||
cl.success(f"\n:tada: Chat stream finished, timecost: {time_cost} s")
|
||||
except Exception as e:
|
||||
cl.error(f"Chat failed: {e}", exit_code=1)
|
||||
import traceback
|
||||
|
||||
messages = traceback.format_exc()
|
||||
cl.error(f"Chat failed: {e}\n, error detail: {messages}", exit_code=1)
|
||||
finally:
|
||||
first_message = False
|
||||
if interactive:
|
||||
|
@@ -1,4 +1,5 @@
|
||||
"""This module contains the client for the DB-GPT API."""
|
||||
|
||||
import atexit
|
||||
import json
|
||||
import os
|
||||
@@ -102,6 +103,15 @@ class Client:
|
||||
)
|
||||
atexit.register(self.close)
|
||||
|
||||
def _base_url(self):
|
||||
parsed_url = urlparse(self._api_url)
|
||||
host = parsed_url.hostname
|
||||
scheme = parsed_url.scheme
|
||||
port = parsed_url.port
|
||||
if port:
|
||||
return f"{scheme}://{host}:{port}"
|
||||
return f"{scheme}://{host}"
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
model: str,
|
||||
|
@@ -1,5 +1,8 @@
|
||||
"""this module contains the flow client functions."""
|
||||
from typing import List
|
||||
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
from httpx import AsyncClient
|
||||
|
||||
from dbgpt.core.awel.flow.flow_factory import FlowPanel
|
||||
from dbgpt.core.schema.api import Result
|
||||
@@ -117,3 +120,181 @@ async def list_flow(
|
||||
raise ClientException(status=result["err_code"], reason=result)
|
||||
except Exception as e:
|
||||
raise ClientException(f"Failed to list flows: {e}")
|
||||
|
||||
|
||||
async def run_flow_cmd(
|
||||
client: Client,
|
||||
name: str | None = None,
|
||||
uid: str | None = None,
|
||||
data: Dict[str, Any] | None = None,
|
||||
non_streaming_callback: Callable[[str], None] | None = None,
|
||||
streaming_callback: Callable[[str], None] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Run flows.
|
||||
|
||||
Args:
|
||||
client (Client): The dbgpt client.
|
||||
name (str): The name of the flow.
|
||||
uid (str): The uid of the flow.
|
||||
data (Dict[str, Any]): The data to run the flow.
|
||||
non_streaming_callback (Callable[[str], None]): The non-streaming callback.
|
||||
streaming_callback (Callable[[str], None]): The streaming callback.
|
||||
Returns:
|
||||
List[FlowPanel]: The list of flow panels.
|
||||
Raises:
|
||||
ClientException: If the request failed.
|
||||
"""
|
||||
try:
|
||||
res = await client.get("/awel/flows", **{"name": name, "uid": uid})
|
||||
result: Result = res.json()
|
||||
if not result["success"]:
|
||||
raise ClientException("Flow not found with the given name or uid")
|
||||
flows = result["data"]["items"]
|
||||
if not flows:
|
||||
raise ClientException("Flow not found with the given name or uid")
|
||||
if len(flows) > 1:
|
||||
raise ClientException("More than one flow found")
|
||||
flow = flows[0]
|
||||
flow_panel = FlowPanel(**flow)
|
||||
metadata = flow.get("metadata")
|
||||
await _run_flow_trigger(
|
||||
client,
|
||||
flow_panel,
|
||||
metadata,
|
||||
data,
|
||||
non_streaming_callback=non_streaming_callback,
|
||||
streaming_callback=streaming_callback,
|
||||
)
|
||||
except Exception as e:
|
||||
raise ClientException(f"Failed to run flows: {e}")
|
||||
|
||||
|
||||
async def _run_flow_trigger(
|
||||
client: Client,
|
||||
flow: FlowPanel,
|
||||
metadata: Dict[str, Any] | None = None,
|
||||
data: Dict[str, Any] | None = None,
|
||||
non_streaming_callback: Callable[[str], None] | None = None,
|
||||
streaming_callback: Callable[[str], None] | None = None,
|
||||
):
|
||||
if not metadata:
|
||||
raise ClientException("No AWEL flow metadata found")
|
||||
if "triggers" not in metadata:
|
||||
raise ClientException("No triggers found in AWEL flow metadata")
|
||||
triggers = metadata["triggers"]
|
||||
if len(triggers) > 1:
|
||||
raise ClientException("More than one trigger found")
|
||||
trigger = triggers[0]
|
||||
sse_output = metadata.get("sse_output", False)
|
||||
streaming_output = metadata.get("streaming_output", False)
|
||||
trigger_type = trigger["trigger_type"]
|
||||
if trigger_type == "http":
|
||||
methods = trigger["methods"]
|
||||
if not methods:
|
||||
method = "GET"
|
||||
else:
|
||||
method = methods[0]
|
||||
path = trigger["path"]
|
||||
base_url = client._base_url()
|
||||
req_url = f"{base_url}{path}"
|
||||
if streaming_output:
|
||||
await _call_stream_request(
|
||||
client._http_client,
|
||||
method,
|
||||
req_url,
|
||||
sse_output,
|
||||
data,
|
||||
streaming_callback,
|
||||
)
|
||||
elif non_streaming_callback:
|
||||
await _call_non_stream_request(
|
||||
client._http_client, method, req_url, data, non_streaming_callback
|
||||
)
|
||||
else:
|
||||
raise ClientException(f"Invalid trigger type: {trigger_type}")
|
||||
|
||||
|
||||
async def _call_non_stream_request(
|
||||
http_client: AsyncClient,
|
||||
method: str,
|
||||
base_url: str,
|
||||
data: Dict[str, Any] | None = None,
|
||||
non_streaming_callback: Callable[[str], None] | None = None,
|
||||
):
|
||||
import httpx
|
||||
|
||||
kwargs: Dict[str, Any] = {"url": base_url, "method": method}
|
||||
if method in ["POST", "PUT"]:
|
||||
kwargs["json"] = data
|
||||
else:
|
||||
kwargs["params"] = data
|
||||
response = await http_client.request(**kwargs)
|
||||
bytes_response_content = await response.aread()
|
||||
if response.status_code != 200:
|
||||
str_error_message = ""
|
||||
error_message = await response.aread()
|
||||
if error_message:
|
||||
str_error_message = error_message.decode("utf-8")
|
||||
raise httpx.RequestError(
|
||||
f"Request failed with status {response.status_code}, error_message: "
|
||||
f"{str_error_message}",
|
||||
request=response.request,
|
||||
)
|
||||
response_content = bytes_response_content.decode("utf-8")
|
||||
if non_streaming_callback:
|
||||
non_streaming_callback(response_content)
|
||||
return response_content
|
||||
|
||||
|
||||
async def _call_stream_request(
|
||||
http_client: AsyncClient,
|
||||
method: str,
|
||||
base_url: str,
|
||||
sse_output: bool,
|
||||
data: Dict[str, Any] | None = None,
|
||||
streaming_callback: Callable[[str], None] | None = None,
|
||||
):
|
||||
full_out = ""
|
||||
async for out in _stream_request(http_client, method, base_url, sse_output, data):
|
||||
if streaming_callback:
|
||||
streaming_callback(out)
|
||||
full_out += out
|
||||
return full_out
|
||||
|
||||
|
||||
async def _stream_request(
|
||||
http_client: AsyncClient,
|
||||
method: str,
|
||||
base_url: str,
|
||||
sse_output: bool,
|
||||
data: Dict[str, Any] | None = None,
|
||||
):
|
||||
import json
|
||||
|
||||
from dbgpt.core.awel.util.chat_util import parse_openai_output
|
||||
|
||||
kwargs: Dict[str, Any] = {"url": base_url, "method": method}
|
||||
if method in ["POST", "PUT"]:
|
||||
kwargs["json"] = data
|
||||
else:
|
||||
kwargs["params"] = data
|
||||
|
||||
async with http_client.stream(**kwargs) as response:
|
||||
if response.status_code == 200:
|
||||
async for line in response.aiter_lines():
|
||||
if not line:
|
||||
continue
|
||||
if sse_output:
|
||||
out = parse_openai_output(line)
|
||||
if not out.success:
|
||||
raise ClientException(f"Failed to parse output: {out.text}")
|
||||
yield out.text
|
||||
else:
|
||||
yield line
|
||||
else:
|
||||
try:
|
||||
error = await response.aread()
|
||||
yield json.loads(error)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
Reference in New Issue
Block a user