# DB-GPT/dbgpt/core/awel/tests/conftest.py
# Author: Fangyin Cheng (65c875db20)
# feat(core): Support higher-order operators (#1984)
# Co-authored-by: 谨欣 <echo.cmy@antgroup.com>
# 2024-09-09 10:15:37 +08:00

from contextlib import asynccontextmanager
from typing import AsyncIterator, List
import pytest
import pytest_asyncio
from dbgpt.component import SystemApp
from ...interface.variables import (
StorageVariables,
StorageVariablesProvider,
VariablesIdentifier,
)
from .. import DAGVar, DefaultWorkflowRunner, InputOperator, SimpleInputSource
from ..task.task_impl import _is_async_iterator
@pytest.fixture
def runner():
    """Provide a fresh ``DefaultWorkflowRunner`` for each test."""
    workflow_runner = DefaultWorkflowRunner()
    return workflow_runner
def _create_stream(num_nodes) -> List[AsyncIterator[int]]:
    """Build ``num_nodes`` async iterators, each yielding the ints 0..9."""

    async def _ten_ints():
        for value in range(10):
            yield value

    streams = []
    for _ in range(num_nodes):
        stream = _ten_ints()
        assert _is_async_iterator(stream)
        streams.append(stream)
    return streams
def _create_stream_from(output_streams: List[List[int]]) -> List[AsyncIterator[int]]:
    """Create one async iterator per input list, yielding that list's items.

    The source list is bound as a default argument (``items=single_stream``)
    so that each generator captures its own list. The generator body runs
    lazily — without the binding it would read ``single_stream`` only when
    consumed, after the loop finished, and every stream would replay the
    *last* list (classic late-binding closure bug).
    """
    iters = []
    for single_stream in output_streams:

        async def stream_iter(items=single_stream):
            for i in items:
                yield i

        stream_iter = stream_iter()
        assert _is_async_iterator(stream_iter)
        iters.append(stream_iter)
    return iters
@asynccontextmanager
async def _create_input_node(**kwargs):
    """Async context manager yielding a list of ``InputOperator`` nodes.

    Keyword Args:
        num_nodes: Number of input nodes to create (stream mode only; must
            match ``len(output_streams)`` when both are given).
        is_stream: If True, each node's input source is an async iterator.
        output_streams: Stream mode only — explicit per-node value lists.
        outputs: Non-stream mode — one output value per node
            (defaults to ``["Hello."]``).

    Raises:
        ValueError: If ``num_nodes`` disagrees with ``len(output_streams)``.
    """
    num_nodes = kwargs.get("num_nodes")
    is_stream = kwargs.get("is_stream", False)
    if is_stream:
        outputs = kwargs.get("output_streams")
        if outputs:
            if num_nodes and num_nodes != len(outputs):
                raise ValueError(
                    f"num_nodes {num_nodes} != the length of output_streams {len(outputs)}"
                )
            outputs = _create_stream_from(outputs)
        else:
            # Default to a single stream of 0..9 when no explicit streams given.
            num_nodes = num_nodes or 1
            outputs = _create_stream(num_nodes)
    else:
        outputs = kwargs.get("outputs", ["Hello."])
    nodes = []
    for i, output in enumerate(outputs):
        # NOTE(review): a leftover debug `print` of each output was removed
        # here; it only cluttered the test output.
        input_source = SimpleInputSource(output)
        input_node = InputOperator(input_source, task_id="input_node_" + str(i))
        nodes.append(input_node)
    yield nodes
@pytest_asyncio.fixture
async def input_node(request):
    """Yield a single non-stream input node, configured via ``request.param``."""
    params = getattr(request, "param", {})
    async with _create_input_node(**params) as nodes:
        yield nodes[0]
@pytest_asyncio.fixture
async def stream_input_node(request):
    """Yield a single stream input node, configured via ``request.param``.

    The param dict is copied before forcing ``is_stream=True`` so the
    fixture never mutates the dict pytest passes via ``parametrize`` —
    that object may be shared with other fixtures or later tests.
    """
    param = {**getattr(request, "param", {}), "is_stream": True}
    async with _create_input_node(**param) as input_nodes:
        yield input_nodes[0]
@pytest_asyncio.fixture
async def input_nodes(request):
    """Yield the full list of non-stream input nodes for parametrized tests."""
    params = getattr(request, "param", {})
    async with _create_input_node(**params) as nodes:
        yield nodes
@pytest_asyncio.fixture
async def stream_input_nodes(request):
    """Yield the full list of stream input nodes for parametrized tests.

    The param dict is copied before forcing ``is_stream=True`` so the
    fixture never mutates the dict pytest passes via ``parametrize`` —
    that object may be shared with other fixtures or later tests.
    """
    param = {**getattr(request, "param", {}), "is_stream": True}
    async with _create_input_node(**param) as input_nodes:
        yield input_nodes
@asynccontextmanager
async def _create_variables(**kwargs):
    """Build a ``StorageVariablesProvider`` pre-populated from ``kwargs["vars"]``.

    ``vars`` must be a dict mapping parameter names to dicts with ``key``,
    ``value``, ``value_type`` and optional ``category`` entries. Raises
    ``ValueError`` when it is missing or not a dict.
    """
    sys_app = SystemApp()
    DAGVar.set_current_system_app(sys_app)
    provider = StorageVariablesProvider(system_app=sys_app)
    variables = kwargs.get("vars")
    # Guard clause: renamed from `vars`/`id` to avoid shadowing builtins.
    if not (variables and isinstance(variables, dict)):
        raise ValueError("vars is required.")
    for var_conf in variables.values():
        key = var_conf.get("key")
        value = var_conf.get("value")
        value_type = var_conf.get("value_type")
        category = var_conf.get("category", "common")
        identifier = VariablesIdentifier.from_str_identifier(key)
        provider.save(
            StorageVariables.from_identifier(
                identifier, value, value_type, label="", category=category
            )
        )
    yield provider
@pytest_asyncio.fixture
async def variables_provider(request):
    """Yield a variables provider seeded from ``request.param``."""
    params = getattr(request, "param", {})
    async with _create_variables(**params) as provider:
        yield provider