mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-31 15:47:05 +00:00
feat(awel): AWEL supports http trigger and add some AWEL examples (#815)
- AWEL supports HTTP triggers. - Disassemble the KBQA flow into atomic operators. - Add some AWEL examples.
This commit is contained in:
commit
2ce77519de
54
examples/awel/simple_chat_dag_example.py
Normal file
54
examples/awel/simple_chat_dag_example.py
Normal file
@ -0,0 +1,54 @@
|
||||
"""AWEL: Simple chat dag example
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
curl -X POST http://127.0.0.1:5000/api/v1/awel/trigger/examples/simple_chat \
|
||||
-H "Content-Type: application/json" -d '{
|
||||
"model": "proxyllm",
|
||||
"user_input": "hello"
|
||||
}'
|
||||
"""
|
||||
from typing import Dict
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pilot.awel import DAG, HttpTrigger, MapOperator
|
||||
from pilot.scene.base_message import ModelMessage
|
||||
from pilot.model.base import ModelOutput
|
||||
from pilot.model.operator.model_operator import ModelOperator
|
||||
|
||||
|
||||
class TriggerReqBody(BaseModel):
    """Request body schema for the ``/examples/simple_chat`` HTTP trigger."""

    model: str = Field(..., description="Model name")
    user_input: str = Field(..., description="User input")
|
||||
|
||||
|
||||
class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
    """Map the HTTP trigger's request body into the params dict consumed by
    ``ModelOperator`` (prompt, messages, model name, echo flag)."""

    async def map(self, input_value: TriggerReqBody) -> Dict:
        """Build model-generation parameters from the request body.

        Args:
            input_value: Parsed HTTP request body.

        Returns:
            Dict with ``prompt``, ``messages``, ``model`` and ``echo`` keys.
        """
        hist = [ModelMessage.build_human_message(input_value.user_input)]
        params = {
            "prompt": input_value.user_input,
            # Serialize messages to plain dicts for the model worker payload.
            "messages": [h.dict() for h in hist],
            "model": input_value.model,
            "echo": False,
        }
        print(f"Receive input value: {input_value}")
        return params
|
||||
|
||||
|
||||
with DAG("dbgpt_awel_simple_dag_example") as dag:
    # Receive http request and trigger dag to run.
    trigger = HttpTrigger(
        "/examples/simple_chat", methods="POST", request_body=TriggerReqBody
    )
    request_handle_task = RequestHandleOperator()
    model_task = ModelOperator()
    # type(out) == ModelOutput
    model_parse_task = MapOperator(lambda out: out.to_dict())
    # Pipeline: HTTP request -> model params -> model call -> dict response.
    trigger >> request_handle_task >> model_task >> model_parse_task
|
32
examples/awel/simple_dag_example.py
Normal file
32
examples/awel/simple_dag_example.py
Normal file
@ -0,0 +1,32 @@
|
||||
"""AWEL: Simple dag example
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
curl -X GET http://127.0.0.1:5000/api/v1/awel/trigger/examples/hello\?name\=zhangsan
|
||||
|
||||
"""
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pilot.awel import DAG, HttpTrigger, MapOperator
|
||||
|
||||
|
||||
class TriggerReqBody(BaseModel):
    """Request body schema for the ``/examples/hello`` HTTP trigger."""

    name: str = Field(..., description="User name")
    # Age is optional in the request; defaults to 18.
    age: int = Field(18, description="User age")
|
||||
|
||||
|
||||
class RequestHandleOperator(MapOperator[TriggerReqBody, str]):
    """Turn the parsed request body into a greeting string.

    The redundant ``__init__`` that only forwarded to ``super().__init__``
    has been removed; the inherited constructor is sufficient.
    """

    async def map(self, input_value: TriggerReqBody) -> str:
        """Return a greeting built from ``name`` and ``age``."""
        print(f"Receive input value: {input_value}")
        return f"Hello, {input_value.name}, your age is {input_value.age}"
|
||||
|
||||
|
||||
with DAG("simple_dag_example") as dag:
    # GET endpoint; query params are parsed into TriggerReqBody.
    trigger = HttpTrigger("/examples/hello", request_body=TriggerReqBody)
    map_node = RequestHandleOperator()
    # Pipeline: HTTP request -> greeting string.
    trigger >> map_node
|
70
examples/awel/simple_rag_example.py
Normal file
70
examples/awel/simple_rag_example.py
Normal file
@ -0,0 +1,70 @@
|
||||
"""AWEL: Simple rag example
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
curl -X POST http://127.0.0.1:5000/api/v1/awel/trigger/examples/simple_rag \
|
||||
-H "Content-Type: application/json" -d '{
|
||||
"conv_uid": "36f0e992-8825-11ee-8638-0242ac150003",
|
||||
"model_name": "proxyllm",
|
||||
"chat_mode": "chat_knowledge",
|
||||
"user_input": "What is DB-GPT?",
|
||||
"select_param": "default"
|
||||
}'
|
||||
|
||||
"""
|
||||
|
||||
from pilot.awel import HttpTrigger, DAG, MapOperator
|
||||
from pilot.scene.operator._experimental import (
|
||||
ChatContext,
|
||||
PromptManagerOperator,
|
||||
ChatHistoryStorageOperator,
|
||||
ChatHistoryOperator,
|
||||
EmbeddingEngingOperator,
|
||||
BaseChatOperator,
|
||||
)
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.openapi.api_view_model import ConversationVo
|
||||
from pilot.model.base import ModelOutput
|
||||
from pilot.model.operator.model_operator import ModelOperator
|
||||
|
||||
|
||||
class RequestParseOperator(MapOperator[ConversationVo, ChatContext]):
    """Convert the API-facing ``ConversationVo`` into the internal
    ``ChatContext`` used by the chat operators.

    The redundant ``__init__`` that only forwarded to ``super().__init__``
    has been removed; the inherited constructor is sufficient.
    """

    async def map(self, input_value: ConversationVo) -> ChatContext:
        """Build a ChatContext from the conversation view object."""
        # chat_scene is fixed to ChatKnowledge: this example DAG only serves
        # the knowledge-base QA scene.
        return ChatContext(
            current_user_input=input_value.user_input,
            model_name=input_value.model_name,
            chat_session_id=input_value.conv_uid,
            select_param=input_value.select_param,
            chat_scene=ChatScene.ChatKnowledge,
        )
|
||||
|
||||
|
||||
with DAG("simple_rag_example") as dag:
    # POST endpoint; JSON body parsed into ConversationVo.
    trigger_task = HttpTrigger(
        "/examples/simple_rag", methods="POST", request_body=ConversationVo
    )
    req_parse_task = RequestParseOperator()
    prompt_task = PromptManagerOperator()
    history_storage_task = ChatHistoryStorageOperator()
    history_task = ChatHistoryOperator()
    embedding_task = EmbeddingEngingOperator()
    chat_task = BaseChatOperator()
    model_task = ModelOperator()
    # Extract only the generated text from the ModelOutput dict.
    output_parser_task = MapOperator(lambda out: out.to_dict()["text"])

    # KBQA pipeline: request -> context -> prompt -> history -> embedding
    # retrieval -> chat -> model -> text.
    (
        trigger_task
        >> req_parse_task
        >> prompt_task
        >> history_storage_task
        >> history_task
        >> embedding_task
        >> chat_task
        >> model_task
        >> output_parser_task
    )
|
@ -1,8 +1,17 @@
|
||||
"""Agentic Workflow Expression Language (AWEL)"""
|
||||
"""Agentic Workflow Expression Language (AWEL)
|
||||
|
||||
Note:
|
||||
|
||||
AWEL is still an experimental feature and only opens the lowest level API.
|
||||
The stability of this API cannot be guaranteed at present.
|
||||
|
||||
"""
|
||||
|
||||
from pilot.component import SystemApp
|
||||
|
||||
from .dag.base import DAGContext, DAG
|
||||
|
||||
from .operator.base import BaseOperator, WorkflowRunner, initialize_awel
|
||||
from .operator.base import BaseOperator, WorkflowRunner
|
||||
from .operator.common_operator import (
|
||||
JoinOperator,
|
||||
ReduceStreamOperator,
|
||||
@ -28,6 +37,7 @@ from .task.task_impl import (
|
||||
SimpleStreamTaskOutput,
|
||||
_is_async_iterator,
|
||||
)
|
||||
from .trigger.http_trigger import HttpTrigger
|
||||
from .runner.local_runner import DefaultWorkflowRunner
|
||||
|
||||
__all__ = [
|
||||
@ -57,4 +67,21 @@ __all__ = [
|
||||
"StreamifyAbsOperator",
|
||||
"UnstreamifyAbsOperator",
|
||||
"TransformStreamAbsOperator",
|
||||
"HttpTrigger",
|
||||
]
|
||||
|
||||
|
||||
def initialize_awel(system_app: SystemApp, dag_filepath: str):
    """Bootstrap the AWEL subsystem for the given application.

    Registers the trigger manager and DAG manager components, installs the
    default workflow runner, and loads every DAG found under *dag_filepath*.
    """
    from .dag.base import DAGVar
    from .dag.dag_manager import DAGManager
    from .operator.base import initialize_runner
    from .trigger.trigger_manager import DefaultTriggerManager

    # Make the app available to DAG nodes created afterwards.
    DAGVar.set_current_system_app(system_app)

    system_app.register(DefaultTriggerManager)

    manager = DAGManager(system_app, dag_filepath)
    system_app.register_instance(manager)

    initialize_runner(DefaultWorkflowRunner())

    # Scan dag_filepath and register every DAG (and its triggers) found there.
    manager.load_dags()
|
||||
|
7
pilot/awel/base.py
Normal file
7
pilot/awel/base.py
Normal file
@ -0,0 +1,7 @@
|
||||
from abc import ABC, abstractmethod


class Trigger(ABC):
    """Abstract base for anything that can kick off a workflow run."""

    @abstractmethod
    async def trigger(self) -> None:
        """Trigger the workflow or a specific operation in the workflow."""
|
@ -1,14 +1,20 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Dict, List, Sequence, Union, Any
|
||||
from typing import Optional, Dict, List, Sequence, Union, Any, Set
|
||||
import uuid
|
||||
import contextvars
|
||||
import threading
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import deque
|
||||
from functools import cache
|
||||
from concurrent.futures import Executor
|
||||
|
||||
from pilot.component import SystemApp
|
||||
from ..resource.base import ResourceGroup
|
||||
from ..task.base import TaskContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DependencyType = Union["DependencyMixin", Sequence["DependencyMixin"]]
|
||||
|
||||
|
||||
@ -96,6 +102,8 @@ class DependencyMixin(ABC):
|
||||
class DAGVar:
|
||||
_thread_local = threading.local()
|
||||
_async_local = contextvars.ContextVar("current_dag_stack", default=deque())
|
||||
_system_app: SystemApp = None
|
||||
_executor: Executor = None
|
||||
|
||||
@classmethod
|
||||
def enter_dag(cls, dag) -> None:
|
||||
@ -138,18 +146,48 @@ class DAGVar:
|
||||
return cls._thread_local.current_dag_stack[-1]
|
||||
return None
|
||||
|
||||
    @classmethod
    def get_current_system_app(cls) -> SystemApp:
        """Return the process-wide SystemApp registered via
        ``set_current_system_app``; raise RuntimeError if none was set."""
        if not cls._system_app:
            raise RuntimeError("System APP not set for DAGVar")
        return cls._system_app
|
||||
|
||||
@classmethod
|
||||
def set_current_system_app(cls, system_app: SystemApp) -> None:
|
||||
if cls._system_app:
|
||||
logger.warn("System APP has already set, nothing to do")
|
||||
else:
|
||||
cls._system_app = system_app
|
||||
|
||||
    @classmethod
    def get_executor(cls) -> Executor:
        # May return None when no executor has been set yet.
        return cls._executor

    @classmethod
    def set_executor(cls, executor: Executor) -> None:
        # Store the shared Executor used to offload blocking work to async.
        cls._executor = executor
|
||||
|
||||
|
||||
class DAGNode(DependencyMixin, ABC):
|
||||
resource_group: Optional[ResourceGroup] = None
|
||||
"""The resource group of current DAGNode"""
|
||||
|
||||
def __init__(
|
||||
self, dag: Optional["DAG"] = None, node_id: str = None, node_name: str = None
|
||||
self,
|
||||
dag: Optional["DAG"] = None,
|
||||
node_id: Optional[str] = None,
|
||||
node_name: Optional[str] = None,
|
||||
system_app: Optional[SystemApp] = None,
|
||||
executor: Optional[Executor] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._upstream: List["DAGNode"] = []
|
||||
self._downstream: List["DAGNode"] = []
|
||||
self._dag: Optional["DAG"] = dag or DAGVar.get_current_dag()
|
||||
self._system_app: Optional[SystemApp] = (
|
||||
system_app or DAGVar.get_current_system_app()
|
||||
)
|
||||
self._executor: Optional[Executor] = executor or DAGVar.get_executor()
|
||||
if not node_id and self._dag:
|
||||
node_id = self._dag._new_node_id()
|
||||
self._node_id: str = node_id
|
||||
@ -159,6 +197,10 @@ class DAGNode(DependencyMixin, ABC):
|
||||
def node_id(self) -> str:
|
||||
return self._node_id
|
||||
|
||||
@property
|
||||
def system_app(self) -> SystemApp:
|
||||
return self._system_app
|
||||
|
||||
def set_node_id(self, node_id: str) -> None:
|
||||
self._node_id = node_id
|
||||
|
||||
@ -178,7 +220,7 @@ class DAGNode(DependencyMixin, ABC):
|
||||
return self._node_name
|
||||
|
||||
@property
|
||||
def dag(self) -> "DAGNode":
|
||||
def dag(self) -> "DAG":
|
||||
return self._dag
|
||||
|
||||
def set_upstream(self, nodes: DependencyType) -> "DAGNode":
|
||||
@ -254,17 +296,69 @@ class DAG:
|
||||
def __init__(
|
||||
self, dag_id: str, resource_group: Optional[ResourceGroup] = None
|
||||
) -> None:
|
||||
self._dag_id = dag_id
|
||||
self.node_map: Dict[str, DAGNode] = {}
|
||||
self._root_nodes: Set[DAGNode] = None
|
||||
self._leaf_nodes: Set[DAGNode] = None
|
||||
self._trigger_nodes: Set[DAGNode] = None
|
||||
|
||||
def _append_node(self, node: DAGNode) -> None:
|
||||
self.node_map[node.node_id] = node
|
||||
# clear cached nodes
|
||||
self._root_nodes = None
|
||||
self._leaf_nodes = None
|
||||
|
||||
    def _new_node_id(self) -> str:
        # UUID4 keeps node ids unique without any coordination.
        return str(uuid.uuid4())

    @property
    def dag_id(self) -> str:
        # Stable identifier supplied at construction time.
        return self._dag_id
|
||||
|
||||
def _build(self) -> None:
|
||||
from ..operator.common_operator import TriggerOperator
|
||||
|
||||
nodes = set()
|
||||
for _, node in self.node_map.items():
|
||||
nodes = nodes.union(_get_nodes(node))
|
||||
self._root_nodes = list(set(filter(lambda x: not x.upstream, nodes)))
|
||||
self._leaf_nodes = list(set(filter(lambda x: not x.downstream, nodes)))
|
||||
self._trigger_nodes = list(
|
||||
set(filter(lambda x: isinstance(x, TriggerOperator), nodes))
|
||||
)
|
||||
|
||||
@property
|
||||
def root_nodes(self) -> List[DAGNode]:
|
||||
if not self._root_nodes:
|
||||
self._build()
|
||||
return self._root_nodes
|
||||
|
||||
@property
|
||||
def leaf_nodes(self) -> List[DAGNode]:
|
||||
if not self._leaf_nodes:
|
||||
self._build()
|
||||
return self._leaf_nodes
|
||||
|
||||
@property
|
||||
def trigger_nodes(self):
|
||||
if not self._trigger_nodes:
|
||||
self._build()
|
||||
return self._trigger_nodes
|
||||
|
||||
    def __enter__(self):
        # Push this DAG so operators created in the with-block attach to it.
        DAGVar.enter_dag(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Pop this DAG from the context stack; exceptions propagate.
        DAGVar.exit_dag()
|
||||
|
||||
|
||||
def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> set[DAGNode]:
|
||||
nodes = set()
|
||||
if not node:
|
||||
return nodes
|
||||
nodes.add(node)
|
||||
stream_nodes = node.upstream if is_upstream else node.downstream
|
||||
for node in stream_nodes:
|
||||
nodes = nodes.union(_get_nodes(node, is_upstream))
|
||||
return nodes
|
||||
|
42
pilot/awel/dag/dag_manager.py
Normal file
42
pilot/awel/dag/dag_manager.py
Normal file
@ -0,0 +1,42 @@
|
||||
from typing import Dict, Optional
|
||||
import logging
|
||||
from pilot.component import BaseComponent, ComponentType, SystemApp
|
||||
from .loader import DAGLoader, LocalFileDAGLoader
|
||||
from .base import DAG
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DAGManager(BaseComponent):
    """System component that loads DAG definitions from a local path and
    registers their triggers with the trigger manager."""

    name = ComponentType.AWEL_DAG_MANAGER

    def __init__(self, system_app: SystemApp, dag_filepath: str):
        super().__init__(system_app)
        self.dag_loader = LocalFileDAGLoader(dag_filepath)
        self.system_app = system_app
        # dag_id -> DAG, used to reject duplicate registrations.
        self.dag_map: Dict[str, DAG] = {}

    def init_app(self, system_app: SystemApp):
        self.system_app = system_app

    def load_dags(self):
        """Load all DAGs from the configured path and register their triggers.

        Raises:
            ValueError: If a DAG with the same id was already loaded.
        """
        dags = self.dag_loader.load_dags()
        triggers = []
        for dag in dags:
            dag_id = dag.dag_id
            if dag_id in self.dag_map:
                raise ValueError(f"Load DAG error, DAG ID {dag_id} has already exist")
            # BUGFIX: the original never stored the DAG, so the duplicate
            # check above could never fire across loads.
            self.dag_map[dag_id] = dag
            triggers += dag.trigger_nodes
        from ..trigger.trigger_manager import DefaultTriggerManager

        trigger_manager: DefaultTriggerManager = self.system_app.get_component(
            ComponentType.AWEL_TRIGGER_MANAGER,
            DefaultTriggerManager,
            default_component=None,
        )
        if trigger_manager:
            for trigger in triggers:
                trigger_manager.register_trigger(trigger)
            trigger_manager.after_register()
        else:
            # logger.warn() is deprecated; logger.warning() is the supported API.
            logger.warning("No trigger manager, not register dag trigger")
|
93
pilot/awel/dag/loader.py
Normal file
93
pilot/awel/dag/loader.py
Normal file
@ -0,0 +1,93 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
import os
|
||||
import hashlib
|
||||
import sys
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
from .base import DAG
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DAGLoader(ABC):
    """Strategy interface for discovering DAG definitions."""

    @abstractmethod
    def load_dags(self) -> List["DAG"]:
        """Load dags"""


class LocalFileDAGLoader(DAGLoader):
    """Load DAGs from a local file, or from every ``.py`` file in a directory."""

    def __init__(self, filepath: str) -> None:
        super().__init__()
        self._filepath = filepath

    def load_dags(self) -> List["DAG"]:
        """Return the DAGs found at the configured path.

        A missing path is treated as "no DAGs" rather than an error, so the
        application can start before any DAG files exist.

        Note: the ``"DAG"`` annotations are forward references so this module
        does not hard-require ``.base`` to be resolved at class-definition
        time (guards against import cycles).
        """
        if not os.path.exists(self._filepath):
            return []
        if os.path.isdir(self._filepath):
            return _process_directory(self._filepath)
        return _process_file(self._filepath)
|
||||
|
||||
|
||||
def _process_directory(directory: str) -> List[DAG]:
|
||||
dags = []
|
||||
for file in os.listdir(directory):
|
||||
if file.endswith(".py"):
|
||||
filepath = os.path.join(directory, file)
|
||||
dags += _process_file(filepath)
|
||||
return dags
|
||||
|
||||
|
||||
def _process_file(filepath: str) -> List[DAG]:
    """Import *filepath* as a module and collect the DAG objects it defines."""
    mods = _load_modules_from_file(filepath)
    results = _process_modules(mods)
    return results
|
||||
|
||||
|
||||
def _load_modules_from_file(filepath: str):
    """Import the Python file at *filepath* under a collision-proof module name.

    Returns:
        A single-element list containing the imported module, or an empty list
        when the import fails (the error is logged, never raised, so one broken
        DAG file cannot abort the whole scan).
    """
    import importlib
    import importlib.machinery
    import importlib.util

    logger.info(f"Importing {filepath}")

    org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
    # Hash the full path into the module name so files with identical
    # basenames in different directories cannot collide in sys.modules.
    path_hash = hashlib.sha1(filepath.encode("utf-8")).hexdigest()
    mod_name = f"unusual_prefix_{path_hash}_{org_mod_name}"

    if mod_name in sys.modules:
        del sys.modules[mod_name]

    # The original wrapped this in a single-use inner closure (`parse`)
    # that was defined and immediately called once; flattened here.
    try:
        loader = importlib.machinery.SourceFileLoader(mod_name, filepath)
        spec = importlib.util.spec_from_loader(mod_name, loader)
        new_module = importlib.util.module_from_spec(spec)
        sys.modules[spec.name] = new_module
        loader.exec_module(new_module)
        return [new_module]
    except Exception:
        msg = traceback.format_exc()
        logger.error(f"Failed to import: {filepath}, error message: {msg}")
        # TODO save error message
        return []
|
||||
|
||||
|
||||
def _process_modules(mods) -> List[DAG]:
    """Collect every module-level DAG instance defined in *mods*."""
    found_dags = []
    for mod in mods:
        for dag in mod.__dict__.values():
            if not isinstance(dag, DAG):
                continue
            try:
                # TODO validate dag params
                logger.info(f"Found dag {dag} from mod {mod} and model file {mod.__file__}")
                found_dags.append(dag)
            except Exception:
                msg = traceback.format_exc()
                logger.error(f"Failed to dag file, error message: {msg}")
    return found_dags
|
@ -14,6 +14,13 @@ from typing import (
|
||||
)
|
||||
import functools
|
||||
from inspect import signature
|
||||
from pilot.component import SystemApp, ComponentType
|
||||
from pilot.utils.executor_utils import (
|
||||
ExecutorFactory,
|
||||
DefaultExecutorFactory,
|
||||
blocking_func_to_async,
|
||||
BlockingFunction,
|
||||
)
|
||||
|
||||
from ..dag.base import DAGNode, DAGContext, DAGVar, DAG
|
||||
from ..task.base import (
|
||||
@ -67,6 +74,19 @@ class BaseOperatorMeta(ABCMeta):
|
||||
def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any:
|
||||
dag: Optional[DAG] = kwargs.get("dag") or DAGVar.get_current_dag()
|
||||
task_id: Optional[str] = kwargs.get("task_id")
|
||||
system_app: Optional[SystemApp] = (
|
||||
kwargs.get("system_app") or DAGVar.get_current_system_app()
|
||||
)
|
||||
executor = kwargs.get("executor") or DAGVar.get_executor()
|
||||
if not executor:
|
||||
if system_app:
|
||||
executor = system_app.get_component(
|
||||
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
|
||||
).create()
|
||||
else:
|
||||
executor = DefaultExecutorFactory().create()
|
||||
DAGVar.set_executor(executor)
|
||||
|
||||
if not task_id and dag:
|
||||
task_id = dag._new_node_id()
|
||||
runner: Optional[WorkflowRunner] = kwargs.get("runner") or default_runner
|
||||
@ -80,6 +100,10 @@ class BaseOperatorMeta(ABCMeta):
|
||||
kwargs["task_id"] = task_id
|
||||
if not kwargs.get("runner"):
|
||||
kwargs["runner"] = runner
|
||||
if not kwargs.get("system_app"):
|
||||
kwargs["system_app"] = system_app
|
||||
if not kwargs.get("executor"):
|
||||
kwargs["executor"] = executor
|
||||
real_obj = func(self, *args, **kwargs)
|
||||
return real_obj
|
||||
|
||||
@ -171,7 +195,12 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
out_ctx = await self._runner.execute_workflow(self, call_data)
|
||||
return out_ctx.current_task_context.task_output.output_stream
|
||||
|
||||
    async def blocking_func_to_async(
        self, func: BlockingFunction, *args, **kwargs
    ) -> Any:
        """Run the blocking *func* on this operator's executor and await
        its result, keeping the event loop responsive."""
        return await blocking_func_to_async(self._executor, func, *args, **kwargs)
|
||||
|
||||
def initialize_awel(runner: WorkflowRunner):
|
||||
|
||||
def initialize_runner(runner: WorkflowRunner):
    """Install *runner* as the module-level default WorkflowRunner used by
    operators that are not given an explicit runner."""
    global default_runner
    default_runner = runner
||||
|
@ -237,3 +237,10 @@ class InputOperator(BaseOperator, Generic[OUT]):
|
||||
task_output = await self._input_source.read(curr_task_ctx)
|
||||
curr_task_ctx.set_task_output(task_output)
|
||||
return task_output
|
||||
|
||||
|
||||
class TriggerOperator(InputOperator, Generic[OUT]):
    """Input operator whose value comes from call data at trigger time
    instead of a fixed input source."""

    def __init__(self, **kwargs) -> None:
        # Imported lazily — presumably to avoid a circular import with
        # task_impl; confirm before moving to module level.
        from ..task.task_impl import SimpleCallDataInputSource

        super().__init__(input_source=SimpleCallDataInputSource(), **kwargs)
|
||||
|
@ -3,7 +3,7 @@ import logging
|
||||
|
||||
from ..dag.base import DAGContext
|
||||
from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA
|
||||
from ..operator.common_operator import BranchOperator, JoinOperator
|
||||
from ..operator.common_operator import BranchOperator, JoinOperator, TriggerOperator
|
||||
from ..task.base import TaskContext, TaskState
|
||||
from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput
|
||||
from .job_manager import JobManager
|
||||
@ -67,7 +67,7 @@ class DefaultWorkflowRunner(WorkflowRunner):
|
||||
node_outputs[node.node_id] = task_ctx
|
||||
return
|
||||
try:
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}"
|
||||
)
|
||||
await node._run(dag_ctx)
|
||||
@ -76,7 +76,7 @@ class DefaultWorkflowRunner(WorkflowRunner):
|
||||
|
||||
if isinstance(node, BranchOperator):
|
||||
skip_nodes = task_ctx.metadata.get("skip_node_names", [])
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"Current is branch operator, skip node names: {skip_nodes}"
|
||||
)
|
||||
_skip_current_downstream_by_node_name(node, skip_nodes, skip_node_ids)
|
||||
|
15
pilot/awel/trigger/base.py
Normal file
15
pilot/awel/trigger/base.py
Normal file
@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from ..operator.base import BaseOperator
|
||||
from ..operator.common_operator import TriggerOperator
|
||||
from ..dag.base import DAGContext
|
||||
from ..task.base import TaskOutput
|
||||
|
||||
|
||||
class Trigger(TriggerOperator, ABC):
    """Operator that starts a workflow run from an external event
    (e.g. an incoming HTTP request)."""

    @abstractmethod
    async def trigger(self, end_operator: "BaseOperator") -> None:
        """Trigger the workflow or a specific operation in the workflow."""
|
117
pilot/awel/trigger/http_trigger.py
Normal file
117
pilot/awel/trigger/http_trigger.py
Normal file
@ -0,0 +1,117 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Union, Type, List, TYPE_CHECKING, Optional, Any, Dict
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
|
||||
from .base import Trigger
|
||||
from ..operator.base import BaseOperator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import APIRouter, FastAPI
|
||||
|
||||
RequestBody = Union[Request, Type[BaseModel], str]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HttpTrigger(Trigger):
    """Trigger that mounts a FastAPI route; each incoming HTTP request runs
    the DAG this trigger belongs to and returns the leaf node's output."""

    def __init__(
        self,
        endpoint: str,
        methods: Optional[Union[str, List[str]]] = "GET",
        request_body: Optional[RequestBody] = None,
        streaming_response: Optional[bool] = False,
        response_model: Optional[Type] = None,
        response_headers: Optional[Dict[str, str]] = None,
        response_media_type: Optional[str] = None,
        status_code: Optional[int] = 200,
        router_tags: Optional[List[str]] = None,
        **kwargs,
    ) -> None:
        """Configure the HTTP endpoint for this trigger.

        Args:
            endpoint: Route path; a leading "/" is added if missing.
            methods: HTTP method or list of methods (default "GET").
            request_body: Pydantic model class (or raw Request) used to parse
                the incoming body/query params.
            streaming_response: When True, the leaf node's stream is returned
                as a StreamingResponse.
            response_model: FastAPI response_model for the route.
            response_headers: Headers for the streaming response.
            response_media_type: Media type for the streaming response.
            status_code: HTTP status code of the route (default 200).
            router_tags: Tags applied to the mounted route.
        """
        super().__init__(**kwargs)
        if not endpoint.startswith("/"):
            endpoint = "/" + endpoint
        self._endpoint = endpoint
        self._methods = methods
        self._req_body = request_body
        self._streaming_response = streaming_response
        self._response_model = response_model
        self._status_code = status_code
        self._router_tags = router_tags
        self._response_headers = response_headers
        self._response_media_type = response_media_type
        # NOTE(review): assigned here but never read in this class — confirm
        # whether _end_node is still needed.
        self._end_node: BaseOperator = None

    async def trigger(self) -> None:
        # NOTE(review): the abstract base declares trigger(self, end_operator);
        # this override takes no end_operator — confirm intended. The actual
        # triggering happens through the mounted route, not this method.
        pass

    def mount_to_router(self, router: "APIRouter") -> None:
        """Register this trigger's endpoint on *router*, binding a dynamically
        created route function that executes the DAG's single leaf node."""
        from fastapi import Depends
        from fastapi.responses import StreamingResponse

        # Normalize to a list for router.api_route.
        methods = self._methods if isinstance(self._methods, list) else [self._methods]

        def create_route_function(name):
            async def _request_body_dependency(request: Request):
                return await _parse_request_body(request, self._req_body)

            async def route_function(body: Any = Depends(_request_body_dependency)):
                # The DAG must have exactly one leaf node: it produces the
                # HTTP response.
                end_node = self.dag.leaf_nodes
                if len(end_node) != 1:
                    raise ValueError("HttpTrigger just support one leaf node in dag")
                end_node = end_node[0]
                if not self._streaming_response:
                    return await end_node.call(call_data={"data": body})
                else:
                    headers = self._response_headers
                    media_type = (
                        self._response_media_type
                        if self._response_media_type
                        else "text/event-stream"
                    )
                    if not headers:
                        # Default SSE-style headers when none were configured.
                        headers = {
                            "Content-Type": "text/event-stream",
                            "Cache-Control": "no-cache",
                            "Connection": "keep-alive",
                            "Transfer-Encoding": "chunked",
                        }
                    return StreamingResponse(
                        end_node.call_stream(call_data={"data": body}),
                        headers=headers,
                        media_type=media_type,
                    )

            # Give each route a unique __name__ so FastAPI/OpenAPI can tell
            # the dynamically created handlers apart.
            route_function.__name__ = name
            return route_function

        function_name = f"dynamic_route_{self._endpoint.replace('/', '_')}"
        dynamic_route_function = create_route_function(function_name)
        logger.info(
            f"mount router function {dynamic_route_function}({function_name}), endpoint: {self._endpoint}, methods: {methods}"
        )

        router.api_route(
            self._endpoint,
            methods=methods,
            response_model=self._response_model,
            status_code=self._status_code,
            tags=self._router_tags,
        )(dynamic_route_function)
|
||||
|
||||
|
||||
async def _parse_request_body(
|
||||
request: Request, request_body_cls: Optional[Type[BaseModel]]
|
||||
):
|
||||
if not request_body_cls:
|
||||
return None
|
||||
if request.method == "POST":
|
||||
json_data = await request.json()
|
||||
return request_body_cls(**json_data)
|
||||
elif request.method == "GET":
|
||||
return request_body_cls(**request.query_params)
|
||||
else:
|
||||
return request
|
74
pilot/awel/trigger/trigger_manager.py
Normal file
74
pilot/awel/trigger/trigger_manager.py
Normal file
@ -0,0 +1,74 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, TYPE_CHECKING, Optional
|
||||
import logging
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import APIRouter
|
||||
|
||||
from pilot.component import SystemApp, BaseComponent, ComponentType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TriggerManager(ABC):
    """Interface for components that accept and manage workflow triggers."""

    @abstractmethod
    def register_trigger(self, trigger: Any) -> None:
        """Register a trigger to current manager"""
|
||||
|
||||
|
||||
class HttpTriggerManager(TriggerManager):
    """Mounts HttpTrigger instances onto a FastAPI router under a common
    prefix and includes that router into the application."""

    def __init__(
        self,
        router: Optional["APIRouter"] = None,
        router_prefix: Optional[str] = "/api/v1/awel/trigger",
    ) -> None:
        if not router:
            from fastapi import APIRouter

            router = APIRouter()
        self._router_prefix = router_prefix
        self._router = router
        # node_id -> HttpTrigger; keeps registration idempotent.
        self._trigger_map = {}

    def register_trigger(self, trigger: Any) -> None:
        """Mount *trigger* on the router; re-registering the same node id is
        a no-op.

        Raises:
            ValueError: If *trigger* is not an HttpTrigger.
        """
        from .http_trigger import HttpTrigger

        if not isinstance(trigger, HttpTrigger):
            raise ValueError(f"Current trigger {trigger} not an object of HttpTrigger")
        # (Removed the redundant `trigger: HttpTrigger = trigger`
        # self-assignment; the isinstance check above already narrows it.)
        trigger_id = trigger.node_id
        if trigger_id not in self._trigger_map:
            trigger.mount_to_router(self._router)
            self._trigger_map[trigger_id] = trigger

    def _init_app(self, system_app: SystemApp):
        # Attach the accumulated routes to the running FastAPI app.
        logger.info(
            f"Include router {self._router} to prefix path {self._router_prefix}"
        )
        system_app.app.include_router(
            self._router, prefix=self._router_prefix, tags=["AWEL"]
        )
|
||||
|
||||
|
||||
class DefaultTriggerManager(TriggerManager, BaseComponent):
    """System component that dispatches registered triggers to
    protocol-specific managers (currently HTTP only)."""

    name = ComponentType.AWEL_TRIGGER_MANAGER

    def __init__(self, system_app: Optional[SystemApp] = None):
        # Use Optional[SystemApp] instead of `SystemApp | None`: this module
        # has no `from __future__ import annotations`, so a PEP 604 union in
        # the signature would raise TypeError at import time on Python < 3.10.
        self.system_app = system_app
        self.http_trigger = HttpTriggerManager()
        super().__init__(None)

    def init_app(self, system_app: SystemApp):
        self.system_app = system_app

    def register_trigger(self, trigger: Any) -> None:
        """Route *trigger* to the matching protocol manager.

        Raises:
            ValueError: For trigger types with no registered manager.
        """
        from .http_trigger import HttpTrigger

        if isinstance(trigger, HttpTrigger):
            logger.info(f"Register trigger {trigger}")
            self.http_trigger.register_trigger(trigger)
        else:
            raise ValueError(f"Unsupport trigger: {trigger}")

    def after_register(self) -> None:
        """Called once all triggers are registered; mounts the HTTP routes."""
        self.http_trigger._init_app(self.system_app)
|
@ -54,6 +54,8 @@ class ComponentType(str, Enum):
|
||||
TRACER = "dbgpt_tracer"
|
||||
TRACER_SPAN_STORAGE = "dbgpt_tracer_span_storage"
|
||||
RAG_GRAPH_DEFAULT = "dbgpt_rag_engine_default"
|
||||
AWEL_TRIGGER_MANAGER = "dbgpt_awel_trigger_manager"
|
||||
AWEL_DAG_MANAGER = "dbgpt_awel_dag_manager"
|
||||
|
||||
|
||||
class BaseComponent(LifeCycle, ABC):
|
||||
|
@ -16,6 +16,7 @@ DATA_DIR = os.path.join(PILOT_PATH, "data")
|
||||
PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins")
|
||||
FONT_DIR = os.path.join(PILOT_PATH, "fonts")
|
||||
MODEL_DISK_CACHE_DIR = os.path.join(DATA_DIR, "model_cache")
|
||||
_DAG_DEFINITION_DIR = os.path.join(ROOT_PATH, "examples/awel")
|
||||
|
||||
current_directory = os.getcwd()
|
||||
|
||||
|
@ -47,7 +47,7 @@ class DbHistoryMemory(BaseChatHistoryMemory):
|
||||
logger.error("init create conversation log error!" + str(e))
|
||||
|
||||
def append(self, once_message: OnceConversation) -> None:
|
||||
logger.info(f"db history append: {once_message}")
|
||||
logger.debug(f"db history append: {once_message}")
|
||||
chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid(
|
||||
self.chat_seesion_id
|
||||
)
|
||||
|
@ -7,9 +7,10 @@ from pilot.awel import (
|
||||
MapOperator,
|
||||
TransformStreamAbsOperator,
|
||||
)
|
||||
from pilot.component import ComponentType
|
||||
from pilot.awel.operator.base import BaseOperator
|
||||
from pilot.model.base import ModelOutput
|
||||
from pilot.model.cluster import WorkerManager
|
||||
from pilot.model.cluster import WorkerManager, WorkerManagerFactory
|
||||
from pilot.cache import LLMCacheClient, CacheManager, LLMCacheKey, LLMCacheValue
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -29,7 +30,7 @@ class ModelStreamOperator(StreamifyAbsOperator[Dict, ModelOutput]):
|
||||
streamify: Asynchronously processes a stream of inputs, yielding model outputs.
|
||||
"""
|
||||
|
||||
def __init__(self, worker_manager: WorkerManager, **kwargs) -> None:
|
||||
def __init__(self, worker_manager: WorkerManager = None, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.worker_manager = worker_manager
|
||||
|
||||
@ -42,6 +43,10 @@ class ModelStreamOperator(StreamifyAbsOperator[Dict, ModelOutput]):
|
||||
Returns:
|
||||
AsyncIterator[ModelOutput]: An asynchronous iterator of model outputs.
|
||||
"""
|
||||
if not self.worker_manager:
|
||||
self.worker_manager = self.system_app.get_component(
|
||||
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
|
||||
).create()
|
||||
async for out in self.worker_manager.generate_stream(input_value):
|
||||
yield out
|
||||
|
||||
@ -57,9 +62,9 @@ class ModelOperator(MapOperator[Dict, ModelOutput]):
|
||||
map: Asynchronously processes a single input and returns the model output.
|
||||
"""
|
||||
|
||||
def __init__(self, worker_manager: WorkerManager, **kwargs) -> None:
|
||||
self.worker_manager = worker_manager
|
||||
def __init__(self, worker_manager: WorkerManager = None, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.worker_manager = worker_manager
|
||||
|
||||
async def map(self, input_value: Dict) -> ModelOutput:
|
||||
"""Process a single input and return the model output.
|
||||
@ -70,6 +75,10 @@ class ModelOperator(MapOperator[Dict, ModelOutput]):
|
||||
Returns:
|
||||
ModelOutput: The output from the model.
|
||||
"""
|
||||
if not self.worker_manager:
|
||||
self.worker_manager = self.system_app.get_component(
|
||||
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
|
||||
).create()
|
||||
return await self.worker_manager.generate(input_value)
|
||||
|
||||
|
||||
|
@ -143,9 +143,7 @@ def _build_request(model: ProxyModel, params):
|
||||
proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
|
||||
payloads["model"] = proxyllm_backend
|
||||
|
||||
logger.info(
|
||||
f"Send request to real model {proxyllm_backend}, openai_params: {openai_params}"
|
||||
)
|
||||
logger.info(f"Send request to real model {proxyllm_backend}")
|
||||
return history, payloads
|
||||
|
||||
|
||||
|
@ -68,7 +68,7 @@ class BaseChat(ABC):
|
||||
CFG.prompt_template_registry.get_prompt_template(
|
||||
self.chat_mode.value(),
|
||||
language=CFG.LANGUAGE,
|
||||
model_name=CFG.LLM_MODEL,
|
||||
model_name=self.llm_model,
|
||||
proxyllm_backend=CFG.PROXYLLM_BACKEND,
|
||||
)
|
||||
)
|
||||
@ -141,13 +141,7 @@ class BaseChat(ABC):
|
||||
return speak_to_user
|
||||
|
||||
async def __call_base(self):
|
||||
import inspect
|
||||
|
||||
input_values = (
|
||||
await self.generate_input_values()
|
||||
if inspect.isawaitable(self.generate_input_values())
|
||||
else self.generate_input_values()
|
||||
)
|
||||
input_values = await self.generate_input_values()
|
||||
### Chat sequence advance
|
||||
self.current_message.chat_order = len(self.history_message) + 1
|
||||
self.current_message.add_user_message(self.current_user_input)
|
||||
@ -379,16 +373,18 @@ class BaseChat(ABC):
|
||||
if self.prompt_template.template_define:
|
||||
text += self.prompt_template.template_define + self.prompt_template.sep
|
||||
### Load prompt
|
||||
text += self.__load_system_message()
|
||||
text += _load_system_message(self.current_message, self.prompt_template)
|
||||
|
||||
### Load examples
|
||||
text += self.__load_example_messages()
|
||||
text += _load_example_messages(self.prompt_template)
|
||||
|
||||
### Load History
|
||||
text += self.__load_history_messages()
|
||||
text += _load_history_messages(
|
||||
self.prompt_template, self.history_message, self.chat_retention_rounds
|
||||
)
|
||||
|
||||
### Load User Input
|
||||
text += self.__load_user_message()
|
||||
text += _load_user_message(self.current_message, self.prompt_template)
|
||||
return text
|
||||
|
||||
def generate_llm_messages(self) -> List[ModelMessage]:
|
||||
@ -406,137 +402,26 @@ class BaseChat(ABC):
|
||||
)
|
||||
)
|
||||
### Load prompt
|
||||
messages += self.__load_system_message(str_message=False)
|
||||
messages += _load_system_message(
|
||||
self.current_message, self.prompt_template, str_message=False
|
||||
)
|
||||
### Load examples
|
||||
messages += self.__load_example_messages(str_message=False)
|
||||
messages += _load_example_messages(self.prompt_template, str_message=False)
|
||||
|
||||
### Load History
|
||||
messages += self.__load_history_messages(str_message=False)
|
||||
messages += _load_history_messages(
|
||||
self.prompt_template,
|
||||
self.history_message,
|
||||
self.chat_retention_rounds,
|
||||
str_message=False,
|
||||
)
|
||||
|
||||
### Load User Input
|
||||
messages += self.__load_user_message(str_message=False)
|
||||
messages += _load_user_message(
|
||||
self.current_message, self.prompt_template, str_message=False
|
||||
)
|
||||
return messages
|
||||
|
||||
def __load_system_message(self, str_message: bool = True):
|
||||
system_convs = self.current_message.get_system_conv()
|
||||
system_text = ""
|
||||
system_messages = []
|
||||
for system_conv in system_convs:
|
||||
system_text += (
|
||||
system_conv.type + ":" + system_conv.content + self.prompt_template.sep
|
||||
)
|
||||
system_messages.append(
|
||||
ModelMessage(role=system_conv.type, content=system_conv.content)
|
||||
)
|
||||
return system_text if str_message else system_messages
|
||||
|
||||
def __load_user_message(self, str_message: bool = True):
|
||||
user_conv = self.current_message.get_user_conv()
|
||||
user_messages = []
|
||||
if user_conv:
|
||||
user_text = (
|
||||
user_conv.type + ":" + user_conv.content + self.prompt_template.sep
|
||||
)
|
||||
user_messages.append(
|
||||
ModelMessage(role=user_conv.type, content=user_conv.content)
|
||||
)
|
||||
return user_text if str_message else user_messages
|
||||
else:
|
||||
raise ValueError("Hi! What do you want to talk about?")
|
||||
|
||||
def __load_example_messages(self, str_message: bool = True):
|
||||
example_text = ""
|
||||
example_messages = []
|
||||
if self.prompt_template.example_selector:
|
||||
for round_conv in self.prompt_template.example_selector.examples():
|
||||
for round_message in round_conv["messages"]:
|
||||
if not round_message["type"] in [
|
||||
ModelMessageRoleType.VIEW,
|
||||
ModelMessageRoleType.SYSTEM,
|
||||
]:
|
||||
message_type = round_message["type"]
|
||||
message_content = round_message["data"]["content"]
|
||||
example_text += (
|
||||
message_type
|
||||
+ ":"
|
||||
+ message_content
|
||||
+ self.prompt_template.sep
|
||||
)
|
||||
example_messages.append(
|
||||
ModelMessage(role=message_type, content=message_content)
|
||||
)
|
||||
return example_text if str_message else example_messages
|
||||
|
||||
def __load_history_messages(self, str_message: bool = True):
|
||||
history_text = ""
|
||||
history_messages = []
|
||||
if self.prompt_template.need_historical_messages:
|
||||
if self.history_message:
|
||||
logger.info(
|
||||
f"There are already {len(self.history_message)} rounds of conversations! Will use {self.chat_retention_rounds} rounds of content as history!"
|
||||
)
|
||||
if len(self.history_message) > self.chat_retention_rounds:
|
||||
for first_message in self.history_message[0]["messages"]:
|
||||
if not first_message["type"] in [
|
||||
ModelMessageRoleType.VIEW,
|
||||
ModelMessageRoleType.SYSTEM,
|
||||
]:
|
||||
message_type = first_message["type"]
|
||||
message_content = first_message["data"]["content"]
|
||||
history_text += (
|
||||
message_type
|
||||
+ ":"
|
||||
+ message_content
|
||||
+ self.prompt_template.sep
|
||||
)
|
||||
history_messages.append(
|
||||
ModelMessage(role=message_type, content=message_content)
|
||||
)
|
||||
if self.chat_retention_rounds > 1:
|
||||
index = self.chat_retention_rounds - 1
|
||||
for round_conv in self.history_message[-index:]:
|
||||
for round_message in round_conv["messages"]:
|
||||
if not round_message["type"] in [
|
||||
ModelMessageRoleType.VIEW,
|
||||
ModelMessageRoleType.SYSTEM,
|
||||
]:
|
||||
message_type = round_message["type"]
|
||||
message_content = round_message["data"]["content"]
|
||||
history_text += (
|
||||
message_type
|
||||
+ ":"
|
||||
+ message_content
|
||||
+ self.prompt_template.sep
|
||||
)
|
||||
history_messages.append(
|
||||
ModelMessage(
|
||||
role=message_type, content=message_content
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
### user all history
|
||||
for conversation in self.history_message:
|
||||
for message in conversation["messages"]:
|
||||
### histroy message not have promot and view info
|
||||
if not message["type"] in [
|
||||
ModelMessageRoleType.VIEW,
|
||||
ModelMessageRoleType.SYSTEM,
|
||||
]:
|
||||
message_type = message["type"]
|
||||
message_content = message["data"]["content"]
|
||||
history_text += (
|
||||
message_type
|
||||
+ ":"
|
||||
+ message_content
|
||||
+ self.prompt_template.sep
|
||||
)
|
||||
history_messages.append(
|
||||
ModelMessage(role=message_type, content=message_content)
|
||||
)
|
||||
|
||||
return history_text if str_message else history_messages
|
||||
|
||||
def current_ai_response(self) -> str:
|
||||
for message in self.current_message.messages:
|
||||
if message.type == "view":
|
||||
@ -656,3 +541,127 @@ def _build_model_operator(
|
||||
cache_check_branch_node >> cached_node >> join_node
|
||||
|
||||
return join_node
|
||||
|
||||
|
||||
def _load_system_message(
|
||||
current_message: OnceConversation,
|
||||
prompt_template: PromptTemplate,
|
||||
str_message: bool = True,
|
||||
):
|
||||
system_convs = current_message.get_system_conv()
|
||||
system_text = ""
|
||||
system_messages = []
|
||||
for system_conv in system_convs:
|
||||
system_text += (
|
||||
system_conv.type + ":" + system_conv.content + prompt_template.sep
|
||||
)
|
||||
system_messages.append(
|
||||
ModelMessage(role=system_conv.type, content=system_conv.content)
|
||||
)
|
||||
return system_text if str_message else system_messages
|
||||
|
||||
|
||||
def _load_user_message(
|
||||
current_message: OnceConversation,
|
||||
prompt_template: PromptTemplate,
|
||||
str_message: bool = True,
|
||||
):
|
||||
user_conv = current_message.get_user_conv()
|
||||
user_messages = []
|
||||
if user_conv:
|
||||
user_text = user_conv.type + ":" + user_conv.content + prompt_template.sep
|
||||
user_messages.append(
|
||||
ModelMessage(role=user_conv.type, content=user_conv.content)
|
||||
)
|
||||
return user_text if str_message else user_messages
|
||||
else:
|
||||
raise ValueError("Hi! What do you want to talk about?")
|
||||
|
||||
|
||||
def _load_example_messages(prompt_template: PromptTemplate, str_message: bool = True):
|
||||
example_text = ""
|
||||
example_messages = []
|
||||
if prompt_template.example_selector:
|
||||
for round_conv in prompt_template.example_selector.examples():
|
||||
for round_message in round_conv["messages"]:
|
||||
if not round_message["type"] in [
|
||||
ModelMessageRoleType.VIEW,
|
||||
ModelMessageRoleType.SYSTEM,
|
||||
]:
|
||||
message_type = round_message["type"]
|
||||
message_content = round_message["data"]["content"]
|
||||
example_text += (
|
||||
message_type + ":" + message_content + prompt_template.sep
|
||||
)
|
||||
example_messages.append(
|
||||
ModelMessage(role=message_type, content=message_content)
|
||||
)
|
||||
return example_text if str_message else example_messages
|
||||
|
||||
|
||||
def _load_history_messages(
|
||||
prompt_template: PromptTemplate,
|
||||
history_message: List[OnceConversation],
|
||||
chat_retention_rounds: int,
|
||||
str_message: bool = True,
|
||||
):
|
||||
history_text = ""
|
||||
history_messages = []
|
||||
if prompt_template.need_historical_messages:
|
||||
if history_message:
|
||||
logger.info(
|
||||
f"There are already {len(history_message)} rounds of conversations! Will use {chat_retention_rounds} rounds of content as history!"
|
||||
)
|
||||
if len(history_message) > chat_retention_rounds:
|
||||
for first_message in history_message[0]["messages"]:
|
||||
if not first_message["type"] in [
|
||||
ModelMessageRoleType.VIEW,
|
||||
ModelMessageRoleType.SYSTEM,
|
||||
]:
|
||||
message_type = first_message["type"]
|
||||
message_content = first_message["data"]["content"]
|
||||
history_text += (
|
||||
message_type + ":" + message_content + prompt_template.sep
|
||||
)
|
||||
history_messages.append(
|
||||
ModelMessage(role=message_type, content=message_content)
|
||||
)
|
||||
if chat_retention_rounds > 1:
|
||||
index = chat_retention_rounds - 1
|
||||
for round_conv in history_message[-index:]:
|
||||
for round_message in round_conv["messages"]:
|
||||
if not round_message["type"] in [
|
||||
ModelMessageRoleType.VIEW,
|
||||
ModelMessageRoleType.SYSTEM,
|
||||
]:
|
||||
message_type = round_message["type"]
|
||||
message_content = round_message["data"]["content"]
|
||||
history_text += (
|
||||
message_type
|
||||
+ ":"
|
||||
+ message_content
|
||||
+ prompt_template.sep
|
||||
)
|
||||
history_messages.append(
|
||||
ModelMessage(role=message_type, content=message_content)
|
||||
)
|
||||
|
||||
else:
|
||||
### user all history
|
||||
for conversation in history_message:
|
||||
for message in conversation["messages"]:
|
||||
### histroy message not have promot and view info
|
||||
if not message["type"] in [
|
||||
ModelMessageRoleType.VIEW,
|
||||
ModelMessageRoleType.SYSTEM,
|
||||
]:
|
||||
message_type = message["type"]
|
||||
message_content = message["data"]["content"]
|
||||
history_text += (
|
||||
message_type + ":" + message_content + prompt_template.sep
|
||||
)
|
||||
history_messages.append(
|
||||
ModelMessage(role=message_type, content=message_content)
|
||||
)
|
||||
|
||||
return history_text if str_message else history_messages
|
||||
|
@ -117,6 +117,10 @@ class ModelMessage(BaseModel):
|
||||
def to_dict_list(messages: List["ModelMessage"]) -> List[Dict[str, str]]:
|
||||
return list(map(lambda m: m.dict(), messages))
|
||||
|
||||
@staticmethod
|
||||
def build_human_message(content: str) -> "ModelMessage":
|
||||
return ModelMessage(role=ModelMessageRoleType.HUMAN, content=content)
|
||||
|
||||
|
||||
class Generation(BaseModel):
|
||||
"""Output of a single generation."""
|
||||
|
@ -6,7 +6,6 @@ import re
|
||||
import sqlparse
|
||||
import pandas as pd
|
||||
import chardet
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from pyparsing import (
|
||||
CaselessKeyword,
|
||||
@ -27,6 +26,8 @@ from pyparsing import (
|
||||
from pilot.common.pd_utils import csv_colunm_foramt
|
||||
from pilot.common.string_utils import is_chinese_include_number
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def excel_colunm_format(old_name: str) -> str:
|
||||
new_column = old_name.strip()
|
||||
@ -263,7 +264,7 @@ class ExcelReader:
|
||||
file_name = os.path.basename(file_path)
|
||||
self.file_name_without_extension = os.path.splitext(file_name)[0]
|
||||
encoding, confidence = detect_encoding(file_path)
|
||||
logging.error(f"Detected Encoding: {encoding} (Confidence: {confidence})")
|
||||
logger.error(f"Detected Encoding: {encoding} (Confidence: {confidence})")
|
||||
self.excel_file_name = file_name
|
||||
self.extension = os.path.splitext(file_name)[1]
|
||||
# read excel file
|
||||
@ -323,7 +324,7 @@ class ExcelReader:
|
||||
colunms.append(descrip[0])
|
||||
return colunms, results.fetchall()
|
||||
except Exception as e:
|
||||
logging.error("excel sql run error!", e)
|
||||
logger.error(f"excel sql run error!, {str(e)}")
|
||||
raise ValueError(f"Data Query Exception!\\nSQL[{sql}].\\nError:{str(e)}")
|
||||
|
||||
def get_df_by_sql_ex(self, sql):
|
||||
|
@ -37,7 +37,7 @@ class DbChatOutputParser(BaseOutputParser):
|
||||
|
||||
def parse_prompt_response(self, model_out_text):
|
||||
clean_str = super().parse_prompt_response(model_out_text)
|
||||
logging.info("clean prompt response:", clean_str)
|
||||
logger.info(f"clean prompt response: {clean_str}")
|
||||
# Compatible with community pure sql output model
|
||||
if self.is_sql_statement(clean_str):
|
||||
return SqlAction(clean_str, "")
|
||||
@ -51,7 +51,7 @@ class DbChatOutputParser(BaseOutputParser):
|
||||
thoughts = response[key]
|
||||
return SqlAction(sql, thoughts)
|
||||
except Exception as e:
|
||||
logging.error("json load faild")
|
||||
logger.error("json load faild")
|
||||
return SqlAction("", clean_str)
|
||||
|
||||
def parse_view_response(self, speak, data, prompt_response) -> str:
|
||||
|
@ -24,7 +24,7 @@ class ExtractEntity(BaseChat):
|
||||
self.user_input = chat_param["current_user_input"]
|
||||
self.extract_mode = chat_param["select_param"]
|
||||
|
||||
def generate_input_values(self):
|
||||
async def generate_input_values(self):
|
||||
input_values = {
|
||||
"text": self.user_input,
|
||||
}
|
||||
|
@ -24,7 +24,7 @@ class ExtractTriplet(BaseChat):
|
||||
self.user_input = chat_param["current_user_input"]
|
||||
self.extract_mode = chat_param["select_param"]
|
||||
|
||||
def generate_input_values(self):
|
||||
async def generate_input_values(self):
|
||||
input_values = {
|
||||
"text": self.user_input,
|
||||
}
|
||||
|
@ -23,7 +23,7 @@ class ExtractRefineSummary(BaseChat):
|
||||
|
||||
self.existing_answer = chat_param["select_param"]
|
||||
|
||||
def generate_input_values(self):
|
||||
async def generate_input_values(self):
|
||||
input_values = {
|
||||
# "context": self.user_input,
|
||||
"existing_answer": self.existing_answer,
|
||||
|
@ -23,7 +23,7 @@ class ExtractSummary(BaseChat):
|
||||
|
||||
self.user_input = chat_param["select_param"]
|
||||
|
||||
def generate_input_values(self):
|
||||
async def generate_input_values(self):
|
||||
input_values = {
|
||||
"context": self.user_input,
|
||||
}
|
||||
|
@ -104,7 +104,7 @@ class ChatKnowledge(BaseChat):
|
||||
self.current_user_input,
|
||||
self.top_k,
|
||||
)
|
||||
self.sources = self.merge_by_key(
|
||||
self.sources = _merge_by_key(
|
||||
list(map(lambda doc: doc.metadata, docs)), "source"
|
||||
)
|
||||
|
||||
@ -149,29 +149,6 @@ class ChatKnowledge(BaseChat):
|
||||
)
|
||||
return html
|
||||
|
||||
def merge_by_key(self, data, key):
|
||||
result = {}
|
||||
for item in data:
|
||||
if item.get(key):
|
||||
item_key = os.path.basename(item.get(key))
|
||||
if item_key in result:
|
||||
if "pages" in result[item_key] and "page" in item:
|
||||
result[item_key]["pages"].append(str(item["page"]))
|
||||
elif "page" in item:
|
||||
result[item_key]["pages"] = [
|
||||
result[item_key]["pages"],
|
||||
str(item["page"]),
|
||||
]
|
||||
else:
|
||||
if "page" in item:
|
||||
result[item_key] = {
|
||||
"source": item_key,
|
||||
"pages": [str(item["page"])],
|
||||
}
|
||||
else:
|
||||
result[item_key] = {"source": item_key}
|
||||
return list(result.values())
|
||||
|
||||
@property
|
||||
def chat_type(self) -> str:
|
||||
return ChatScene.ChatKnowledge.value()
|
||||
@ -179,3 +156,27 @@ class ChatKnowledge(BaseChat):
|
||||
def get_space_context(self, space_name):
|
||||
service = KnowledgeService()
|
||||
return service.get_space_context(space_name)
|
||||
|
||||
|
||||
def _merge_by_key(data, key):
|
||||
result = {}
|
||||
for item in data:
|
||||
if item.get(key):
|
||||
item_key = os.path.basename(item.get(key))
|
||||
if item_key in result:
|
||||
if "pages" in result[item_key] and "page" in item:
|
||||
result[item_key]["pages"].append(str(item["page"]))
|
||||
elif "page" in item:
|
||||
result[item_key]["pages"] = [
|
||||
result[item_key]["pages"],
|
||||
str(item["page"]),
|
||||
]
|
||||
else:
|
||||
if "page" in item:
|
||||
result[item_key] = {
|
||||
"source": item_key,
|
||||
"pages": [str(item["page"])],
|
||||
}
|
||||
else:
|
||||
result[item_key] = {"source": item_key}
|
||||
return list(result.values())
|
||||
|
0
pilot/scene/operator/__init__.py
Normal file
0
pilot/scene/operator/__init__.py
Normal file
255
pilot/scene/operator/_experimental.py
Normal file
255
pilot/scene/operator/_experimental.py
Normal file
@ -0,0 +1,255 @@
|
||||
from typing import Dict, Optional, List, Any
|
||||
from dataclasses import dataclass
|
||||
import datetime
|
||||
import os
|
||||
from pilot.awel import MapOperator
|
||||
from pilot.prompts.prompt_new import PromptTemplate
|
||||
from pilot.configs.config import Config
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.scene.message import OnceConversation
|
||||
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
|
||||
|
||||
|
||||
from pilot.memory.chat_history.base import BaseChatHistoryMemory
|
||||
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
|
||||
|
||||
# TODO move global config
|
||||
CFG = Config()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatContext:
|
||||
current_user_input: str
|
||||
model_name: Optional[str]
|
||||
chat_session_id: Optional[str] = None
|
||||
select_param: Optional[str] = None
|
||||
chat_scene: Optional[ChatScene] = ChatScene.ChatNormal
|
||||
prompt_template: Optional[PromptTemplate] = None
|
||||
chat_retention_rounds: Optional[int] = 0
|
||||
history_storage: Optional[BaseChatHistoryMemory] = None
|
||||
history_manager: Optional["ChatHistoryManager"] = None
|
||||
# The input values for prompt template
|
||||
input_values: Optional[Dict] = None
|
||||
echo: Optional[bool] = False
|
||||
|
||||
def build_model_payload(self) -> Dict:
|
||||
if not self.input_values:
|
||||
raise ValueError("The input value can't be empty")
|
||||
llm_messages = self.history_manager._new_chat(self.input_values)
|
||||
return {
|
||||
"model": self.model_name,
|
||||
"prompt": "",
|
||||
"messages": llm_messages,
|
||||
"temperature": float(self.prompt_template.temperature),
|
||||
"max_new_tokens": int(self.prompt_template.max_new_tokens),
|
||||
"echo": self.echo,
|
||||
}
|
||||
|
||||
|
||||
class ChatHistoryManager:
|
||||
def __init__(
|
||||
self,
|
||||
chat_ctx: ChatContext,
|
||||
prompt_template: PromptTemplate,
|
||||
history_storage: BaseChatHistoryMemory,
|
||||
chat_retention_rounds: Optional[int] = 0,
|
||||
) -> None:
|
||||
self._chat_ctx = chat_ctx
|
||||
self.chat_retention_rounds = chat_retention_rounds
|
||||
self.current_message: OnceConversation = OnceConversation(
|
||||
chat_ctx.chat_scene.value()
|
||||
)
|
||||
self.prompt_template = prompt_template
|
||||
self.history_storage: BaseChatHistoryMemory = history_storage
|
||||
self.history_message: List[OnceConversation] = history_storage.messages()
|
||||
self.current_message.model_name = chat_ctx.model_name
|
||||
if chat_ctx.select_param:
|
||||
if len(chat_ctx.chat_scene.param_types()) > 0:
|
||||
self.current_message.param_type = chat_ctx.chat_scene.param_types()[0]
|
||||
self.current_message.param_value = chat_ctx.select_param
|
||||
|
||||
def _new_chat(self, input_values: Dict) -> List[ModelMessage]:
|
||||
self.current_message.chat_order = len(self.history_message) + 1
|
||||
self.current_message.add_user_message(self._chat_ctx.current_user_input)
|
||||
self.current_message.start_date = datetime.datetime.now().strftime(
|
||||
"%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
self.current_message.tokens = 0
|
||||
if self.prompt_template.template:
|
||||
current_prompt = self.prompt_template.format(**input_values)
|
||||
self.current_message.add_system_message(current_prompt)
|
||||
return self._generate_llm_messages()
|
||||
|
||||
def _generate_llm_messages(self) -> List[ModelMessage]:
|
||||
from pilot.scene.base_chat import (
|
||||
_load_system_message,
|
||||
_load_example_messages,
|
||||
_load_history_messages,
|
||||
_load_user_message,
|
||||
)
|
||||
|
||||
messages = []
|
||||
### Load scene setting or character definition as system message
|
||||
if self.prompt_template.template_define:
|
||||
messages.append(
|
||||
ModelMessage(
|
||||
role=ModelMessageRoleType.SYSTEM,
|
||||
content=self.prompt_template.template_define,
|
||||
)
|
||||
)
|
||||
### Load prompt
|
||||
messages += _load_system_message(
|
||||
self.current_message, self.prompt_template, str_message=False
|
||||
)
|
||||
### Load examples
|
||||
messages += _load_example_messages(self.prompt_template, str_message=False)
|
||||
|
||||
### Load History
|
||||
messages += _load_history_messages(
|
||||
self.prompt_template,
|
||||
self.history_message,
|
||||
self.chat_retention_rounds,
|
||||
str_message=False,
|
||||
)
|
||||
|
||||
### Load User Input
|
||||
messages += _load_user_message(
|
||||
self.current_message, self.prompt_template, str_message=False
|
||||
)
|
||||
return messages
|
||||
|
||||
|
||||
class PromptManagerOperator(MapOperator[ChatContext, ChatContext]):
|
||||
def __init__(self, prompt_template: PromptTemplate = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._prompt_template = prompt_template
|
||||
|
||||
async def map(self, input_value: ChatContext) -> ChatContext:
|
||||
if not self._prompt_template:
|
||||
self._prompt_template: PromptTemplate = (
|
||||
CFG.prompt_template_registry.get_prompt_template(
|
||||
input_value.chat_scene.value(),
|
||||
language=CFG.LANGUAGE,
|
||||
model_name=input_value.model_name,
|
||||
proxyllm_backend=CFG.PROXYLLM_BACKEND,
|
||||
)
|
||||
)
|
||||
input_value.prompt_template = self._prompt_template
|
||||
return input_value
|
||||
|
||||
|
||||
class ChatHistoryStorageOperator(MapOperator[ChatContext, ChatContext]):
|
||||
def __init__(self, history: BaseChatHistoryMemory = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._history = history
|
||||
|
||||
async def map(self, input_value: ChatContext) -> ChatContext:
|
||||
if self._history:
|
||||
return self._history
|
||||
chat_history_fac = ChatHistory()
|
||||
input_value.history_storage = chat_history_fac.get_store_instance(
|
||||
input_value.chat_session_id
|
||||
)
|
||||
return input_value
|
||||
|
||||
|
||||
class ChatHistoryOperator(MapOperator[ChatContext, ChatContext]):
|
||||
def __init__(self, history: BaseChatHistoryMemory = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._history = history
|
||||
|
||||
async def map(self, input_value: ChatContext) -> ChatContext:
|
||||
history_storage = self._history or input_value.history_storage
|
||||
if not history_storage:
|
||||
from pilot.memory.chat_history.store_type.mem_history import (
|
||||
MemHistoryMemory,
|
||||
)
|
||||
|
||||
history_storage = MemHistoryMemory(input_value.chat_session_id)
|
||||
input_value.history_storage = history_storage
|
||||
input_value.history_manager = ChatHistoryManager(
|
||||
input_value,
|
||||
input_value.prompt_template,
|
||||
history_storage,
|
||||
input_value.chat_retention_rounds,
|
||||
)
|
||||
return input_value
|
||||
|
||||
|
||||
class EmbeddingEngingOperator(MapOperator[ChatContext, ChatContext]):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def map(self, input_value: ChatContext) -> ChatContext:
|
||||
from pilot.configs.model_config import EMBEDDING_MODEL_CONFIG
|
||||
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
|
||||
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
|
||||
from pilot.scene.chat_knowledge.v1.chat import _merge_by_key
|
||||
|
||||
# TODO, decompose the current operator into some atomic operators
|
||||
knowledge_space = input_value.select_param
|
||||
vector_store_config = {
|
||||
"vector_store_name": knowledge_space,
|
||||
"vector_store_type": CFG.VECTOR_STORE_TYPE,
|
||||
}
|
||||
embedding_factory = self.system_app.get_component(
|
||||
"embedding_factory", EmbeddingFactory
|
||||
)
|
||||
knowledge_embedding_client = EmbeddingEngine(
|
||||
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
|
||||
vector_store_config=vector_store_config,
|
||||
embedding_factory=embedding_factory,
|
||||
)
|
||||
space_context = await self._get_space_context(knowledge_space)
|
||||
top_k = (
|
||||
CFG.KNOWLEDGE_SEARCH_TOP_SIZE
|
||||
if space_context is None
|
||||
else int(space_context["embedding"]["topk"])
|
||||
)
|
||||
max_token = (
|
||||
CFG.KNOWLEDGE_SEARCH_MAX_TOKEN
|
||||
if space_context is None or space_context.get("prompt") is None
|
||||
else int(space_context["prompt"]["max_token"])
|
||||
)
|
||||
input_value.prompt_template.template_is_strict = False
|
||||
if space_context and space_context.get("prompt"):
|
||||
input_value.prompt_template.template_define = space_context["prompt"][
|
||||
"scene"
|
||||
]
|
||||
input_value.prompt_template.template = space_context["prompt"]["template"]
|
||||
|
||||
docs = await self.blocking_func_to_async(
|
||||
knowledge_embedding_client.similar_search,
|
||||
input_value.current_user_input,
|
||||
top_k,
|
||||
)
|
||||
sources = _merge_by_key(list(map(lambda doc: doc.metadata, docs)), "source")
|
||||
if not docs or len(docs) == 0:
|
||||
print("no relevant docs to retrieve")
|
||||
context = "no relevant docs to retrieve"
|
||||
else:
|
||||
context = [d.page_content for d in docs]
|
||||
context = context[:max_token]
|
||||
relations = list(
|
||||
set([os.path.basename(str(d.metadata.get("source", ""))) for d in docs])
|
||||
)
|
||||
input_value.input_values = {
|
||||
"context": context,
|
||||
"question": input_value.current_user_input,
|
||||
"relations": relations,
|
||||
}
|
||||
return input_value
|
||||
|
||||
async def _get_space_context(self, space_name):
|
||||
from pilot.server.knowledge.service import KnowledgeService
|
||||
|
||||
service = KnowledgeService()
|
||||
return await self.blocking_func_to_async(service.get_space_context, space_name)
|
||||
|
||||
|
||||
class BaseChatOperator(MapOperator[ChatContext, Dict]):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def map(self, input_value: ChatContext) -> Dict:
|
||||
return input_value.build_model_payload()
|
@ -45,6 +45,7 @@ def initialize_components(
|
||||
param, system_app, embedding_model_name, embedding_model_path
|
||||
)
|
||||
_initialize_model_cache(system_app)
|
||||
_initialize_awel(system_app)
|
||||
|
||||
|
||||
def _initialize_embedding_model(
|
||||
@ -149,3 +150,10 @@ def _initialize_model_cache(system_app: SystemApp):
|
||||
max_memory_mb = CFG.MODEL_CACHE_MAX_MEMORY_MB or 256
|
||||
persist_dir = CFG.MODEL_CACHE_STORAGE_DISK_DIR or MODEL_DISK_CACHE_DIR
|
||||
initialize_cache(system_app, storage_type, max_memory_mb, persist_dir)
|
||||
|
||||
|
||||
def _initialize_awel(system_app: SystemApp):
|
||||
from pilot.awel import initialize_awel
|
||||
from pilot.configs.model_config import _DAG_DEFINITION_DIR
|
||||
|
||||
initialize_awel(system_app, _DAG_DEFINITION_DIR)
|
||||
|
@ -1,9 +1,14 @@
|
||||
import argparse
|
||||
import os
|
||||
from dataclasses import dataclass, fields, MISSING, asdict, field, is_dataclass
|
||||
from typing import Any, List, Optional, Type, Union, Callable, Dict
|
||||
from typing import Any, List, Optional, Type, Union, Callable, Dict, TYPE_CHECKING
|
||||
from collections import OrderedDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic import BaseModel
|
||||
|
||||
MISSING_DEFAULT_VALUE = "__MISSING_DEFAULT_VALUE__"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParameterDescription:
|
||||
@ -613,6 +618,64 @@ def _get_dict_from_obj(obj, default_value=None) -> Optional[Dict]:
|
||||
return default_value
|
||||
|
||||
|
||||
def _get_base_model_descriptions(model_cls: "BaseModel") -> List[ParameterDescription]:
    """Build a ParameterDescription for every field of a pydantic model.

    Supports both pydantic v1 and v2 by dispatching on the installed major
    version (v2 exposes ``model_fields``/``model_json_schema``; v1 exposes
    ``__fields__``/``schema``).

    Args:
        model_cls: The pydantic model class to inspect.

    Returns:
        One ParameterDescription per field declared on ``model_cls``.
    """
    import pydantic

    version = int(pydantic.VERSION.split(".")[0])
    schema = model_cls.model_json_schema() if version >= 2 else model_cls.schema()
    required_fields = set(schema.get("required", []))
    # Fully-qualified class name; loop-invariant, so computed once.
    # Fixed: a stray trailing comma previously made this a 1-tuple.
    param_class = f"{model_cls.__module__}.{model_cls.__name__}"
    param_descs = []
    for field_name, field_schema in schema.get("properties", {}).items():
        # Field objects live in `model_fields` on v2 but `__fields__` on v1;
        # the unconditional v2 access previously broke the v1 path entirely.
        if version >= 2:
            field = model_cls.model_fields[field_name]
        else:
            field = model_cls.__fields__[field_name]
        param_type = field_schema.get("type")
        if not param_type and "anyOf" in field_schema:
            # Optional[X] renders as anyOf [X, null]; pick the non-null arm.
            # Entries may carry "$ref" instead of "type", so use .get().
            for any_of in field_schema["anyOf"]:
                if any_of.get("type") and any_of["type"] != "null":
                    param_type = any_of["type"]
                    break
        if version >= 2:
            # v2 marks "no default" with the PydanticUndefined sentinel.
            default_value = (
                field.default
                if hasattr(field, "default")
                and str(field.default) != "PydanticUndefined"
                else None
            )
        else:
            default_value = (
                field.default
                if not field.allow_none
                else (
                    field.default_factory() if callable(field.default_factory) else None
                )
            )
        description = field_schema.get("description", "")
        is_required = field_name in required_fields
        valid_values = None
        ext_metadata = None
        # `field_info` only exists on pydantic v1 field objects.
        if hasattr(field, "field_info"):
            valid_values = (
                list(field.field_info.choices)
                if hasattr(field.field_info, "choices")
                else None
            )
            ext_metadata = (
                field.field_info.extra if hasattr(field.field_info, "extra") else None
            )
        param_desc = ParameterDescription(
            param_class=param_class,
            param_name=field_name,
            param_type=param_type,
            default_value=default_value,
            description=description,
            required=is_required,
            valid_values=valid_values,
            ext_metadata=ext_metadata,
        )
        param_descs.append(param_desc)
    return param_descs
|
||||
|
||||
|
||||
class _SimpleArgParser:
|
||||
def __init__(self, *args):
|
||||
self.params = {arg.replace("_", "-"): None for arg in args}
|
||||
|
Loading…
Reference in New Issue
Block a user