feat: Run AWEL flow in CLI (#1341)

This commit is contained in:
Fangyin Cheng
2024-03-27 12:50:05 +08:00
committed by GitHub
parent 340a9fbc35
commit 3a7a2cbbb8
42 changed files with 1454 additions and 422 deletions

View File

@@ -1,7 +1,8 @@
"""This module contains the client for the DB-GPT API."""
import atexit
import json
import os
from typing import Any, AsyncGenerator, List, Optional, Union
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from urllib.parse import urlparse
import httpx
@@ -98,6 +99,7 @@ class Client:
self._http_client = httpx.AsyncClient(
headers=headers, timeout=timeout if timeout else httpx.Timeout(None)
)
atexit.register(self.close)
async def chat(
self,
@@ -113,6 +115,7 @@ class Client:
span_id: Optional[str] = None,
incremental: bool = True,
enable_vis: bool = True,
**kwargs,
) -> ChatCompletionResponse:
"""
Chat Completion.
@@ -187,6 +190,7 @@ class Client:
span_id: Optional[str] = None,
incremental: bool = True,
enable_vis: bool = True,
**kwargs,
) -> AsyncGenerator[ChatCompletionStreamResponse, None]:
"""
Chat Stream Completion.
@@ -238,10 +242,23 @@ class Client:
incremental=incremental,
enable_vis=enable_vis,
)
async for chat_completion_response in self._chat_stream(request.dict()):
yield chat_completion_response
async def _chat_stream(
self, data: Dict[str, Any]
) -> AsyncGenerator[ChatCompletionStreamResponse, None]:
"""Chat Stream Completion.
Args:
data: dict, The data to send to the API.
Returns:
AsyncGenerator[dict, None]: The chat completion response.
"""
async with self._http_client.stream(
method="POST",
url=self._api_url + "/chat/completions",
json=request.dict(),
json=data,
headers={},
) as response:
if response.status_code == 200:
@@ -250,7 +267,11 @@ class Client:
if line.strip() == "data: [DONE]":
break
if line.startswith("data:"):
json_data = json.loads(line[len("data: ") :])
if line.startswith("data: "):
sse_data = line[len("data: ") :]
else:
sse_data = line[len("data:") :]
json_data = json.loads(sse_data)
chat_completion_response = ChatCompletionStreamResponse(
**json_data
)
@@ -265,21 +286,20 @@ class Client:
except Exception as e:
raise e
async def get(self, path: str, *args):
async def get(self, path: str, *args, **kwargs):
"""Get method.
Args:
path: str, The path to get.
args: Any, The arguments to pass to the get method.
"""
try:
response = await self._http_client.get(
f"{self._api_url}/{CLIENT_SERVE_PATH}{path}",
*args,
)
return response
finally:
await self._http_client.aclose()
kwargs = {k: v for k, v in kwargs.items() if v is not None}
response = await self._http_client.get(
f"{self._api_url}/{CLIENT_SERVE_PATH}{path}",
*args,
params=kwargs,
)
return response
async def post(self, path: str, args):
"""Post method.
@@ -288,13 +308,10 @@ class Client:
path: str, The path to post.
args: Any, The arguments to pass to the post
"""
try:
return await self._http_client.post(
f"{self._api_url}/{CLIENT_SERVE_PATH}{path}",
json=args,
)
finally:
await self._http_client.aclose()
return await self._http_client.post(
f"{self._api_url}/{CLIENT_SERVE_PATH}{path}",
json=args,
)
async def post_param(self, path: str, args):
"""Post method.
@@ -303,13 +320,10 @@ class Client:
path: str, The path to post.
args: Any, The arguments to pass to the post
"""
try:
return await self._http_client.post(
f"{self._api_url}/{CLIENT_SERVE_PATH}{path}",
params=args,
)
finally:
await self._http_client.aclose()
return await self._http_client.post(
f"{self._api_url}/{CLIENT_SERVE_PATH}{path}",
params=args,
)
async def patch(self, path: str, *args):
"""Patch method.
@@ -329,26 +343,20 @@ class Client:
path: str, The path to put.
args: Any, The arguments to pass to the put.
"""
try:
return await self._http_client.put(
f"{self._api_url}/{CLIENT_SERVE_PATH}{path}", json=args
)
finally:
await self._http_client.aclose()
return await self._http_client.put(
f"{self._api_url}/{CLIENT_SERVE_PATH}{path}", json=args
)
async def delete(self, path: str, *args):
"""Delete method.
Args:
path: str, The path to delete.
args: Any, The arguments to pass to the delete.
args: Any, The arguments to pass to delete.
"""
try:
return await self._http_client.delete(
f"{self._api_url}/{CLIENT_SERVE_PATH}{path}", *args
)
finally:
await self._http_client.aclose()
return await self._http_client.delete(
f"{self._api_url}/{CLIENT_SERVE_PATH}{path}", *args
)
async def head(self, path: str, *args):
"""Head method.
@@ -359,6 +367,18 @@ class Client:
"""
return self._http_client.head(self._api_url + path, *args)
def close(self):
"""Close the client."""
from dbgpt.util import get_or_create_event_loop
if not self._http_client.is_closed:
loop = get_or_create_event_loop()
loop.run_until_complete(self._http_client.aclose())
async def aclose(self):
"""Close the client."""
await self._http_client.aclose()
def is_valid_url(api_url: Any) -> bool:
"""Check if the given URL is valid.