mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-05 10:29:36 +00:00
136 lines
3.8 KiB
Python
136 lines
3.8 KiB
Python
from contextlib import asynccontextmanager
|
|
from typing import AsyncIterator, List
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
|
|
from dbgpt.component import SystemApp
|
|
|
|
from ...interface.variables import (
|
|
StorageVariables,
|
|
StorageVariablesProvider,
|
|
VariablesIdentifier,
|
|
)
|
|
from .. import DAGVar, DefaultWorkflowRunner, InputOperator, SimpleInputSource
|
|
from ..task.task_impl import _is_async_iterator
|
|
|
|
|
|
@pytest.fixture
def runner():
    """Provide a freshly constructed ``DefaultWorkflowRunner`` to a test."""
    workflow_runner = DefaultWorkflowRunner()
    return workflow_runner
|
|
|
|
|
|
def _create_stream(num_nodes) -> List[AsyncIterator[int]]:
    """Build ``num_nodes`` independent async iterators, each yielding 0..9.

    Args:
        num_nodes: How many async iterators to create.

    Returns:
        A list with one async generator object per requested node.
    """

    async def _count_to_ten():
        # Every call produces a new generator object with its own state.
        for value in range(10):
            yield value

    streams = []
    for _ in range(num_nodes):
        stream = _count_to_ten()
        assert _is_async_iterator(stream)
        streams.append(stream)
    return streams
|
|
|
|
|
|
def _create_stream_from(output_streams: List[List[int]]) -> List[AsyncIterator[int]]:
    """Wrap each list in ``output_streams`` in an async iterator over its items.

    Args:
        output_streams: One list of values per stream to create.

    Returns:
        A list of async generator objects, in the same order as
        ``output_streams``; the n-th one yields the items of the n-th list.
    """

    async def _replay(items: List[int]) -> AsyncIterator[int]:
        # The list is received as an argument rather than read from the
        # enclosing loop variable. An async generator body only runs once it
        # is iterated, which happens after the loop below has finished, so a
        # closure over the loop variable would make every generator replay
        # the last list.
        for item in items:
            yield item

    iters = []
    for single_stream in output_streams:
        stream_iter = _replay(single_stream)
        assert _is_async_iterator(stream_iter)
        iters.append(stream_iter)
    return iters
|
|
|
|
|
|
@asynccontextmanager
async def _create_input_node(**kwargs):
    """Async context manager that builds input operators and yields them as a list.

    Recognised keyword arguments:
        num_nodes: Number of nodes; only consulted when ``is_stream`` is truthy.
        is_stream: When truthy, each node's source is an async iterator.
        output_streams: For streaming nodes, one list of values per node.
        outputs: For non-streaming nodes, one value per node
            (defaults to ``["Hello."]``).

    Raises:
        ValueError: If streaming, ``output_streams`` is non-empty, and a truthy
            ``num_nodes`` differs from ``len(output_streams)``.
    """
    num_nodes = kwargs.get("num_nodes")

    if not kwargs.get("is_stream", False):
        payloads = kwargs.get("outputs", ["Hello."])
    else:
        streams = kwargs.get("output_streams")
        if not streams:
            # No explicit stream contents: fall back to generated streams,
            # a single one when num_nodes is missing or falsy.
            payloads = _create_stream(num_nodes or 1)
        elif num_nodes and num_nodes != len(streams):
            raise ValueError(
                f"num_nodes {num_nodes} != the length of output_streams {len(streams)}"
            )
        else:
            payloads = _create_stream_from(streams)

    operators = []
    for index, payload in enumerate(payloads):
        print(f"output: {payload}")
        source = SimpleInputSource(payload)
        operators.append(InputOperator(source, task_id="input_node_" + str(index)))

    yield operators
|
|
|
|
|
|
@pytest_asyncio.fixture
async def input_node(request):
    """Yield the first input node built from the fixture's parameters.

    The optional ``request.param`` mapping is forwarded as keyword arguments
    to ``_create_input_node``; an empty mapping is used when it is absent.
    """
    options = getattr(request, "param", {})
    async with _create_input_node(**options) as created:
        first_node = created[0]
        yield first_node
|
|
|
|
|
|
@pytest_asyncio.fixture
async def stream_input_node(request):
    """Yield the first streaming input node built from the fixture's parameters.

    The optional ``request.param`` mapping is forwarded as keyword arguments
    to ``_create_input_node`` with ``is_stream`` forced to ``True``.
    """
    # Build a new dict instead of assigning into request.param: that object
    # belongs to the test's parametrization, and writing "is_stream" into it
    # would leak the flag to anything else that reuses the same dict.
    param = {**getattr(request, "param", {}), "is_stream": True}
    async with _create_input_node(**param) as input_nodes:
        yield input_nodes[0]
|
|
|
|
|
|
@pytest_asyncio.fixture
async def input_nodes(request):
    """Yield every input node built from the fixture's parameters.

    The optional ``request.param`` mapping is forwarded as keyword arguments
    to ``_create_input_node``; an empty mapping is used when it is absent.
    """
    options = getattr(request, "param", {})
    async with _create_input_node(**options) as created:
        yield created
|
|
|
|
|
|
@pytest_asyncio.fixture
async def stream_input_nodes(request):
    """Yield every streaming input node built from the fixture's parameters.

    The optional ``request.param`` mapping is forwarded as keyword arguments
    to ``_create_input_node`` with ``is_stream`` forced to ``True``.
    """
    # Build a new dict instead of assigning into request.param: that object
    # belongs to the test's parametrization, and writing "is_stream" into it
    # would leak the flag to anything else that reuses the same dict.
    param = {**getattr(request, "param", {}), "is_stream": True}
    async with _create_input_node(**param) as input_nodes:
        yield input_nodes
|
|
|
|
|
|
@asynccontextmanager
async def _create_variables(**kwargs):
    """Async context manager yielding a variables provider filled from ``vars``.

    A new ``SystemApp`` is created and registered through
    ``DAGVar.set_current_system_app`` before ``vars`` is validated, then each
    entry of ``vars`` is saved into a ``StorageVariablesProvider``.

    Recognised keyword arguments:
        vars: Non-empty dict whose values are mappings with ``key``, ``value``,
            ``value_type`` and optionally ``category`` (default ``"common"``).

    Raises:
        ValueError: If ``vars`` is missing, empty, or not a dict.
    """
    app = SystemApp()
    DAGVar.set_current_system_app(app)
    provider = StorageVariablesProvider(system_app=app)

    declared = kwargs.get("vars")
    if not declared or not isinstance(declared, dict):
        raise ValueError("vars is required.")

    # Only the values are used; the outer dict keys are ignored.
    for spec in declared.values():
        raw_key = spec.get("key")
        raw_value = spec.get("value")
        raw_type = spec.get("value_type")
        raw_category = spec.get("category", "common")
        identifier = VariablesIdentifier.from_str_identifier(raw_key)
        record = StorageVariables.from_identifier(
            identifier, raw_value, raw_type, label="", category=raw_category
        )
        provider.save(record)

    yield provider
|
|
|
|
|
|
@pytest_asyncio.fixture
async def variables_provider(request):
    """Yield the variables provider built from the fixture's parameters.

    The optional ``request.param`` mapping is forwarded as keyword arguments
    to ``_create_variables``; an empty mapping is used when it is absent.
    """
    options = getattr(request, "param", {})
    async with _create_variables(**options) as provider:
        yield provider
|