Mirror of https://github.com/csunny/DB-GPT.git
Synced 2025-07-22 11:51:42 +00:00

chore: Add pylint for DB-GPT core lib (#1076)

parent 3a54d1ef9a
commit 40c853575a
.flake8 — new file (40 lines)

@@ -0,0 +1,40 @@
+[flake8]
+exclude =
+    .eggs/
+    build/
+    */tests/*
+    *_private
+max-line-length = 88
+inline-quotes = "
+ignore =
+    C408
+    C417
+    E121
+    E123
+    E126
+    E203
+    E226
+    E24
+    E704
+    W503
+    W504
+    W605
+    I
+    N
+    B001
+    B002
+    B003
+    B004
+    B005
+    B007
+    B008
+    B009
+    B010
+    B011
+    B012
+    B013
+    B014
+    B015
+    B016
+    B017
+avoid-escape = no
.isort.cfg — new file (13 lines)

@@ -0,0 +1,13 @@
+[settings]
+# This is to make isort compatible with Black. See
+# https://black.readthedocs.io/en/stable/the_black_code_style.html#how-black-wraps-lines.
+line_length=88
+profile=black
+multi_line_output=3
+include_trailing_comma=True
+use_parentheses=True
+float_to_top=True
+filter_files=True
+
+skip_glob=examples/notebook/*
+sections=FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER,AFTERRAY
.mypy.ini — new file (20 lines)

@@ -0,0 +1,20 @@
+[mypy]
+exclude = /tests/
+# plugins = pydantic.mypy
+
+[mypy-graphviz.*]
+ignore_missing_imports = True
+
+[mypy-cachetools.*]
+ignore_missing_imports = True
+
+[mypy-coloredlogs.*]
+ignore_missing_imports = True
+
+[mypy-termcolor.*]
+ignore_missing_imports = True
+
+[mypy-pydantic.*]
+strict_optional = False
+ignore_missing_imports = True
+follow_imports = skip
@@ -20,4 +20,22 @@ repos:
         stages: [commit]
         pass_filenames: false
         args: []
+      - id: python-test-doc
+        name: Python Doc Test
+        entry: make test-doc
+        language: system
+        exclude: '^dbgpt/app/static/|^web/'
+        types: [python]
+        stages: [commit]
+        pass_filenames: false
+        args: []
+      - id: python-lint-mypy
+        name: Python Lint mypy
+        entry: make mypy
+        language: system
+        exclude: '^dbgpt/app/static/|^web/'
+        types: [python]
+        stages: [commit]
+        pass_filenames: false
+        args: []
Makefile — 19 lines changed

@@ -14,7 +14,7 @@ setup: $(VENV)/bin/activate

 $(VENV)/bin/activate: $(VENV)/.venv-timestamp

-$(VENV)/.venv-timestamp: setup.py
+$(VENV)/.venv-timestamp: setup.py requirements
 	# Create new virtual environment if setup.py has changed
 	python3 -m venv $(VENV)
 	$(VENV_BIN)/pip install --upgrade pip
@@ -46,15 +46,14 @@ fmt: setup ## Format Python code
 	# $(VENV_BIN)/blackdoc .
 	$(VENV_BIN)/blackdoc dbgpt
 	$(VENV_BIN)/blackdoc examples
-	# TODO: Type checking of Python code.
-	# https://github.com/python/mypy
-	# $(VENV_BIN)/mypy dbgpt
-	# TODO: uUse flake8 to enforce Python style guide.
+	# TODO: Use flake8 to enforce Python style guide.
 	# https://flake8.pycqa.org/en/latest/
 	# $(VENV_BIN)/flake8 dbgpt
+	$(VENV_BIN)/flake8 dbgpt/core/
+	# TODO: More package checks with flake8.


 .PHONY: pre-commit
-pre-commit: fmt test ## Run formatting and unit tests before committing
+pre-commit: fmt test test-doc mypy ## Run formatting and unit tests before committing

 test: $(VENV)/.testenv ## Run unit tests
 	$(VENV_BIN)/pytest dbgpt
@@ -64,6 +63,12 @@ test-doc: $(VENV)/.testenv ## Run doctests
 	# -k "not test_" skips tests that are not doctests.
 	$(VENV_BIN)/pytest --doctest-modules -k "not test_" dbgpt/core

+.PHONY: mypy
+mypy: $(VENV)/.testenv ## Run mypy checks
+	# https://github.com/python/mypy
+	$(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/core/
+	# TODO: More package checks with mypy.
+
 .PHONY: coverage
 coverage: setup ## Run tests and report coverage
 	$(VENV_BIN)/pytest dbgpt --cov=dbgpt
@@ -1,22 +1,22 @@
-import logging
-import os
-import uuid
-
-import matplotlib
-import matplotlib.pyplot as plt
-import matplotlib.ticker as mtick
-import pandas as pd
-import seaborn as sns
-from matplotlib.font_manager import FontManager
-from pandas import DataFrame
-
-from dbgpt.configs.model_config import PILOT_PATH
-from dbgpt.util.string_utils import is_scientific_notation
-
-from ...command_mange import command
-
-matplotlib.use("Agg")
+import logging
+
+import matplotlib.pyplot as plt
+import matplotlib.ticker as mtick
+from matplotlib.font_manager import FontManager
+
+from dbgpt.configs.model_config import PILOT_PATH
+from dbgpt.util.string_utils import is_scientific_notation
+
+logger = logging.getLogger(__name__)
@@ -192,7 +192,8 @@ def _create_mysql_database(db_name: str, db_url: str, try_to_create_db: bool = F
         with engine_no_db.connect() as conn:
             conn.execute(
                 DDL(
-                    f"CREATE DATABASE {db_name} CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"
+                    f"CREATE DATABASE {db_name} CHARACTER SET utf8mb4 COLLATE "
+                    f"utf8mb4_unicode_ci"
                 )
             )
             logger.info(f"Database {db_name} successfully created")
@@ -218,26 +219,31 @@ class WebServerParameters(BaseParameters):
     controller_addr: Optional[str] = field(
         default=None,
         metadata={
-            "help": "The Model controller address to connect. If None, read model controller address from environment key `MODEL_SERVER`."
+            "help": "The Model controller address to connect. If None, read model "
+            "controller address from environment key `MODEL_SERVER`."
         },
     )
     model_name: str = field(
         default=None,
         metadata={
-            "help": "The default model name to use. If None, read model name from environment key `LLM_MODEL`.",
+            "help": "The default model name to use. If None, read model name from "
+            "environment key `LLM_MODEL`.",
             "tags": "fixed",
         },
     )
     share: Optional[bool] = field(
         default=False,
         metadata={
-            "help": "Whether to create a publicly shareable link for the interface. Creates an SSH tunnel to make your UI accessible from anywhere. "
+            "help": "Whether to create a publicly shareable link for the interface. "
+            "Creates an SSH tunnel to make your UI accessible from anywhere. "
         },
     )
     remote_embedding: Optional[bool] = field(
         default=False,
         metadata={
-            "help": "Whether to enable remote embedding models. If it is True, you need to start a embedding model through `dbgpt start worker --worker_type text2vec --model_name xxx --model_path xxx`"
+            "help": "Whether to enable remote embedding models. If it is True, you need"
+            " to start a embedding model through `dbgpt start worker --worker_type "
+            "text2vec --model_name xxx --model_path xxx`"
         },
     )
     log_level: Optional[str] = field(
@@ -286,3 +292,10 @@ class WebServerParameters(BaseParameters):
             "help": "The directories to search awel files, split by `,`",
         },
     )
+    default_thread_pool_size: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "The default thread pool size, If None, "
+            "use default config of python thread pool",
+        },
+    )
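The parameter changes above all follow one pattern: a long `help` string is split into adjacent string literals so each source line fits the 88-character limit set in the new .flake8. A minimal self-contained sketch of that pattern — the `_ExampleParameters` class and its field are hypothetical, not part of this commit:

    from dataclasses import dataclass, field
    from typing import Optional


    @dataclass
    class _ExampleParameters:
        controller_addr: Optional[str] = field(
            default=None,
            metadata={
                # Adjacent string literals are concatenated by the parser, so
                # the full help text survives while each line stays short.
                "help": "The controller address to connect. If None, read the "
                "address from an environment variable."
            },
        )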
@@ -25,7 +25,9 @@ def initialize_components(
     from dbgpt.model.cluster.controller.controller import controller

     # Register global default executor factory first
-    system_app.register(DefaultExecutorFactory)
+    system_app.register(
+        DefaultExecutorFactory, max_workers=param.default_thread_pool_size
+    )
     system_app.register_instance(controller)

     from dbgpt.serve.agent.hub.controller import module_agent
@@ -3,8 +3,6 @@ import os
 import sys
 from typing import List

-ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-sys.path.append(ROOT_PATH)
 from fastapi import FastAPI
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
@@ -41,6 +39,10 @@ from dbgpt.util.utils import (
     setup_logging,
 )

+ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+sys.path.append(ROOT_PATH)
+
+
 static_file_path = os.path.join(ROOT_PATH, "dbgpt", "app/static")

 CFG = Config()
@@ -5,6 +5,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
 from urllib.parse import urljoin

 import requests
+from prettytable import PrettyTable

 from dbgpt.app.knowledge.request.request import (
     ChunkQueryRequest,
@@ -193,9 +194,6 @@ def knowledge_init(
     return


-from prettytable import PrettyTable
-
-
 class _KnowledgeVisualizer:
     def __init__(self, api_address: str, out_format: str):
         self.client = KnowledgeApiClient(api_address)
@@ -4,13 +4,14 @@
 import os
 import sys

-ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-sys.path.append(ROOT_PATH)
-
 from dbgpt._private.config import Config
 from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, LLM_MODEL_CONFIG
 from dbgpt.model.cluster import run_worker_manager

+ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+sys.path.append(ROOT_PATH)
+
+
 CFG = Config()

 model_path = LLM_MODEL_CONFIG.get(CFG.LLM_MODEL)
@@ -313,8 +313,9 @@ class BaseChat(ABC):
             )
-            ### store current conversation
             span.end(metadata={"error": str(e)})
         # self.memory.append(self.current_message)
-        self.current_message.end_current_round()
+        await blocking_func_to_async(
+            self._executor, self.current_message.end_current_round
+        )

     async def nostream_call(self):
         payload = await self._build_model_request()
@@ -381,8 +382,9 @@ class BaseChat(ABC):
             )
             span.end(metadata={"error": str(e)})
-        ### store dialogue
         # self.memory.append(self.current_message)
-        self.current_message.end_current_round()
+        await blocking_func_to_async(
+            self._executor, self.current_message.end_current_round
+        )
         return self.current_ai_response()

     async def get_llm_response(self):
@@ -104,28 +104,37 @@ class BaseComponent(LifeCycle, ABC):

     @classmethod
     def get_instance(
-        cls,
+        cls: Type[T],
         system_app: SystemApp,
         default_component=_EMPTY_DEFAULT_COMPONENT,
-        or_register_component: Type[BaseComponent] = None,
+        or_register_component: Optional[Type[T]] = None,
         *args,
         **kwargs,
-    ) -> BaseComponent:
+    ) -> T:
         """Get the current component instance.

         Args:
             system_app (SystemApp): The system app
             default_component : The default component instance if not retrieve by name
-            or_register_component (Type[BaseComponent]): The new component to register if not retrieve by name
+            or_register_component (Type[T]): The new component to register if not retrieve by name

         Returns:
-            BaseComponent: The component instance
+            T: The component instance
         """
+        # Check for keyword argument conflicts
+        if "default_component" in kwargs:
+            raise ValueError(
+                "default_component argument given in both fixed and **kwargs"
+            )
+        if "or_register_component" in kwargs:
+            raise ValueError(
+                "or_register_component argument given in both fixed and **kwargs"
+            )
+        kwargs["default_component"] = default_component
+        kwargs["or_register_component"] = or_register_component
         return system_app.get_component(
             cls.name,
             cls,
-            default_component=default_component,
-            or_register_component=or_register_component,
             *args,
             **kwargs,
         )
@@ -159,11 +168,11 @@ class SystemApp(LifeCycle):
         """Returns the internal AppConfig."""
         return self._app_config

-    def register(self, component: Type[BaseComponent], *args, **kwargs) -> T:
+    def register(self, component: Type[T], *args, **kwargs) -> T:
         """Register a new component by its type.

         Args:
-            component (Type[BaseComponent]): The component class to register
+            component (Type[T]): The component class to register

         Returns:
             T: The instance of registered component
@@ -198,7 +207,7 @@ class SystemApp(LifeCycle):
         name: Union[str, ComponentType],
         component_type: Type[T],
         default_component=_EMPTY_DEFAULT_COMPONENT,
-        or_register_component: Type[BaseComponent] = None,
+        or_register_component: Optional[Type[T]] = None,
         *args,
         **kwargs,
     ) -> T:
@@ -208,7 +217,7 @@ class SystemApp(LifeCycle):
             name (Union[str, ComponentType]): Component name
             component_type (Type[T]): The type of current retrieve component
             default_component : The default component instance if not retrieve by name
-            or_register_component (Type[BaseComponent]): The new component to register if not retrieve by name
+            or_register_component (Type[T]): The new component to register if not retrieve by name

         Returns:
             T: The instance retrieved by component name
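A hedged sketch of what the `Type[T]`-based signatures above buy callers. `MyComponent` is hypothetical, and the import path and bare `SystemApp()` construction are assumptions based only on the code shown in this diff:

    from dbgpt.component import BaseComponent, SystemApp  # import path assumed


    class MyComponent(BaseComponent):  # hypothetical component
        name = "my_component"

        def init_app(self, system_app: SystemApp):
            pass


    system_app = SystemApp()
    # register() is now typed (Type[T], ...) -> T, so type checkers see this
    # as a MyComponent rather than a bare BaseComponent.
    instance = system_app.register(MyComponent)
    # get_instance() likewise resolves to T; or_register_component registers
    # the class if no instance is found under its name.
    same = MyComponent.get_instance(system_app, or_register_component=MyComponent)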
@@ -1,11 +1,13 @@
-from dbgpt.core.interface.cache import (
+"""The core module contains the core interfaces and classes for dbgpt."""
+
+from dbgpt.core.interface.cache import (  # noqa: F401
     CacheClient,
     CacheConfig,
     CacheKey,
     CachePolicy,
     CacheValue,
 )
-from dbgpt.core.interface.llm import (
+from dbgpt.core.interface.llm import (  # noqa: F401
     DefaultMessageConverter,
     LLMClient,
     MessageConverter,
@@ -16,7 +18,7 @@ from dbgpt.core.interface.llm import (
     ModelRequest,
     ModelRequestContext,
 )
-from dbgpt.core.interface.message import (
+from dbgpt.core.interface.message import (  # noqa: F401
     AIMessage,
     BaseMessage,
     ConversationIdentifier,
@@ -29,8 +31,11 @@ from dbgpt.core.interface.message import (
     StorageConversation,
     SystemMessage,
 )
-from dbgpt.core.interface.output_parser import BaseOutputParser, SQLOutputParser
-from dbgpt.core.interface.prompt import (
+from dbgpt.core.interface.output_parser import (  # noqa: F401
+    BaseOutputParser,
+    SQLOutputParser,
+)
+from dbgpt.core.interface.prompt import (  # noqa: F401
     BasePromptTemplate,
     ChatPromptTemplate,
     HumanPromptTemplate,
@@ -40,8 +45,8 @@ from dbgpt.core.interface.prompt import (
     StoragePromptTemplate,
     SystemPromptTemplate,
 )
-from dbgpt.core.interface.serialization import Serializable, Serializer
-from dbgpt.core.interface.storage import (
+from dbgpt.core.interface.serialization import Serializable, Serializer  # noqa: F401
+from dbgpt.core.interface.storage import (  # noqa: F401
     DefaultStorageItemAdapter,
     InMemoryStorage,
     QuerySpec,
@@ -1,27 +1,34 @@
 """Example selector base class"""

 from abc import ABC
 from enum import Enum
-from typing import List
+from typing import List, Optional

 from dbgpt._private.pydantic import BaseModel


 class ExampleType(Enum):
     """Example type"""

     ONE_SHOT = "one_shot"
     FEW_SHOT = "few_shot"


 class ExampleSelector(BaseModel, ABC):
     """Example selector base class"""

     examples_record: List[dict]
     use_example: bool = False
     type: str = ExampleType.ONE_SHOT.value

     def examples(self, count: int = 2):
         """Return examples"""
         if ExampleType.ONE_SHOT.value == self.type:
-            return self.__one_show_context()
+            return self.__one_shot_context()
         else:
             return self.__few_shot_context(count)

-    def __few_shot_context(self, count: int = 2) -> List[dict]:
+    def __few_shot_context(self, count: int = 2) -> Optional[List[dict]]:
         """
         Use 2 or more examples, default 2
         Returns: example text
@@ -31,14 +38,14 @@ class ExampleSelector(BaseModel, ABC):
             return need_use
         return None

-    def __one_show_context(self) -> dict:
+    def __one_shot_context(self) -> Optional[dict]:
         """
         Use one examples
         Returns:

         """
         if self.use_example:
-            need_use = self.examples_record[:1]
+            need_use = self.examples_record[-1]
             return need_use

         return None
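A short usage sketch for ExampleSelector as defined above; the record contents are illustrative, and attribute mutation assumes the default pydantic model behavior. Note the behavior change in this diff: one-shot selection now returns the last record (`examples_record[-1]`, a single dict) instead of a one-element slice:

    selector = ExampleSelector(
        examples_record=[{"q": "first"}, {"q": "second"}, {"q": "third"}],
        use_example=True,
        type=ExampleType.FEW_SHOT.value,
    )
    few = selector.examples(count=2)  # the first two records

    selector.type = ExampleType.ONE_SHOT.value
    one = selector.examples()  # the last record, per the fix above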
@@ -1,8 +1,12 @@
+"""Prompt template registry.
+
+This module is deprecated. we will remove it in the future.
+"""
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-

 from collections import defaultdict
-from typing import Dict, List
+from typing import Dict, List, Optional

 _DEFAULT_MODEL_KEY = "___default_prompt_template_model_key__"
 _DEFUALT_LANGUAGE_KEY = "___default_prompt_template_language_key__"
@@ -14,15 +18,15 @@ class PromptTemplateRegistry:
     """

     def __init__(self) -> None:
-        self.registry = defaultdict(dict)
+        self.registry = defaultdict(dict)  # type: ignore

     def register(
         self,
         prompt_template,
         language: str = "en",
         is_default: bool = False,
-        model_names: List[str] = None,
-        scene_name: str = None,
+        model_names: Optional[List[str]] = None,
+        scene_name: Optional[str] = None,
     ) -> None:
         """Register prompt template with scene name, language
         registry dict format:
@@ -43,7 +47,7 @@ class PromptTemplateRegistry:
         if not scene_name:
             raise ValueError("Prompt template scene name cannot be empty")
         if not model_names:
-            model_names: List[str] = [_DEFAULT_MODEL_KEY]
+            model_names = [_DEFAULT_MODEL_KEY]
         scene_registry = self.registry[scene_name]
         _register_scene_prompt_template(
             scene_registry, prompt_template, language, model_names
@@ -64,7 +68,7 @@ class PromptTemplateRegistry:
         scene_name: str,
         language: str,
         model_name: str,
-        proxyllm_backend: str = None,
+        proxyllm_backend: Optional[str] = None,
     ):
         """Get prompt template with scene name, language and model name
         proxyllm_backend: see CFG.PROXYLLM_BACKEND
@@ -1,9 +1,10 @@
-"""Agentic Workflow Expression Language (AWEL)
+"""Agentic Workflow Expression Language (AWEL).

-Note:
-
-AWEL is still an experimental feature and only opens the lowest level API.
-The stability of this API cannot be guaranteed at present.
+Agentic Workflow Expression Language(AWEL) is a set of intelligent agent workflow
+expression language specially designed for large model application development. It
+provides great functionality and flexibility. Through the AWEL API, you can focus on
+the development of business logic for LLMs applications without paying attention to
+cumbersome model and environment details.

 """
@@ -71,10 +72,12 @@ __all__ = [
     "TransformStreamAbsOperator",
     "HttpTrigger",
     "setup_dev_environment",
+    "_is_async_iterator",
 ]


 def initialize_awel(system_app: SystemApp, dag_dirs: List[str]):
+    """Initialize AWEL."""
     from .dag.base import DAGVar
     from .dag.dag_manager import DAGManager
     from .operator.base import initialize_runner
@@ -92,13 +95,13 @@ def initialize_awel(system_app: SystemApp, dag_dirs: List[str]):

 def setup_dev_environment(
     dags: List[DAG],
-    host: Optional[str] = "127.0.0.1",
-    port: Optional[int] = 5555,
+    host: str = "127.0.0.1",
+    port: int = 5555,
     logging_level: Optional[str] = None,
     logger_filename: Optional[str] = None,
     show_dag_graph: Optional[bool] = True,
 ) -> None:
-    """Setup a development environment for AWEL.
+    """Run AWEL in development environment.

     Just using in development environment, not production environment.

@@ -107,9 +110,11 @@ def setup_dev_environment(
         host (Optional[str], optional): The host. Defaults to "127.0.0.1"
         port (Optional[int], optional): The port. Defaults to 5555.
         logging_level (Optional[str], optional): The logging level. Defaults to None.
-        logger_filename (Optional[str], optional): The logger filename. Defaults to None.
-        show_dag_graph (Optional[bool], optional): Whether show the DAG graph. Defaults to True.
-            If True, the DAG graph will be saved to a file and open it automatically.
+        logger_filename (Optional[str], optional): The logger filename.
+            Defaults to None.
+        show_dag_graph (Optional[bool], optional): Whether show the DAG graph.
+            Defaults to True. If True, the DAG graph will be saved to a file and open
+            it automatically.
     """
     import uvicorn
     from fastapi import FastAPI
@@ -138,7 +143,9 @@ def setup_dev_environment(
             logger.info(f"Visualize DAG {str(dag)} to {dag_graph_file}")
         except Exception as e:
             logger.warning(
-                f"Visualize DAG {str(dag)} failed: {e}, if your system has no graphviz, you can install it by `pip install graphviz` or `sudo apt install graphviz`"
+                f"Visualize DAG {str(dag)} failed: {e}, if your system has no "
+                f"graphviz, you can install it by `pip install graphviz` or "
+                f"`sudo apt install graphviz`"
             )
         for trigger in dag.trigger_nodes:
             trigger_manager.register_trigger(trigger)
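A minimal sketch of the development entry point documented above. The import path is assumed from this package's layout, and the empty DAG body is a placeholder:

    from dbgpt.core.awel import DAG, setup_dev_environment  # path assumed

    with DAG("example_dag") as dag:
        ...  # build operators here and wire them with >>

    # Starts a local uvicorn/FastAPI app on 127.0.0.1:5555, visualizes the DAG
    # when graphviz is available, and registers the DAG's trigger nodes,
    # per the function shown above.
    setup_dev_environment([dag], port=5555)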
@@ -1,7 +1,10 @@
+"""Base classes for AWEL."""
 from abc import ABC, abstractmethod


 class Trigger(ABC):
+    """Base class for trigger."""
+
     @abstractmethod
     async def trigger(self) -> None:
         """Trigger the workflow or a specific operation in the workflow."""
@@ -0,0 +1 @@
+"""The module of DAGs."""
@@ -1,3 +1,7 @@
+"""The base module of DAG.
+
+DAG is the core component of AWEL, it is used to define the relationship between tasks.
+"""
 import asyncio
 import contextvars
 import logging
@@ -6,7 +10,7 @@ import uuid
 from abc import ABC, abstractmethod
 from collections import deque
 from concurrent.futures import Executor
-from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Union, cast

 from dbgpt.component import SystemApp

@@ -27,86 +31,108 @@ def _is_async_context():


 class DependencyMixin(ABC):
+    """The mixin class for DAGNode.
+
+    This class defines the interface for setting upstream and downstream nodes.
+
+    And it also implements the operator << and >> for setting upstream
+    and downstream nodes.
+    """
+
     @abstractmethod
-    def set_upstream(self, nodes: DependencyType) -> "DependencyMixin":
+    def set_upstream(self, nodes: DependencyType) -> None:
         """Set one or more upstream nodes for this node.

         Args:
             nodes (DependencyType): Upstream nodes to be set to current node.

-        Returns:
-            DependencyMixin: Returns self to allow method chaining.
-
         Raises:
-            ValueError: If no upstream nodes are provided or if an argument is not a DependencyMixin.
+            ValueError: If no upstream nodes are provided or if an argument is
+                not a DependencyMixin.
         """

     @abstractmethod
-    def set_downstream(self, nodes: DependencyType) -> "DependencyMixin":
+    def set_downstream(self, nodes: DependencyType) -> None:
         """Set one or more downstream nodes for this node.

         Args:
             nodes (DependencyType): Downstream nodes to be set to current node.

-        Returns:
-            DependencyMixin: Returns self to allow method chaining.
-
         Raises:
-            ValueError: If no downstream nodes are provided or if an argument is not a DependencyMixin.
+            ValueError: If no downstream nodes are provided or if an argument is
+                not a DependencyMixin.
         """

     def __lshift__(self, nodes: DependencyType) -> DependencyType:
-        """Implements self << nodes
+        """Set upstream nodes for current node.
+
+        Implements: self << nodes.

         Example:
-
-        .. code-block:: python
-            # means node.set_upstream(input_node)
-            node << input_node
-            # means node2.set_upstream([input_node])
-            node2 << [input_node]
+            .. code-block:: python
+
+                # means node.set_upstream(input_node)
+                node << input_node
+
+                # means node2.set_upstream([input_node])
+                node2 << [input_node]
         """
         self.set_upstream(nodes)
         return nodes

     def __rshift__(self, nodes: DependencyType) -> DependencyType:
-        """Implements self >> nodes
+        """Set downstream nodes for current node.

-        Example:
+        Implements: self >> nodes.

-        .. code-block:: python
+        Examples:
+            .. code-block:: python

-            # means node.set_downstream(next_node)
-            node >> next_node
+                # means node.set_downstream(next_node)
+                node >> next_node

-            # means node2.set_downstream([next_node])
-            node2 >> [next_node]
+                # means node2.set_downstream([next_node])
+                node2 >> [next_node]

         """
         self.set_downstream(nodes)
         return nodes

     def __rrshift__(self, nodes: DependencyType) -> "DependencyMixin":
-        """Implements [node] >> self"""
+        """Set upstream nodes for current node.
+
+        Implements: [node] >> self
+        """
         self.__lshift__(nodes)
         return self

     def __rlshift__(self, nodes: DependencyType) -> "DependencyMixin":
-        """Implements [node] << self"""
+        """Set downstream nodes for current node.
+
+        Implements: [node] << self
+        """
         self.__rshift__(nodes)
         return self


 class DAGVar:
+    """The DAGVar is used to store the current DAG context."""
+
     _thread_local = threading.local()
-    _async_local = contextvars.ContextVar("current_dag_stack", default=deque())
-    _system_app: SystemApp = None
-    _executor: Executor = None
+    _async_local: contextvars.ContextVar = contextvars.ContextVar(
+        "current_dag_stack", default=deque()
+    )
+    _system_app: Optional[SystemApp] = None
+    # The executor for current DAG, this is used run some sync tasks in async DAG
+    _executor: Optional[Executor] = None

     @classmethod
     def enter_dag(cls, dag) -> None:
+        """Enter a DAG context.
+
+        Args:
+            dag (DAG): The DAG to enter
+        """
         is_async = _is_async_context()
         if is_async:
             stack = cls._async_local.get()
@@ -119,6 +145,7 @@ class DAGVar:

     @classmethod
     def exit_dag(cls) -> None:
+        """Exit a DAG context."""
         is_async = _is_async_context()
         if is_async:
             stack = cls._async_local.get()
@@ -134,6 +161,11 @@ class DAGVar:

     @classmethod
     def get_current_dag(cls) -> Optional["DAG"]:
+        """Get the current DAG.
+
+        Returns:
+            Optional[DAG]: The current DAG
+        """
         is_async = _is_async_context()
         if is_async:
             stack = cls._async_local.get()
@@ -147,36 +179,56 @@ class DAGVar:
         return None

     @classmethod
-    def get_current_system_app(cls) -> SystemApp:
+    def get_current_system_app(cls) -> Optional[SystemApp]:
+        """Get the current system app.
+
+        Returns:
+            Optional[SystemApp]: The current system app
+        """
         # if not cls._system_app:
         #     raise RuntimeError("System APP not set for DAGVar")
         return cls._system_app

     @classmethod
     def set_current_system_app(cls, system_app: SystemApp) -> None:
+        """Set the current system app.
+
+        Args:
+            system_app (SystemApp): The system app to set
+        """
         if cls._system_app:
-            logger.warn("System APP has already set, nothing to do")
+            logger.warning("System APP has already set, nothing to do")
         else:
             cls._system_app = system_app

     @classmethod
-    def get_executor(cls) -> Executor:
+    def get_executor(cls) -> Optional[Executor]:
+        """Get the current executor.
+
+        Returns:
+            Optional[Executor]: The current executor
+        """
         return cls._executor

     @classmethod
     def set_executor(cls, executor: Executor) -> None:
+        """Set the current executor.
+
+        Args:
+            executor (Executor): The executor to set
+        """
         cls._executor = executor


 class DAGLifecycle:
-    """The lifecycle of DAG"""
+    """The lifecycle of DAG."""

     async def before_dag_run(self):
-        """The callback before DAG run"""
+        """Execute before DAG run."""
         pass

     async def after_dag_end(self):
-        """The callback after DAG end,
+        """Execute after DAG end.

         This method may be called multiple times, please make sure it is idempotent.
         """
@@ -184,6 +236,8 @@ class DAGLifecycle:


 class DAGNode(DAGLifecycle, DependencyMixin, ABC):
+    """The base class of DAGNode."""
+
     resource_group: Optional[ResourceGroup] = None
     """The resource group of current DAGNode"""

@@ -196,6 +250,17 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
         executor: Optional[Executor] = None,
         **kwargs,
     ) -> None:
+        """Initialize a DAGNode.
+
+        Args:
+            dag (Optional["DAG"], optional): The DAG to add this node to.
+                Defaults to None.
+            node_id (Optional[str], optional): The node id. Defaults to None.
+            node_name (Optional[str], optional): The node name. Defaults to None.
+            system_app (Optional[SystemApp], optional): The system app.
+                Defaults to None.
+            executor (Optional[Executor], optional): The executor. Defaults to None.
+        """
         super().__init__()
         self._upstream: List["DAGNode"] = []
         self._downstream: List["DAGNode"] = []
@@ -206,24 +271,28 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
         self._executor: Optional[Executor] = executor or DAGVar.get_executor()
         if not node_id and self._dag:
             node_id = self._dag._new_node_id()
-        self._node_id: str = node_id
-        self._node_name: str = node_name
+        self._node_id: Optional[str] = node_id
+        self._node_name: Optional[str] = node_name

     @property
     def node_id(self) -> str:
+        """Return the node id of current DAGNode."""
+        if not self._node_id:
+            raise ValueError("Node id not set for current DAGNode")
         return self._node_id

     @property
     @abstractmethod
     def dev_mode(self) -> bool:
-        """Whether current DAGNode is in dev mode"""
+        """Whether current DAGNode is in dev mode."""

     @property
-    def system_app(self) -> SystemApp:
+    def system_app(self) -> Optional[SystemApp]:
+        """Return the system app of current DAGNode."""
         return self._system_app

     def set_system_app(self, system_app: SystemApp) -> None:
-        """Set system app for current DAGNode
+        """Set system app for current DAGNode.

         Args:
             system_app (SystemApp): The system app
@@ -231,50 +300,97 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
         self._system_app = system_app

     def set_node_id(self, node_id: str) -> None:
+        """Set node id for current DAGNode.
+
+        Args:
+            node_id (str): The node id
+        """
         self._node_id = node_id

     def __hash__(self) -> int:
+        """Return the hash value of current DAGNode.
+
+        If the node_id is not None, return the hash value of node_id.
+        """
         if self.node_id:
             return hash(self.node_id)
         else:
             return super().__hash__()

     def __eq__(self, other: Any) -> bool:
+        """Return whether the current DAGNode is equal to other DAGNode."""
         if not isinstance(other, DAGNode):
             return False
         return self.node_id == other.node_id

     @property
-    def node_name(self) -> str:
+    def node_name(self) -> Optional[str]:
+        """Return the node name of current DAGNode.
+
+        Returns:
+            Optional[str]: The node name of current DAGNode
+        """
         return self._node_name

     @property
-    def dag(self) -> "DAG":
+    def dag(self) -> Optional["DAG"]:
+        """Return the DAG of current DAGNode.
+
+        Returns:
+            Optional["DAG"]: The DAG of current DAGNode
+        """
         return self._dag

-    def set_upstream(self, nodes: DependencyType) -> "DAGNode":
+    def set_upstream(self, nodes: DependencyType) -> None:
+        """Set upstream nodes for current node.
+
+        Args:
+            nodes (DependencyType): Upstream nodes to be set to current node.
+        """
         self.set_dependency(nodes)

-    def set_downstream(self, nodes: DependencyType) -> "DAGNode":
+    def set_downstream(self, nodes: DependencyType) -> None:
+        """Set downstream nodes for current node.
+
+        Args:
+            nodes (DependencyType): Downstream nodes to be set to current node.
+        """
         self.set_dependency(nodes, is_upstream=False)

     @property
     def upstream(self) -> List["DAGNode"]:
+        """Return the upstream nodes of current DAGNode.
+
+        Returns:
+            List["DAGNode"]: The upstream nodes of current DAGNode
+        """
         return self._upstream

     @property
     def downstream(self) -> List["DAGNode"]:
+        """Return the downstream nodes of current DAGNode.
+
+        Returns:
+            List["DAGNode"]: The downstream nodes of current DAGNode
+        """
         return self._downstream

     def set_dependency(self, nodes: DependencyType, is_upstream: bool = True) -> None:
+        """Set dependency for current node.
+
+        Args:
+            nodes (DependencyType): The nodes to set dependency to current node.
+            is_upstream (bool, optional): Whether set upstream nodes. Defaults to True.
+        """
         if not isinstance(nodes, Sequence):
             nodes = [nodes]
         if not all(isinstance(node, DAGNode) for node in nodes):
             raise ValueError(
-                "all nodes to set dependency to current node must be instance of 'DAGNode'"
+                "all nodes to set dependency to current node must be instance "
+                "of 'DAGNode'"
             )
-        nodes: Sequence[DAGNode] = nodes
-        dags = set([node.dag for node in nodes if node.dag])
+        nodes = cast(Sequence[DAGNode], nodes)
+        dags = set([node.dag for node in nodes if node.dag])  # noqa: C403
         if self.dag:
             dags.add(self.dag)
         if not dags:
@@ -302,6 +418,7 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
             node._upstream.append(self)

     def __repr__(self):
+        """Return the representation of current DAGNode."""
         cls_name = self.__class__.__name__
         if self.node_name and self.node_name:
             return f"{cls_name}(node_id={self.node_id}, node_name={self.node_name})"
@@ -313,6 +430,7 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
         return f"{cls_name}"

     def __str__(self):
+        """Return the string of current DAGNode."""
         return self.__repr__()


@@ -321,7 +439,7 @@ def _build_task_key(task_name: str, key: str) -> str:


 class DAGContext:
-    """The context of current DAG, created when the DAG is running
+    """The context of current DAG, created when the DAG is running.

     Every DAG has been triggered will create a new DAGContext.
     """
@@ -329,22 +447,32 @@ class DAGContext:
     def __init__(
         self,
         streaming_call: bool = False,
-        node_to_outputs: Dict[str, TaskContext] = None,
-        node_name_to_ids: Dict[str, str] = None,
+        node_to_outputs: Optional[Dict[str, TaskContext]] = None,
+        node_name_to_ids: Optional[Dict[str, str]] = None,
     ) -> None:
+        """Initialize a DAGContext.
+
+        Args:
+            streaming_call (bool, optional): Whether the current DAG is streaming call.
+                Defaults to False.
+            node_to_outputs (Optional[Dict[str, TaskContext]], optional):
+                The task outputs of current DAG. Defaults to None.
+            node_name_to_ids (Optional[Dict[str, str]], optional):
+                The task name to task id mapping. Defaults to None.
+        """
+        if not node_to_outputs:
+            node_to_outputs = {}
+        if not node_name_to_ids:
+            node_name_to_ids = {}
         self._streaming_call = streaming_call
-        self._curr_task_ctx = None
+        self._curr_task_ctx: Optional[TaskContext] = None
         self._share_data: Dict[str, Any] = {}
-        self._node_to_outputs = node_to_outputs
-        self._node_name_to_ids = node_name_to_ids
+        self._node_to_outputs: Dict[str, TaskContext] = node_to_outputs
+        self._node_name_to_ids: Dict[str, str] = node_name_to_ids

     @property
     def _task_outputs(self) -> Dict[str, TaskContext]:
-        """The task outputs of current DAG
+        """Return the task outputs of current DAG.

         Just use for internal for now.
         Returns:
@@ -354,18 +482,28 @@ class DAGContext:

     @property
     def current_task_context(self) -> TaskContext:
+        """Return the current task context."""
+        if not self._curr_task_ctx:
+            raise RuntimeError("Current task context not set")
         return self._curr_task_ctx

     @property
     def streaming_call(self) -> bool:
-        """Whether the current DAG is streaming call"""
+        """Whether the current DAG is streaming call."""
         return self._streaming_call

     def set_current_task_context(self, _curr_task_ctx: TaskContext) -> None:
+        """Set the current task context.
+
+        When the task is running, the current task context
+        will be set to the task context.
+
+        TODO: We should support parallel task running in the future.
+        """
         self._curr_task_ctx = _curr_task_ctx

     def get_task_output(self, task_name: str) -> TaskOutput:
-        """Get the task output by task name
+        """Get the task output by task name.

         Args:
             task_name (str): The task name
@@ -376,22 +514,41 @@ class DAGContext:
         if task_name is None:
             raise ValueError("task_name can't be None")
         node_id = self._node_name_to_ids.get(task_name)
-        if node_id:
+        if not node_id:
             raise ValueError(f"Task name {task_name} not exists in DAG")
-        return self._task_outputs.get(node_id).task_output
+        task_output = self._task_outputs.get(node_id)
+        if not task_output:
+            raise ValueError(f"Task output for task {task_name} not exists")
+        return task_output.task_output

     async def get_from_share_data(self, key: str) -> Any:
+        """Get share data by key.
+
+        Args:
+            key (str): The share data key
+
+        Returns:
+            Any: The share data, you can cast it to the real type
+        """
         return self._share_data.get(key)

     async def save_to_share_data(
         self, key: str, data: Any, overwrite: bool = False
     ) -> None:
+        """Save share data by key.
+
+        Args:
+            key (str): The share data key
+            data (Any): The share data
+            overwrite (bool): Whether overwrite the share data if the key
+                already exists. Defaults to None.
+        """
         if key in self._share_data and not overwrite:
             raise ValueError(f"Share data key {key} already exists")
         self._share_data[key] = data

     async def get_task_share_data(self, task_name: str, key: str) -> Any:
-        """Get share data by task name and key
+        """Get share data by task name and key.

         Args:
             task_name (str): The task name
@@ -409,14 +566,14 @@ class DAGContext:
     async def save_task_share_data(
         self, task_name: str, key: str, data: Any, overwrite: bool = False
     ) -> None:
-        """Save share data by task name and key
+        """Save share data by task name and key.

         Args:
             task_name (str): The task name
             key (str): The share data key
             data (Any): The share data
-            overwrite (bool): Whether overwrite the share data if the key already exists.
-                Defaults to None.
+            overwrite (bool): Whether overwrite the share data if the key
+                already exists. Defaults to None.

         Raises:
             ValueError: If the share data key already exists and overwrite is not True
@@ -429,15 +586,22 @@ class DAGContext:


 class DAG:
+    """The DAG class.
+
+    Manage the DAG nodes and the relationship between them.
+    """
+
     def __init__(
         self, dag_id: str, resource_group: Optional[ResourceGroup] = None
     ) -> None:
+        """Initialize a DAG."""
         self._dag_id = dag_id
         self.node_map: Dict[str, DAGNode] = {}
+        self.node_name_to_node: Dict[str, DAGNode] = {}
-        self._root_nodes: List[DAGNode] = None
-        self._leaf_nodes: List[DAGNode] = None
-        self._trigger_nodes: List[DAGNode] = None
+        self._root_nodes: List[DAGNode] = []
+        self._leaf_nodes: List[DAGNode] = []
+        self._trigger_nodes: List[DAGNode] = []
         self._resource_group: Optional[ResourceGroup] = resource_group

     def _append_node(self, node: DAGNode) -> None:
         if node.node_id in self.node_map:
@@ -448,22 +612,26 @@ class DAG:
                 f"Node name {node.node_name} already exists in DAG {self.dag_id}"
             )
+            self.node_name_to_node[node.node_name] = node
-        self.node_map[node.node_id] = node
+        node_id = node.node_id
+        if not node_id:
+            raise ValueError("Node id can't be None")
+        self.node_map[node_id] = node
         # clear cached nodes
-        self._root_nodes = None
-        self._leaf_nodes = None
+        self._root_nodes = []
+        self._leaf_nodes = []

     def _new_node_id(self) -> str:
         return str(uuid.uuid4())

     @property
     def dag_id(self) -> str:
+        """Return the dag id of current DAG."""
         return self._dag_id

     def _build(self) -> None:
         from ..operator.common_operator import TriggerOperator

-        nodes = set()
+        nodes: Set[DAGNode] = set()
         for _, node in self.node_map.items():
             nodes = nodes.union(_get_nodes(node))
         self._root_nodes = list(set(filter(lambda x: not x.upstream, nodes)))
@@ -474,7 +642,7 @@ class DAG:

     @property
     def root_nodes(self) -> List[DAGNode]:
-        """The root nodes of current DAG
+        """Return the root nodes of current DAG.

         Returns:
             List[DAGNode]: The root nodes of current DAG, no repeat
@@ -485,7 +653,7 @@ class DAG:

     @property
     def leaf_nodes(self) -> List[DAGNode]:
-        """The leaf nodes of current DAG
+        """Return the leaf nodes of current DAG.

         Returns:
             List[DAGNode]: The leaf nodes of current DAG, no repeat
@@ -496,7 +664,7 @@ class DAG:

     @property
     def trigger_nodes(self) -> List[DAGNode]:
-        """The trigger nodes of current DAG
+        """Return the trigger nodes of current DAG.

         Returns:
             List[DAGNode]: The trigger nodes of current DAG, no repeat
@@ -506,34 +674,42 @@ class DAG:
         return self._trigger_nodes

     async def _after_dag_end(self) -> None:
-        """The callback after DAG end"""
+        """Execute after DAG end."""
         tasks = []
         for node in self.node_map.values():
             tasks.append(node.after_dag_end())
         await asyncio.gather(*tasks)

     def print_tree(self) -> None:
-        """Print the DAG tree"""
+        """Print the DAG tree"""  # noqa: D400
         _print_format_dag_tree(self)

     def visualize_dag(self, view: bool = True, **kwargs) -> Optional[str]:
-        """Create the DAG graph"""
+        """Visualize the DAG.
+
+        Args:
+            view (bool, optional): Whether view the DAG graph. Defaults to True,
+                if True, it will open the graph file with your default viewer.
+        """
         self.print_tree()
         return _visualize_dag(self, view=view, **kwargs)

     def __enter__(self):
+        """Enter a DAG context."""
         DAGVar.enter_dag(self)
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
+        """Exit a DAG context."""
         DAGVar.exit_dag()

     def __repr__(self):
+        """Return the representation of current DAG."""
         return f"DAG(dag_id={self.dag_id})"


-def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> set[DAGNode]:
-    nodes = set()
+def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> Set[DAGNode]:
+    nodes: Set[DAGNode] = set()
     if not node:
         return nodes
     nodes.add(node)
@@ -553,7 +729,7 @@ def _print_dag(
     level: int = 0,
     prefix: str = "",
     last: bool = True,
-    level_dict: Dict[str, Any] = None,
+    level_dict: Optional[Dict[int, Any]] = None,
 ):
     if level_dict is None:
         level_dict = {}
@@ -606,7 +782,7 @@ def _handle_dag_nodes(


 def _visualize_dag(dag: DAG, view: bool = True, **kwargs) -> Optional[str]:
-    """Visualize the DAG
+    """Visualize the DAG.

     Args:
         dag (DAG): The DAG to visualize
@@ -641,7 +817,7 @@ def _visualize_dag(dag: DAG, view: bool = True, **kwargs) -> Optional[str]:
         filename = kwargs["filename"]
         del kwargs["filename"]

-    if not "directory" in kwargs:
+    if "directory" not in kwargs:
         from dbgpt.configs.model_config import LOGDIR

         kwargs["directory"] = LOGDIR
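Pulling the DAG pieces together, a hedged sketch of the context manager and the `>>` wiring from DependencyMixin above. The `_Node` subclass is hypothetical (real AWEL nodes are operators), and the import path is assumed from the module layout:

    from dbgpt.core.awel.dag.base import DAG, DAGNode  # paths assumed


    class _Node(DAGNode):  # minimal hypothetical node for illustration
        @property
        def dev_mode(self) -> bool:
            return True


    with DAG("demo_dag") as dag:
        # DAGVar.enter_dag() runs in __enter__, so nodes created here attach
        # to demo_dag and get auto-generated node ids.
        a, b, c = _Node(node_name="a"), _Node(node_name="b"), _Node(node_name="c")
        a >> b        # a.set_downstream(b) via DependencyMixin.__rshift__
        [a, b] >> c   # __rrshift__: both a and b become upstream of c

    print(dag.root_nodes)  # nodes with no upstream, e.g. a (and b via a >> b? no: a only)
    print(dag.leaf_nodes)  # nodes with no downstream, e.g. c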
@@ -1,3 +1,8 @@
+"""DAGManager is a component of AWEL, it is used to manage DAGs.
+
+DAGManager will load DAGs from dag_dirs, and register the trigger nodes
+to TriggerManager.
+"""
 import logging
 from typing import Dict, List

@@ -10,24 +15,35 @@ logger = logging.getLogger(__name__)


 class DAGManager(BaseComponent):
+    """The component of DAGManager."""
+
     name = ComponentType.AWEL_DAG_MANAGER

     def __init__(self, system_app: SystemApp, dag_dirs: List[str]):
+        """Initialize a DAGManager.
+
+        Args:
+            system_app (SystemApp): The system app.
+            dag_dirs (List[str]): The directories to load DAGs.
+        """
         super().__init__(system_app)
         self.dag_loader = LocalFileDAGLoader(dag_dirs)
         self.system_app = system_app
         self.dag_map: Dict[str, DAG] = {}

     def init_app(self, system_app: SystemApp):
+        """Initialize the DAGManager."""
         self.system_app = system_app

     def load_dags(self):
+        """Load DAGs from dag_dirs."""
         dags = self.dag_loader.load_dags()
         triggers = []
         for dag in dags:
             dag_id = dag.dag_id
             if dag_id in self.dag_map:
                 raise ValueError(f"Load DAG error, DAG ID {dag_id} has already exist")
             self.dag_map[dag_id] = dag
             triggers += dag.trigger_nodes
         from ..trigger.trigger_manager import DefaultTriggerManager
@@ -1,3 +1,8 @@
+"""DAG loader.
+
+DAGLoader will load DAGs from dag_dirs or other sources.
+Now only support load DAGs from local files.
+"""
 import hashlib
 import logging
 import os
@@ -12,16 +17,26 @@ logger = logging.getLogger(__name__)


 class DAGLoader(ABC):
+    """Abstract base class representing a loader for loading DAGs."""
+
     @abstractmethod
     def load_dags(self) -> List[DAG]:
-        """Load dags"""
+        """Load dags."""


 class LocalFileDAGLoader(DAGLoader):
+    """DAG loader for loading DAGs from local files."""
+
     def __init__(self, dag_dirs: List[str]) -> None:
+        """Initialize a LocalFileDAGLoader.
+
+        Args:
+            dag_dirs (List[str]): The directories to load DAGs.
+        """
         self._dag_dirs = dag_dirs

     def load_dags(self) -> List[DAG]:
+        """Load dags from local files."""
         dags = []
         for filepath in self._dag_dirs:
             if not os.path.exists(filepath):
@@ -70,7 +85,7 @@ def _load_modules_from_file(filepath: str):
         sys.modules[spec.name] = new_module
         loader.exec_module(new_module)
         return [new_module]
-    except Exception as e:
+    except Exception:
         msg = traceback.format_exc()
         logger.error(f"Failed to import: {filepath}, error message: {msg}")
         # TODO save error message
@@ -0,0 +1 @@
+"""The module of operator."""
@ -1,7 +1,7 @@
|
||||
"""Base classes for operators that can be executed within a workflow."""
|
||||
import asyncio
|
||||
import functools
|
||||
from abc import ABC, ABCMeta, abstractmethod
|
||||
from inspect import signature
|
||||
from types import FunctionType
|
||||
from typing import (
|
||||
Any,
|
||||
@ -9,7 +9,6 @@ from typing import (
|
||||
Dict,
|
||||
Generic,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
@ -21,7 +20,6 @@ from dbgpt.util.executor_utils import (
|
||||
AsyncToSyncIterator,
|
||||
BlockingFunction,
|
||||
DefaultExecutorFactory,
|
||||
ExecutorFactory,
|
||||
blocking_func_to_async,
|
||||
)
|
||||
|
||||
@ -54,13 +52,15 @@ class WorkflowRunner(ABC, Generic[T]):
|
||||
node (RunnableDAGNode): The starting node of the workflow to be executed.
|
||||
call_data (CALL_DATA): The data pass to root operator node.
|
||||
streaming_call (bool): Whether the call is a streaming call.
|
||||
exist_dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None.
|
||||
exist_dag_ctx (DAGContext): The context of the DAG when this node is run,
|
||||
Defaults to None.
|
||||
Returns:
|
||||
DAGContext: The context after executing the workflow, containing the final state and data.
|
||||
DAGContext: The context after executing the workflow, containing the final
|
||||
state and data.
|
||||
"""
|
||||
|
||||
|
||||
default_runner: WorkflowRunner = None
|
||||
default_runner: Optional[WorkflowRunner] = None
|
||||
|
||||
|
||||
class BaseOperatorMeta(ABCMeta):
|
||||
@ -68,8 +68,7 @@ class BaseOperatorMeta(ABCMeta):
|
||||
|
||||
@classmethod
|
||||
def _apply_defaults(cls, func: F) -> F:
|
||||
sig_cache = signature(func)
|
||||
|
||||
# sig_cache = signature(func)
|
||||
@functools.wraps(func)
|
||||
def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any:
|
||||
dag: Optional[DAG] = kwargs.get("dag") or DAGVar.get_current_dag()
|
||||
@ -81,7 +80,7 @@ class BaseOperatorMeta(ABCMeta):
|
||||
if not executor:
|
||||
if system_app:
|
||||
executor = system_app.get_component(
|
||||
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
|
||||
ComponentType.EXECUTOR_DEFAULT, DefaultExecutorFactory
|
||||
).create()
|
||||
else:
|
||||
executor = DefaultExecutorFactory().create()
|
||||
@ -107,9 +106,10 @@ class BaseOperatorMeta(ABCMeta):
|
||||
real_obj = func(self, *args, **kwargs)
|
||||
return real_obj
|
||||
|
||||
return cast(T, apply_defaults)
|
||||
return cast(F, apply_defaults)
|
||||
|
||||
def __new__(cls, name, bases, namespace, **kwargs):
|
||||
"""Create a new BaseOperator class with default arguments."""
|
||||
new_cls = super().__new__(cls, name, bases, namespace, **kwargs)
|
||||
new_cls.__init__ = cls._apply_defaults(new_cls.__init__)
|
||||
return new_cls
|
||||
@ -126,13 +126,14 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
task_id: Optional[str] = None,
|
||||
task_name: Optional[str] = None,
|
||||
dag: Optional[DAG] = None,
|
||||
        runner: WorkflowRunner = None,
        runner: Optional[WorkflowRunner] = None,
        **kwargs,
    ) -> None:
        """Initializes a BaseOperator with an optional workflow runner.
        """Create a BaseOperator with an optional workflow runner.

        Args:
            runner (WorkflowRunner, optional): The runner used to execute the workflow. Defaults to None.
            runner (WorkflowRunner, optional): The runner used to execute the workflow.
                Defaults to None.
        """
        super().__init__(node_id=task_id, node_name=task_name, dag=dag, **kwargs)
        if not runner:
@ -141,19 +142,24 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
            runner = DefaultWorkflowRunner()

        self._runner: WorkflowRunner = runner
        self._dag_ctx: DAGContext = None
        self._dag_ctx: Optional[DAGContext] = None

    @property
    def current_dag_context(self) -> DAGContext:
        """Return the current DAG context."""
        if not self._dag_ctx:
            raise ValueError("DAGContext is not set")
        return self._dag_ctx

    @property
    def dev_mode(self) -> bool:
        """Whether the operator is in dev mode.

        In production mode, the default runner is not None.

        Returns:
            bool: Whether the operator is in dev mode. True if the default runner is None.
            bool: Whether the operator is in dev mode. True if the
                default runner is None.
        """
        return default_runner is None

@ -186,7 +192,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):

        Args:
            call_data (CALL_DATA): The data passed to the root operator node.
            dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None.
            dag_ctx (DAGContext): The context of the DAG when this node is run,
                Defaults to None.
        Returns:
            OUT: The output of the node after execution.
        """
@ -196,7 +203,9 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
        return out_ctx.current_task_context.task_output.output

    def _blocking_call(
        self, call_data: Optional[CALL_DATA] = None, loop: asyncio.BaseEventLoop = None
        self,
        call_data: Optional[CALL_DATA] = None,
        loop: Optional[asyncio.BaseEventLoop] = None,
    ) -> OUT:
        """Execute the node and return the output.

@ -213,6 +222,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):

        if not loop:
            loop = get_or_create_event_loop()
        loop = cast(asyncio.BaseEventLoop, loop)
        return loop.run_until_complete(self.call(call_data))

    async def call_stream(
@ -226,7 +236,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):

        Args:
            call_data (CALL_DATA): The data passed to the root operator node.
            dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None.
            dag_ctx (DAGContext): The context of the DAG when this node is run,
                Defaults to None.

        Returns:
            AsyncIterator[OUT]: An asynchronous iterator over the output stream.
@ -237,7 +248,9 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
        return out_ctx.current_task_context.task_output.output_stream

    def _blocking_call_stream(
        self, call_data: Optional[CALL_DATA] = None, loop: asyncio.BaseEventLoop = None
        self,
        call_data: Optional[CALL_DATA] = None,
        loop: Optional[asyncio.BaseEventLoop] = None,
    ) -> Iterator[OUT]:
        """Execute the node and return the output as a stream.

@ -259,9 +272,22 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
    async def blocking_func_to_async(
        self, func: BlockingFunction, *args, **kwargs
    ) -> Any:
        """Execute a blocking function asynchronously.

        In AWEL, the operators are executed asynchronously. However,
        some functions are blocking, so we run them in a separate thread.

        Args:
            func (BlockingFunction): The blocking function to be executed.
            *args: Positional arguments for the function.
            **kwargs: Keyword arguments for the function.
        """
        if not self._executor:
            raise ValueError("Executor is not set")
        return await blocking_func_to_async(self._executor, func, *args, **kwargs)


def initialize_runner(runner: WorkflowRunner):
    """Initialize the default runner."""
    global default_runner
    default_runner = runner
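Editorial note: the `blocking_func_to_async` helper above offloads a blocking callable to an executor so the event loop stays responsive. A minimal standalone sketch of the same pattern, using only the standard library; `slow_lookup` and `blocking_to_async` are hypothetical names for illustration, not DB-GPT API:

import asyncio
import time
from concurrent.futures import ThreadPoolExecutor


def slow_lookup(key: str) -> str:
    # A blocking function: awaiting it directly would stall the event loop.
    time.sleep(0.1)
    return f"value-for-{key}"


async def blocking_to_async(executor, func, *args):
    # Run the blocking callable in a worker thread and await its result.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, func, *args)


async def main():
    with ThreadPoolExecutor(max_workers=2) as executor:
        result = await blocking_to_async(executor, slow_lookup, "k1")
        print(result)  # value-for-k1


asyncio.run(main())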
@ -1,7 +1,7 @@
"""Common operators of AWEL."""
import asyncio
import logging
from typing import (
    Any,
    AsyncIterator,
    Awaitable,
    Callable,
@ -13,7 +13,17 @@ from typing import (
)

from ..dag.base import DAGContext
from ..task.base import IN, OUT, InputContext, InputSource, TaskContext, TaskOutput
from ..task.base import (
    IN,
    OUT,
    InputContext,
    InputSource,
    JoinFunc,
    MapFunc,
    ReduceFunc,
    TaskContext,
    TaskOutput,
)
from .base import BaseOperator

logger = logging.getLogger(__name__)
@ -25,7 +35,12 @@ class JoinOperator(BaseOperator, Generic[OUT]):
    This node type is useful for combining the outputs of upstream nodes.
    """

    def __init__(self, combine_function, **kwargs):
    def __init__(self, combine_function: JoinFunc, **kwargs):
        """Create a JoinDAGNode with a combine function.

        Args:
            combine_function: A function that defines how to combine inputs.
        """
        super().__init__(**kwargs)
        if not callable(combine_function):
            raise ValueError("combine_function must be callable")
@ -33,6 +48,7 @@ class JoinOperator(BaseOperator, Generic[OUT]):

    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        """Run the join operation on the DAG context's inputs.

        Args:
            dag_ctx (DAGContext): The current context of the DAG.

@ -50,8 +66,10 @@ class JoinOperator(BaseOperator, Generic[OUT]):


class ReduceStreamOperator(BaseOperator, Generic[IN, OUT]):
    def __init__(self, reduce_function=None, **kwargs):
        """Initializes a ReduceStreamOperator with a combine function.
    """Operator that reduces inputs using a custom reduce function."""

    def __init__(self, reduce_function: Optional[ReduceFunc] = None, **kwargs):
        """Create a ReduceStreamOperator with a combine function.

        Args:
            combine_function: A function that defines how to combine inputs.
@ -89,6 +107,7 @@ class ReduceStreamOperator(BaseOperator, Generic[IN, OUT]):
        return reduce_output

    async def reduce(self, input_value: AsyncIterator[IN]) -> OUT:
        """Reduce the input stream to a single value."""
        raise NotImplementedError


@ -99,8 +118,8 @@ class MapOperator(BaseOperator, Generic[IN, OUT]):
    passes the transformed data downstream.
    """

    def __init__(self, map_function=None, **kwargs):
        """Initializes a MapDAGNode with a mapping function.
    def __init__(self, map_function: Optional[MapFunc] = None, **kwargs):
        """Create a MapDAGNode with a mapping function.

        Args:
            map_function: A function that defines how to map the input data.
@ -133,13 +152,18 @@ class MapOperator(BaseOperator, Generic[IN, OUT]):
        if not call_data and not curr_task_ctx.task_input.check_single_parent():
            num_parents = len(curr_task_ctx.task_input.parent_outputs)
            raise ValueError(
                f"task {curr_task_ctx.task_id} MapDAGNode expects single parent, now number of parents: {num_parents}"
                f"task {curr_task_ctx.task_id} MapDAGNode expects single parent,"
                f"now number of parents: {num_parents}"
            )
        map_function = self.map_function or self.map

        if call_data:
            call_data = await curr_task_ctx._call_data_to_output()
            output = await call_data.map(map_function)
            wrapped_call_data = await curr_task_ctx._call_data_to_output()
            if not wrapped_call_data:
                raise ValueError(
                    f"task {curr_task_ctx.task_id} MapDAGNode expects wrapped_call_data"
                )
            output: TaskOutput[OUT] = await wrapped_call_data.map(map_function)
            curr_task_ctx.set_task_output(output)
            return output

@ -150,6 +174,7 @@ class MapOperator(BaseOperator, Generic[IN, OUT]):
        return output

    async def map(self, input_value: IN) -> OUT:
        """Map the input data to a new value."""
        raise NotImplementedError


@ -161,6 +186,11 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):

    This node filters its input data using a branching function and
    allows for conditional paths in the workflow.

    If a branch function returns True, the corresponding task will be executed;
    otherwise, the corresponding task will be skipped, and the output of
    this skipped node will be set to `SKIP_DATA`.

    """

    def __init__(
@ -168,11 +198,11 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
        branches: Optional[Dict[BranchFunc[IN], Union[BaseOperator, str]]] = None,
        **kwargs,
    ):
        """
        Initializes a BranchDAGNode with a branching function.
        """Create a BranchDAGNode with a branching function.

        Args:
            branches (Dict[BranchFunc[IN], Union[BaseOperator, str]]): Dict of function that defines the branching condition.
            branches (Dict[BranchFunc[IN], Union[BaseOperator, str]]):
                Dict of functions that define the branching condition.

        Raises:
            ValueError: If the branch_function is not callable.
@ -183,7 +213,9 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
            if not callable(branch_function):
                raise ValueError("branch_function must be callable")
            if isinstance(value, BaseOperator):
                branches[branch_function] = value.node_name or value.node_name
                if not value.node_name:
                    raise ValueError("branch node name must be set")
                branches[branch_function] = value.node_name
        self._branches = branches

    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
@ -210,7 +242,7 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
        branches = await self.branches()

        branch_func_tasks = []
        branch_nodes: List[str] = []
        branch_nodes: List[Union[BaseOperator, str]] = []
        for func, node_name in branches.items():
            branch_nodes.append(node_name)
            branch_func_tasks.append(
@ -225,20 +257,25 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
            node_name = branch_nodes[i]
            branch_out = ctx.parent_outputs[0].task_output
            logger.info(
                f"branch_input_ctxs {i} result {branch_out.output}, is_empty: {branch_out.is_empty}"
                f"branch_input_ctxs {i} result {branch_out.output}, "
                f"is_empty: {branch_out.is_empty}"
            )
            if ctx.parent_outputs[0].task_output.is_empty:
            if ctx.parent_outputs[0].task_output.is_none:
                logger.info(f"Skip node name {node_name}")
                skip_node_names.append(node_name)
        curr_task_ctx.update_metadata("skip_node_names", skip_node_names)
        return parent_output

    async def branches(self) -> Dict[BranchFunc[IN], Union[BaseOperator, str]]:
        """Return branch logic based on input data."""
        raise NotImplementedError


class InputOperator(BaseOperator, Generic[OUT]):
    """Operator node that reads data from an input source."""

    def __init__(self, input_source: InputSource[OUT], **kwargs) -> None:
        """Create an InputDAGNode with an input source."""
        super().__init__(**kwargs)
        self._input_source = input_source

@ -250,7 +287,10 @@ class InputOperator(BaseOperator, Generic[OUT]):


class TriggerOperator(InputOperator, Generic[OUT]):
    """Operator node that triggers the DAG to run."""

    def __init__(self, **kwargs) -> None:
        """Create a TriggerDAGNode."""
        from ..task.task_impl import SimpleCallDataInputSource

        super().__init__(input_source=SimpleCallDataInputSource(), **kwargs)
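Editorial note: a hedged sketch of how these common operators compose into a DAG. It assumes `DAG`, `InputOperator`, `MapOperator`, `JoinOperator`, and `SimpleInputSource` are exported from `dbgpt.core.awel` at this commit, that `>>` wires edges, and that parent outputs reach the join function in connect order; treat these as assumptions, not confirmed API:

import asyncio

# Assumed exports; the exact import path may differ at this commit.
from dbgpt.core.awel import (
    DAG,
    InputOperator,
    JoinOperator,
    MapOperator,
    SimpleInputSource,
)

with DAG("join_example"):
    # Two input nodes feeding one join node through a map node.
    left = InputOperator(SimpleInputSource(2), task_id="left")
    right = InputOperator(SimpleInputSource(3), task_id="right")
    double = MapOperator(map_function=lambda x: x * 2, task_id="double")
    add = JoinOperator(combine_function=lambda a, b: a + b, task_id="add")
    left >> double >> add
    right >> add

# JoinOperator receives both parent outputs: (2 * 2) + 3 == 7.
print(asyncio.run(add.call()))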
@ -1,3 +1,4 @@
"""The module of stream operator."""
from abc import ABC, abstractmethod
from typing import AsyncIterator, Generic

@ -7,12 +8,18 @@ from .base import BaseOperator


class StreamifyAbsOperator(BaseOperator[OUT], ABC, Generic[IN, OUT]):
    """An abstract operator that converts a value of IN to an AsyncIterator[OUT]."""

    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
        call_data = curr_task_ctx.call_data
        if call_data:
            call_data = await curr_task_ctx._call_data_to_output()
            output = await call_data.streamify(self.streamify)
            wrapped_call_data = await curr_task_ctx._call_data_to_output()
            if not wrapped_call_data:
                raise ValueError(
                    f"task {curr_task_ctx.task_id} MapDAGNode expects wrapped_call_data"
                )
            output = await wrapped_call_data.streamify(self.streamify)
            curr_task_ctx.set_task_output(output)
            return output
        output = await curr_task_ctx.task_input.parent_outputs[0].task_output.streamify(
@ -23,26 +30,28 @@ class StreamifyAbsOperator(BaseOperator[OUT], ABC, Generic[IN, OUT]):

    @abstractmethod
    async def streamify(self, input_value: IN) -> AsyncIterator[OUT]:
        """Convert a value of IN to an AsyncIterator[OUT]
        """Convert a value of IN to an AsyncIterator[OUT].

        Args:
            input_value (IN): The data of the parent operator's output

        Example:
        Examples:
            .. code-block:: python

            .. code-block:: python
            class MyStreamOperator(StreamifyAbsOperator[int, int]):
                async def streamify(self, input_value: int) -> AsyncIterator[int]:
                    for i in range(input_value):
                        yield i

                class MyStreamOperator(StreamifyAbsOperator[int, int]):
                    async def streamify(self, input_value: int) -> AsyncIterator[int]:
                        for i in range(input_value):
                            yield i
        """


class UnstreamifyAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
    """An abstract operator that converts a value of AsyncIterator[IN] to an OUT."""

    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
        output = await curr_task_ctx.task_input.parent_outputs[
        output: TaskOutput[OUT] = await curr_task_ctx.task_input.parent_outputs[
            0
        ].task_output.unstreamify(self.unstreamify)
        curr_task_ctx.set_task_output(output)
@ -56,24 +65,30 @@ class UnstreamifyAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
            input_value (AsyncIterator[IN]): The data of the parent operator's output

        Example:
            .. code-block:: python

            .. code-block:: python

            class MyUnstreamOperator(UnstreamifyAbsOperator[int, int]):
                async def unstreamify(self, input_value: AsyncIterator[int]) -> int:
                    value_cnt = 0
                    async for v in input_value:
                        value_cnt += 1
                    return value_cnt

                class MyUnstreamOperator(UnstreamifyAbsOperator[int, int]):
                    async def unstreamify(self, input_value: AsyncIterator[int]) -> int:
                        value_cnt = 0
                        async for v in input_value:
                            value_cnt += 1
                        return value_cnt
        """


class TransformStreamAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
    """Streaming to other streaming data.

    An abstract operator that transforms a value of
    AsyncIterator[IN] to another AsyncIterator[OUT].
    """

    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
        output = await curr_task_ctx.task_input.parent_outputs[
        output: TaskOutput[OUT] = await curr_task_ctx.task_input.parent_outputs[
            0
        ].task_output.transform_stream(self.transform_stream)

        curr_task_ctx.set_task_output(output)
        return output

@ -81,19 +96,18 @@ class TransformStreamAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
    async def transform_stream(
        self, input_value: AsyncIterator[IN]
    ) -> AsyncIterator[OUT]:
        """Transform an AsyncIterator[IN] to another AsyncIterator[OUT] using a given function.
        """Transform an AsyncIterator[IN] to another AsyncIterator[OUT].

        Args:
            input_value (AsyncIterator[IN]): The data of the parent operator's output

        Example:
        Examples:
            .. code-block:: python

            .. code-block:: python

            class MyTransformStreamOperator(TransformStreamAbsOperator[int, int]):
                async def unstreamify(
                    self, input_value: AsyncIterator[int]
                ) -> AsyncIterator[int]:
                    async for v in input_value:
                        yield v + 1

                class MyTransformStreamOperator(TransformStreamAbsOperator[int, int]):
                    async def unstreamify(
                        self, input_value: AsyncIterator[int]
                    ) -> AsyncIterator[int]:
                        async for v in input_value:
                            yield v + 1
        """
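Editorial note: the streamify/unstreamify pair above is easiest to see outside the operator machinery. A minimal sketch with plain async generators (no DB-GPT imports); the function names are illustrative only:

import asyncio
from typing import AsyncIterator


async def streamify(n: int) -> AsyncIterator[int]:
    # IN -> AsyncIterator[OUT]: what a StreamifyAbsOperator subclass implements.
    for i in range(n):
        yield i


async def unstreamify(stream: AsyncIterator[int]) -> int:
    # AsyncIterator[IN] -> OUT: what an UnstreamifyAbsOperator subclass implements.
    total = 0
    async for v in stream:
        total += v
    return total


async def main():
    # Compose the two halves: 0 + 1 + 2 + 3 == 6.
    print(await unstreamify(streamify(4)))


asyncio.run(main())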
@ -0,0 +1,4 @@
"""The module of AWEL resource.

Not implemented yet.
"""
@ -1,8 +1,15 @@
"""Base class for resource group."""
from abc import ABC, abstractmethod


class ResourceGroup(ABC):
    """Base class for resource group.

    A resource group is a group of resources that are related to each other.
    It contains all the resources that are needed to run a workflow.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """The name of current resource group"""
        """Return the name of the current resource group."""
@ -0,0 +1,4 @@
"""The module to run AWEL operators.

You can implement your own runner by inheriting the `WorkflowRunner` class.
"""
@ -1,33 +1,38 @@
"""Job manager for DAG."""
import asyncio
import logging
import uuid
from typing import Dict, List, Optional, Set
from typing import Dict, List, Optional, cast

from ..dag.base import DAG, DAGLifecycle
from ..dag.base import DAGLifecycle
from ..operator.base import CALL_DATA, BaseOperator

logger = logging.getLogger(__name__)


class DAGNodeInstance:
    def __init__(self, node_instance: DAG) -> None:
        pass


class DAGInstance:
    def __init__(self, dag: DAG) -> None:
        self._dag = dag


class JobManager(DAGLifecycle):
    """Job manager for DAG.

    This class is used to manage the DAG lifecycle.
    """

    def __init__(
        self,
        root_nodes: List[BaseOperator],
        all_nodes: List[BaseOperator],
        end_node: BaseOperator,
        id2call_data: Dict[str, Dict],
        id2call_data: Dict[str, Optional[Dict]],
        node_name_to_ids: Dict[str, str],
    ) -> None:
        """Create a job manager.

        Args:
            root_nodes (List[BaseOperator]): The root nodes of the DAG.
            all_nodes (List[BaseOperator]): All nodes of the DAG.
            end_node (BaseOperator): The end node of the DAG.
            id2call_data (Dict[str, Optional[Dict]]): The call data of each node.
            node_name_to_ids (Dict[str, str]): The node name to node id mapping.
        """
        self._root_nodes = root_nodes
        self._all_nodes = all_nodes
        self._end_node = end_node
@ -38,6 +43,15 @@ class JobManager(DAGLifecycle):
    def build_from_end_node(
        end_node: BaseOperator, call_data: Optional[CALL_DATA] = None
    ) -> "JobManager":
        """Build a job manager from the end node.

        This will get all upstream nodes from the end node, and build a job manager.

        Args:
            end_node (BaseOperator): The end node of the DAG.
            call_data (Optional[CALL_DATA], optional): The call data of the end node.
                Defaults to None.
        """
        nodes = _build_from_end_node(end_node)
        root_nodes = _get_root_nodes(nodes)
        id2call_data = _save_call_data(root_nodes, call_data)
@ -50,17 +64,22 @@ class JobManager(DAGLifecycle):
        return JobManager(root_nodes, nodes, end_node, id2call_data, node_name_to_ids)

    def get_call_data_by_id(self, node_id: str) -> Optional[Dict]:
        """Get the call data by node id.

        Args:
            node_id (str): The node id.
        """
        return self._id2node_data.get(node_id)

    async def before_dag_run(self):
        """The callback before DAG run"""
        """Execute the callback before DAG run."""
        tasks = []
        for node in self._all_nodes:
            tasks.append(node.before_dag_run())
        await asyncio.gather(*tasks)

    async def after_dag_end(self):
        """The callback after DAG end"""
        """Execute the callback after DAG end."""
        tasks = []
        for node in self._all_nodes:
            tasks.append(node.after_dag_end())
@ -68,9 +87,9 @@ class JobManager(DAGLifecycle):


def _save_call_data(
    root_nodes: List[BaseOperator], call_data: CALL_DATA
) -> Dict[str, Dict]:
    id2call_data = {}
    root_nodes: List[BaseOperator], call_data: Optional[CALL_DATA]
) -> Dict[str, Optional[Dict]]:
    id2call_data: Dict[str, Optional[Dict]] = {}
    logger.debug(f"_save_call_data: {call_data}, root_nodes: {root_nodes}")
    if not call_data:
        return id2call_data
@ -82,7 +101,8 @@ def _save_call_data(
    for node in root_nodes:
        node_id = node.node_id
        logger.debug(
            f"Save call data to node {node.node_id}, call_data: {call_data.get(node_id)}"
            f"Save call data to node {node.node_id}, call_data: "
            f"{call_data.get(node_id)}"
        )
        id2call_data[node_id] = call_data.get(node_id)
    return id2call_data
@ -91,13 +111,11 @@ def _save_call_data(
def _build_from_end_node(end_node: BaseOperator) -> List[BaseOperator]:
    """Build all nodes from the end node."""
    nodes = []
    if isinstance(end_node, BaseOperator):
        task_id = end_node.node_id
        if not task_id:
            task_id = str(uuid.uuid4())
            end_node.set_node_id(task_id)
    if isinstance(end_node, BaseOperator) and not end_node._node_id:
        end_node.set_node_id(str(uuid.uuid4()))
    nodes.append(end_node)
    for node in end_node.upstream:
        node = cast(BaseOperator, node)
        nodes += _build_from_end_node(node)
    return nodes
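Editorial note: `_build_from_end_node` is a plain recursive upstream walk. A standalone sketch with a hypothetical `Node` type (not DB-GPT's API); note that, like the original, it does not deduplicate shared upstream nodes:

from dataclasses import dataclass, field
from typing import List


@dataclass
class Node:
    node_id: str
    upstream: List["Node"] = field(default_factory=list)


def build_from_end_node(end_node: Node) -> List[Node]:
    # Collect the end node plus every transitive upstream node,
    # duplicates included, mirroring the recursive walk above.
    nodes = [end_node]
    for parent in end_node.upstream:
        nodes += build_from_end_node(parent)
    return nodes


a = Node("a")
b = Node("b", upstream=[a])
c = Node("c", upstream=[a, b])
print([n.node_id for n in build_from_end_node(c)])  # ['c', 'a', 'b', 'a']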
@ -1,12 +1,16 @@
"""Local runner for workflow.

This runner will run the workflow in the current process.
"""
import logging
from typing import Dict, List, Optional, Set
from typing import Dict, List, Optional, Set, cast

from dbgpt.component import SystemApp

from ..dag.base import DAGContext, DAGVar
from ..operator.base import CALL_DATA, BaseOperator, WorkflowRunner
from ..operator.common_operator import BranchOperator, JoinOperator, TriggerOperator
from ..task.base import TaskContext, TaskState
from ..operator.common_operator import BranchOperator, JoinOperator
from ..task.base import SKIP_DATA, TaskContext, TaskState
from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput
from .job_manager import JobManager

@ -14,6 +18,8 @@ logger = logging.getLogger(__name__)


class DefaultWorkflowRunner(WorkflowRunner):
    """The default workflow runner."""

    async def execute_workflow(
        self,
        node: BaseOperator,
@ -21,6 +27,17 @@ class DefaultWorkflowRunner(WorkflowRunner):
        streaming_call: bool = False,
        exist_dag_ctx: Optional[DAGContext] = None,
    ) -> DAGContext:
        """Execute the workflow.

        Args:
            node (BaseOperator): The end node of the workflow.
            call_data (Optional[CALL_DATA], optional): The call data of the end node.
                Defaults to None.
            streaming_call (bool, optional): Whether the call is a streaming call.
                Defaults to False.
            exist_dag_ctx (Optional[DAGContext], optional): The existing DAG context.
                Defaults to None.
        """
        # Save node output
        # dag = node.dag
        job_manager = JobManager.build_from_end_node(node, call_data)
@ -37,8 +54,8 @@ class DefaultWorkflowRunner(WorkflowRunner):
        )
        logger.info(f"Begin run workflow from end operator, id: {node.node_id}")
        logger.debug(f"Node id {node.node_id}, call_data: {call_data}")
        skip_node_ids = set()
        system_app: SystemApp = DAGVar.get_current_system_app()
        skip_node_ids: Set[str] = set()
        system_app: Optional[SystemApp] = DAGVar.get_current_system_app()

        await job_manager.before_dag_run()
        await self._execute_node(
@ -57,7 +74,7 @@ class DefaultWorkflowRunner(WorkflowRunner):
        dag_ctx: DAGContext,
        node_outputs: Dict[str, TaskContext],
        skip_node_ids: Set[str],
        system_app: SystemApp,
        system_app: Optional[SystemApp],
    ):
        # Skip run node
        if node.node_id in node_outputs:
@ -79,8 +96,12 @@ class DefaultWorkflowRunner(WorkflowRunner):
            node_outputs[upstream_node.node_id] for upstream_node in node.upstream
        ]
        input_ctx = DefaultInputContext(inputs)
        task_ctx = DefaultTaskContext(node.node_id, TaskState.INIT, task_output=None)
        task_ctx.set_call_data(job_manager.get_call_data_by_id(node.node_id))
        task_ctx: DefaultTaskContext = DefaultTaskContext(
            node.node_id, TaskState.INIT, task_output=None
        )
        current_call_data = job_manager.get_call_data_by_id(node.node_id)
        if current_call_data:
            task_ctx.set_call_data(current_call_data)

        task_ctx.set_task_input(input_ctx)
        dag_ctx.set_current_task_context(task_ctx)
@ -88,12 +109,13 @@ class DefaultWorkflowRunner(WorkflowRunner):

        if node.node_id in skip_node_ids:
            task_ctx.set_current_state(TaskState.SKIP)
            task_ctx.set_task_output(SimpleTaskOutput(None))
            task_ctx.set_task_output(SimpleTaskOutput(SKIP_DATA))
            node_outputs[node.node_id] = task_ctx
            return
        try:
            logger.debug(
                f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}"
                f"Begin run operator, node id: {node.node_id}, node name: "
                f"{node.node_name}, cls: {node}"
            )
            if system_app is not None and node.system_app is None:
                node.set_system_app(system_app)
@ -120,6 +142,7 @@ def _skip_current_downstream_by_node_name(
    if not skip_nodes:
        return
    for child in branch_node.downstream:
        child = cast(BaseOperator, child)
        if child.node_name in skip_nodes:
            logger.info(f"Skip node name {child.node_name}, node id {child.node_id}")
            _skip_downstream_by_id(child, skip_node_ids)
@ -131,4 +154,5 @@ def _skip_downstream_by_id(node: BaseOperator, skip_node_ids: Set[str]):
        return
    skip_node_ids.add(node.node_id)
    for child in node.downstream:
        child = cast(BaseOperator, child)
        _skip_downstream_by_id(child, skip_node_ids)
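Editorial note: the skip propagation in `_skip_downstream_by_id` amounts to a guarded depth-first marking of downstream nodes. A standalone sketch over a plain adjacency map; all names here are illustrative, not DB-GPT API:

from typing import Dict, List, Set


def skip_downstream(node_id: str, children: Dict[str, List[str]], skip_ids: Set[str]) -> None:
    # Mark a node and everything downstream of it as skipped, as the runner
    # does after a branch condition fails; the guard stops revisits.
    if node_id in skip_ids:
        return
    skip_ids.add(node_id)
    for child in children.get(node_id, []):
        skip_downstream(child, children, skip_ids)


children = {"branch": ["left", "right"], "left": ["join"], "right": ["join"]}
skipped: Set[str] = set()
skip_downstream("left", children, skipped)
print(sorted(skipped))  # ['join', 'left']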
@ -0,0 +1 @@
"""The module of Task."""
@ -1,8 +1,10 @@
"""Base classes for task-related objects."""
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
    Any,
    AsyncIterator,
    Awaitable,
    Callable,
    Dict,
    Generic,
@ -17,6 +19,24 @@ OUT = TypeVar("OUT")
T = TypeVar("T")


class _EMPTY_DATA_TYPE:
    def __bool__(self):
        return False


EMPTY_DATA = _EMPTY_DATA_TYPE()
SKIP_DATA = _EMPTY_DATA_TYPE()
PLACEHOLDER_DATA = _EMPTY_DATA_TYPE()

MapFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]]
ReduceFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]]
StreamFunc = Callable[[IN], Awaitable[AsyncIterator[OUT]]]
UnStreamFunc = Callable[[AsyncIterator[IN]], OUT]
TransformFunc = Callable[[AsyncIterator[IN]], Awaitable[AsyncIterator[OUT]]]
PredicateFunc = Union[Callable[[IN], bool], Callable[[IN], Awaitable[bool]]]
JoinFunc = Union[Callable[..., OUT], Callable[..., Awaitable[OUT]]]

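Editorial note: the new function aliases accept either plain or async callables, which is why the implementations below dispatch on `asyncio.iscoroutinefunction`. A minimal sketch of that dispatch, specialized to `int` for brevity; `apply_map` is an illustrative name, not DB-GPT API:

import asyncio
from typing import Awaitable, Callable, Union

MapFunc = Union[Callable[[int], int], Callable[[int], Awaitable[int]]]


async def apply_map(func: MapFunc, value: int) -> int:
    # Dispatch exactly as the task implementation does: await coroutine
    # functions, call plain functions directly.
    if asyncio.iscoroutinefunction(func):
        return await func(value)
    return func(value)


async def double_async(x: int) -> int:
    return x * 2


async def main():
    print(await apply_map(lambda x: x + 1, 1))  # 2
    print(await apply_map(double_async, 2))  # 4


asyncio.run(main())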
class TaskState(str, Enum):
    """Enumeration representing the state of a task in the workflow.

@ -33,8 +53,8 @@ class TaskState(str, Enum):
class TaskOutput(ABC, Generic[T]):
    """Abstract base class representing the output of a task.

    This class encapsulates the output of a task and provides methods to access the output data.
    It can be subclassed to implement specific output behaviors.
    This class encapsulates the output of a task and provides methods to access the
    output data. It can be subclassed to implement specific output behaviors.
    """

    @property
@ -56,20 +76,30 @@ class TaskOutput(ABC, Generic[T]):
        return False

    @property
    def output(self) -> Optional[T]:
    def is_none(self) -> bool:
        """Check if the output is None.

        Returns:
            bool: True if the output is None, False otherwise.
        """
        return False

    @property
    def output(self) -> T:
        """Return the output of the task.

        Returns:
            T: The output of the task. None if the output is empty.
            T: The output of the task.
        """
        raise NotImplementedError

    @property
    def output_stream(self) -> Optional[AsyncIterator[T]]:
    def output_stream(self) -> AsyncIterator[T]:
        """Return the output of the task as an asynchronous stream.

        Returns:
            AsyncIterator[T]: An asynchronous iterator over the output. None if the output is empty.
            AsyncIterator[T]: An asynchronous iterator over the output. None if the
                output is empty.
        """
        raise NotImplementedError

@ -83,39 +113,38 @@ class TaskOutput(ABC, Generic[T]):

    @abstractmethod
    def new_output(self) -> "TaskOutput[T]":
        """Create new output object"""
        """Create new output object."""

    async def map(self, map_func) -> "TaskOutput[T]":
    async def map(self, map_func: MapFunc) -> "TaskOutput[OUT]":
        """Apply a mapping function to the task's output.

        Args:
            map_func: A function to apply to the task's output.
            map_func (MapFunc): A function to apply to the task's output.

        Returns:
            TaskOutput[T]: The result of applying the mapping function.
            TaskOutput[OUT]: The result of applying the mapping function.
        """
        raise NotImplementedError

    async def reduce(self, reduce_func) -> "TaskOutput[T]":
    async def reduce(self, reduce_func: ReduceFunc) -> "TaskOutput[OUT]":
        """Apply a reducing function to the task's output.

        Stream TaskOutput to Nonstream TaskOutput.
        Stream TaskOutput to non-stream TaskOutput.

        Args:
            reduce_func: A reducing function to apply to the task's output.

        Returns:
            TaskOutput[T]: The result of applying the reducing function.
            TaskOutput[OUT]: The result of applying the reducing function.
        """
        raise NotImplementedError

    async def streamify(
        self, transform_func: Callable[[T], AsyncIterator[T]]
    ) -> "TaskOutput[T]":
    async def streamify(self, transform_func: StreamFunc) -> "TaskOutput[T]":
        """Convert a value of type T to an AsyncIterator[T] using a transform function.

        Args:
            transform_func (Callable[[T], AsyncIterator[T]]): Function to transform a T value into an AsyncIterator[T].
            transform_func (StreamFunc): Function to transform a T value into an
                AsyncIterator[OUT].

        Returns:
            TaskOutput[T]: The result of applying the reducing function.
@ -123,38 +152,39 @@ class TaskOutput(ABC, Generic[T]):
        raise NotImplementedError

    async def transform_stream(
        self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]]
    ) -> "TaskOutput[T]":
        """Transform an AsyncIterator[T] to another AsyncIterator[T] using a given function.
        self, transform_func: TransformFunc
    ) -> "TaskOutput[OUT]":
        """Transform an AsyncIterator[T] to another AsyncIterator[T].

        Args:
            transform_func (Callable[[AsyncIterator[T]], AsyncIterator[T]]): Function to apply to the AsyncIterator[T].
            transform_func (Callable[[AsyncIterator[T]], AsyncIterator[T]]): Function to
                apply to the AsyncIterator[T].

        Returns:
            TaskOutput[T]: The result of applying the reducing function.
        """
        raise NotImplementedError

    async def unstreamify(
        self, transform_func: Callable[[AsyncIterator[T]], T]
    ) -> "TaskOutput[T]":
    async def unstreamify(self, transform_func: UnStreamFunc) -> "TaskOutput[OUT]":
        """Convert an AsyncIterator[T] to a value of type T using a transform function.

        Args:
            transform_func (Callable[[AsyncIterator[T]], T]): Function to transform an AsyncIterator[T] into a T value.
            transform_func (UnStreamFunc): Function to transform an AsyncIterator[T]
                into a T value.

        Returns:
            TaskOutput[T]: The result of applying the reducing function.
        """
        raise NotImplementedError

    async def check_condition(self, condition_func) -> bool:
    async def check_condition(self, condition_func) -> "TaskOutput[OUT]":
        """Check if current output meets a given condition.

        Args:
            condition_func: A function to determine if the condition is met.
        Returns:
            bool: True if current output meet the condition, False otherwise.
            TaskOutput[T]: The result of applying the condition function.
                If the condition is not met, return empty output.
        """
        raise NotImplementedError

@ -182,6 +212,9 @@ class TaskContext(ABC, Generic[T]):

        Returns:
            InputContext: The InputContext of current task.

        Raises:
            Exception: If the InputContext is not set.
        """

    @abstractmethod
@ -216,7 +249,7 @@ class TaskContext(ABC, Generic[T]):

    @abstractmethod
    def set_current_state(self, task_state: TaskState) -> None:
        """Set current task state
        """Set current task state.

        Args:
            task_state (TaskState): The task state to be set.
@ -224,7 +257,7 @@ class TaskContext(ABC, Generic[T]):

    @abstractmethod
    def new_ctx(self) -> "TaskContext":
        """Create new task context
        """Create new task context.

        Returns:
            TaskContext: A new instance of a TaskContext.
@ -233,14 +266,14 @@ class TaskContext(ABC, Generic[T]):
    @property
    @abstractmethod
    def metadata(self) -> Dict[str, Any]:
        """Get the metadata of current task
        """Return the metadata of current task.

        Returns:
            Dict[str, Any]: The metadata
        """

    def update_metadata(self, key: str, value: Any) -> None:
        """Update metadata with key and value
        """Update metadata with key and value.

        Args:
            key (str): The key of metadata
@ -250,15 +283,15 @@ class TaskContext(ABC, Generic[T]):

    @property
    def call_data(self) -> Optional[Dict]:
        """Get the call data for current data"""
        """Return the call data for current data."""
        return self.metadata.get("call_data")

    @abstractmethod
    async def _call_data_to_output(self) -> Optional[TaskOutput[T]]:
        """Get the call data for current data"""
        """Get the call data for current data."""

    def set_call_data(self, call_data: Dict) -> None:
        """Set call data for current task"""
        """Save the call data for current task."""
        self.update_metadata("call_data", call_data)


@ -315,7 +348,8 @@ class InputContext(ABC):
        """Filter the inputs based on a provided function.

        Args:
            filter_func (Callable[[Any], bool]): A function that returns True for inputs to keep.
            filter_func (Callable[[Any], bool]): A function that returns True for
                inputs to keep.

        Returns:
            InputContext: A new InputContext instance with the filtered inputs.
@ -323,13 +357,15 @@ class InputContext(ABC):

    @abstractmethod
    async def predicate_map(
        self, predicate_func: Callable[[Any], bool], failed_value: Any = None
        self, predicate_func: PredicateFunc, failed_value: Any = None
    ) -> "InputContext":
        """Predicate the inputs based on a provided function.

        Args:
            predicate_func (Callable[[Any], bool]): A function that returns True for inputs is predicate True.
            failed_value (Any): The value to be set if the return value of predicate function is False
            predicate_func (Callable[[Any], bool]): A function that returns True for
                inputs that satisfy the predicate.
            failed_value (Any): The value to be set if the return value of the
                predicate function is False.
        Returns:
            InputContext: A new InputContext instance with the predicate inputs.
        """
@ -1,3 +1,7 @@
"""The default implementation of Task.

This implementation can run workflows on the local machine.
"""
import asyncio
import logging
from abc import ABC, abstractmethod
@ -8,15 +12,32 @@ from typing import (
    Coroutine,
    Dict,
    Generic,
    Iterator,
    List,
    Optional,
    Tuple,
    TypeVar,
    Union,
    cast,
)

from .base import InputContext, InputSource, T, TaskContext, TaskOutput, TaskState
from .base import (
    _EMPTY_DATA_TYPE,
    EMPTY_DATA,
    OUT,
    PLACEHOLDER_DATA,
    SKIP_DATA,
    InputContext,
    InputSource,
    MapFunc,
    PredicateFunc,
    ReduceFunc,
    StreamFunc,
    T,
    TaskContext,
    TaskOutput,
    TaskState,
    TransformFunc,
    UnStreamFunc,
)

logger = logging.getLogger(__name__)
@ -37,101 +58,197 @@ async def _reduce_stream(stream: AsyncIterator, reduce_function) -> Any:


class SimpleTaskOutput(TaskOutput[T], Generic[T]):
    def __init__(self, data: T) -> None:
    """The default implementation of TaskOutput.

    It wraps the non-stream data and provides some basic data operations.
    """

    def __init__(self, data: Union[T, _EMPTY_DATA_TYPE] = EMPTY_DATA) -> None:
        """Create a SimpleTaskOutput.

        Args:
            data (Union[T, _EMPTY_DATA_TYPE], optional): The output data. Defaults to
                EMPTY_DATA.
        """
        super().__init__()
        self._data = data

    @property
    def output(self) -> T:
        return self._data
        """Return the output data."""
        if self._data == EMPTY_DATA:
            raise ValueError("No output data for current task output")
        return cast(T, self._data)

    def set_output(self, output_data: T | AsyncIterator[T]) -> None:
        self._data = output_data
        """Save the output data to current object.

        Args:
            output_data (T | AsyncIterator[T]): The output data.
        """
        if _is_async_iterator(output_data):
            raise ValueError(
                f"Can not set stream data {output_data} to SimpleTaskOutput"
            )
        self._data = cast(T, output_data)

    def new_output(self) -> TaskOutput[T]:
        return SimpleTaskOutput(None)
        """Create new output object with empty data."""
        return SimpleTaskOutput()

    @property
    def is_empty(self) -> bool:
        """Return True if the output data is empty."""
        return self._data == EMPTY_DATA or self._data == SKIP_DATA

    @property
    def is_none(self) -> bool:
        """Return True if the output data is None."""
        return self._data is None

    async def _apply_func(self, func) -> Any:
        """Apply the function to current output data."""
        if asyncio.iscoroutinefunction(func):
            out = await func(self._data)
        else:
            out = func(self._data)
        return out

    async def map(self, map_func) -> TaskOutput[T]:
    async def map(self, map_func: MapFunc) -> TaskOutput[OUT]:
        """Apply a mapping function to the task's output.

        Args:
            map_func (MapFunc): A function to apply to the task's output.

        Returns:
            TaskOutput[OUT]: The result of applying the mapping function.
        """
        out = await self._apply_func(map_func)
        return SimpleTaskOutput(out)

    async def check_condition(self, condition_func) -> bool:
        return await self._apply_func(condition_func)
    async def check_condition(self, condition_func) -> TaskOutput[OUT]:
        """Check the condition function."""
        out = await self._apply_func(condition_func)
        if out:
            return SimpleTaskOutput(PLACEHOLDER_DATA)
        return SimpleTaskOutput(EMPTY_DATA)

    async def streamify(
        self, transform_func: Callable[[T], AsyncIterator[T]]
    ) -> TaskOutput[T]:
    async def streamify(self, transform_func: StreamFunc) -> TaskOutput[OUT]:
        """Transform the task's output to a stream output.

        Args:
            transform_func (StreamFunc): A function to transform the task's output to a
                stream output.

        Returns:
            TaskOutput[OUT]: The result of transforming the task's output to a stream
                output.
        """
        out = await self._apply_func(transform_func)
        return SimpleStreamTaskOutput(out)


class SimpleStreamTaskOutput(TaskOutput[T], Generic[T]):
    def __init__(self, data: AsyncIterator[T]) -> None:
    """The default stream implementation of TaskOutput."""

    def __init__(
        self, data: Union[AsyncIterator[T], _EMPTY_DATA_TYPE] = EMPTY_DATA
    ) -> None:
        """Create a SimpleStreamTaskOutput.

        Args:
            data (Union[AsyncIterator[T], _EMPTY_DATA_TYPE], optional): The output data.
                Defaults to EMPTY_DATA.
        """
        super().__init__()
        self._data = data

    @property
    def is_stream(self) -> bool:
        """Return True if the output data is a stream."""
        return True

    @property
    def is_empty(self) -> bool:
        return not self._data
        """Return True if the output data is empty."""
        return self._data == EMPTY_DATA or self._data == SKIP_DATA

    @property
    def is_none(self) -> bool:
        """Return True if the output data is None."""
        return self._data is None

    @property
    def output_stream(self) -> AsyncIterator[T]:
        return self._data
        """Return the output data.

        Returns:
            AsyncIterator[T]: The output data.

        Raises:
            ValueError: If the output data is empty.
        """
        if self._data == EMPTY_DATA:
            raise ValueError("No output data for current task output")
        return cast(AsyncIterator[T], self._data)

    def set_output(self, output_data: T | AsyncIterator[T]) -> None:
        self._data = output_data
        """Save the output data to current object.

        Raises:
            ValueError: If the output data is not a stream.
        """
        if not _is_async_iterator(output_data):
            raise ValueError(
                f"Can not set non-stream data {output_data} to SimpleStreamTaskOutput"
            )
        self._data = cast(AsyncIterator[T], output_data)

    def new_output(self) -> TaskOutput[T]:
        return SimpleStreamTaskOutput(None)
        """Create new output object with empty data."""
        return SimpleStreamTaskOutput()

    async def map(self, map_func) -> TaskOutput[T]:
    async def map(self, map_func: MapFunc) -> TaskOutput[OUT]:
        """Apply a mapping function to the task's output."""
        is_async = asyncio.iscoroutinefunction(map_func)

        async def new_iter() -> AsyncIterator[T]:
            async for out in self._data:
        async def new_iter() -> AsyncIterator[OUT]:
            async for out in self.output_stream:
                if is_async:
                    out = await map_func(out)
                    new_out: OUT = await map_func(out)
                else:
                    out = map_func(out)
                yield out
                    new_out = cast(OUT, map_func(out))
                yield new_out

        return SimpleStreamTaskOutput(new_iter())

    async def reduce(self, reduce_func) -> TaskOutput[T]:
        out = await _reduce_stream(self._data, reduce_func)
    async def reduce(self, reduce_func: ReduceFunc) -> TaskOutput[OUT]:
        """Apply a reduce function to the task's output."""
        out = await _reduce_stream(self.output_stream, reduce_func)
        return SimpleTaskOutput(out)

    async def unstreamify(
        self, transform_func: Callable[[AsyncIterator[T]], T]
    ) -> TaskOutput[T]:
    async def unstreamify(self, transform_func: UnStreamFunc) -> TaskOutput[OUT]:
        """Transform the task's output to a non-stream output."""
        if asyncio.iscoroutinefunction(transform_func):
            out = await transform_func(self._data)
            out = await transform_func(self.output_stream)
        else:
            out = transform_func(self._data)
            out = transform_func(self.output_stream)
        return SimpleTaskOutput(out)

    async def transform_stream(
        self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]]
    ) -> TaskOutput[T]:
    async def transform_stream(self, transform_func: TransformFunc) -> TaskOutput[OUT]:
        """Transform an AsyncIterator[T] to another AsyncIterator[T].

        Args:
            transform_func (Callable[[AsyncIterator[T]], AsyncIterator[T]]): Function to
                apply to the AsyncIterator[T].

        Returns:
            TaskOutput[T]: The result of applying the reducing function.
        """
        if asyncio.iscoroutinefunction(transform_func):
            out = await transform_func(self._data)
            out: AsyncIterator[OUT] = await transform_func(self.output_stream)
        else:
            out = transform_func(self._data)
            out = cast(AsyncIterator[OUT], transform_func(self.output_stream))
        return SimpleStreamTaskOutput(out)
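Editorial note: `SimpleStreamTaskOutput.map` wraps the source iterator lazily; nothing is computed until the stream is consumed. A standalone sketch of the same wrapping, with illustrative names:

import asyncio
from typing import AsyncIterator, Callable


def lazy_map(source: AsyncIterator[int], func: Callable[[int], int]) -> AsyncIterator[int]:
    # Like SimpleStreamTaskOutput.map: build a new async iterator that
    # applies func per item only when iterated.
    async def new_iter() -> AsyncIterator[int]:
        async for item in source:
            yield func(item)

    return new_iter()


async def numbers() -> AsyncIterator[int]:
    for i in range(3):
        yield i


async def main():
    async for v in lazy_map(numbers(), lambda x: x * 10):
        print(v)  # 0, 10, 20


asyncio.run(main())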
@ -145,20 +262,34 @@ def _is_async_iterator(obj):


class BaseInputSource(InputSource, ABC):
    """The base class of InputSource."""

    def __init__(self) -> None:
        """Create a BaseInputSource."""
        super().__init__()
        self._is_read = False

    @abstractmethod
    def _read_data(self, task_ctx: TaskContext) -> Any:
        """Read data with task context"""
        """Return data with task context."""

    async def read(self, task_ctx: TaskContext) -> TaskOutput:
        """Read data with task context.

        Args:
            task_ctx (TaskContext): The task context.

        Returns:
            TaskOutput: The task output.

        Raises:
            ValueError: If the input source is a stream and has been read.
        """
        data = self._read_data(task_ctx)
        if _is_async_iterator(data):
            if self._is_read:
                raise ValueError(f"Input iterator {data} has been read!")
            output = SimpleStreamTaskOutput(data)
            output: TaskOutput = SimpleStreamTaskOutput(data)
        else:
            output = SimpleTaskOutput(data)
        self._is_read = True
@ -166,7 +297,14 @@ class BaseInputSource(InputSource, ABC):


class SimpleInputSource(BaseInputSource):
    """The default implementation of InputSource."""

    def __init__(self, data: Any) -> None:
        """Create a SimpleInputSource.

        Args:
            data (Any): The input data.
        """
        super().__init__()
        self._data = data

@ -175,63 +313,121 @@ class SimpleInputSource(BaseInputSource):


class SimpleCallDataInputSource(BaseInputSource):
    """The implementation of InputSource for call data."""

    def __init__(self) -> None:
        """Create a SimpleCallDataInputSource."""
        super().__init__()

    def _read_data(self, task_ctx: TaskContext) -> Any:
        """Read data from task context.

        Returns:
            Any: The data.

        Raises:
            ValueError: If the call data is empty.
        """
        call_data = task_ctx.call_data
        data = call_data.get("data") if call_data else None
        if not (call_data and data):
        data = call_data.get("data", EMPTY_DATA) if call_data else EMPTY_DATA
        if data == EMPTY_DATA:
            raise ValueError("No call data for current SimpleCallDataInputSource")
        return data


class DefaultTaskContext(TaskContext, Generic[T]):
    """The default implementation of TaskContext."""

    def __init__(
        self, task_id: str, task_state: TaskState, task_output: TaskOutput[T]
        self,
        task_id: str,
        task_state: TaskState,
        task_output: Optional[TaskOutput[T]] = None,
    ) -> None:
        """Create a DefaultTaskContext.

        Args:
            task_id (str): The task id.
            task_state (TaskState): The task state.
            task_output (Optional[TaskOutput[T]], optional): The task output. Defaults
                to None.
        """
        super().__init__()
        self._task_id = task_id
        self._task_state = task_state
        self._output = task_output
        self._task_input = None
        self._metadata = {}
        self._output: Optional[TaskOutput[T]] = task_output
        self._task_input: Optional[InputContext] = None
        self._metadata: Dict[str, Any] = {}

    @property
    def task_id(self) -> str:
        """Return the task id."""
        return self._task_id

    @property
    def task_input(self) -> InputContext:
        """Return the task input."""
        if not self._task_input:
            raise ValueError("No input for current task context")
        return self._task_input

    def set_task_input(self, input_ctx: "InputContext") -> None:
    def set_task_input(self, input_ctx: InputContext) -> None:
        """Save the task input to current task."""
        self._task_input = input_ctx

    @property
    def task_output(self) -> TaskOutput:
        """Return the task output.

        Returns:
            TaskOutput: The task output.

        Raises:
            ValueError: If the task output is empty.
        """
        if not self._output:
            raise ValueError("No output for current task context")
        return self._output

    def set_task_output(self, task_output: TaskOutput) -> None:
        """Save the task output to current task.

        Args:
            task_output (TaskOutput): The task output.
        """
        self._output = task_output

    @property
    def current_state(self) -> TaskState:
        """Return the current task state."""
        return self._task_state

    def set_current_state(self, task_state: TaskState) -> None:
        """Save the current task state to current task."""
        self._task_state = task_state

    def new_ctx(self) -> TaskContext:
        """Create new task context with empty output."""
        if not self._output:
            raise ValueError("No output for current task context")
        new_output = self._output.new_output()
        return DefaultTaskContext(self._task_id, self._task_state, new_output)

    @property
    def metadata(self) -> Dict[str, Any]:
        """Return the metadata of current task.

        Returns:
            Dict[str, Any]: The metadata.
        """
        return self._metadata

    async def _call_data_to_output(self) -> Optional[TaskOutput[T]]:
        """Get the call data for current data"""
        """Return the call data of current task.

        Returns:
            Optional[TaskOutput[T]]: The call data.
        """
        call_data = self.call_data
        if not call_data:
            return None
@ -240,24 +436,48 @@ class DefaultTaskContext(TaskContext, Generic[T]):


class DefaultInputContext(InputContext):
    """The default implementation of InputContext.

    It wraps all the inputs from parent tasks and provides some basic data operations.
    """

    def __init__(self, outputs: List[TaskContext]) -> None:
        """Create a DefaultInputContext.

        Args:
            outputs (List[TaskContext]): The outputs from parent tasks.
        """
        super().__init__()
        self._outputs = outputs

    @property
    def parent_outputs(self) -> List[TaskContext]:
        """Return the outputs from parent tasks.

        Returns:
            List[TaskContext]: The outputs from parent tasks.
        """
        return self._outputs

    async def _apply_func(
        self, func: Callable[[Any], Any], apply_type: str = "map"
    ) -> Tuple[List[TaskContext], List[TaskOutput]]:
        """Apply the function to all parent outputs.

        Args:
            func (Callable[[Any], Any]): The function to apply.
            apply_type (str, optional): The apply type. Defaults to "map".

        Returns:
            Tuple[List[TaskContext], List[TaskOutput]]: The new parent outputs and the
                results of applying the function.
        """
        new_outputs: List[TaskContext] = []
        map_tasks = []
        for out in self._outputs:
            new_outputs.append(out.new_ctx())
            result = None
            if apply_type == "map":
                result = out.task_output.map(func)
                result: Coroutine[Any, Any, TaskOutput[Any]] = out.task_output.map(func)
            elif apply_type == "reduce":
                result = out.task_output.reduce(func)
            elif apply_type == "check_condition":
@ -269,29 +489,40 @@ class DefaultInputContext(InputContext):
        return new_outputs, results

    async def map(self, map_func: Callable[[Any], Any]) -> InputContext:
        """Apply a mapping function to all parent outputs."""
        new_outputs, results = await self._apply_func(map_func)
        for i, task_ctx in enumerate(new_outputs):
            task_ctx: TaskContext = task_ctx
            task_ctx = cast(TaskContext, task_ctx)
            task_ctx.set_task_output(results[i])
        return DefaultInputContext(new_outputs)

    async def map_all(self, map_func: Callable[..., Any]) -> InputContext:
        """Apply a mapping function to all parent outputs.

        The parent outputs will be unpacked and passed to the mapping function.

        Args:
            map_func (Callable[..., Any]): The mapping function.

        Returns:
            InputContext: The new input context.
        """
        if not self._outputs:
            return DefaultInputContext([])
        # Some parents may be empty
        not_empty_idx = 0
        for i, p in enumerate(self._outputs):
            if p.task_output.is_empty:
                # Skip empty parent
                continue
            not_empty_idx = i
            break
        # All output is empty?
        is_steam = self._outputs[not_empty_idx].task_output.is_stream
        if is_steam:
            if not self.check_stream(skip_empty=True):
                raise ValueError(
                    "The output in all tasks must has same output format to map_all"
                )
        if is_steam and not self.check_stream(skip_empty=True):
            raise ValueError(
                "The output in all tasks must have the same output format to map_all"
            )
        outputs = []
        for out in self._outputs:
            if out.task_output.is_stream:
@ -305,22 +536,26 @@ class DefaultInputContext(InputContext):
        single_output: TaskContext = self._outputs[not_empty_idx].new_ctx()
        single_output.task_output.set_output(map_res)
        logger.debug(
            f"Current map_all map_res: {map_res}, is steam: {single_output.task_output.is_stream}"
            f"Current map_all map_res: {map_res}, is steam: "
            f"{single_output.task_output.is_stream}"
        )
        return DefaultInputContext([single_output])

    async def reduce(self, reduce_func: Callable[[Any], Any]) -> InputContext:
        """Apply a reduce function to all parent outputs."""
        if not self.check_stream():
            raise ValueError(
                "The output in all tasks must has same output format of stream to apply reduce function"
                "The output in all tasks must have the same output format of stream to"
                " apply the reduce function"
            )
        new_outputs, results = await self._apply_func(reduce_func, apply_type="reduce")
        for i, task_ctx in enumerate(new_outputs):
            task_ctx: TaskContext = task_ctx
            task_ctx = cast(TaskContext, task_ctx)
            task_ctx.set_task_output(results[i])
        return DefaultInputContext(new_outputs)

    async def filter(self, filter_func: Callable[[Any], bool]) -> InputContext:
        """Filter all parent outputs."""
        new_outputs, results = await self._apply_func(
            filter_func, apply_type="check_condition"
        )
@ -331,15 +566,16 @@ class DefaultInputContext(InputContext):
        return DefaultInputContext(result_outputs)

    async def predicate_map(
        self, predicate_func: Callable[[Any], bool], failed_value: Any = None
        self, predicate_func: PredicateFunc, failed_value: Any = None
    ) -> "InputContext":
        """Apply a predicate function to all parent outputs."""
        new_outputs, results = await self._apply_func(
            predicate_func, apply_type="check_condition"
        )
        result_outputs = []
        for i, task_ctx in enumerate(new_outputs):
            task_ctx: TaskContext = task_ctx
            if results[i]:
            task_ctx = cast(TaskContext, task_ctx)
            if not results[i].is_empty:
                task_ctx.task_output.set_output(True)
                result_outputs.append(task_ctx)
            else:
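Editorial note: `map_all` unpacks one value per parent into the combine function, which is how `JoinOperator` merges its inputs. A minimal sketch of that unpacking, with illustrative names only:

import asyncio
from typing import Any, Callable, List


async def map_all(parent_outputs: List[Any], func: Callable[..., Any]) -> Any:
    # Unpack one value per parent into the combine function, as
    # DefaultInputContext.map_all does; async combine functions are awaited.
    if asyncio.iscoroutinefunction(func):
        return await func(*parent_outputs)
    return func(*parent_outputs)


async def main():
    result = await map_all([2, 3, 4], lambda a, b, c: a * b + c)
    print(result)  # 10


asyncio.run(main())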
@ -66,10 +66,10 @@ async def _create_input_node(**kwargs):
    else:
        outputs = kwargs.get("outputs", ["Hello."])
        nodes = []
        for output in outputs:
        for i, output in enumerate(outputs):
            print(f"output: {output}")
            input_source = SimpleInputSource(output)
            input_node = InputOperator(input_source)
            input_node = InputOperator(input_source, task_id="input_node_" + str(i))
            nodes.append(input_node)
        yield nodes

@ -26,7 +26,7 @@ from .conftest import (

@pytest.mark.asyncio
async def test_input_node(runner: WorkflowRunner):
    input_node = InputOperator(SimpleInputSource("hello"))
    input_node = InputOperator(SimpleInputSource("hello"), task_id="112232")
    res: DAGContext[str] = await runner.execute_workflow(input_node)
    assert res.current_task_context.current_state == TaskState.SUCCESS
    assert res.current_task_context.task_output.output == "hello"
@ -36,7 +36,9 @@ async def test_input_node(runner: WorkflowRunner):
        yield i

    num_iter = 10
    steam_input_node = InputOperator(SimpleInputSource(new_steam_iter(num_iter)))
    steam_input_node = InputOperator(
        SimpleInputSource(new_steam_iter(num_iter)), task_id="112232"
    )
    res: DAGContext[str] = await runner.execute_workflow(steam_input_node)
    assert res.current_task_context.current_state == TaskState.SUCCESS
    output_steam = res.current_task_context.task_output.output_stream
@@ -0,0 +1 @@
"""The trigger module of AWEL."""

@@ -1,3 +1,4 @@
"""Base class for all trigger classes."""
from __future__ import annotations

from abc import ABC, abstractmethod
@@ -6,6 +7,11 @@ from ..operator.common_operator import TriggerOperator


class Trigger(TriggerOperator, ABC):
    """Base class for all trigger classes.

    Now only support http trigger.
    """

    @abstractmethod
    async def trigger(self) -> None:
        """Trigger the workflow or a specific operation in the workflow."""
@@ -1,10 +1,11 @@
"""Http trigger for AWEL."""
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union, cast

from starlette.requests import Request
from starlette.responses import Response

from dbgpt._private.pydantic import BaseModel

@@ -13,29 +14,35 @@ from ..operator.base import BaseOperator
from .base import Trigger

if TYPE_CHECKING:
    from fastapi import APIRouter, FastAPI
    from fastapi import APIRouter

RequestBody = Union[Type[Request], Type[BaseModel], str]
RequestBody = Union[Type[Request], Type[BaseModel], Type[str]]
StreamingPredictFunc = Callable[[Union[Request, BaseModel]], bool]

logger = logging.getLogger(__name__)


class HttpTrigger(Trigger):
    """Http trigger for AWEL.

    Http trigger is used to trigger a DAG by http request.
    """

    def __init__(
        self,
        endpoint: str,
        methods: Optional[Union[str, List[str]]] = "GET",
        request_body: Optional[RequestBody] = None,
        streaming_response: Optional[bool] = False,
        streaming_response: bool = False,
        streaming_predict_func: Optional[StreamingPredictFunc] = None,
        response_model: Optional[Type] = None,
        response_headers: Optional[Dict[str, str]] = None,
        response_media_type: Optional[str] = None,
        status_code: Optional[int] = 200,
        router_tags: Optional[List[str]] = None,
        router_tags: Optional[List[str | Enum]] = None,
        **kwargs,
    ) -> None:
        """Initialize a HttpTrigger."""
        super().__init__(**kwargs)
        if not endpoint.startswith("/"):
            endpoint = "/" + endpoint
@@ -49,15 +56,21 @@ class HttpTrigger(Trigger):
        self._router_tags = router_tags
        self._response_headers = response_headers
        self._response_media_type = response_media_type
        self._end_node: BaseOperator = None
        self._end_node: Optional[BaseOperator] = None

    async def trigger(self) -> None:
        """Trigger the DAG. Not used in HttpTrigger."""
        pass

    def mount_to_router(self, router: "APIRouter") -> None:
        """Mount the trigger to a router.

        Args:
            router (APIRouter): The router to mount the trigger.
        """
        from fastapi import Depends

        methods = self._methods if isinstance(self._methods, list) else [self._methods]
        methods = [self._methods] if isinstance(self._methods, str) else self._methods

        def create_route_function(name, req_body_cls: Optional[Type[BaseModel]]):
            async def _request_body_dependency(request: Request):
@@ -87,7 +100,8 @@ class HttpTrigger(Trigger):
        )
        dynamic_route_function = create_route_function(function_name, request_model)
        logger.info(
            f"mount router function {dynamic_route_function}({function_name}), endpoint: {self._endpoint}, methods: {methods}"
            f"mount router function {dynamic_route_function}({function_name}), "
            f"endpoint: {self._endpoint}, methods: {methods}"
        )

        router.api_route(
@@ -100,17 +114,27 @@ class HttpTrigger(Trigger):


async def _parse_request_body(
    request: Request, request_body_cls: Optional[Type[BaseModel]]
    request: Request, request_body_cls: Optional[RequestBody]
):
    if not request_body_cls:
        return None
    if request.method == "POST":
        json_data = await request.json()
        return request_body_cls(**json_data)
    elif request.method == "GET":
        return request_body_cls(**request.query_params)
    else:
    if request_body_cls == Request:
        return request
    if request.method == "POST":
        if request_body_cls == str:
            bytes_body = await request.body()
            str_body = bytes_body.decode("utf-8")
            return str_body
        elif issubclass(request_body_cls, BaseModel):
            json_data = await request.json()
            return request_body_cls(**json_data)
        else:
            raise ValueError(f"Invalid request body cls: {request_body_cls}")
    elif request.method == "GET":
        if issubclass(request_body_cls, BaseModel):
            return request_body_cls(**request.query_params)
        else:
            raise ValueError(f"Invalid request body cls: {request_body_cls}")


async def _trigger_dag(
@@ -123,10 +147,10 @@ async def _trigger_dag(
    from fastapi import BackgroundTasks
    from fastapi.responses import StreamingResponse

    end_node = dag.leaf_nodes
    if len(end_node) != 1:
    leaf_nodes = dag.leaf_nodes
    if len(leaf_nodes) != 1:
        raise ValueError("HttpTrigger just support one leaf node in dag")
    end_node = end_node[0]
    end_node = cast(BaseOperator, leaf_nodes[0])
    if not streaming_response:
        return await end_node.call(call_data={"data": body})
    else:
@@ -141,7 +165,7 @@ async def _trigger_dag(
        }
        generator = await end_node.call_stream(call_data={"data": body})
        background_tasks = BackgroundTasks()
        background_tasks.add_task(end_node.dag._after_dag_end)
        background_tasks.add_task(dag._after_dag_end)
        return StreamingResponse(
            generator,
            headers=headers,
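With typed request bodies and explicit leaf-node handling, HttpTrigger can serve as the entry point of a DAG. A rough sketch of the wiring, assuming DAG, HttpTrigger, and MapOperator are exported from dbgpt.core.awel as in examples of this era (the endpoint and request model are illustrative):

from dbgpt._private.pydantic import BaseModel
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator


class HelloRequest(BaseModel):
    name: str


with DAG("awel_hello_world") as dag:
    # Served under the trigger router prefix, e.g.
    # GET /api/v1/awel/trigger/examples/hello?name=AWEL
    trigger = HttpTrigger("/examples/hello", request_body=HelloRequest)
    task = MapOperator(lambda req: f"Hello, {req.name}!")
    trigger >> task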
@@ -1,41 +1,63 @@
"""Trigger manager for AWEL.

After DB-GPT started, the trigger manager will be initialized and register all triggers
"""
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional

from dbgpt.component import BaseComponent, ComponentType, SystemApp

from .base import Trigger

if TYPE_CHECKING:
    from fastapi import APIRouter

from dbgpt.component import BaseComponent, ComponentType, SystemApp

logger = logging.getLogger(__name__)


class TriggerManager(ABC):
    """Base class for trigger manager."""

    @abstractmethod
    def register_trigger(self, trigger: Any) -> None:
        """ "Register a trigger to current manager"""
        """Register a trigger to current manager."""


class HttpTriggerManager(TriggerManager):
    """Http trigger manager.

    Register all http triggers to a router.
    """

    def __init__(
        self,
        router: Optional["APIRouter"] = None,
        router_prefix: Optional[str] = "/api/v1/awel/trigger",
        router_prefix: str = "/api/v1/awel/trigger",
    ) -> None:
        """Initialize a HttpTriggerManager.

        Args:
            router (Optional["APIRouter"], optional): The router. Defaults to None.
                If None, will create a new FastAPI router.
            router_prefix (str, optional): The router prefix. Defaults
                to "/api/v1/awel/trigger".
        """
        if not router:
            from fastapi import APIRouter

            router = APIRouter()
        self._router_prefix = router_prefix
        self._router = router
        self._trigger_map = {}
        self._trigger_map: Dict[str, Trigger] = {}

    def register_trigger(self, trigger: Any) -> None:
        """Register a trigger to current manager."""
        from .http_trigger import HttpTrigger

        if not isinstance(trigger, HttpTrigger):
            raise ValueError(f"Current trigger {trigger} not an object of HttpTrigger")
        trigger: HttpTrigger = trigger
        trigger_id = trigger.node_id
        if trigger_id not in self._trigger_map:
            trigger.mount_to_router(self._router)
@@ -45,23 +67,32 @@ class HttpTriggerManager(TriggerManager):
        logger.info(
            f"Include router {self._router} to prefix path {self._router_prefix}"
        )
        system_app.app.include_router(
            self._router, prefix=self._router_prefix, tags=["AWEL"]
        )
        app = system_app.app
        if not app:
            raise RuntimeError("System app not initialized")
        app.include_router(self._router, prefix=self._router_prefix, tags=["AWEL"])


class DefaultTriggerManager(TriggerManager, BaseComponent):
    """Default trigger manager for AWEL.

    Manage all trigger managers. Just support http trigger now.
    """

    name = ComponentType.AWEL_TRIGGER_MANAGER

    def __init__(self, system_app: SystemApp | None = None):
        """Initialize a DefaultTriggerManager."""
        self.system_app = system_app
        self.http_trigger = HttpTriggerManager()
        super().__init__(None)

    def init_app(self, system_app: SystemApp):
        """Initialize the trigger manager."""
        self.system_app = system_app

    def register_trigger(self, trigger: Any) -> None:
        """Register a trigger to current manager."""
        from .http_trigger import HttpTrigger

        if isinstance(trigger, HttpTrigger):
@@ -71,4 +102,6 @@ class DefaultTriggerManager(TriggerManager, BaseComponent):
            raise ValueError(f"Unsupport trigger: {trigger}")

    def after_register(self) -> None:
        self.http_trigger._init_app(self.system_app)
        """After register, init the trigger manager."""
        if self.system_app:
            self.http_trigger._init_app(self.system_app)
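Registration now flows through DefaultTriggerManager, which delegates HTTP triggers to HttpTriggerManager and only includes the router once a FastAPI app is attached. A rough, simplified sketch of that flow (the module path and the SystemApp wiring are assumptions, not confirmed by this diff):

from dbgpt.component import SystemApp
from dbgpt.core.awel import HttpTrigger
from dbgpt.core.awel.trigger.trigger_manager import DefaultTriggerManager

# Assumes a FastAPI app is attached to the SystemApp elsewhere;
# after_register() raises RuntimeError if it is not.
system_app = SystemApp()
manager = DefaultTriggerManager(system_app=system_app)
manager.init_app(system_app)

trigger = HttpTrigger("/examples/ping")
manager.register_trigger(trigger)  # mounts the route on the internal router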
@@ -0,0 +1,4 @@
"""The core interface of DB-GPT.

Just include the core interface to keep our dependencies clean.
"""

@@ -1,3 +1,10 @@
"""The cache interface.

The cache interface is used to cache LLM results and embedding results.

Maybe we can cache more server results in the future.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
@@ -10,17 +17,23 @@ V = TypeVar("V")


class RetrievalPolicy(str, Enum):
    """The retrieval policy of the cache."""

    EXACT_MATCH = "exact_match"
    SIMILARITY_MATCH = "similarity_match"


class CachePolicy(str, Enum):
    """The cache policy of the cache."""

    LRU = "lru"
    FIFO = "fifo"


@dataclass
class CacheConfig:
    """The cache config."""

    retrieval_policy: Optional[RetrievalPolicy] = RetrievalPolicy.EXACT_MATCH
    cache_policy: Optional[CachePolicy] = CachePolicy.LRU

@@ -30,7 +43,8 @@ class CacheKey(Serializable, ABC, Generic[K]):

    Supported cache keys:
    - The LLM cache key: Include user prompt and the parameters to LLM.
    - The embedding model cache key: Include the texts to embedding and the parameters to embedding model.
    - The embedding model cache key: Include the texts to embedding and the parameters
      to embedding model.
    """

    @abstractmethod
@@ -76,7 +90,8 @@ class CacheClient(ABC, Generic[K, V]):
            cache_config (Optional[CacheConfig]): Cache config

        Returns:
            Optional[CacheValue[V]]: The value retrieved according to key. If cache key not exist, return None.
            Optional[CacheValue[V]]: The value retrieved according to key. If cache key
                not exist, return None.
        """

    @abstractmethod
@@ -110,8 +125,8 @@ class CacheClient(ABC, Generic[K, V]):

    @abstractmethod
    def new_key(self, **kwargs) -> CacheKey[K]:
        """Create a cache key with params"""
        """Create a cache key with params."""

    @abstractmethod
    def new_value(self, **kwargs) -> CacheValue[K]:
        """Create a cache key with params"""
        """Create a cache key with params."""
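The cache interface is policy-driven: a CacheConfig pairs a retrieval policy with an eviction policy, and concrete clients interpret both. Building a config from the enums defined above might look like this (the import path is assumed from the module layout):

from dbgpt.core.interface.cache import CacheConfig, CachePolicy, RetrievalPolicy

# Exact-match lookup with LRU eviction (both are the defaults).
config = CacheConfig(
    retrieval_policy=RetrievalPolicy.EXACT_MATCH,
    cache_policy=CachePolicy.LRU,
)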
@@ -1,3 +1,5 @@
"""The interface for LLM."""

import collections
import copy
import logging
@@ -31,7 +33,8 @@ class ModelInferenceMetrics:
    """The timestamp (in milliseconds) when the model inference ends."""

    current_time_ms: Optional[int] = None
    """The current timestamp (in milliseconds) when the model inference return partially output(stream)."""
    """The current timestamp (in milliseconds) when the model inference return
    partially output(stream)."""

    first_token_time_ms: Optional[int] = None
    """The timestamp (in milliseconds) when the first token is generated."""
@@ -64,6 +67,14 @@ class ModelInferenceMetrics:
    def create_metrics(
        last_metrics: Optional["ModelInferenceMetrics"] = None,
    ) -> "ModelInferenceMetrics":
        """Create metrics for model inference.

        Args:
            last_metrics(ModelInferenceMetrics): The last metrics.

        Returns:
            ModelInferenceMetrics: The metrics for model inference.
        """
        start_time_ms = last_metrics.start_time_ms if last_metrics else None
        first_token_time_ms = last_metrics.first_token_time_ms if last_metrics else None
        first_completion_time_ms = (
@@ -100,15 +111,21 @@ class ModelInferenceMetrics:
        )

    def to_dict(self) -> Dict:
        """Convert the model inference metrics to dict."""
        return asdict(self)


@dataclass
@PublicAPI(stability="beta")
class ModelRequestContext:
    stream: Optional[bool] = False
    """A class to represent the context of a LLM model request."""

    stream: bool = False
    """Whether to return a stream of responses."""

    cache_enable: bool = False
    """Whether to enable the cache for the model inference"""

    user_name: Optional[str] = None
    """The user name of the model request."""

@@ -129,8 +146,6 @@ class ModelRequestContext:

    request_id: Optional[str] = None
    """The request id of the model inference."""
    cache_enable: Optional[bool] = False
    """Whether to enable the cache for the model inference"""


@dataclass
@@ -141,27 +156,31 @@ class ModelOutput:
    text: str
    """The generated text."""
    error_code: int
    """The error code of the model inference. If the model inference is successful, the error code is 0."""
    model_context: Dict = None
    finish_reason: str = None
    usage: Dict[str, Any] = None
    """The error code of the model inference. If the model inference is successful,
    the error code is 0."""
    model_context: Optional[Dict] = None
    finish_reason: Optional[str] = None
    usage: Optional[Dict[str, Any]] = None
    metrics: Optional[ModelInferenceMetrics] = None
    """Some metrics for model inference"""

    def to_dict(self) -> Dict:
        """Convert the model output to dict."""
        return asdict(self)


_ModelMessageType = Union[ModelMessage, Dict[str, Any]]
_ModelMessageType = Union[List[ModelMessage], List[Dict[str, Any]]]


@dataclass
@PublicAPI(stability="beta")
class ModelRequest:
    """The model request."""

    model: str
    """The name of the model."""

    messages: List[_ModelMessageType]
    messages: _ModelMessageType
    """The input messages."""

    temperature: Optional[float] = None
@@ -189,28 +208,42 @@ class ModelRequest:
    @property
    def stream(self) -> bool:
        """Whether to return a stream of responses."""
        return self.context and self.context.stream
        return bool(self.context and self.context.stream)

    def copy(self):
    def copy(self) -> "ModelRequest":
        """Copy the model request.

        Returns:
            ModelRequest: The copied model request.
        """
        new_request = copy.deepcopy(self)
        # Transform messages to List[ModelMessage]
        new_request.messages = list(
            map(
                lambda m: m if isinstance(m, ModelMessage) else ModelMessage(**m),
                new_request.messages,
            )
        )
        new_request.messages = new_request.get_messages()
        return new_request

    def to_dict(self) -> Dict[str, Any]:
        """Convert the model request to dict.

        Returns:
            Dict[str, Any]: The model request in dict.
        """
        new_reqeust = copy.deepcopy(self)
        new_reqeust.messages = list(
            map(lambda m: m if isinstance(m, dict) else m.dict(), new_reqeust.messages)
        )
        new_messages = []
        for message in new_reqeust.messages:
            if isinstance(message, dict):
                new_messages.append(message)
            else:
                new_messages.append(message.dict())
        new_reqeust.messages = new_messages
        # Skip None fields
        return {k: v for k, v in asdict(new_reqeust).items() if v is not None}

    def to_trace_metadata(self):
    def to_trace_metadata(self) -> Dict[str, Any]:
        """Convert the model request to trace metadata.

        Returns:
            Dict[str, Any]: The trace metadata.
        """
        metadata = self.to_dict()
        metadata["prompt"] = self.messages_to_string()
        return metadata
@@ -218,16 +251,19 @@ class ModelRequest:
    def get_messages(self) -> List[ModelMessage]:
        """Get the messages.

        If the messages is not a list of ModelMessage, it will be converted to a list of ModelMessage.
        If the messages is not a list of ModelMessage, it will be converted to a list
        of ModelMessage.

        Returns:
            List[ModelMessage]: The messages.
        """
        return list(
            map(
                lambda m: m if isinstance(m, ModelMessage) else ModelMessage(**m),
                self.messages,
            )
        )
        messages = []
        for message in self.messages:
            if isinstance(message, dict):
                messages.append(ModelMessage(**message))
            else:
                messages.append(message)
        return messages

    def get_single_user_message(self) -> Optional[ModelMessage]:
        """Get the single user message.
@@ -245,20 +281,35 @@ class ModelRequest:
        model: str,
        messages: List[ModelMessage],
        context: Optional[Union[ModelRequestContext, Dict[str, Any], BaseModel]] = None,
        stream: Optional[bool] = False,
        echo: Optional[bool] = False,
        stream: bool = False,
        echo: bool = False,
        **kwargs,
    ):
        """Build a model request.

        Args:
            model(str): The model name.
            messages(List[ModelMessage]): The messages.
            context(Optional[Union[ModelRequestContext, Dict[str, Any], BaseModel]]):
                The context.
            stream(bool): Whether to return a stream of responses. Defaults to False.
            echo(bool): Whether to echo the input messages. Defaults to False.
            **kwargs: Other arguments.
        """
        if not context:
            context = ModelRequestContext(stream=stream)
        context_dict = None
        if isinstance(context, dict):
            context_dict = context
        elif isinstance(context, BaseModel):
            context_dict = context.dict()
        if context_dict and "stream" not in context_dict:
            context_dict["stream"] = stream
            context = ModelRequestContext(**context_dict)
        elif not isinstance(context, ModelRequestContext):
            context_dict = None
            if isinstance(context, dict):
                context_dict = context
            elif isinstance(context, BaseModel):
                context_dict = context.dict()
            if context_dict and "stream" not in context_dict:
                context_dict["stream"] = stream
            if context_dict:
                context = ModelRequestContext(**context_dict)
            else:
                context = ModelRequestContext(stream=stream)
        return ModelRequest(
            model=model,
            messages=messages,
@@ -292,7 +343,6 @@ class ModelRequest:
            ValueError: If the message role is not supported

        Examples:

            .. code-block:: python

                from dbgpt.core.interface.message import (
@@ -337,7 +387,7 @@ class ModelRequest:
class ModelExtraMedata(BaseParameters):
    """A class to represent the extra metadata of a LLM."""

    prompt_roles: Optional[List[str]] = field(
    prompt_roles: List[str] = field(
        default_factory=lambda: [
            ModelMessageRoleType.SYSTEM,
            ModelMessageRoleType.HUMAN,
@@ -356,7 +406,8 @@ class ModelExtraMedata(BaseParameters):
    prompt_chat_template: Optional[str] = field(
        default=None,
        metadata={
            "help": "The chat template, see: https://huggingface.co/docs/transformers/main/en/chat_templating"
            "help": "The chat template, see: "
            "https://huggingface.co/docs/transformers/main/en/chat_templating"
        },
    )
@@ -403,19 +454,19 @@ class ModelMetadata(BaseParameters):
    def from_dict(
        cls, data: dict, ignore_extra_fields: bool = False
    ) -> "ModelMetadata":
        """Create a new model metadata from a dict."""
        if "ext_metadata" in data:
            data["ext_metadata"] = ModelExtraMedata(**data["ext_metadata"])
        return cls(**data)


class MessageConverter(ABC):
    """An abstract class for message converter.
    r"""An abstract class for message converter.

    Different LLMs may have different message formats, this class is used to convert the messages
    to the format of the LLM.
    Different LLMs may have different message formats, this class is used to convert
    the messages to the format of the LLM.

    Examples:

        >>> from typing import List
        >>> from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
        >>> from dbgpt.core.interface.llm import MessageConverter, ModelMetadata
@@ -425,7 +476,8 @@ class MessageConverter(ABC):
        ...     messages: List[ModelMessage],
        ...     model_metadata: Optional[ModelMetadata] = None,
        ... ) -> List[ModelMessage]:
        ...     # Convert the messages, merge system messages to the last user message.
        ...     # Convert the messages, merge system messages to the last user
        ...     # message.
        ...     system_message = None
        ...     other_messages = []
        ...     sep = "\\n"
@@ -478,6 +530,7 @@ class DefaultMessageConverter(MessageConverter):
    """The default message converter."""

    def __init__(self, prompt_sep: Optional[str] = None):
        """Create a new default message converter."""
        self._prompt_sep = prompt_sep

    def convert(
@@ -493,7 +546,8 @@ class DefaultMessageConverter(MessageConverter):

        2. Move the last user's message to the end of the list

        3. Convert the messages to no system message if the model does not support system message
        3. Convert the messages to no system message if the model does not support
           system message

        Args:
            messages(List[ModelMessage]): The messages.
@@ -520,10 +574,11 @@ class DefaultMessageConverter(MessageConverter):
        messages: List[ModelMessage],
        model_metadata: Optional[ModelMetadata] = None,
    ) -> List[ModelMessage]:
        """Convert the messages to no system message.
        r"""Convert the messages to no system message.

        Examples:
            >>> # Convert the messages to no system message, just merge system messages to the last user message
            >>> # Convert the messages to no system message, just merge system messages
            >>> # to the last user message
            >>> from typing import List
            >>> from dbgpt.core.interface.message import (
            ...     ModelMessage,
@@ -550,7 +605,7 @@ class DefaultMessageConverter(MessageConverter):
            >>> assert converted_messages == [
            ...     ModelMessage(
            ...         role=ModelMessageRoleType.HUMAN,
            ...         content="You are a helpful assistant\\nWho are you",
            ...         content="You are a helpful assistant\nWho are you",
            ...     ),
            ... ]
        """
@@ -562,7 +617,8 @@ class DefaultMessageConverter(MessageConverter):
        result_messages = []
        for message in messages:
            if message.role == ModelMessageRoleType.SYSTEM:
                # Not support system message, append system message to the last user message
                # Not support system message, append system message to the last user
                # message
                system_messages.append(message)
            elif message.role in [
                ModelMessageRoleType.HUMAN,
@@ -578,7 +634,8 @@ class DefaultMessageConverter(MessageConverter):
        system_message_str = system_messages[0].content

        if system_message_str and result_messages:
            # Not support system messages, merge system messages to the last user message
            # Not support system messages, merge system messages to the last user
            # message
            result_messages[-1].content = (
                system_message_str + prompt_sep + result_messages[-1].content
            )
@@ -587,10 +644,9 @@ class DefaultMessageConverter(MessageConverter):
    def move_last_user_message_to_end(
        self, messages: List[ModelMessage]
    ) -> List[ModelMessage]:
        """Move the last user message to the end of the list.
        """Try to move the last user message to the end of the list.

        Examples:

            >>> from typing import List
            >>> from dbgpt.core.interface.message import (
            ...     ModelMessage,
@@ -660,7 +716,7 @@ class LLMClient(ABC):

    @property
    def cache(self) -> collections.abc.MutableMapping:
        """The cache object to cache the model metadata.
        """Return the cache object to cache the model metadata.

        You can override this property to use your own cache object.
        Returns:
@@ -677,7 +733,8 @@ class LLMClient(ABC):
        """Generate a response for a given model request.

        Sometimes, different LLMs may have different message formats,
        you can use the message converter to convert the messages to the format of the LLM.
        you can use the message converter to convert the messages to the format of the
        LLM.

        Args:
            request(ModelRequest): The model request.
@@ -697,7 +754,8 @@ class LLMClient(ABC):
        """Generate a stream of responses for a given model request.

        Sometimes, different LLMs may have different message formats,
        you can use the message converter to convert the messages to the format of the LLM.
        you can use the message converter to convert the messages to the format of the
        LLM.

        Args:
            request(ModelRequest): The model request.
@@ -733,6 +791,7 @@ class LLMClient(ABC):
        message_converter: Optional[MessageConverter] = None,
    ) -> ModelRequest:
        """Covert the message.

        If no message converter is provided, the original request will be returned.

        Args:
@@ -746,14 +805,15 @@ class LLMClient(ABC):
            return request
        new_request = request.copy()
        model_metadata = await self.get_model_metadata(request.model)
        new_messages = message_converter.convert(request.messages, model_metadata)
        new_messages = message_converter.convert(request.get_messages(), model_metadata)
        new_request.messages = new_messages
        return new_request

    async def cached_models(self) -> List[ModelMetadata]:
        """Get all the models from the cache or the llm server.

        If the model metadata is not in the cache, it will be fetched from the llm server.
        If the model metadata is not in the cache, it will be fetched from the
        llm server.

        Returns:
            List[ModelMetadata]: A list of model metadata.
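Taken together, the llm.py changes make ModelRequest construction and serialization fully typed. Based on the build_request, stream, and to_dict definitions above, a request could be assembled like this (the model name is a placeholder):

from dbgpt.core.interface.llm import ModelRequest
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType

messages = [
    ModelMessage(
        role=ModelMessageRoleType.SYSTEM, content="You are a helpful assistant."
    ),
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="Who are you?"),
]
request = ModelRequest.build_request("proxyllm", messages=messages, stream=True)

assert request.stream  # the property now returns a real bool
payload = request.to_dict()  # messages become plain dicts, None fields dropped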
@@ -1,8 +1,10 @@
"""The conversation and message module."""

from __future__ import annotations

from abc import ABC, abstractmethod
from datetime import datetime
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union, cast

from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core.interface.storage import (
@@ -29,11 +31,11 @@ class BaseMessage(BaseModel, ABC):

    @property
    def pass_to_model(self) -> bool:
        """Whether the message will be passed to the model"""
        """Whether the message will be passed to the model."""
        return True

    def to_dict(self) -> Dict:
        """Convert to dict
        """Convert to dict.

        Returns:
            Dict: The dict object
@@ -47,7 +49,7 @@ class BaseMessage(BaseModel, ABC):

    @staticmethod
    def messages_to_string(messages: List["BaseMessage"]) -> str:
        """Convert messages to str
        """Convert messages to str.

        Args:
            messages (List[BaseMessage]): The messages
@@ -92,7 +94,7 @@ class ViewMessage(BaseMessage):

    @property
    def pass_to_model(self) -> bool:
        """Whether the message will be passed to the model
        """Whether the message will be passed to the model.

        The view message will not be passed to the model
        """
@@ -109,7 +111,7 @@ class SystemMessage(BaseMessage):


class ModelMessageRoleType:
    """ "Type of ModelMessage role"""
    """Type of ModelMessage role."""

    SYSTEM = "system"
    HUMAN = "human"
@@ -118,7 +120,7 @@ class ModelMessageRoleType:


class ModelMessage(BaseModel):
    """Type of message that interaction between dbgpt-server and llm-server"""
    """Type of message that interaction between dbgpt-server and llm-server."""

    """Similar to openai's message format"""
    role: str
@@ -127,7 +129,7 @@ class ModelMessage(BaseModel):

    @property
    def pass_to_model(self) -> bool:
        """Whether the message will be passed to the model
        """Whether the message will be passed to the model.

        The view message will not be passed to the model

@@ -142,6 +144,14 @@ class ModelMessage(BaseModel):

    @staticmethod
    def from_base_messages(messages: List[BaseMessage]) -> List["ModelMessage"]:
        """Covert BaseMessage format to current ModelMessage format.

        Args:
            messages (List[BaseMessage]): The base messages

        Returns:
            List[ModelMessage]: The model messages
        """
        result = []
        for message in messages:
            content, round_index = message.content, message.round_index
@@ -173,7 +183,7 @@ class ModelMessage(BaseModel):
    def from_openai_messages(
        messages: Union[str, List[Dict[str, str]]]
    ) -> List["ModelMessage"]:
        """Openai message format to current ModelMessage format"""
        """Openai message format to current ModelMessage format."""
        if isinstance(messages, str):
            return [ModelMessage(role=ModelMessageRoleType.HUMAN, content=messages)]
        result = []
@@ -202,8 +212,11 @@ class ModelMessage(BaseModel):
        convert_to_compatible_format: bool = False,
        support_system_role: bool = True,
    ) -> List[Dict[str, str]]:
        """Convert to common message format(e.g. OpenAI message format) and
        huggingface [Templates of Chat Models](https://huggingface.co/docs/transformers/v4.34.1/en/chat_templating)
        """Cover to common message format.

        Convert to common message format(e.g. OpenAI message format) and
        huggingface [Templates of Chat Models]
        (https://huggingface.co/docs/transformers/v4.34.1/en/chat_templating)

        Args:
            messages (List["ModelMessage"]): The model messages
@@ -243,15 +256,38 @@ class ModelMessage(BaseModel):

    @staticmethod
    def to_dict_list(messages: List["ModelMessage"]) -> List[Dict[str, str]]:
        """Convert to dict list.

        Args:
            messages (List["ModelMessage"]): The model messages

        Returns:
            List[Dict[str, str]]: The dict list
        """
        return list(map(lambda m: m.dict(), messages))

    @staticmethod
    def build_human_message(content: str) -> "ModelMessage":
        """Build human message.

        Args:
            content (str): The content

        Returns:
            ModelMessage: The model message
        """
        return ModelMessage(role=ModelMessageRoleType.HUMAN, content=content)

    @staticmethod
    def get_printable_message(messages: List["ModelMessage"]) -> str:
        """Get the printable message"""
        """Get the printable message.

        Args:
            messages (List["ModelMessage"]): The model messages

        Returns:
            str: The printable message
        """
        str_msg = ""
        for message in messages:
            curr_message = (
@@ -263,7 +299,7 @@ class ModelMessage(BaseModel):

    @staticmethod
    def messages_to_string(messages: List["ModelMessage"]) -> str:
        """Convert messages to str
        """Convert messages to str.

        Args:
            messages (List[ModelMessage]): The messages
@@ -343,21 +379,27 @@ def _messages_from_dict(messages: List[Dict]) -> List[BaseMessage]:

def parse_model_messages(
    messages: List[ModelMessage],
) -> Tuple[str, List[str], List[List[str, str]]]:
    """
    Parse model messages to extract the user prompt, system messages, and a history of conversation.
) -> Tuple[str, List[str], List[List[str]]]:
    """Parse model messages.

    This function analyzes a list of ModelMessage objects, identifying the role of each message (e.g., human, system, ai)
    and categorizes them accordingly. The last message is expected to be from the user (human), and it's treated as
    the current user prompt. System messages are extracted separately, and the conversation history is compiled into
    pairs of human and AI messages.
    Parse model messages to extract the user prompt, system messages, and a history of
    conversation.

    This function analyzes a list of ModelMessage objects, identifying the role of each
    message (e.g., human, system, ai)
    and categorizes them accordingly. The last message is expected to be from the user
    (human), and it's treated as
    the current user prompt. System messages are extracted separately, and the
    conversation history is compiled into pairs of human and AI messages.

    Args:
        messages (List[ModelMessage]): List of messages from a chat conversation.

    Returns:
        tuple: A tuple containing the user prompt, list of system messages, and the conversation history.
            The conversation history is a list of message pairs, each containing a user message and the corresponding AI response.
        tuple: A tuple containing the user prompt, list of system messages, and the
            conversation history.
            The conversation history is a list of message pairs, each containing a
            user message and the corresponding AI response.

    Examples:
        .. code-block:: python
@@ -399,7 +441,6 @@ def parse_model_messages(
        # system_messages: ["Error 404"]
        # history: [["Hi", "Hello!"], ["What's the error?", "Just a joke."]]
    """

    system_messages: List[str] = []
    history_messages: List[List[str]] = [[]]

@@ -420,27 +461,30 @@ def parse_model_messages(


class OnceConversation:
    """All the information of a conversation, the current single service in memory,
    can expand cache and database support distributed services.
    """Once conversation.

    All the information of a conversation, the current single service in memory,
    can expand cache and database support distributed services.
    """

    def __init__(
        self,
        chat_mode: str,
        user_name: str = None,
        sys_code: str = None,
        summary: str = None,
        user_name: Optional[str] = None,
        sys_code: Optional[str] = None,
        summary: Optional[str] = None,
        **kwargs,
    ):
        """Create a new conversation."""
        self.chat_mode: str = chat_mode
        self.user_name: str = user_name
        self.sys_code: str = sys_code
        self.summary: str = summary
        self.user_name: Optional[str] = user_name
        self.sys_code: Optional[str] = sys_code
        self.summary: Optional[str] = summary

        self.messages: List[BaseMessage] = kwargs.get("messages", [])
        self.start_date: str = kwargs.get("start_date", "")
        # After each complete round of dialogue, the current value will be increased by 1
        # After each complete round of dialogue, the current value will be
        # increased by 1
        self.chat_order: int = int(kwargs.get("chat_order", 0))
        self.model_name: str = kwargs.get("model_name", "")
        self.param_type: str = kwargs.get("param_type", "")
@@ -460,10 +504,9 @@ class OnceConversation:
        self.messages.append(message)

    def start_new_round(self) -> None:
        """Start a new round of conversation
        """Start a new round of conversation.

        Example:

        >>> conversation = OnceConversation("chat_normal")
        >>> # The chat order will be 0, then we start a new round of conversation
        >>> assert conversation.chat_order == 0
@@ -473,7 +516,8 @@ class OnceConversation:
        >>> conversation.add_user_message("hello")
        >>> conversation.add_ai_message("hi")
        >>> conversation.end_current_round()
        >>> # Now the chat order will be 1, then we start a new round of conversation
        >>> # Now the chat order will be 1, then we start a new round of
        >>> # conversation
        >>> conversation.start_new_round()
        >>> # Now the chat order will be 2
        >>> assert conversation.chat_order == 2
@@ -485,7 +529,7 @@ class OnceConversation:
        self.chat_order += 1

    def end_current_round(self) -> None:
        """End the current round of conversation
        """Execute the end of the current round of conversation.

        We do noting here, just for the interface
        """
@@ -494,7 +538,7 @@ class OnceConversation:
    def add_user_message(
        self, message: str, check_duplicate_type: Optional[bool] = False
    ) -> None:
        """Add a user message to the conversation
        """Save a user message to the conversation.

        Args:
            message (str): The message content
@@ -514,11 +558,12 @@ class OnceConversation:
    def add_ai_message(
        self, message: str, update_if_exist: Optional[bool] = False
    ) -> None:
        """Add an AI message to the conversation
        """Save an AI message to current conversation.

        Args:
            message (str): The message content
            update_if_exist (bool): Whether to update the message if the message type is duplicate
            update_if_exist (bool): Whether to update the message if the message type
                is duplicate
        """
        if not update_if_exist:
            self._append_message(AIMessage(content=message))
@@ -530,51 +575,57 @@ class OnceConversation:
            self._append_message(AIMessage(content=message))

    def _update_ai_message(self, new_message: str) -> None:
        """
        """Update the all AI message to new message.

        stream out message update

        Args:
            new_message:

        Returns:

            new_message (str): The new message
        """

        for item in self.messages:
            if item.type == "ai":
                item.content = new_message

    def add_view_message(self, message: str) -> None:
        """Add an AI message to the store"""
        """Save a view message to current conversation."""
        self._append_message(ViewMessage(content=message))

    def add_system_message(self, message: str) -> None:
        """Add a system message to the store"""
        """Save a system message to current conversation."""
        self._append_message(SystemMessage(content=message))

    def set_start_time(self, datatime: datetime):
        """Set the start time of the conversation."""
        dt_str = datatime.strftime("%Y-%m-%d %H:%M:%S")
        self.start_date = dt_str

    def clear(self) -> None:
        """Remove all messages from the store"""
        """Remove all messages from the store."""
        self.messages.clear()

    def get_latest_user_message(self) -> Optional[HumanMessage]:
        """Get the latest user message"""
        """Get the latest user message."""
        for message in self.messages[::-1]:
            if isinstance(message, HumanMessage):
                return message
        return None

    def get_system_messages(self) -> List[SystemMessage]:
        """Get the latest user message"""
        return list(filter(lambda x: isinstance(x, SystemMessage), self.messages))
        """Get the latest user message.

        Returns:
            List[SystemMessage]: The system messages
        """
        return cast(
            List[SystemMessage],
            list(filter(lambda x: isinstance(x, SystemMessage), self.messages)),
        )

    def _to_dict(self) -> Dict:
        return _conversation_to_dict(self)

    def from_conversation(self, conversation: OnceConversation) -> None:
        """Load the conversation from the storage"""
        """Load the conversation from the storage."""
        self.chat_mode = conversation.chat_mode
        self.messages = conversation.messages
        self.start_date = conversation.start_date
@@ -592,7 +643,7 @@ class OnceConversation:
        self._message_index = conversation._message_index

    def get_messages_by_round(self, round_index: int) -> List[BaseMessage]:
        """Get the messages by round index
        """Get the messages by round index.

        Args:
            round_index (int): The round index
@@ -603,7 +654,7 @@ class OnceConversation:
        return list(filter(lambda x: x.round_index == round_index, self.messages))

    def get_latest_round(self) -> List[BaseMessage]:
        """Get the latest round messages
        """Get the latest round messages.

        Returns:
            List[BaseMessage]: The messages
@@ -611,7 +662,7 @@ class OnceConversation:
        return self.get_messages_by_round(self.chat_order)

    def get_messages_with_round(self, round_count: int) -> List[BaseMessage]:
        """Get the messages with round count
        """Get the messages with round count.

        If the round count is 1, the history messages will not be included.

@@ -660,16 +711,19 @@ class OnceConversation:
        return messages

    def get_model_messages(self) -> List[ModelMessage]:
        """Get the model messages
        """Get the model messages.

        Model messages just include human, ai and system messages.
        Model messages maybe include the history messages, The order of the messages is the same as the order of
        Model messages maybe include the history messages, The order of the messages is
        the same as the order of
        the messages in the conversation, the last message is the latest message.

        If you want to hand the message with your own logic, you can override this method.
        If you want to hand the message with your own logic, you can override this
        method.

        Examples:
            If you not need the history messages, you can override this method like this:
            If you not need the history messages, you can override this method
            like this:
            .. code-block:: python

                def get_model_messages(self) -> List[ModelMessage]:
@@ -681,7 +735,8 @@ class OnceConversation:
                    )
                    return messages

            If you want to add the one round history messages, you can override this method like this:
            If you want to add the one round history messages, you can override this
            method like this:
            .. code-block:: python

                def get_model_messages(self) -> List[ModelMessage]:
@@ -717,7 +772,7 @@ class OnceConversation:
    def get_history_message(
        self, include_system_message: bool = False
    ) -> List[BaseMessage]:
        """Get the history message
        """Get the history message.

        Not include the system messages.

@@ -729,46 +784,60 @@ class OnceConversation:
        """
        messages = []
        for message in self.messages:
            if message.pass_to_model:
                if include_system_message:
                    messages.append(message)
                elif message.type != "system":
                    messages.append(message)
            if (
                message.pass_to_model
                and include_system_message
                or message.type != "system"
            ):
                messages.append(message)
        return messages


class ConversationIdentifier(ResourceIdentifier):
    """Conversation identifier"""
    """Conversation identifier."""

    def __init__(self, conv_uid: str, identifier_type: str = "conversation"):
        """Create a conversation identifier.

        Args:
            conv_uid (str): The conversation uid
            identifier_type (str): The identifier type
        """
        self.conv_uid = conv_uid
        self.identifier_type = identifier_type

    @property
    def str_identifier(self) -> str:
        """Return the str identifier."""
        return f"{self.identifier_type}:{self.conv_uid}"

    def to_dict(self) -> Dict:
        """Convert to dict."""
        return {"conv_uid": self.conv_uid, "identifier_type": self.identifier_type}


class MessageIdentifier(ResourceIdentifier):
    """Message identifier"""
    """Message identifier."""

    identifier_split = "___"

    def __init__(self, conv_uid: str, index: int, identifier_type: str = "message"):
        """Create a message identifier."""
        self.conv_uid = conv_uid
        self.index = index
        self.identifier_type = identifier_type

    @property
    def str_identifier(self) -> str:
        return f"{self.identifier_type}{self.identifier_split}{self.conv_uid}{self.identifier_split}{self.index}"
        """Return the str identifier."""
        return (
            f"{self.identifier_type}{self.identifier_split}{self.conv_uid}"
            f"{self.identifier_split}{self.index}"
        )

    @staticmethod
    def from_str_identifier(str_identifier: str) -> MessageIdentifier:
        """Convert from str identifier
        """Convert from str identifier.

        Args:
            str_identifier (str): The str identifier
@@ -782,6 +851,7 @@ class MessageIdentifier(ResourceIdentifier):
        return MessageIdentifier(parts[1], int(parts[2]))

    def to_dict(self) -> Dict:
        """Convert to dict."""
        return {
            "conv_uid": self.conv_uid,
            "index": self.index,
@@ -790,17 +860,31 @@ class MessageIdentifier(ResourceIdentifier):


class MessageStorageItem(StorageItem):
    """The message storage item.

    Keep the message detail and the message index.
    """

    @property
    def identifier(self) -> MessageIdentifier:
        """Return the identifier."""
        return self._id

    def __init__(self, conv_uid: str, index: int, message_detail: Dict):
        """Create a message storage item.

        Args:
            conv_uid (str): The conversation uid
            index (int): The message index
            message_detail (Dict): The message detail
        """
        self.conv_uid = conv_uid
        self.index = index
        self.message_detail = message_detail
        self._id = MessageIdentifier(conv_uid, index)

    def to_dict(self) -> Dict:
        """Convert to dict."""
        return {
            "conv_uid": self.conv_uid,
            "index": self.index,
@@ -808,7 +892,8 @@ class MessageStorageItem(StorageItem):
        }

    def to_message(self) -> BaseMessage:
        """Convert to message object
        """Convert to message object.

        Returns:
            BaseMessage: The message object

@@ -818,7 +903,7 @@ class MessageStorageItem(StorageItem):
        return _message_from_dict(self.message_detail)

    def merge(self, other: "StorageItem") -> None:
        """Merge the other message to self
        """Merge the other message to self.

        Args:
            other (StorageItem): The other message
@@ -829,16 +914,20 @@ class MessageStorageItem(StorageItem):


class StorageConversation(OnceConversation, StorageItem):
    """All the information of a conversation, the current single service in memory,
    """The storage conversation.

    All the information of a conversation, the current single service in memory,
    can expand cache and database support distributed services.

    """

    @property
    def identifier(self) -> ConversationIdentifier:
        """Return the identifier."""
        return self._id

    def to_dict(self) -> Dict:
        """Convert to dict."""
        dict_data = self._to_dict()
        messages: Dict = dict_data.pop("messages")
        message_ids = []
@@ -859,7 +948,7 @@ class StorageConversation(OnceConversation, StorageItem):
        return dict_data

    def merge(self, other: "StorageItem") -> None:
        """Merge the other conversation to self
        """Merge the other conversation to self.

        Args:
            other (StorageItem): The other conversation
@@ -871,17 +960,18 @@ class StorageConversation(OnceConversation, StorageItem):
    def __init__(
        self,
        conv_uid: str,
        chat_mode: str = None,
        user_name: str = None,
        sys_code: str = None,
        message_ids: List[str] = None,
        summary: str = None,
        save_message_independent: Optional[bool] = True,
        conv_storage: StorageInterface = None,
        message_storage: StorageInterface = None,
        chat_mode: str = "chat_normal",
        user_name: Optional[str] = None,
        sys_code: Optional[str] = None,
        message_ids: Optional[List[str]] = None,
        summary: Optional[str] = None,
        save_message_independent: bool = True,
        conv_storage: Optional[StorageInterface] = None,
        message_storage: Optional[StorageInterface] = None,
        load_message: bool = True,
        **kwargs,
    ):
        """Create a conversation."""
        super().__init__(chat_mode, user_name, sys_code, summary, **kwargs)
        self.conv_uid = conv_uid
        self._message_ids = message_ids
@ -905,7 +995,7 @@ class StorageConversation(OnceConversation, StorageItem):
|
||||
|
||||
@property
|
||||
def message_ids(self) -> List[str]:
|
||||
"""Get the message ids
|
||||
"""Return the message ids.
|
||||
|
||||
Returns:
|
||||
List[str]: The message ids
|
||||
@ -913,7 +1003,7 @@ class StorageConversation(OnceConversation, StorageItem):
|
||||
return self._message_ids if self._message_ids else []
|
||||
|
||||
def end_current_round(self) -> None:
|
||||
"""End the current round of conversation
|
||||
"""End the current round of conversation.
|
||||
|
||||
Save the conversation to the storage after a round of conversation
|
||||
"""
@@ -926,7 +1016,7 @@ class StorageConversation(OnceConversation, StorageItem):
]

def save_to_storage(self) -> None:
"""Save the conversation to the storage"""
"""Save the conversation to the storage."""
# Save messages first
message_list = self._get_message_items()
self._message_ids = [
@@ -943,7 +1033,7 @@ class StorageConversation(OnceConversation, StorageItem):
def load_from_storage(
self, conv_storage: StorageInterface, message_storage: StorageInterface
) -> None:
"""Load the conversation from the storage
"""Load the conversation from the storage.

Warning: This will overwrite the current conversation.

@@ -952,7 +1042,7 @@ class StorageConversation(OnceConversation, StorageItem):
message_storage (StorageInterface): The storage interface
"""
# Load conversation first
conversation: StorageConversation = conv_storage.load(
conversation: Optional[StorageConversation] = conv_storage.load(
self._id, StorageConversation
)
if conversation is None:
@@ -988,18 +1078,18 @@ class StorageConversation(OnceConversation, StorageItem):
def _append_additional_kwargs(
self, conversation: StorageConversation, messages: List[BaseMessage]
) -> None:
"""Parse the additional kwargs and append to the conversation
"""Parse the additional kwargs and append to the conversation.

Args:
conversation (StorageConversation): The conversation
messages (List[BaseMessage]): The messages
"""
param_type = None
param_value = None
param_type = ""
param_value = ""
for message in messages[::-1]:
if message.additional_kwargs:
param_type = message.additional_kwargs.get("param_type")
param_value = message.additional_kwargs.get("param_value")
param_type = message.additional_kwargs.get("param_type", "")
param_value = message.additional_kwargs.get("param_value", "")
break
if not conversation.param_type:
conversation.param_type = param_type
@@ -1007,7 +1097,7 @@ class StorageConversation(OnceConversation, StorageItem):
conversation.param_value = param_value

def delete(self) -> None:
"""Delete all the messages and conversation from the storage"""
"""Delete all the messages and conversation."""
# Delete messages first
message_list = self._get_message_items()
message_ids = [message.identifier for message in message_list]
@@ -1055,13 +1145,13 @@ def _conversations_to_dict(conversations: List[OnceConversation]) -> List[dict]:

def _conversation_from_dict(once: dict) -> OnceConversation:
conversation = OnceConversation(
once.get("chat_mode"), once.get("user_name"), once.get("sys_code")
once.get("chat_mode", ""), once.get("user_name"), once.get("sys_code")
)
conversation.cost = once.get("cost", 0)
conversation.chat_mode = once.get("chat_mode", "chat_normal")
conversation.tokens = once.get("tokens", 0)
conversation.start_date = once.get("start_date", "")
conversation.chat_order = int(once.get("chat_order"))
conversation.chat_order = int(once.get("chat_order", 0))
conversation.param_type = once.get("param_type", "")
conversation.param_value = once.get("param_value", "")
conversation.model_name = once.get("model_name", "proxyllm")
@@ -1093,7 +1183,8 @@ def _split_messages_by_round(messages: List[BaseMessage]) -> List[List[BaseMessa


def _append_view_messages(messages: List[BaseMessage]) -> List[BaseMessage]:
"""Append the view message to the messages
"""Append the view message to the messages.

Just for show in DB-GPT-Web.
If already have view message, do nothing.

@@ -0,0 +1 @@
"""The module include all core operators of DB-GPT."""
@@ -1,5 +1,9 @@
"""The chat history prompt composer operator.

We can wrap some atomic operators to a complex operator.
"""
import dataclasses
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, cast

from dbgpt.core import (
ChatPromptTemplate,
@@ -51,6 +55,7 @@ class ChatHistoryPromptComposerOperator(MapOperator[ChatComposerInput, ModelRequ
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
**kwargs,
):
"""Create a new chat history prompt composer operator."""
super().__init__(**kwargs)
self._prompt_template = prompt_template
self._history_key = history_key
@@ -61,7 +66,8 @@ class ChatHistoryPromptComposerOperator(MapOperator[ChatComposerInput, ModelRequ
self._sub_compose_dag = self._build_composer_dag()

async def map(self, input_value: ChatComposerInput) -> ModelRequest:
end_node: BaseOperator = self._sub_compose_dag.leaf_nodes[0]
"""Compose the chat history prompt."""
end_node: BaseOperator = cast(BaseOperator, self._sub_compose_dag.leaf_nodes[0])
# Sub dag, use the same dag context in the parent dag
return await end_node.call(
call_data={"data": input_value}, dag_ctx=self.current_dag_context
@@ -82,7 +88,9 @@ class ChatHistoryPromptComposerOperator(MapOperator[ChatComposerInput, ModelRequ
history_prompt_build_task = HistoryPromptBuilderOperator(
prompt=self._prompt_template, history_key=self._history_key
)
model_request_build_task = JoinOperator(self._build_model_request)
model_request_build_task: JoinOperator[ModelRequest] = JoinOperator(
combine_function=self._build_model_request
)

# Build composer dag
(
@@ -113,5 +121,6 @@ class ChatHistoryPromposerOperator(MapOperator[ChatComposerInput, ModelRequ
return ModelRequest.build_request(messages=messages, **model_dict)

async def after_dag_end(self):
"""Execute after dag end."""
# Should call after_dag_end() of sub dag
await self._sub_compose_dag._after_dag_end()
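For illustration, a minimal sketch of how the composed operator above is typically wired into a DAG. The DAG id, template content, and the import path of the operator are assumptions, not part of this commit:

.. code-block:: python

    from dbgpt.core import ChatPromptTemplate, HumanPromptTemplate
    from dbgpt.core.awel import DAG

    # Hypothetical wiring; import path for the composer is assumed, and
    # conversation/message storage fall back to defaults when omitted.
    from dbgpt.core.operator import ChatHistoryPromptComposerOperator

    with DAG("composer_example") as dag:
        composer = ChatHistoryPromptComposerOperator(
            prompt_template=ChatPromptTemplate(
                messages=[HumanPromptTemplate.from_template("{user_input}")]
            ),
            history_key="chat_history",
        )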
@@ -1,9 +1,12 @@
"""The LLM operator."""

import dataclasses
from abc import ABC
from typing import Any, AsyncIterator, Dict, List, Optional, Union

from dbgpt._private.pydantic import BaseModel
from dbgpt.core.awel import (
BaseOperator,
BranchFunc,
BranchOperator,
DAGContext,
@@ -32,11 +35,13 @@ class RequestBuilderOperator(MapOperator[RequestInput, ModelRequest], ABC):
"""Build the model request from the input value."""

def __init__(self, model: Optional[str] = None, **kwargs):
"""Create a new request builder operator."""
self._model = model
super().__init__(**kwargs)

async def map(self, input_value: RequestInput) -> ModelRequest:
req_dict = {}
"""Transform the input value to a model request."""
req_dict: Dict[str, Any] = {}
if not input_value:
raise ValueError("input_value is not set")
if isinstance(input_value, str):
@@ -47,7 +52,9 @@ class RequestBuilderOperator(MapOperator[RequestInput, ModelRequest], ABC):
req_dict = {"messages": [input_value]}
elif isinstance(input_value, list) and isinstance(input_value[0], ModelMessage):
req_dict = {"messages": input_value}
elif dataclasses.is_dataclass(input_value):
elif dataclasses.is_dataclass(input_value) and not isinstance(
input_value, type
):
req_dict = dataclasses.asdict(input_value)
elif isinstance(input_value, BaseModel):
req_dict = input_value.dict()
@@ -90,6 +97,7 @@ class BaseLLM:
SHARE_DATA_KEY_MODEL_OUTPUT = "share_data_key_model_output"

def __init__(self, llm_client: Optional[LLMClient] = None):
"""Create a new LLM operator."""
self._llm_client = llm_client

@property
@@ -118,10 +126,19 @@ class BaseLLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
"""

def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
"""Create a new LLM operator."""
super().__init__(llm_client=llm_client)
MapOperator.__init__(self, **kwargs)

async def map(self, request: ModelRequest) -> ModelOutput:
"""Generate the model output.

Args:
request (ModelRequest): The model request.

Returns:
ModelOutput: The model output.
"""
await self.current_dag_context.save_to_share_data(
self.SHARE_DATA_KEY_MODEL_NAME, request.model
)
@@ -142,15 +159,23 @@ class BaseStreamingLLMOperator(
"""

def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
super().__init__(llm_client=llm_client)
StreamifyAbsOperator.__init__(self, **kwargs)
"""Create a streaming operator for a LLM.

async def streamify(self, request: ModelRequest) -> AsyncIterator[ModelOutput]:
Args:
llm_client (LLMClient, optional): The LLM client. Defaults to None.
"""
super().__init__(llm_client=llm_client)
BaseOperator.__init__(self, **kwargs)

async def streamify(  # type: ignore
self, request: ModelRequest  # type: ignore
) -> AsyncIterator[ModelOutput]:  # type: ignore
"""Streamify the request."""
await self.current_dag_context.save_to_share_data(
self.SHARE_DATA_KEY_MODEL_NAME, request.model
)
model_output = None
async for output in self.llm_client.generate_stream(request):
async for output in self.llm_client.generate_stream(request):  # type: ignore
model_output = output
yield output
if model_output:
@@ -160,10 +185,17 @@ class BaseStreamingLLMOperator(
class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
"""Branch operator for LLM.

This operator will branch the workflow based on the stream flag of the request.
This operator will branch the workflow based on
the stream flag of the request.
"""

def __init__(self, stream_task_name: str, no_stream_task_name: str, **kwargs):
"""Create a new LLM branch operator.

Args:
stream_task_name (str): The name of the streaming task.
no_stream_task_name (str): The name of the non-streaming task.
"""
super().__init__(**kwargs)
if not stream_task_name:
raise ValueError("stream_task_name is not set")
@@ -172,18 +204,22 @@ class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
self._stream_task_name = stream_task_name
self._no_stream_task_name = no_stream_task_name

async def branches(self) -> Dict[BranchFunc[ModelRequest], str]:
async def branches(
self,
) -> Dict[BranchFunc[ModelRequest], Union[BaseOperator, str]]:
"""
Return a dict of branch function and task name.

Returns:
Dict[BranchFunc[ModelRequest], str]: A dict of branch function and task name.
the key is a predicate function, the value is the task name. If the predicate function returns True,
we will run the corresponding task.
Dict[BranchFunc[ModelRequest], str]: A dict of branch function and task
name. the key is a predicate function, the value is the task name.
If the predicate function returns True, we will run the corresponding
task.
"""

async def check_stream_true(r: ModelRequest) -> bool:
# If stream is true, we will run the streaming task. otherwise, we will run the non-streaming task.
# If stream is true, we will run the streaming task. otherwise, we will run
# the non-streaming task.
return r.stream

return {
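A hedged sketch of the branching contract that the new `branches()` signature implies. Only `check_stream_true` appears in the hunk above; the non-streaming predicate name below is an assumption:

.. code-block:: python

    branch = LLMBranchOperator(
        stream_task_name="streaming_llm_task",
        no_stream_task_name="llm_task",
    )
    # branches() resolves to a predicate-keyed mapping, roughly:
    # {check_stream_true: "streaming_llm_task",
    #  check_stream_false: "llm_task"}  # second predicate name assumed
    # The DAG runs the task whose predicate returns True for the request.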
@@ -1,3 +1,4 @@
"""The message operator."""
import uuid
from abc import ABC
from typing import Any, Callable, Dict, List, Optional, Union
@@ -36,6 +37,7 @@ class BaseConversationOperator(BaseOperator, ABC):
check_storage: bool = True,
**kwargs,
):
"""Create a new BaseConversationOperator."""
self._check_storage = check_storage
self._storage = storage
self._message_storage = message_storage
@@ -102,7 +104,8 @@ class PreChatHistoryLoadOperator(
):
"""The operator to prepare the storage conversation.

In DB-GPT, conversation record and the messages in the conversation are stored in the storage,
In DB-GPT, conversation record and the messages in the conversation are stored in
the storage,
and they can store in different storage(for high performance).

This operator just load the conversation and messages from storage.
@@ -115,6 +118,7 @@ class PreChatHistoryLoadOperator(
include_system_message: bool = False,
**kwargs,
):
"""Create a new PreChatHistoryLoadOperator."""
super().__init__(storage=storage, message_storage=message_storage)
MapOperator.__init__(self, **kwargs)
self._include_system_message = include_system_message
@@ -139,7 +143,8 @@ class PreChatHistoryLoadOperator(

chat_mode = input_value.chat_mode

# Create a new storage conversation, this will load the conversation from storage, so we must do this async
# Create a new storage conversation, this will load the conversation from
# storage, so we must do this async
storage_conv: StorageConversation = await self.blocking_func_to_async(
StorageConversation,
conv_uid=input_value.conv_uid,
@@ -167,14 +172,21 @@ class PreChatHistoryLoadOperator(
class ConversationMapperOperator(
BaseConversationOperator, MapOperator[List[BaseMessage], List[BaseMessage]]
):
def __init__(self, message_mapper: _MultiRoundMessageMapper = None, **kwargs):
"""The base conversation mapper operator."""

def __init__(
self, message_mapper: Optional[_MultiRoundMessageMapper] = None, **kwargs
):
"""Create a new ConversationMapperOperator."""
MapOperator.__init__(self, **kwargs)
self._message_mapper = message_mapper

async def map(self, input_value: List[BaseMessage]) -> List[BaseMessage]:
"""Map the input value to a ModelRequest."""
return await self.map_messages(input_value)

async def map_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]:
"""Map multi round messages to a list of BaseMessage."""
messages_by_round: List[List[BaseMessage]] = _split_messages_by_round(messages)
message_mapper = self._message_mapper or self.map_multi_round_messages
return message_mapper(messages_by_round)
@@ -184,11 +196,11 @@ class ConversationMapperOperator(
) -> List[BaseMessage]:
"""Map multi round messages to a list of BaseMessage.

By default, just merge all multi round messages to a list of BaseMessage according origin order.
By default, just merge all multi round messages to a list of BaseMessage
according origin order.
And you can overwrite this method to implement your own logic.

Examples:

Merge multi round messages to a list of BaseMessage according origin order.

>>> from dbgpt.core.interface.message import (
@@ -215,7 +227,8 @@ class ConversationMapperOperator(
... AIMessage(content="Just a joke.", round_index=2),
... ]

Map multi round messages to a list of BaseMessage just keep the last one round.
Map multi round messages to a list of BaseMessage just keep the last one
round.

>>> class MyMapper(ConversationMapperOperator):
... def __init__(self, **kwargs):
@@ -234,13 +247,16 @@ class ConversationMapperOperator(
... ]

Args:
messages_by_round (List[List[BaseMessage]]):
The messages grouped by round.
"""
# Just merge and return
return _merge_multi_round_messages(messages_by_round)


class BufferedConversationMapperOperator(ConversationMapperOperator):
"""
"""Buffered conversation mapper operator.

The buffered conversation mapper operator which can be configured to keep
a certain number of starting and/or ending rounds of a conversation.

@@ -249,39 +265,44 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
keep_end_rounds (Optional[int]): Number of final rounds to keep.

Examples:
# Keeping the first 2 and the last 1 rounds of a conversation
import asyncio
from dbgpt.core.interface.message import AIMessage, HumanMessage
from dbgpt.core.operator import BufferedConversationMapperOperator
.. code-block:: python

operator = BufferedConversationMapperOperator(keep_start_rounds=2, keep_end_rounds=1)
messages = [
# Assume each HumanMessage and AIMessage belongs to separate rounds
HumanMessage(content="Hi", round_index=1),
AIMessage(content="Hello!", round_index=1),
HumanMessage(content="How are you?", round_index=2),
AIMessage(content="I'm good, thanks!", round_index=2),
HumanMessage(content="What's new today?", round_index=3),
AIMessage(content="Lots of things!", round_index=3),
]
# This will keep rounds 1, 2, and 3
assert asyncio.run(operator.map_messages(messages)) == [
HumanMessage(content="Hi", round_index=1),
AIMessage(content="Hello!", round_index=1),
HumanMessage(content="How are you?", round_index=2),
AIMessage(content="I'm good, thanks!", round_index=2),
HumanMessage(content="What's new today?", round_index=3),
AIMessage(content="Lots of things!", round_index=3),
]
# Keeping the first 2 and the last 1 rounds of a conversation
import asyncio
from dbgpt.core.interface.message import AIMessage, HumanMessage
from dbgpt.core.operator import BufferedConversationMapperOperator

operator = BufferedConversationMapperOperator(
keep_start_rounds=2, keep_end_rounds=1
)
messages = [
# Assume each HumanMessage and AIMessage belongs to separate rounds
HumanMessage(content="Hi", round_index=1),
AIMessage(content="Hello!", round_index=1),
HumanMessage(content="How are you?", round_index=2),
AIMessage(content="I'm good, thanks!", round_index=2),
HumanMessage(content="What's new today?", round_index=3),
AIMessage(content="Lots of things!", round_index=3),
]
# This will keep rounds 1, 2, and 3
assert asyncio.run(operator.map_messages(messages)) == [
HumanMessage(content="Hi", round_index=1),
AIMessage(content="Hello!", round_index=1),
HumanMessage(content="How are you?", round_index=2),
AIMessage(content="I'm good, thanks!", round_index=2),
HumanMessage(content="What's new today?", round_index=3),
AIMessage(content="Lots of things!", round_index=3),
]
"""

def __init__(
self,
keep_start_rounds: Optional[int] = None,
keep_end_rounds: Optional[int] = None,
message_mapper: _MultiRoundMessageMapper = None,
message_mapper: Optional[_MultiRoundMessageMapper] = None,
**kwargs,
):
"""Create a new BufferedConversationMapperOperator."""
# Validate the input parameters
if keep_start_rounds is not None and keep_start_rounds < 0:
raise ValueError("keep_start_rounds must be non-negative")
@@ -311,10 +332,11 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
def _filter_round_messages(
self, messages_by_round: List[List[BaseMessage]]
) -> List[List[BaseMessage]]:
"""Filters the messages to keep only the specified starting and/or ending rounds.
"""Return a filtered list of messages.

Filters the messages to keep only the specified starting and/or ending rounds.

Examples:

>>> from dbgpt.core import AIMessage, HumanMessage
>>> from dbgpt.core.operator import BufferedConversationMapperOperator
>>> messages = [
@@ -395,15 +417,18 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
... ]

Args:
messages_by_round (List[List[BaseMessage]]): The messages grouped by round.
messages_by_round (List[List[BaseMessage]]):
The messages grouped by round.

Returns:
List[List[BaseMessage]]: Filtered list of messages.

Returns:
List[List[BaseMessage]]: Filtered list of messages.
"""
total_rounds = len(messages_by_round)
if self._keep_start_rounds is not None and self._keep_end_rounds is not None:
if self._keep_start_rounds + self._keep_end_rounds > total_rounds:
# Avoid overlapping when the sum of start and end rounds exceeds total rounds
# Avoid overlapping when the sum of start and end rounds exceeds total
# rounds
return messages_by_round
return (
messages_by_round[: self._keep_start_rounds]
@@ -423,14 +448,16 @@ EvictionPolicyType = Callable[[List[List[BaseMessage]]], List[List[BaseMessage]]
class TokenBufferedConversationMapperOperator(ConversationMapperOperator):
"""The token buffered conversation mapper operator.

If the token count of the messages is greater than the max token limit, we will evict the messages by round.
If the token count of the messages is greater than the max token limit, we will
evict the messages by round.

Args:
model (str): The model name.
llm_client (LLMClient): The LLM client.
max_token_limit (int): The max token limit.
eviction_policy (EvictionPolicyType): The eviction policy.
message_mapper (_MultiRoundMessageMapper): The message mapper, it applies after all messages are handled.
message_mapper (_MultiRoundMessageMapper): The message mapper, it applies after
all messages are handled.
"""

def __init__(
@@ -438,10 +465,11 @@ class TokenBufferedConversationMapperOperator(ConversationMapperOperator):
model: str,
llm_client: LLMClient,
max_token_limit: int = 2000,
eviction_policy: EvictionPolicyType = None,
message_mapper: _MultiRoundMessageMapper = None,
eviction_policy: Optional[EvictionPolicyType] = None,
message_mapper: Optional[_MultiRoundMessageMapper] = None,
**kwargs,
):
"""Create a new TokenBufferedConversationMapperOperator."""
if max_token_limit < 0:
raise ValueError("Max token limit can't be negative")
self._model = model
@@ -452,6 +480,7 @@ class TokenBufferedConversationMapperOperator(ConversationMapperOperator):
super().__init__(**kwargs)

async def map_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]:
"""Map multi round messages to a list of BaseMessage."""
eviction_policy = self._eviction_policy or self.eviction_policy
messages_by_round: List[List[BaseMessage]] = _split_messages_by_round(messages)
messages_str = _messages_to_str(_merge_multi_round_messages(messages_by_round))
@@ -459,7 +488,8 @@ class TokenBufferedConversationMapperOperator(ConversationMapperOperator):
current_tokens = await self._llm_client.count_token(self._model, messages_str)

while current_tokens > self._max_token_limit:
# Evict the messages by round after all tokens are not greater than the max token limit
# Evict the messages by round after all tokens are not greater than the max
# token limit
# TODO: We should find a high performance way to do this
messages_by_round = eviction_policy(messages_by_round)
messages_str = _messages_to_str(
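The eviction loop above is the core of the token-buffered mapper. A standalone sketch of the idea, with `count_tokens` standing in for `LLMClient.count_token` and the drop-oldest-round policy assumed rather than taken from this commit:

.. code-block:: python

    from typing import Callable, List

    def evict_by_round(
        rounds: List[List[str]],
        count_tokens: Callable[[List[List[str]]], int],
        max_tokens: int,
    ) -> List[List[str]]:
        # Drop whole rounds (oldest first, an assumed policy) until the
        # merged messages fit under the token budget.
        while rounds and count_tokens(rounds) > max_tokens:
            rounds = rounds[1:]
        return rounds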
@@ -1,8 +1,8 @@
"""The prompt operator."""
from abc import ABC
from typing import Any, Dict, List, Optional, Union

from dbgpt.core import (
BasePromptTemplate,
ChatPromptTemplate,
ModelMessage,
ModelMessageRoleType,
@@ -13,7 +13,13 @@ from dbgpt.core.awel import JoinOperator, MapOperator
from dbgpt.core.interface.message import BaseMessage
from dbgpt.core.interface.operator.llm_operator import BaseLLM
from dbgpt.core.interface.operator.message_operator import BaseConversationOperator
from dbgpt.core.interface.prompt import HumanPromptTemplate, MessageType
from dbgpt.core.interface.prompt import (
BaseChatPromptTemplate,
HumanPromptTemplate,
MessagesPlaceholder,
MessageType,
PromptTemplate,
)
from dbgpt.util.function_utils import rearrange_args_by_type


@@ -21,6 +27,7 @@ class BasePromptBuilderOperator(BaseConversationOperator, ABC):
"""The base prompt builder operator."""

def __init__(self, check_storage: bool, **kwargs):
"""Create a new prompt builder operator."""
super().__init__(check_storage=check_storage, **kwargs)

async def format_prompt(
@@ -39,10 +46,10 @@ class BasePromptBuilderOperator(BaseConversationOperator, ABC):
kwargs.update(prompt_dict)
pass_kwargs = {k: v for k, v in kwargs.items() if k in prompt.input_variables}
messages = prompt.format_messages(**pass_kwargs)
messages = ModelMessage.from_base_messages(messages)
model_messages = ModelMessage.from_base_messages(messages)
# Start new round conversation, and save user message to storage
await self.start_new_round_conv(messages)
return messages
await self.start_new_round_conv(model_messages)
return model_messages

async def start_new_round_conv(self, messages: List[ModelMessage]) -> None:
"""Start a new round conversation.
@@ -50,7 +57,6 @@ class BasePromptBuilderOperator(BaseConversationOperator, ABC):
Args:
messages (List[ModelMessage]): The messages.
"""

last_user_message = None
for message in messages[::-1]:
if message.role == ModelMessageRoleType.HUMAN:
@@ -58,7 +64,9 @@ class BasePromptBuilderOperator(BaseConversationOperator, ABC):
break
if not last_user_message:
raise ValueError("No user message")
storage_conv: StorageConversation = await self.get_storage_conversation()
storage_conv: Optional[
StorageConversation
] = await self.get_storage_conversation()
if not storage_conv:
return
# Start new round
@@ -66,13 +74,17 @@ class BasePromptBuilderOperator(BaseConversationOperator, ABC):
storage_conv.add_user_message(last_user_message)

async def after_dag_end(self):
"""The callback after DAG end"""
# TODO remove this to start_new_round()
"""Execute after the DAG finished."""
# Save the storage conversation to storage after the whole DAG finished
storage_conv: StorageConversation = await self.get_storage_conversation()
storage_conv: Optional[
StorageConversation
] = await self.get_storage_conversation()

if not storage_conv:
return
model_output: ModelOutput = await self.current_dag_context.get_from_share_data(
model_output: Optional[
ModelOutput
] = await self.current_dag_context.get_from_share_data(
BaseLLM.SHARE_DATA_KEY_MODEL_OUTPUT
)
if model_output:
@@ -82,7 +94,7 @@ class BasePromptBuilderOperator(BaseConversationOperator, ABC):
storage_conv.end_current_round()


PromptTemplateType = Union[ChatPromptTemplate, BasePromptTemplate, MessageType, str]
PromptTemplateType = Union[ChatPromptTemplate, PromptTemplate, MessageType, str]


class PromptBuilderOperator(
@@ -91,7 +103,6 @@ class PromptBuilderOperator(
"""The operator to build the prompt with static prompt.

Examples:

.. code-block:: python

import asyncio
@@ -119,7 +130,8 @@ class PromptBuilderOperator(
ChatPromptTemplate(
messages=[
HumanPromptTemplate.from_template(
"Please write a {dialect} SQL count the length of a field"
"Please write a {dialect} SQL count the length of a"
" field"
)
]
)
@@ -131,7 +143,8 @@ class PromptBuilderOperator(
"You are a {dialect} SQL expert"
),
HumanPromptTemplate.from_template(
"Please write a {dialect} SQL count the length of a field"
"Please write a {dialect} SQL count the length of a"
" field"
),
],
)
@@ -171,17 +184,18 @@ class PromptBuilderOperator(
"""

def __init__(self, prompt: PromptTemplateType, **kwargs):
"""Create a new prompt builder operator."""
if isinstance(prompt, str):
prompt = ChatPromptTemplate(
messages=[HumanPromptTemplate.from_template(prompt)]
)
elif isinstance(prompt, BasePromptTemplate) and not isinstance(
prompt, ChatPromptTemplate
):
elif isinstance(prompt, PromptTemplate):
prompt = ChatPromptTemplate(
messages=[HumanPromptTemplate.from_template(prompt.template)]
)
elif isinstance(prompt, MessageType):
elif isinstance(
prompt, (BaseChatPromptTemplate, MessagesPlaceholder, BaseMessage)
):
prompt = ChatPromptTemplate(messages=[prompt])
self._prompt = prompt

@@ -190,6 +204,7 @@ class PromptBuilderOperator(

@rearrange_args_by_type
async def merge_prompt(self, prompt_dict: Dict[str, Any]) -> List[ModelMessage]:
"""Format the prompt."""
return await self.format_prompt(self._prompt, prompt_dict)


@@ -202,6 +217,7 @@ class DynamicPromptBuilderOperator(
"""

def __init__(self, **kwargs):
"""Create a new dynamic prompt builder operator."""
super().__init__(check_storage=False, **kwargs)
JoinOperator.__init__(self, combine_function=self.merge_prompt, **kwargs)

@@ -209,20 +225,37 @@ class DynamicPromptBuilderOperator(
async def merge_prompt(
self, prompt: ChatPromptTemplate, prompt_dict: Dict[str, Any]
) -> List[ModelMessage]:
"""Merge the prompt and history."""
return await self.format_prompt(prompt, prompt_dict)


class HistoryPromptBuilderOperator(
BasePromptBuilderOperator, JoinOperator[List[ModelMessage]]
):
"""The operator to build the prompt with static prompt.

The prompt will pass to this operator.
"""

def __init__(
self,
prompt: ChatPromptTemplate,
history_key: Optional[str] = None,
history_key: str = "chat_history",
check_storage: bool = True,
str_history: bool = False,
**kwargs,
):
"""Create a new history prompt builder operator.

Args:
prompt (ChatPromptTemplate): The prompt.
history_key (str, optional): The key of history in prompt dict. Defaults
to "chat_history".
check_storage (bool, optional): Whether to check the storage.
Defaults to True.
str_history (bool, optional): Whether to convert the history to string.
Defaults to False.
"""
self._prompt = prompt
self._history_key = history_key
self._str_history = str_history
@@ -233,6 +266,7 @@ class HistoryPromptBuilderOperator(
async def merge_history(
self, history: List[BaseMessage], prompt_dict: Dict[str, Any]
) -> List[ModelMessage]:
"""Merge the prompt and history."""
if self._str_history:
prompt_dict[self._history_key] = BaseMessage.messages_to_string(history)
else:
@@ -250,11 +284,12 @@ class HistoryDynamicPromptBuilderOperator(

def __init__(
self,
history_key: Optional[str] = None,
history_key: str = "chat_history",
check_storage: bool = True,
str_history: bool = False,
**kwargs,
):
"""Create a new history dynamic prompt builder operator."""
self._history_key = history_key
self._str_history = str_history
BasePromptBuilderOperator.__init__(self, check_storage=check_storage)
@@ -267,6 +302,7 @@ class HistoryDynamicPromptBuilderOperator(
history: List[BaseMessage],
prompt_dict: Dict[str, Any],
) -> List[ModelMessage]:
"""Merge the prompt and history."""
if self._str_history:
prompt_dict[self._history_key] = BaseMessage.messages_to_string(history)
else:
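With the widened `PromptTemplateType`, the constructor in the hunks above normalizes several prompt shapes into a `ChatPromptTemplate`. A rough sketch of the accepted inputs, assuming the imports shown in the diff are available:

.. code-block:: python

    # Each call below is coerced to a ChatPromptTemplate internally.
    PromptBuilderOperator("Please write a {dialect} SQL")  # plain str
    PromptBuilderOperator(PromptTemplate.from_template("Hello {name}"))
    PromptBuilderOperator(HumanPromptTemplate.from_template("Hi {name}"))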
@@ -1,3 +1,4 @@
"""The Abstract Retriever Operator."""
from abc import abstractmethod

from dbgpt.core.awel import MapOperator
@@ -16,7 +17,8 @@ class RetrieverOperator(MapOperator[IN, OUT]):
Returns:
OUT: The output value.
"""
# The retrieve function is blocking, so we need to wrap it in a blocking_func_to_async.
# The retrieve function is blocking, so we need to wrap it in a
# blocking_func_to_async.
return await self.blocking_func_to_async(self.retrieve, input_value)

@abstractmethod
@@ -1,10 +1,15 @@
"""The output parser is used to parse the output of an LLM call.

TODO: Make this more general and clear.
"""

from __future__ import annotations

import json
import logging
from abc import ABC
from dataclasses import asdict
from typing import Any, Dict, TypeVar, Union
from typing import Any, TypeVar, Union

from dbgpt.core import ModelOutput
from dbgpt.core.awel import MapOperator
@@ -22,11 +27,16 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
"""

def __init__(self, is_stream_out: bool = True, **kwargs):
"""Create a new output parser."""
super().__init__(**kwargs)
self.is_stream_out = is_stream_out
self.data_schema = None

def update(self, data_schema):
"""Update the data schema.

TODO: Remove this method.
"""
self.data_schema = data_schema

def __post_process_code(self, code):
@@ -40,9 +50,16 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
return code

def parse_model_stream_resp_ex(self, chunk: ResponseTye, skip_echo_len):
data = _parse_model_response(chunk)
""" TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
"""Parse the output of an LLM call.

Args:
chunk (ResponseTye): The output of an LLM call.
skip_echo_len (int): The length of the prompt to skip.
"""
data = _parse_model_response(chunk)
# TODO: Multi mode output handler, rewrite this for multi model, use adapter
# mode.

model_context = data.get("model_context")
has_echo = False
if model_context and "prompt_echo_len_char" in model_context:
@@ -65,6 +82,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
return output

def parse_model_nostream_resp(self, response: ResponseTye, sep: str):
"""Parse the output of an LLM call."""
resp_obj_ex = _parse_model_response(response)
if isinstance(resp_obj_ex, str):
resp_obj_ex = json.loads(resp_obj_ex)
@@ -89,7 +107,8 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
return ai_response
else:
raise ValueError(
f"""Model server error!code={resp_obj_ex["error_code"]}, errmsg is {resp_obj_ex["text"]}"""
f"Model server error!code={resp_obj_ex['error_code']}, error msg is "
f"{resp_obj_ex['text']}"
)

def _illegal_json_ends(self, s):
@@ -117,7 +136,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):

temp_json = self._illegal_json_ends(temp_json)
return temp_json
except Exception as e:
except Exception:
raise ValueError("Failed to find a valid json in LLM response!" + temp_json)

def _json_interception(self, s, is_json_array: bool = False):
@@ -150,17 +169,17 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
break
assert count == 0
return s[i : j + 1]
except Exception as e:
except Exception:
return ""

def parse_prompt_response(self, model_out_text) -> T:
"""
parse model out text to prompt define response
def parse_prompt_response(self, model_out_text) -> Any:
"""Parse model out text to prompt define response.

Args:
model_out_text:
model_out_text: The output of an LLM call.

Returns:

Any: The parsed output of an LLM call.
"""
cleaned_output = model_out_text.rstrip()
if "```json" in cleaned_output:
@@ -194,12 +213,15 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
def parse_view_response(
self, ai_text, data, parse_prompt_response: Any = None
) -> str:
"""
parse the ai response info to user view
"""Parse the AI response info to user view.

Args:
text:
ai_text (str): The output of an LLM call.
data (dict): The data has been handled by some scene.
parse_prompt_response (Any): The prompt response has been parsed.

Returns:
str: The parsed output of an LLM call.

"""
return ai_text
@@ -240,10 +262,14 @@ def _parse_model_response(response: ResponseTye):


class SQLOutputParser(BaseOutputParser):
"""Parse the SQL output of an LLM call."""

def __init__(self, is_stream_out: bool = False, **kwargs):
"""Create a new SQL output parser."""
super().__init__(is_stream_out=is_stream_out, **kwargs)

def parse_model_nostream_resp(self, response: ResponseTye, sep: str):
"""Parse the output of an LLM call."""
model_out_text = super().parse_model_nostream_resp(response, sep)
clean_str = super().parse_prompt_response(model_out_text)
return json.loads(clean_str, strict=True)
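Taken together, `SQLOutputParser.parse_model_nostream_resp` chains the base-class steps shown above. A hedged sketch of the flow; the response payload shape is an assumption, so the call is left commented:

.. code-block:: python

    parser = SQLOutputParser()
    # result = parser.parse_model_nostream_resp(response, sep="###")
    # 1. The base parse_model_nostream_resp extracts the assistant text.
    # 2. parse_prompt_response strips ```json fences and repairs illegal
    #    JSON endings.
    # 3. json.loads turns the cleaned string into a Python object.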
@@ -1,3 +1,5 @@
"""The prompt template interface."""

from __future__ import annotations

import dataclasses
@@ -7,9 +9,7 @@ from string import Formatter
from typing import Any, Callable, Dict, List, Optional, Set, Union

from dbgpt._private.pydantic import BaseModel, root_validator
from dbgpt.core._private.example_base import ExampleSelector
from dbgpt.core.interface.message import BaseMessage, HumanMessage, SystemMessage
from dbgpt.core.interface.output_parser import BaseOutputParser
from dbgpt.core.interface.storage import (
InMemoryStorage,
QuerySpec,
@@ -42,64 +42,32 @@ _DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {


class BasePromptTemplate(BaseModel):
"""Base class for all prompt templates, returning a prompt."""

input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""

template: Optional[str]

class PromptTemplate(BasePromptTemplate):
"""Prompt template."""

template: str
"""The prompt template."""

template_format: Optional[str] = "f-string"
template_format: str = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""

response_key: str = "response"

template_is_strict: bool = True
"""strict template will check template args"""

response_format: Optional[str] = None

response_key: Optional[str] = "response"
template_scene: Optional[str] = None

template_is_strict: Optional[bool] = True
"""strict template will check template args"""

def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs."""
if self.template:
if self.response_format:
kwargs[self.response_key] = json.dumps(
self.response_format, ensure_ascii=False, indent=4
)
return _DEFAULT_FORMATTER_MAPPING[self.template_format](
self.template_is_strict
)(self.template, **kwargs)

@classmethod
def from_template(
cls, template: str, template_format: Optional[str] = "f-string", **kwargs: Any
) -> BasePromptTemplate:
"""Create a prompt template from a template string."""
input_variables = get_template_vars(template, template_format)
return cls(
template=template,
input_variables=input_variables,
template_format=template_format,
**kwargs,
)


class PromptTemplate(BasePromptTemplate):
template_scene: Optional[str]
template_define: Optional[str]
template_define: Optional[str] = None
"""this template define"""
"""default use stream out"""
stream_out: bool = True
""""""
output_parser: BaseOutputParser = None
""""""
sep: str = "###"

example_selector: ExampleSelector = None

need_historical_messages: bool = False

temperature: float = 0.6
max_new_tokens: int = 1024

class Config:
"""Configuration for this pydantic object."""
@@ -111,13 +79,38 @@ class PromptTemplate(BasePromptTemplate):
"""Return the prompt type key."""
return "prompt"

def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs."""
if self.response_format:
kwargs[self.response_key] = json.dumps(
self.response_format, ensure_ascii=False, indent=4
)
return _DEFAULT_FORMATTER_MAPPING[self.template_format](
self.template_is_strict
)(self.template, **kwargs)

@classmethod
def from_template(
cls, template: str, template_format: str = "f-string", **kwargs: Any
) -> BasePromptTemplate:
"""Create a prompt template from a template string."""
input_variables = get_template_vars(template, template_format)
return cls(
template=template,
input_variables=input_variables,
template_format=template_format,
**kwargs,
)


class BaseChatPromptTemplate(BaseModel, ABC):
"""The base chat prompt template."""

prompt: BasePromptTemplate

@property
def input_variables(self) -> List[str]:
"""A list of the names of the variables the prompt template expects."""
"""Return a list of the names of the variables the prompt template expects."""
return self.prompt.input_variables

@abstractmethod
@@ -128,14 +121,14 @@ class BaseChatPromptTemplate(BaseModel, ABC):
def from_template(
cls,
template: str,
template_format: Optional[str] = "f-string",
template_format: str = "f-string",
response_format: Optional[str] = None,
response_key: Optional[str] = "response",
response_key: str = "response",
template_is_strict: bool = True,
**kwargs: Any,
) -> BaseChatPromptTemplate:
"""Create a prompt template from a template string."""
prompt = BasePromptTemplate.from_template(
prompt = PromptTemplate.from_template(
template,
template_format,
response_format=response_format,
@@ -149,6 +142,11 @@ class SystemPromptTemplate(BaseChatPromptTemplate):
"""The system prompt template."""

def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format the prompt with the inputs.

Returns:
List[BaseMessage]: The formatted messages.
"""
content = self.prompt.format(**kwargs)
return [SystemMessage(content=content)]

@@ -157,20 +155,31 @@ class HumanPromptTemplate(BaseChatPromptTemplate):
"""The human prompt template."""

def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format the prompt with the inputs.

Returns:
List[BaseMessage]: The formatted messages.
"""
content = self.prompt.format(**kwargs)
return [HumanMessage(content=content)]


class MessagesPlaceholder(BaseChatPromptTemplate):
class MessagesPlaceholder(BaseModel):
"""The messages placeholder template.

Mostly used for the chat history.
"""

variable_name: str
prompt: BasePromptTemplate = None

def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format the prompt with the inputs.

Just return the messages from the kwargs with the variable name.

Returns:
List[BaseMessage]: The messages.
"""
messages = kwargs.get(self.variable_name, [])
if not isinstance(messages, list):
raise ValueError(
@@ -185,7 +194,7 @@ class MessagesPlaceholder(BaseChatPromptTemplate):

@property
def input_variables(self) -> List[str]:
"""A list of the names of the variables the prompt template expects.
"""Return a list of the names of the variables the prompt template expects.

Returns:
List[str]: The input variables.
@@ -193,10 +202,26 @@ class MessagesPlaceholder(BaseChatPromptTemplate):
return [self.variable_name]


MessageType = Union[BaseChatPromptTemplate, BaseMessage]
MessageType = Union[BaseChatPromptTemplate, MessagesPlaceholder, BaseMessage]


class ChatPromptTemplate(BasePromptTemplate):
"""The chat prompt template.

Examples:
.. code-block:: python

prompt_template = ChatPromptTemplate(
messages=[
SystemPromptTemplate.from_template(
"You are a helpful AI assistant."
),
MessagesPlaceholder(variable_name="chat_history"),
HumanPromptTemplate.from_template("{question}"),
]
)
"""

messages: List[MessageType]

def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
@@ -205,12 +230,7 @@ class ChatPromptTemplate(BasePromptTemplate):
for message in self.messages:
if isinstance(message, BaseMessage):
result_messages.append(message)
elif isinstance(message, BaseChatPromptTemplate):
pass_kwargs = {
k: v for k, v in kwargs.items() if k in message.input_variables
}
result_messages.extend(message.format_messages(**pass_kwargs))
elif isinstance(message, MessagesPlaceholder):
elif isinstance(message, (BaseChatPromptTemplate, MessagesPlaceholder)):
pass_kwargs = {
k: v for k, v in kwargs.items() if k in message.input_variables
}
@@ -227,7 +247,7 @@ class ChatPromptTemplate(BasePromptTemplate):
if not input_variables:
input_variables = set()
for message in messages:
if isinstance(message, BaseChatPromptTemplate):
if isinstance(message, (BaseChatPromptTemplate, MessagesPlaceholder)):
input_variables.update(message.input_variables)
values["input_variables"] = sorted(input_variables)
return values
@@ -235,6 +255,8 @@ class ChatPromptTemplate(BasePromptTemplate):

@dataclasses.dataclass
class PromptTemplateIdentifier(ResourceIdentifier):
"""The identifier of a prompt template."""

identifier_split: str = dataclasses.field(default="___$$$$___", init=False)
prompt_name: str
prompt_language: Optional[str] = None

@@ -242,6 +264,7 @@ class PromptTemplateIdentifier(ResourceIdentifier):
model: Optional[str] = None

def __post_init__(self):
"""Post init method."""
if self.prompt_name is None:
raise ValueError("prompt_name cannot be None")

@@ -256,11 +279,13 @@ class PromptTemplateIdentifier(ResourceIdentifier):
if key is not None
):
raise ValueError(
f"identifier_split {self.identifier_split} is not allowed in prompt_name, prompt_language, sys_code, model"
f"identifier_split {self.identifier_split} is not allowed in "
f"prompt_name, prompt_language, sys_code, model"
)

@property
def str_identifier(self) -> str:
"""Return the string identifier of the identifier."""
return self.identifier_split.join(
key
for key in [
@@ -273,6 +298,11 @@ class PromptTemplateIdentifier(ResourceIdentifier):
)

def to_dict(self) -> Dict:
"""Convert the identifier to a dict.

Returns:
Dict: The dict of the identifier.
"""
return {
"prompt_name": self.prompt_name,
"prompt_language": self.prompt_language,
@@ -283,6 +313,8 @@ class PromptTemplateIdentifier(ResourceIdentifier):

@dataclasses.dataclass
class StoragePromptTemplate(StorageItem):
"""The storage prompt template."""

prompt_name: str
content: Optional[str] = None
prompt_language: Optional[str] = None
@@ -297,25 +329,28 @@ class StoragePromptTemplate(StorageItem):
_identifier: PromptTemplateIdentifier = dataclasses.field(init=False)

def __post_init__(self):
"""Post init method."""
self._identifier = PromptTemplateIdentifier(
prompt_name=self.prompt_name,
prompt_language=self.prompt_language,
sys_code=self.sys_code,
model=self.model,
)
self._check()  # Assuming _check() is a method you need to call after initialization
# Assuming _check() is a method you need to call after initialization
self._check()

def to_prompt_template(self) -> PromptTemplate:
"""Convert the storage prompt template to a prompt template."""
input_variables = (
[] if not self.input_variables else self.input_variables.strip().split(",")
)
template_format = self.prompt_format or "f-string"
return PromptTemplate(
input_variables=input_variables,
template=self.content,
template_scene=self.chat_scene,
prompt_name=self.prompt_name,
template_format=self.prompt_format,
# prompt_name=self.prompt_name,
template_format=template_format,
)

@staticmethod
@@ -335,12 +370,18 @@ class StoragePromptTemplate(StorageItem):
Args:
prompt_template (PromptTemplate): The prompt template to convert from.
prompt_name (str): The name of the prompt.
prompt_language (Optional[str], optional): The language of the prompt. Defaults to None. e.g. zh-cn, en.
prompt_type (Optional[str], optional): The type of the prompt. Defaults to None. e.g. common, private.
sys_code (Optional[str], optional): The system code of the prompt. Defaults to None.
user_name (Optional[str], optional): The username of the prompt. Defaults to None.
sub_chat_scene (Optional[str], optional): The sub chat scene of the prompt. Defaults to None.
model (Optional[str], optional): The model name of the prompt. Defaults to None.
prompt_language (Optional[str], optional): The language of the prompt.
Defaults to None. e.g. zh-cn, en.
prompt_type (Optional[str], optional): The type of the prompt.
Defaults to None. e.g. common, private.
sys_code (Optional[str], optional): The system code of the prompt.
Defaults to None.
user_name (Optional[str], optional): The username of the prompt.
Defaults to None.
sub_chat_scene (Optional[str], optional): The sub chat scene of the prompt.
Defaults to None.
model (Optional[str], optional): The model name of the prompt.
Defaults to None.
kwargs (Dict): Other params to build the storage prompt template.
"""
input_variables = prompt_template.input_variables or kwargs.get(
@@ -365,6 +406,7 @@ class StoragePromptTemplate(StorageItem):

@property
def identifier(self) -> PromptTemplateIdentifier:
"""Return the identifier of the storage prompt template."""
return self._identifier

def merge(self, other: "StorageItem") -> None:
@@ -375,11 +417,17 @@ class StoragePromptTemplate(StorageItem):
"""
if not isinstance(other, StoragePromptTemplate):
raise ValueError(
f"Cannot merge {type(other)} into {type(self)} because they are not the same type."
f"Cannot merge {type(other)} into {type(self)} because they are not "
f"the same type."
)
self.from_object(other)

def to_dict(self) -> Dict:
"""Convert the storage prompt template to a dict.

Returns:
Dict: The dict of the storage prompt template.
"""
return {
"prompt_name": self.prompt_name,
"content": self.content,
@@ -422,7 +470,6 @@ class PromptManager:
Simple wrapper for the storage interface.

Examples:

.. code-block:: python

# Default use InMemoryStorage
@@ -458,13 +505,14 @@ class PromptManager:
def __init__(
self, storage: Optional[StorageInterface[StoragePromptTemplate, Any]] = None
):
"""Create a new prompt manager."""
if storage is None:
storage = InMemoryStorage()
self._storage = storage

@property
def storage(self) -> StorageInterface[StoragePromptTemplate, Any]:
"""The storage interface for prompt templates."""
"""Return the storage interface for prompt templates."""
return self._storage

def prefer_query(
@@ -477,11 +525,12 @@ class PromptManager:
) -> List[StoragePromptTemplate]:
"""Query prompt templates from storage with prefer params.

Sometimes, we want to query prompt templates with prefer params(e.g. some language or some model).
This method will query prompt templates with prefer params first, if not found, will query all prompt templates.
Sometimes, we want to query prompt templates with prefer params(e.g. some
language or some model).
This method will query prompt templates with prefer params first, if not found,
will query all prompt templates.

Examples:

Query a prompt template.
.. code-block:: python

@@ -500,7 +549,8 @@ class PromptManager:
.. code-block:: python

# First query with prompt name "hello" exactly.
# Second filter with prompt language "zh-cn", if not found, will return all prompt templates.
# Second filter with prompt language "zh-cn", if not found, will return
# all prompt templates.
prompt_template_list = prompt_manager.prefer_query(
"hello", prefer_prompt_language="zh-cn"
)
@@ -510,17 +560,22 @@ class PromptManager:
.. code-block:: python

# First query with prompt name "hello" exactly.
# Second filter with model "vicuna-13b-v1.5", if not found, will return all prompt templates.
# Second filter with model "vicuna-13b-v1.5", if not found, will return
# all prompt templates.
prompt_template_list = prompt_manager.prefer_query(
"hello", prefer_model="vicuna-13b-v1.5"
)

Args:
prompt_name (str): The name of the prompt template.
sys_code (Optional[str], optional): The system code of the prompt template. Defaults to None.
prefer_prompt_language (Optional[str], optional): The language of the prompt template. Defaults to None.
prefer_model (Optional[str], optional): The model of the prompt template. Defaults to None.
kwargs (Dict): Other query params(If some key and value not None, we query it exactly).
sys_code (Optional[str], optional): The system code of the prompt template.
Defaults to None.
prefer_prompt_language (Optional[str], optional): The language of the
prompt template. Defaults to None.
prefer_model (Optional[str], optional): The model of the prompt template.
Defaults to None.
kwargs (Dict): Other query params(If some key and value not None, we
query it exactly).
"""
query_spec = QuerySpec(
conditions={
@@ -559,7 +614,6 @@ class PromptManager:
"""Save a prompt template to storage.

Examples:

.. code-block:: python

prompt_template = PromptTemplate(
@@ -618,15 +672,17 @@ class PromptManager:
if exist_prompt_template:
return exist_prompt_template
self.save(prompt_template, prompt_name, **kwargs)
return self.storage.load(
prompt = self.storage.load(
storage_prompt_template.identifier, StoragePromptTemplate
)
if not prompt:
raise ValueError("Can't read prompt from storage")
return prompt

def list(self, **kwargs) -> List[StoragePromptTemplate]:
"""List prompt templates from storage.

Examples:

List all prompt templates.
.. code-block:: python

@@ -656,7 +712,6 @@ class PromptManager:
"""Delete a prompt template from storage.

Examples:

Delete a prompt template.

.. code-block:: python
@@ -673,9 +728,12 @@ class PromptManager:

Args:
prompt_name (str): The name of the prompt template.
prompt_language (Optional[str], optional): The language of the prompt template. Defaults to None.
sys_code (Optional[str], optional): The system code of the prompt template. Defaults to None.
model (Optional[str], optional): The model of the prompt template. Defaults to None.
prompt_language (Optional[str], optional): The language of the prompt
template. Defaults to None.
sys_code (Optional[str], optional): The system code of the prompt template.
Defaults to None.
model (Optional[str], optional): The model of the prompt template.
Defaults to None.
"""
identifier = PromptTemplateIdentifier(
prompt_name=prompt_name,
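The fallback behaviour of `prefer_query` can be summarized with a short illustrative session; the storage contents here are assumed:

.. code-block:: python

    pm = PromptManager(InMemoryStorage())
    # 1) query storage for prompt_name == "hello" exactly;
    # 2) filter the hits by prompt_language == "zh-cn";
    # 3) if the preferred filter leaves nothing, return the unfiltered hits.
    templates = pm.prefer_query("hello", prefer_prompt_language="zh-cn")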
@ -1,11 +1,15 @@
|
||||
"""The interface for serializing."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Type
|
||||
from typing import Dict, Optional, Type
|
||||
|
||||
|
||||
class Serializable(ABC):
|
||||
serializer: "Serializer" = None
|
||||
"""The serializable abstract class."""
|
||||
|
||||
serializer: Optional["Serializer"] = None
|
||||
|
||||
@abstractmethod
|
||||
def to_dict(self) -> Dict:
|
||||
|
@@ -1,5 +1,7 @@
"""The storage interface for storing and loading data."""

from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, cast

from dbgpt.core.interface.serialization import Serializable, Serializer
from dbgpt.util.annotations import PublicAPI
@@ -55,17 +57,22 @@ class StorageItem(Serializable, ABC):
    """


ID = TypeVar("ID", bound=ResourceIdentifier)
T = TypeVar("T", bound=StorageItem)
TDataRepresentation = TypeVar("TDataRepresentation")


class StorageItemAdapter(Generic[T, TDataRepresentation]):
    """The storage item adapter for converting storage items to and from the storage format.
    """Storage item adapter.

    The storage item adapter for converting storage items to and from the storage
    format.

    Sometimes, the storage item is not the same as the storage format,
    so we need to convert the storage item to the storage format and vice versa.

    In database storage, the storage format is a database model, but the StorageItem is the user-defined object.
    In database storage, the storage format is a database model, but the StorageItem is
    the user-defined object.
    """

    @abstractmethod
@@ -110,20 +117,44 @@ class StorageItemAdapter(Generic[T, TDataRepresentation]):


class DefaultStorageItemAdapter(StorageItemAdapter[T, T]):
    """The default storage item adapter for converting storage items to and from the storage format.
    """Default storage item adapter.

    The default storage item adapter for converting storage items to and from the
    storage format.

    The storage item is the same as the storage format, so no conversion is required.
    """

    def to_storage_format(self, item: T) -> T:
        """Convert the storage item to the storage format.

        Returns the storage item itself.

        Args:
            item (T): The storage item

        Returns:
            T: The data in the storage format
        """
        return item

    def from_storage_format(self, data: T) -> T:
        """Convert the storage format to the storage item.

        Returns the storage format itself.

        Args:
            data (T): The data in the storage format

        Returns:
            T: The storage item
        """
        return data

    def get_query_for_identifier(
        self, storage_format: Type[T], resource_id: ResourceIdentifier, **kwargs
        self, storage_format: Type[T], resource_id: ID, **kwargs
    ) -> bool:
        """Return the query for the resource identifier."""
        return True

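For the non-default case, the adapter pair is where an ORM model would be mapped. A rough sketch of a custom adapter, assuming a hypothetical `PromptModel` database class, a `StoragePromptTemplate` item, and a SQLAlchemy session passed through `kwargs` (all details are illustrative, not the project's actual adapter):

    class PromptModelAdapter(StorageItemAdapter):
        def to_storage_format(self, item):
            # User-defined object -> database model
            return PromptModel(name=item.prompt_name, content=item.content)

        def from_storage_format(self, data):
            # Database model -> user-defined object
            return StoragePromptTemplate(prompt_name=data.name, content=data.content)

        def get_query_for_identifier(self, storage_format, resource_id, **kwargs):
            session = kwargs["session"]  # assumed to be a SQLAlchemy session
            return session.query(storage_format).filter(
                storage_format.name == resource_id.prompt_name
            )
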
@@ -132,6 +163,7 @@ class StorageError(Exception):
    """The base exception class for storage errors."""

    def __init__(self, message: str):
        """Create a new StorageError."""
        super().__init__(message)


@@ -146,8 +178,9 @@ class QuerySpec:
    """

    def __init__(
        self, conditions: Dict[str, Any], limit: int = None, offset: int = 0
        self, conditions: Dict[str, Any], limit: Optional[int] = None, offset: int = 0
    ) -> None:
        """Create a new QuerySpec."""
        self.conditions = conditions
        self.limit = limit
        self.offset = offset
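With `limit` now `Optional[int]`, a spec without an upper bound type-checks cleanly. A small usage sketch (field names are illustrative):

    # All items whose sys_code matches, skipping the first 10, no upper bound
    spec = QuerySpec(conditions={"sys_code": "dbgpt"}, offset=10)
    # First 20 matching items
    bounded = QuerySpec(conditions={"sys_code": "dbgpt"}, limit=20)
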
@@ -162,6 +195,7 @@ class StorageInterface(Generic[T, TDataRepresentation], ABC):
        serializer: Optional[Serializer] = None,
        adapter: Optional[StorageItemAdapter[T, TDataRepresentation]] = None,
    ):
        """Create a new StorageInterface."""
        self._serializer = serializer or JsonSerializer()
        self._storage_item_adapter = adapter or DefaultStorageItemAdapter()

@@ -238,7 +272,7 @@ class StorageInterface(Generic[T, TDataRepresentation], ABC):
            self.save_or_update(d)

    @abstractmethod
    def load(self, resource_id: ResourceIdentifier, cls: Type[T]) -> Optional[T]:
    def load(self, resource_id: ID, cls: Type[T]) -> Optional[T]:
        """Load the data from the storage.

        None will be returned if the data does not exist in the storage.
@@ -247,14 +281,14 @@ class StorageInterface(Generic[T, TDataRepresentation], ABC):
        so we suggest using load if possible.

        Args:
            resource_id (ResourceIdentifier): The resource identifier of the data
            resource_id (ID): The resource identifier of the data
            cls (Type[T]): The type of the data

        Returns:
            Optional[T]: The loaded data
        """

    def load_list(self, resource_id: List[ResourceIdentifier], cls: Type[T]) -> List[T]:
    def load_list(self, resource_id: List[ID], cls: Type[T]) -> List[T]:
        """Load the data from the storage.

        None will be returned if the data does not exist in the storage.
@@ -263,7 +297,7 @@ class StorageInterface(Generic[T, TDataRepresentation], ABC):
        so we suggest using load if possible.

        Args:
            resource_id (ResourceIdentifier): The resource identifier of the data
            resource_id (ID): The resource identifier of the data
            cls (Type[T]): The type of the data

        Returns:
@@ -277,18 +311,18 @@ class StorageInterface(Generic[T, TDataRepresentation], ABC):
        return result

    @abstractmethod
    def delete(self, resource_id: ResourceIdentifier) -> None:
    def delete(self, resource_id: ID) -> None:
        """Delete the data from the storage.

        Args:
            resource_id (ResourceIdentifier): The resource identifier of the data
            resource_id (ID): The resource identifier of the data
        """

    def delete_list(self, resource_id: List[ResourceIdentifier]) -> None:
    def delete_list(self, resource_id: List[ID]) -> None:
        """Delete the data from the storage.

        Args:
            resource_id (ResourceIdentifier): The resource identifier of the data
            resource_id (ID): The resource identifier of the data
        """
        for r in resource_id:
            self.delete(r)
@@ -297,7 +331,8 @@ class StorageInterface(Generic[T, TDataRepresentation], ABC):
    def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]:
        """Query data from the storage.

        Querying data with a resource_id is faster than querying with conditions, so please use load if possible.
        Querying data with a resource_id is faster than querying with conditions,
        so please use load if possible.

        Args:
            spec (QuerySpec): The query specification
@@ -328,7 +363,8 @@ class StorageInterface(Generic[T, TDataRepresentation], ABC):
            page (int): The page number
            page_size (int): The number of items per page
            cls (Type[T]): The type of the data
            spec (Optional[QuerySpec], optional): The query specification. Defaults to None.
            spec (Optional[QuerySpec], optional): The query specification.
                Defaults to None.

        Returns:
            PaginationResult[T]: The pagination result
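A quick sketch of how the paginated query above is typically driven, assuming the signature shown in the docstring and that `storage` is some `StorageInterface` implementation with a `MyItem` StorageItem subclass (both names are placeholders):

    spec = QuerySpec(conditions={"sys_code": "dbgpt"})
    page_result = storage.paginate_query(page=1, page_size=20, cls=MyItem, spec=spec)
    # PaginationResult is assumed to carry the page of items plus paging metadata
    for item in page_result.items:
        ...
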
@@ -356,10 +392,17 @@ class InMemoryStorage(StorageInterface[T, T]):
        self,
        serializer: Optional[Serializer] = None,
    ):
        """Create a new InMemoryStorage."""
        super().__init__(serializer)
        self._data = {}  # Key: ResourceIdentifier, Value: Serialized data
        # Key: ResourceIdentifier, Value: Serialized data
        self._data: Dict[str, bytes] = {}

    def save(self, data: T) -> None:
        """Save the data to the storage.

        Args:
            data (T): The data to save
        """
        if not data:
            raise StorageError("Data cannot be None")
        if not data.serializer:
@@ -372,6 +415,7 @@ class InMemoryStorage(StorageInterface[T, T]):
        self._data[data.identifier.str_identifier] = data.serialize()

    def update(self, data: T) -> None:
        """Update the data in the storage."""
        if not data:
            raise StorageError("Data cannot be None")
        if not data.serializer:
@@ -379,22 +423,34 @@ class InMemoryStorage(StorageInterface[T, T]):
        self._data[data.identifier.str_identifier] = data.serialize()

    def save_or_update(self, data: T) -> None:
        """Save or update the data in the storage."""
        self.update(data)

    def load(self, resource_id: ResourceIdentifier, cls: Type[T]) -> Optional[T]:
    def load(self, resource_id: ID, cls: Type[T]) -> Optional[T]:
        """Load the data from the storage."""
        serialized_data = self._data.get(resource_id.str_identifier)
        if serialized_data is None:
            return None
        return self.serializer.deserialize(serialized_data, cls)
        return cast(T, self.serializer.deserialize(serialized_data, cls))

    def delete(self, resource_id: ResourceIdentifier) -> None:
    def delete(self, resource_id: ID) -> None:
        """Delete the data from the storage."""
        if resource_id.str_identifier in self._data:
            del self._data[resource_id.str_identifier]

    def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]:
        """Query data from the storage.

        Args:
            spec (QuerySpec): The query specification
            cls (Type[T]): The type of the data

        Returns:
            List[T]: The queried data
        """
        result = []
        for serialized_data in self._data.values():
            data = self._serializer.deserialize(serialized_data, cls)
            data = cast(T, self._serializer.deserialize(serialized_data, cls))
            if all(
                getattr(data, key) == value for key, value in spec.conditions.items()
            ):
@@ -408,6 +464,15 @@ class InMemoryStorage(StorageInterface[T, T]):
        return result

    def count(self, spec: QuerySpec, cls: Type[T]) -> int:
        """Count the number of items in the storage.

        Args:
            spec (QuerySpec): The query specification
            cls (Type[T]): The type of the data

        Returns:
            int: The number of items
        """
        count = 0
        for serialized_data in self._data.values():
            data = self._serializer.deserialize(serialized_data, cls)
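The typed `Dict[str, bytes]` makes the contract explicit: keys are `str_identifier` strings, values are serialized bytes. A minimal round-trip sketch, assuming `item` is a StorageItem with its serializer configured and `MyItem` is its class (placeholder names):

    storage = InMemoryStorage()
    storage.save(item)                                 # serializes and stores bytes
    loaded = storage.load(item.identifier, MyItem)     # deserializes, or None
    hits = storage.query(QuerySpec(conditions={"name": "demo"}), MyItem)
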
@@ -1,22 +1,24 @@
from dbgpt.core.interface.operator.composer_operator import (
"""All core operators."""

from dbgpt.core.interface.operator.composer_operator import (  # noqa: F401
    ChatComposerInput,
    ChatHistoryPromptComposerOperator,
)
from dbgpt.core.interface.operator.llm_operator import (
from dbgpt.core.interface.operator.llm_operator import (  # noqa: F401
    BaseLLM,
    BaseLLMOperator,
    BaseStreamingLLMOperator,
    LLMBranchOperator,
    RequestBuilderOperator,
)
from dbgpt.core.interface.operator.message_operator import (
from dbgpt.core.interface.operator.message_operator import (  # noqa: F401
    BaseConversationOperator,
    BufferedConversationMapperOperator,
    ConversationMapperOperator,
    PreChatHistoryLoadOperator,
    TokenBufferedConversationMapperOperator,
)
from dbgpt.core.interface.operator.prompt_operator import (
from dbgpt.core.interface.operator.prompt_operator import (  # noqa: F401
    DynamicPromptBuilderOperator,
    HistoryDynamicPromptBuilderOperator,
    HistoryPromptBuilderOperator,

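The `# noqa: F401` markers silence flake8's unused-import warning for these intentional re-exports. An equivalent and arguably more self-documenting pattern is an explicit `__all__`, which pyflakes treats as usage (a sketch of the alternative, not what this commit does):

    from dbgpt.core.interface.operator.llm_operator import BaseLLM  # re-export

    __all__ = ["BaseLLM"]  # flake8 then treats the import as used
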
@@ -8,6 +8,9 @@ import threading
from functools import cache
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple

from dbgpt.model.adapter.base import LLMModelAdapter
from dbgpt.model.adapter.template import ConversationAdapter, PromptType

try:
    from fastchat.conversation import (
        Conversation,
@@ -20,8 +23,6 @@ except ImportError as exc:
        "Please install fastchat by command `pip install fschat` "
    ) from exc

from dbgpt.model.adapter.base import LLMModelAdapter
from dbgpt.model.adapter.template import ConversationAdapter, PromptType

if TYPE_CHECKING:
    from fastchat.model.model_adapter import BaseModelAdapter

@@ -196,7 +196,14 @@ class DefaultModelWorker(ModelWorker):
        return _try_to_count_token(prompt, self.tokenizer, self.model)

    async def async_count_token(self, prompt: str) -> int:
        # TODO if we deploy the model with vllm, this can't work; we should run the transformer _try_to_count_token asynchronously
        # TODO if we deploy the model with vllm, this can't work; we should run
        # the transformer _try_to_count_token asynchronously
        from dbgpt.model.proxy.llms.proxy_model import ProxyModel

        if isinstance(self.model, ProxyModel) and self.model.proxy_llm_client:
            return await self.model.proxy_llm_client.count_token(
                self.model.proxy_llm_client.default_model, prompt
            )
        raise NotImplementedError

    def get_model_metadata(self, params: Dict) -> ModelMetadata:

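The wrapped TODO above is about pushing the blocking `_try_to_count_token` call off the event loop. A hedged sketch of one way to do that with a thread-pool executor (the helper names are illustrative, not the project's eventual fix):

    import asyncio
    from concurrent.futures import ThreadPoolExecutor

    _executor = ThreadPoolExecutor(max_workers=1)

    async def count_token_async(prompt: str, tokenizer, model) -> int:
        loop = asyncio.get_running_loop()
        # Run the blocking tokenizer call in a worker thread
        return await loop.run_in_executor(
            _executor, _try_to_count_token, prompt, tokenizer, model
        )
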
@@ -118,9 +118,6 @@ def replace_llama_attn_with_non_inplace_operations():
    transformers.models.llama.modeling_llama.LlamaAttention.forward = forward


import transformers


def replace_llama_attn_with_non_inplace_operations():
    """Avoid bugs in mps backend by not using in-place operations."""
    transformers.models.llama.modeling_llama.LlamaAttention.forward = forward

@@ -196,6 +196,15 @@ class ProxyLLMClient(LLMClient):
        """
        return self._models()

    @property
    def default_model(self) -> str:
        """Get the default model name.

        Returns:
            str: the default model name
        """
        return self.model_names[0]

    @cache
    def _models(self) -> List[ModelMetadata]:
        results = []
@@ -237,6 +246,7 @@ class ProxyLLMClient(LLMClient):
        Returns:
            int: token count, -1 if failed
        """
        return await blocking_func_to_async(
        counts = await blocking_func_to_async(
            self.executor, self.proxy_tokenizer.count_token, model, [prompt]
        )[0]
        )
        return counts[0]

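The `count_token` fix above is worth spelling out: subscripting binds tighter than `await`, so the old `await blocking_func_to_async(...)[0]` tries to index the coroutine object itself and raises `TypeError` before anything is awaited. A distilled, runnable sketch of the pitfall:

    import asyncio

    async def get_counts():
        return [42]

    async def main():
        # counts = await get_counts()[0]   # TypeError: 'coroutine' object is not subscriptable
        counts = (await get_counts())[0]   # await first, then index
        assert counts == 42

    asyncio.run(main())
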
@@ -86,6 +86,11 @@ class OpenAILLMClient(ProxyLLMClient):
        self._openai_kwargs = openai_kwargs or {}
        super().__init__(model_names=[model_alias], context_length=context_length)

        if self._openai_less_then_v1:
            from dbgpt.model.utils.chatgpt_utils import _initialize_openai

            _initialize_openai(self._init_params)

    @classmethod
    def new_client(
        cls,

@@ -114,7 +114,6 @@ class GeminiLLMClient(ProxyLLMClient):
        self._api_key = api_key if api_key else os.getenv("GEMINI_PROXY_API_KEY")
        self._api_base = api_base if api_base else os.getenv("GEMINI_PROXY_API_BASE")
        self._model = model
        self.default_model = self._model
        if not self._api_key:
            raise RuntimeError("api_key can't be empty")

@@ -148,6 +147,10 @@ class GeminiLLMClient(ProxyLLMClient):
            executor=default_executor,
        )

    @property
    def default_model(self) -> str:
        return self._model

    def sync_generate_stream(
        self,
        request: ModelRequest,

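The same edit recurs in the Spark, Tongyi, Wenxin, and Zhipu clients below: since this commit makes `default_model` a read-only property on `ProxyLLMClient`, assigning `self.default_model = ...` in `__init__` would raise `AttributeError`, so each subclass now overrides the property instead. The shape of the pattern, reduced to a runnable sketch:

    class Base:
        @property
        def default_model(self) -> str:
            return "base"

    class Client(Base):
        def __init__(self, model: str):
            self._model = model  # store privately; no assignment to the property

        @property
        def default_model(self) -> str:
            return self._model
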
@@ -8,9 +8,6 @@ from datetime import datetime
from time import mktime
from typing import Iterator, Optional
from urllib.parse import urlencode, urlparse
from wsgiref.handlers import format_date_time

from websockets.sync.client import connect

from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.model.parameter import ProxyModelParameters
@@ -56,6 +53,8 @@


def get_response(request_url, data):
    from websockets.sync.client import connect

    with connect(request_url) as ws:
        ws.send(json.dumps(data, ensure_ascii=False))
        result = ""
@@ -87,6 +86,8 @@ class SparkAPI:
        self.spark_url = spark_url

    def gen_url(self):
        from wsgiref.handlers import format_date_time

        # Generate an RFC 1123 formatted timestamp
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))
@@ -145,7 +146,6 @@ class SparkLLMClient(ProxyLLMClient):
        if not api_domain:
            api_domain = domain
        self._model = model
        self.default_model = self._model
        self._model_version = model_version
        self._api_base = api_base
        self._domain = api_domain
@@ -183,6 +183,10 @@ class SparkLLMClient(ProxyLLMClient):
            executor=default_executor,
        )

    @property
    def default_model(self) -> str:
        return self._model

    def sync_generate_stream(
        self,
        request: ModelRequest,

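Moving `connect` and `format_date_time` into the functions that use them is the lazy-import pattern: the module now imports cleanly even when the optional dependency is absent, and the import cost is paid only on first call. A generic sketch:

    def send_once(url: str, payload: str) -> None:
        # Imported here so the module loads without `websockets` installed
        from websockets.sync.client import connect

        with connect(url) as ws:
            ws.send(payload)
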
@@ -51,7 +51,6 @@ class TongyiLLMClient(ProxyLLMClient):
        if api_region:
            dashscope.api_region = api_region
        self._model = model
        self.default_model = self._model

        super().__init__(
            model_names=[model, model_alias],
@@ -73,6 +72,10 @@ class TongyiLLMClient(ProxyLLMClient):
            executor=default_executor,
        )

    @property
    def default_model(self) -> str:
        return self._model

    def sync_generate_stream(
        self,
        request: ModelRequest,

@@ -121,7 +121,6 @@ class WenxinLLMClient(ProxyLLMClient):
        self._api_key = api_key
        self._api_secret = api_secret
        self._model_version = model_version
        self.default_model = self._model

        super().__init__(
            model_names=[model, model_alias],
@@ -145,6 +144,10 @@ class WenxinLLMClient(ProxyLLMClient):
            executor=default_executor,
        )

    @property
    def default_model(self) -> str:
        return self._model

    def sync_generate_stream(
        self,
        request: ModelRequest,

@@ -54,7 +54,6 @@ class ZhipuLLMClient(ProxyLLMClient):
        if api_key:
            zhipuai.api_key = api_key
        self._model = model
        self.default_model = self._model

        super().__init__(
            model_names=[model, model_alias],
@@ -76,6 +75,10 @@ class ZhipuLLMClient(ProxyLLMClient):
            executor=default_executor,
        )

    @property
    def default_model(self) -> str:
        return self._model

    def sync_generate_stream(
        self,
        request: ModelRequest,

@@ -88,6 +88,42 @@ def _initialize_openai_v1(init_params: OpenAIParameters):
    return openai_params, api_type, api_version


def _initialize_openai(params: OpenAIParameters):
    try:
        import openai
    except ImportError as exc:
        raise ValueError(
            "Could not import python package: openai "
            "Please install openai by command `pip install openai` "
        ) from exc

    api_type = params.api_type or os.getenv("OPENAI_API_TYPE", "open_ai")

    api_base = params.api_base or os.getenv(
        "OPENAI_API_BASE",
        os.getenv("AZURE_OPENAI_ENDPOINT") if api_type == "azure" else None,
    )
    api_key = params.api_key or os.getenv(
        "OPENAI_API_KEY",
        os.getenv("AZURE_OPENAI_KEY") if api_type == "azure" else None,
    )
    api_version = params.api_version or os.getenv("OPENAI_API_VERSION")

    if not api_base and params.full_url:
        # Adapt previous proxy_server_url configuration
        api_base = params.full_url.split("/chat/completions")[0]
    if api_type:
        openai.api_type = api_type
    if api_base:
        openai.api_base = api_base
    if api_key:
        openai.api_key = api_key
    if api_version:
        openai.api_version = api_version
    if params.proxies:
        openai.proxy = params.proxies


def _build_openai_client(init_params: OpenAIParameters) -> Tuple[str, ClientType]:
    import httpx

@@ -112,9 +148,7 @@ def _build_openai_client(init_params: OpenAIParameters) -> Tuple[str, ClientType
class OpenAIStreamingOutputOperator(TransformStreamAbsOperator[ModelOutput, str]):
    """Transform ModelOutput to openai stream format."""

    async def transform_stream(
        self, input_value: AsyncIterator[ModelOutput]
    ) -> AsyncIterator[str]:
    async def transform_stream(self, input_value: AsyncIterator[ModelOutput]):
        async def model_caller() -> str:
            """Read model name from share data.
            In streaming mode, this transform_stream function will be executed

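For the pre-1.0 SDK branch, `_initialize_openai` mutates the module-level `openai` configuration from `OpenAIParameters`, falling back to environment variables. A hedged usage sketch, assuming `OpenAIParameters` accepts the fields referenced above as keyword arguments (endpoint and key are placeholders):

    params = OpenAIParameters(
        api_type="azure",
        api_base="https://my-endpoint.openai.azure.com",
        api_key="sk-...",
        api_version="2023-05-15",
    )
    _initialize_openai(params)  # openai.api_key etc. are now set globally
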
@@ -1,6 +1,6 @@
from typing import Any

from dbgpt.core.interface.retriever import RetrieverOperator
from dbgpt.core.interface.operator.retriever import RetrieverOperator
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary

@@ -1,7 +1,7 @@
from typing import Any, Optional

from dbgpt.core.awel.task.base import IN
from dbgpt.core.interface.retriever import RetrieverOperator
from dbgpt.core.interface.operator.retriever import RetrieverOperator
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
from dbgpt.storage.vector_store.connector import VectorStoreConnector

@@ -2,7 +2,7 @@ from functools import reduce
from typing import Any, Optional

from dbgpt.core.awel.task.base import IN
from dbgpt.core.interface.retriever import RetrieverOperator
from dbgpt.core.interface.operator.retriever import RetrieverOperator
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
from dbgpt.rag.retriever.rerank import Ranker
from dbgpt.rag.retriever.rewrite import QueryRewrite

@@ -1,3 +1,4 @@
import asyncio
import json
import logging
import uuid
@@ -30,7 +31,6 @@ from .dbgpts import DbGptsInstance

CFG = Config()

import asyncio

router = APIRouter()
logger = logging.getLogger(__name__)

@@ -6,13 +6,13 @@ from dbgpt.storage.chat_history.base import BaseChatHistoryMemory

from .base import MemoryStoreType

# Import first for auto create table
from .store_type.meta_db_history import DbHistoryMemory

# TODO remove global variable
CFG = Config()
logger = logging.getLogger(__name__)

# Import first for auto create table
from .store_type.meta_db_history import DbHistoryMemory


class ChatHistory:
    def __init__(self):

@@ -5,6 +5,8 @@ from sqlalchemy.orm.session import Session

from dbgpt.util.pagination_utils import PaginationResult

from .db_manager import BaseQuery, DatabaseManager, db

# The entity type
T = TypeVar("T")
# The request schema type
@@ -12,7 +14,6 @@ REQ = TypeVar("REQ")
# The response schema type
RES = TypeVar("RES")

from .db_manager import BaseQuery, DatabaseManager, db

QUERY_SPEC = Union[REQ, Dict[str, Any]]

@@ -1,3 +1,6 @@
from typing import Optional


def PublicAPI(*args, **kwargs):
    """Decorator to mark a function or class as a public API.

@@ -64,7 +67,7 @@ def DeveloperAPI(*args, **kwargs):
    return decorator


def _modify_docstring(obj, message: str = None):
def _modify_docstring(obj, message: Optional[str] = None):
    if not message:
        return
    if not obj.__doc__:
@@ -81,6 +84,7 @@ def _modify_docstring(obj, message: str = None):

    if min_indent == float("inf"):
        min_indent = 0
    min_indent = int(min_indent)
    indented_message = message.rstrip() + "\n" + (" " * min_indent)
    obj.__doc__ = indented_message + original_doc

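The indentation logic above matters because `__doc__` bodies carry the indentation of their source. A small illustration of the intended effect (the decorator message is a placeholder):

    def demo():
        """Do something.

        More detail here.
        """

    # _modify_docstring(demo, "PublicAPI (beta): this API may change.")
    # would prepend the message, indented to match "More detail here.",
    # so help(demo) renders a single well-formed docstring.
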
@@ -1,6 +1,6 @@
import os
from functools import cache
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, cast


class AppConfig:
@@ -46,7 +46,7 @@ class AppConfig:
        """
        env_lang = (
            "zh"
            if os.getenv("LANG") and os.getenv("LANG").startswith("zh")
            if os.getenv("LANG") and cast(str, os.getenv("LANG")).startswith("zh")
            else default
        )
        return self.get("dbgpt.app.global.language", env_lang)

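`typing.cast` is a no-op at runtime; it only tells mypy that the second `os.getenv("LANG")` cannot be `None` on this branch, which the checker cannot infer across two separate calls. A sketch of the equivalent, often cleaner, single-lookup form:

    lang = os.getenv("LANG")
    env_lang = "zh" if lang and lang.startswith("zh") else default
    # binding once lets mypy narrow `lang` without a cast
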
@@ -1,7 +1,7 @@
"""Utilities for formatting strings."""
import json
from string import Formatter
from typing import Any, List, Mapping, Sequence, Union
from typing import Any, List, Mapping, Sequence, Set, Union


class StrictFormatter(Formatter):
@@ -9,7 +9,7 @@ class StrictFormatter(Formatter):

    def check_unused_args(
        self,
        used_args: Sequence[Union[int, str]],
        used_args: Set[Union[int, str]],
        args: Sequence,
        kwargs: Mapping[str, Any],
    ) -> None:
@@ -39,7 +39,7 @@ class StrictFormatter:
class NoStrictFormatter(StrictFormatter):
    def check_unused_args(
        self,
        used_args: Sequence[Union[int, str]],
        used_args: Set[Union[int, str]],
        args: Sequence,
        kwargs: Mapping[str, Any],
    ) -> None:

@@ -12,14 +12,14 @@ MISSING_DEFAULT_VALUE = "__MISSING_DEFAULT_VALUE__"

@dataclass
class ParameterDescription:
    param_class: str
    param_name: str
    param_type: str
    default_value: Optional[Any]
    description: str
    required: Optional[bool]
    valid_values: Optional[List[Any]]
    ext_metadata: Dict
    required: bool = False
    param_class: Optional[str] = None
    param_name: Optional[str] = None
    param_type: Optional[str] = None
    description: Optional[str] = None
    default_value: Optional[Any] = None
    valid_values: Optional[List[Any]] = None
    ext_metadata: Optional[Dict[str, Any]] = None


@dataclass
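Once some fields gain defaults, the dataclass rule that non-default fields cannot follow defaulted ones comes into play; defaulting every field, as the rewrite above does, sidesteps it entirely. A distilled sketch of the rule:

    from dataclasses import dataclass

    # @dataclass
    # class Broken:
    #     a: int = 0
    #     b: str   # TypeError: non-default argument 'b' follows default argument

    @dataclass
    class Fine:
        b: str = ""
        a: int = 0   # every field defaulted; declaration order is free
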
@@ -186,7 +186,9 @@ def _get_simple_privacy_field_value(obj, field_info):
    return "******"


def _genenv_ignoring_key_case(env_key: str, env_prefix: str = None, default_value=None):
def _genenv_ignoring_key_case(
    env_key: str, env_prefix: Optional[str] = None, default_value: Optional[str] = None
):
    """Get the value from the environment variable, ignoring the case of the key"""
    if env_prefix:
        env_key = env_prefix + env_key
@@ -196,7 +198,9 @@ def _genenv_ignoring_key_case(env_key: str, env_prefix: str = None, default_valu


def _genenv_ignoring_key_case_with_prefixes(
    env_key: str, env_prefixes: List[str] = None, default_value=None
    env_key: str,
    env_prefixes: Optional[List[str]] = None,
    default_value: Optional[str] = None,
) -> str:
    if env_prefixes:
        for env_prefix in env_prefixes:
@@ -208,7 +212,7 @@ def _genenv_ignoring_key_case_with_prefixes(

class EnvArgumentParser:
    @staticmethod
    def get_env_prefix(env_key: str) -> str:
    def get_env_prefix(env_key: str) -> Optional[str]:
        if not env_key:
            return None
        env_key = env_key.replace("-", "_")
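A quick sketch of how the prefix-aware lookup behaves, assuming the case-insensitive lookup tries the prefixed key before the default (environment values are illustrative):

    import os

    os.environ["WEBSERVER_PORT"] = "5000"
    # Tries the prefixed key (case-insensitively) before falling back
    value = _genenv_ignoring_key_case_with_prefixes(
        "port", env_prefixes=["WEBSERVER_"], default_value="8080"
    )
    assert value == "5000"
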
@@ -217,14 +221,14 @@ class EnvArgumentParser:
    def parse_args_into_dataclass(
        self,
        dataclass_type: Type,
        env_prefixes: List[str] = None,
        command_args: List[str] = None,
        env_prefixes: Optional[List[str]] = None,
        command_args: Optional[List[str]] = None,
        **kwargs,
    ) -> Any:
        """Parse parameters from environment variables and command lines and populate them into a data class"""
        parser = argparse.ArgumentParser()
        for field in fields(dataclass_type):
            env_var_value = _genenv_ignoring_key_case_with_prefixes(
            env_var_value: Any = _genenv_ignoring_key_case_with_prefixes(
                field.name, env_prefixes
            )
            if env_var_value:
@@ -313,7 +317,8 @@ class EnvArgumentParser:

    @staticmethod
    def create_click_option(
        *dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
        *dataclass_types: Type,
        _dynamic_factory: Optional[Callable[[], List[Type]]] = None,
    ):
        import functools
        from collections import OrderedDict
@@ -322,8 +327,9 @@ class EnvArgumentParser:
        if _dynamic_factory:
            _types = _dynamic_factory()
            if _types:
                dataclass_types = list(_types)
                dataclass_types = list(_types)  # type: ignore
        for dataclass_type in dataclass_types:
            # type: ignore
            for field in fields(dataclass_type):
                if field.name not in combined_fields:
                    combined_fields[field.name] = field
@@ -345,7 +351,8 @@ class EnvArgumentParser:

    @staticmethod
    def _create_raw_click_option(
        *dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
        *dataclass_types: Type,
        _dynamic_factory: Optional[Callable[[], List[Type]]] = None,
    ):
        combined_fields = _merge_dataclass_types(
            *dataclass_types, _dynamic_factory=_dynamic_factory
@@ -362,7 +369,8 @@ class EnvArgumentParser:

    @staticmethod
    def create_argparse_option(
        *dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
        *dataclass_types: Type,
        _dynamic_factory: Optional[Callable[[], List[Type]]] = None,
    ) -> argparse.ArgumentParser:
        combined_fields = _merge_dataclass_types(
            *dataclass_types, _dynamic_factory=_dynamic_factory
@@ -429,7 +437,7 @@ class EnvArgumentParser:
        return "str"

    @staticmethod
    def _is_require_type(field_type: Type) -> str:
    def _is_require_type(field_type: Type) -> bool:
        return field_type not in [Optional[int], Optional[float], Optional[bool]]

    @staticmethod
@@ -455,13 +463,13 @@ class EnvArgumentParser:


def _merge_dataclass_types(
    *dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
    *dataclass_types: Type, _dynamic_factory: Optional[Callable[[], List[Type]]] = None
) -> OrderedDict:
    combined_fields = OrderedDict()
    if _dynamic_factory:
        _types = _dynamic_factory()
        if _types:
            dataclass_types = list(_types)
            dataclass_types = list(_types)  # type: ignore
    for dataclass_type in dataclass_types:
        for field in fields(dataclass_type):
            if field.name not in combined_fields:
@@ -511,11 +519,12 @@ def _build_parameter_class(desc: List[ParameterDescription]) -> Type:
    if not desc:
        raise ValueError("Parameter descriptions can't be empty")
    param_class_str = desc[0].param_class
    class_name = None
    if param_class_str:
        param_class = import_from_string(param_class_str, ignore_import_error=True)
        if param_class:
            return param_class
    module_name, _, class_name = param_class_str.rpartition(".")
        module_name, _, class_name = param_class_str.rpartition(".")

    fields_dict = {}  # This will store field names and their default values or field()
    annotations = {}  # This will store the type annotations for the fields
@@ -526,25 +535,30 @@ def _build_parameter_class(desc: List[ParameterDescription]) -> Type:
            metadata["valid_values"] = d.valid_values

        annotations[d.param_name] = _type_str_to_python_type(
            d.param_type
            d.param_type  # type: ignore
        )  # Set type annotation
        fields_dict[d.param_name] = field(default=d.default_value, metadata=metadata)

    # Create the new class. Note the setting of __annotations__ for type hints
    new_class = type(
        class_name, (object,), {**fields_dict, "__annotations__": annotations}
        class_name,  # type: ignore
        (object,),
        {**fields_dict, "__annotations__": annotations},  # type: ignore
    )
    result_class = dataclass(new_class)  # Make it a dataclass
    # Make it a dataclass
    result_class = dataclass(new_class)  # type: ignore

    return result_class


def _extract_parameter_details(
    parser: argparse.ArgumentParser,
    param_class: str = None,
    skip_names: List[str] = None,
    overwrite_default_values: Dict = {},
    param_class: Optional[str] = None,
    skip_names: Optional[List[str]] = None,
    overwrite_default_values: Optional[Dict[str, Any]] = None,
) -> List[ParameterDescription]:
    if overwrite_default_values is None:
        overwrite_default_values = {}
    descriptions = []

    for action in parser._actions:
@@ -575,7 +589,9 @@ def _extract_parameter_details(
        if param_name in overwrite_default_values:
            default_value = overwrite_default_values[param_name]
        arg_type = (
            action.type if not callable(action.type) else str(action.type.__name__)
            action.type
            if not callable(action.type)
            else str(action.type.__name__)  # type: ignore
        )
        description = action.help

@@ -583,10 +599,10 @@ def _extract_parameter_details(
        required = action.required

        # extract valid values for choices, if provided
        valid_values = action.choices if action.choices is not None else None
        valid_values = list(action.choices) if action.choices is not None else None

        # set ext_metadata as an empty dict for now, can be updated later if needed
        ext_metadata = {}
        ext_metadata: Dict[str, Any] = {}

        descriptions.append(
            ParameterDescription(
@@ -621,7 +637,7 @@ def _get_dict_from_obj(obj, default_value=None) -> Optional[Dict]:
def _get_base_model_descriptions(model_cls: "BaseModel") -> List[ParameterDescription]:
    from dbgpt._private import pydantic

    version = int(pydantic.VERSION.split(".")[0])
    version = int(pydantic.VERSION.split(".")[0])  # type: ignore
    schema = model_cls.model_json_schema() if version >= 2 else model_cls.schema()
    required_fields = set(schema.get("required", []))
    param_descs = []
@@ -661,7 +677,7 @@ def _get_base_model_descriptions(model_cls: "BaseModel") -> List[ParameterDescri
    ext_metadata = (
        field.field_info.extra if hasattr(field.field_info, "extra") else None
    )
    param_class = (f"{model_cls.__module__}.{model_cls.__name__}",)
    param_class = f"{model_cls.__module__}.{model_cls.__name__}"
    param_desc = ParameterDescription(
        param_class=param_class,
        param_name=field_name,
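A usage sketch for the parser above, assuming the field-name-to-flag mapping suggested by the code (the `ModelParameters` dataclass here is hypothetical):

    from dataclasses import dataclass

    @dataclass
    class ModelParameters:
        model_name: str = "vicuna-13b"
        port: int = 8000

    parser = EnvArgumentParser()
    params: ModelParameters = parser.parse_args_into_dataclass(
        ModelParameters,
        env_prefixes=["LLM_"],           # e.g. LLM_MODEL_NAME overrides the default
        command_args=["--port", "9000"],
    )
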
@@ -5,7 +5,7 @@ import asyncio
import logging
import logging.handlers
import os
from typing import Any, List
from typing import Any, List, Optional, cast

from dbgpt.configs.model_config import LOGDIR

@@ -28,19 +28,25 @@ def _get_logging_level() -> str:
    return os.getenv("DBGPT_LOG_LEVEL", "INFO")


def setup_logging_level(logging_level=None, logger_name: str = None):
def setup_logging_level(
    logging_level: Optional[str] = None, logger_name: Optional[str] = None
):
    if not logging_level:
        logging_level = _get_logging_level()
    if type(logging_level) is str:
        logging_level = logging.getLevelName(logging_level.upper())
    if logger_name:
        logger = logging.getLogger(logger_name)
        logger.setLevel(logging_level)
        logger.setLevel(cast(str, logging_level))
    else:
        logging.basicConfig(level=logging_level, encoding="utf-8")


def setup_logging(logger_name: str, logging_level=None, logger_filename: str = None):
def setup_logging(
    logger_name: str,
    logging_level: Optional[str] = None,
    logger_filename: Optional[str] = None,
):
    if not logging_level:
        logging_level = _get_logging_level()
    logger = _build_logger(logger_name, logging_level, logger_filename)
@@ -74,7 +80,11 @@ def get_gpu_memory(max_gpus=None):
    return gpu_memory


def _build_logger(logger_name, logging_level=None, logger_filename: str = None):
def _build_logger(
    logger_name,
    logging_level: Optional[str] = None,
    logger_filename: Optional[str] = None,
):
    global handler

    formatter = logging.Formatter(
@@ -111,14 +121,14 @@ def get_or_create_event_loop() -> asyncio.BaseEventLoop:
    try:
        loop = asyncio.get_event_loop()
        assert loop is not None
        return loop
        return cast(asyncio.BaseEventLoop, loop)
    except RuntimeError as e:
        if not "no running event loop" in str(e) and not "no current event loop" in str(
            e
        ):
            raise e
        logging.warning("Cannot get a running event loop; creating a new event loop now")
        return asyncio.get_event_loop_policy().new_event_loop()
        return cast(asyncio.BaseEventLoop, asyncio.get_event_loop_policy().new_event_loop())


def logging_str_to_uvicorn_level(log_level_str):
@@ -152,7 +162,7 @@ class EndpointFilter(logging.Filter):
        return record.getMessage().find(self._path) == -1


def setup_http_service_logging(exclude_paths: List[str] = None):
def setup_http_service_logging(exclude_paths: Optional[List[str]] = None):
    """Setup http service logging

    Now just disable some logs

@@ -53,7 +53,7 @@ class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
        return params


with DAG("dbgpt_awel_simple_rag_rewrite_example") as dag:
with DAG("dbgpt_awel_simple_rag_summary_example") as dag:
    trigger = HttpTrigger(
        "/examples/rag/summary", methods="POST", request_body=TriggerReqBody
    )

@@ -13,4 +13,4 @@ aioresponses
# for git hooks
pre-commit
# Type checking
mypy==0.991
mypy==1.7.0