mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-03 10:05:13 +00:00
refactor(agent): Agent modular refactoring (#1487)
This commit is contained in:
@@ -67,6 +67,44 @@ def DeveloperAPI(*args, **kwargs):
|
||||
return decorator
|
||||
|
||||
|
||||
def mutable(func):
    """Mark a method as one that changes the state of its instance.

    The function is tagged through ``_modify_mutability`` and handed back
    as-is; no wrapper is placed around it.

    Examples:
        >>> from dbgpt.util.annotations import mutable
        >>> class Foo:
        ...     def __init__(self):
        ...         self.a = 1
        ...
        ...     @mutable
        ...     def change_a(self):
        ...         self.a = 2
        ...

    Args:
        func: The method to mark.

    Returns:
        The same ``func`` object that was passed in.
    """
    _modify_mutability(func, True)
    return func
|
||||
|
||||
|
||||
def immutable(func):
    """Mark a method as one that leaves the state of its instance untouched.

    The function is tagged through ``_modify_mutability`` and handed back
    as-is; no wrapper is placed around it.

    Examples:
        >>> from dbgpt.util.annotations import immutable
        >>> class Foo:
        ...     def __init__(self):
        ...         self.a = 1
        ...
        ...     @immutable
        ...     def get_a(self):
        ...         return self.a
        ...

    Args:
        func: The method to mark.

    Returns:
        The same ``func`` object that was passed in.
    """
    _modify_mutability(func, False)
    return func
|
||||
|
||||
|
||||
def _modify_docstring(obj, message: Optional[str] = None):
|
||||
if not message:
|
||||
return
|
||||
@@ -94,3 +132,7 @@ def _modify_annotation(obj, stability) -> None:
|
||||
obj._public_stability = stability
|
||||
if hasattr(obj, "__name__"):
|
||||
obj._annotated = obj.__name__
|
||||
|
||||
|
||||
def _modify_mutability(obj, mutability) -> None:
    """Record the mutability flag on ``obj``.

    Args:
        obj: The object (typically a function) to tag.
        mutability: The value stored in the ``_mutability`` attribute.
    """
    setattr(obj, "_mutability", mutability)
|
||||
|
@@ -70,6 +70,7 @@ def extract_code(
|
||||
text: Union[str, List],
|
||||
pattern: str = CODE_BLOCK_PATTERN,
|
||||
detect_single_line_code: bool = False,
|
||||
default_lang: str = "python",
|
||||
) -> List[Tuple[str, str]]:
|
||||
"""Extract code from a text.
|
||||
|
||||
@@ -80,6 +81,7 @@ def extract_code(
|
||||
code block. Defaults to CODE_BLOCK_PATTERN.
|
||||
detect_single_line_code (bool, optional): Enable the new feature for
|
||||
extracting single line code. Defaults to False.
|
||||
default_lang (str, optional): The default language to use when the language
is unknown because no code block was matched. Defaults to "python".
|
||||
|
||||
Returns:
|
||||
list: A list of tuples, each containing the language and the code.
|
||||
@@ -89,7 +91,7 @@ def extract_code(
|
||||
text = content_str(text)
|
||||
if not detect_single_line_code:
|
||||
match = re.findall(pattern, text, flags=re.DOTALL)
|
||||
return match if match else [(UNKNOWN, text)]
|
||||
return match if match else [(default_lang, text)]
|
||||
|
||||
# Extract both multi-line and single-line code block, separated by the | operator
|
||||
# `([^`]+)`: Matches inline code.
|
||||
|
1
dbgpt/util/configure/__init__.py
Normal file
1
dbgpt/util/configure/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .base import ConfigInfo, ConfigProvider, DynConfig
|
163
dbgpt/util/configure/base.py
Normal file
163
dbgpt/util/configure/base.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""Configuration base module."""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _MISSING_TYPE:
    """Type of the module-private ``_MISSING`` sentinel."""


# Sentinel meaning "no default was supplied"; it lets ``None`` remain a
# legitimate default value.
_MISSING = _MISSING_TYPE()
|
||||
|
||||
|
||||
class ConfigCategory(str, Enum):
    """Categories a dynamic configuration entry can belong to.

    Members are also ``str`` instances, so they compare equal to their raw
    string values.
    """

    # Agent-related configuration.
    AGENT = "agent"
|
||||
|
||||
|
||||
class ProviderType(str, Enum):
    """Identifiers of the built-in configuration providers.

    Members are also ``str`` instances, so they compare equal to their raw
    string values.
    """

    # Values come from environment variables.
    ENV = "env"
    # Values come from the prompt manager.
    PROMPT_MANAGER = "prompt_manager"
|
||||
|
||||
|
||||
class ConfigProvider(ABC):
    """Abstract source of configuration values.

    Concrete providers set ``name`` and implement :meth:`query`.
    """

    # Identifies which kind of provider this is.
    name: ProviderType

    @abstractmethod
    def query(self, key: str, **kwargs) -> Any:
        """Look up the configuration value stored under ``key``.

        Args:
            key (str): The configuration key.
            **kwargs: Provider-specific lookup options.

        Returns:
            Any: The value found for ``key``.
        """
|
||||
|
||||
|
||||
class EnvironmentConfigProvider(ConfigProvider):
    """Configuration provider backed by environment variables.

    Each key is looked up directly in the process environment.
    """

    name: ProviderType = ProviderType.ENV

    def query(self, key: str, **kwargs) -> Any:
        """Return the environment variable named ``key``.

        Args:
            key (str): The environment variable name.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            Any: The variable's value, or ``None`` when it is not set.
        """
        import os

        env_value = os.environ.get(key)
        return env_value
|
||||
|
||||
|
||||
class PromptManagerConfigProvider(ConfigProvider):
    """Configuration provider backed by the prompt manager.

    The value is read from the prompt manager of the prompt serve component.

    For now it only works while the DB-GPT web server is running.
    """

    name: ProviderType = ProviderType.PROMPT_MANAGER

    def query(self, key: str, **kwargs) -> Any:
        """Return the template text of the preferred prompt for ``key``.

        Args:
            key (str): The configuration key passed to the prompt manager.
            **kwargs: Extra arguments forwarded to ``prefer_query``.

        Returns:
            Any: The template of the first matching prompt, or ``None`` when
                the serve module cannot be imported, no system app is
                configured, the prompt manager is missing, or nothing
                matches.
        """
        from dbgpt._private.config import Config

        try:
            from dbgpt.serve.prompt.serve import Serve
        except ImportError:
            logger.warning("Prompt manager is not available.")
            return None

        system_app = Config().SYSTEM_APP
        if not system_app:
            return None

        serve = Serve.get_instance(system_app)
        if not (serve and serve.prompt_manager):
            return None

        matches = serve.prompt_manager.prefer_query(key, **kwargs)
        if not matches:
            return None
        # Only the first match is used.
        first_match = matches[0]
        return first_match.to_prompt_template().template
|
||||
|
||||
|
||||
class ConfigInfo:
    """Description of one dynamically resolved configuration value.

    It stores where the value comes from (``provider`` and ``key``), what to
    fall back to (``default``), and how to post-process it (``is_list`` and
    ``separator``).
    """

    def __init__(
        self,
        default: Any,
        key: Optional[str] = None,
        provider: Optional[Union[str, ConfigProvider]] = None,
        is_list: bool = False,
        separator: str = "[LIST_SEP]",
        description: Optional[str] = None,
    ):
        """Store the configuration description.

        Args:
            default (Any): Value used when no provider yields one.
            key (Optional[str]): Key to look up; ``None`` means "always use
                the default".
            provider (Optional[Union[str, ConfigProvider]]): A provider
                instance or a provider type value.
            is_list (bool): Whether a string value is split into a list.
            separator (str): Separator used for that split.
            description (Optional[str]): Free-form description.
        """
        self.default = default
        self.key = key
        self.provider = provider
        self.is_list = is_list
        self.separator = separator
        self.description = description

    def query(self, **kwargs) -> Any:
        """Resolve the current value of this configuration.

        Args:
            **kwargs: Extra arguments forwarded to the provider's ``query``.

        Returns:
            Any: The provider's value, or ``default`` when there is no key,
                no recognised provider, or the provider returns ``None``. A
                truthy string is split on ``separator`` when ``is_list`` is
                set.
        """
        if self.key is None:
            return self.default

        source = self.provider
        if isinstance(source, ConfigProvider):
            resolver = source
        elif source == ProviderType.ENV:
            resolver = EnvironmentConfigProvider()
        elif source == ProviderType.PROMPT_MANAGER:
            resolver = PromptManagerConfigProvider()
        else:
            # Unknown or absent provider: fall through to the default.
            resolver = None

        resolved: Any = None
        if resolver is not None:
            resolved = resolver.query(self.key, **kwargs)
        if resolved is None:
            resolved = self.default

        if resolved and self.is_list and isinstance(resolved, str):
            return resolved.split(self.separator)
        return resolved
|
||||
|
||||
|
||||
def DynConfig(
    default: Any = _MISSING,
    *,
    category: Optional[Union[str, ConfigCategory]] = None,
    key: Optional[str] = None,
    provider: Optional[Union[str, ProviderType, ConfigProvider]] = None,
    is_list: bool = False,
    separator: str = "[LIST_SEP]",
    description: Optional[str] = None,
) -> Any:
    """Dynamic configuration.

    It allows to query the configuration value dynamically.
    It can obtain the configuration value from the specified provider.

    **Note**: Now just support obtaining string value or string list value.

    Args:
        default (Any): The default value.
        category (Optional[Union[str, ConfigCategory]]): The configuration
            category. With no explicit provider, the agent category selects
            the prompt manager provider.
        key (Optional[str]): The configuration key.
        provider (Optional[Union[str, ProviderType, ConfigProvider]]): The
            configuration provider.
        is_list (bool): Whether the value is a list. Forced to ``True`` when
            ``default`` is a list.
        separator (str): The separator to split the list value.
        description (Optional[str]): The configuration description.

    Returns:
        Any: A ``ConfigInfo`` describing the configuration.

    Raises:
        ValueError: If neither a default value nor a key is given.
    """
    if provider is None and category == ConfigCategory.AGENT:
        provider = ProviderType.PROMPT_MANAGER
    # Identity, not equality: ``==`` would run the default value's own
    # ``__eq__``, which may not return a plain bool.
    if default is _MISSING and key is None:
        raise ValueError("Default value or key is required.")
    # The sentinel is never a list, so no separate sentinel check is needed.
    if isinstance(default, list):
        is_list = True
    return ConfigInfo(
        default=default,
        key=key,
        provider=provider,
        is_list=is_list,
        separator=separator,
        description=description,
    )
|
@@ -252,11 +252,12 @@ import asyncio
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from dbgpt.agent import (
|
||||
AgentMessage,
|
||||
Action,
|
||||
ActionOutput,
|
||||
AgentMessage,
|
||||
AgentResource,
|
||||
ConversableAgent,
|
||||
ProfileConfig,
|
||||
)
|
||||
from dbgpt.agent.util import cmp_string_equal
|
||||
|
||||
@@ -264,21 +265,25 @@ _HELLO_WORLD = "Hello world"
|
||||
|
||||
|
||||
class HelloWorldSpeakerAgent(ConversableAgent):
|
||||
name: str = "Hodor"
|
||||
profile: str = "HelloWorldSpeaker"
|
||||
goal: str = f"answer any question from user with '{_HELLO_WORLD}'"
|
||||
desc: str = f"You can answer any question from user with '{_HELLO_WORLD}'"
|
||||
constraints: list[str] = [
|
||||
"You can only answer with '{fix_message}'",
|
||||
"You can't use any other words",
|
||||
]
|
||||
examples: str = (
|
||||
f"user: What's your name?\\nassistant: {_HELLO_WORLD}\\n\\n",
|
||||
f"user: What's the weather today?\\nassistant: {_HELLO_WORLD}\\n\\n",
|
||||
f"user: Can you help me?\\nassistant: {_HELLO_WORLD}\\n\\n",
|
||||
f"user: Please tell me a joke.\\nassistant: {_HELLO_WORLD}\\n\\n",
|
||||
f"user: Please answer me without '{_HELLO_WORLD}'.\\nassistant: {_HELLO_WORLD}"
|
||||
"\\n\\n",
|
||||
|
||||
profile: ProfileConfig = ProfileConfig(
|
||||
name="Hodor",
|
||||
role="HelloWorldSpeaker",
|
||||
goal=f"answer any question from user with '{_HELLO_WORLD}'",
|
||||
desc=f"You can answer any question from user with '{_HELLO_WORLD}'",
|
||||
constraints=[
|
||||
"You can only answer with '{{ fix_message }}'",
|
||||
f"You can't use any other words",
|
||||
],
|
||||
examples=(
|
||||
f"user: What's your name?\\nassistant: {_HELLO_WORLD}\\n\\n"
|
||||
f"user: What's the weather today?\\nassistant: {_HELLO_WORLD}\\n\\n"
|
||||
f"user: Can you help me?\\nassistant: {_HELLO_WORLD}\\n\\n"
|
||||
f"user: Please tell me a joke.\\nassistant: {_HELLO_WORLD}\\n\\n"
|
||||
f"user: Please answer me without '{_HELLO_WORLD}'.\\nassistant: "
|
||||
f"{_HELLO_WORLD}"
|
||||
"\\n\\n"
|
||||
),
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@@ -330,28 +335,28 @@ async def _test_agent():
|
||||
It will not run in the production environment.
|
||||
\"\"\"
|
||||
from dbgpt.model.proxy import OpenAILLMClient
|
||||
from dbgpt.agent import AgentContext, GptsMemory, UserProxyAgent, LLMConfig
|
||||
from dbgpt.agent import AgentContext, AgentMemory, UserProxyAgent, LLMConfig
|
||||
|
||||
llm_client = OpenAILLMClient(model_alias="gpt-3.5-turbo")
|
||||
context: AgentContext = AgentContext(conv_id="summarize")
|
||||
|
||||
default_memory: GptsMemory = GptsMemory()
|
||||
agent_memory: AgentMemory = AgentMemory()
|
||||
|
||||
speaker = (
|
||||
await HelloWorldSpeakerAgent()
|
||||
.bind(context)
|
||||
.bind(LLMConfig(llm_client=llm_client))
|
||||
.bind(default_memory)
|
||||
.bind(agent_memory)
|
||||
.build()
|
||||
)
|
||||
|
||||
user_proxy = await UserProxyAgent().bind(default_memory).bind(context).build()
|
||||
user_proxy = await UserProxyAgent().bind(agent_memory).bind(context).build()
|
||||
await user_proxy.initiate_chat(
|
||||
recipient=speaker,
|
||||
reviewer=user_proxy,
|
||||
message="What's your name?",
|
||||
)
|
||||
print(await default_memory.one_chat_completions("summarize"))
|
||||
print(await agent_memory.gpts_memory.one_chat_completions("summarize"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
73
dbgpt/util/id_generator.py
Normal file
73
dbgpt/util/id_generator.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from typing import Optional
|
||||
|
||||
from snowflake import Snowflake, SnowflakeGenerator
|
||||
|
||||
# Module-wide generator used by new_id(); created with instance value 42 and
# replaced when initialize_id_generator() is called.
_GLOBAL_GENERATOR = SnowflakeGenerator(42)
|
||||
|
||||
|
||||
def initialize_id_generator(
    instance: int, *, seq: int = 0, epoch: int = 0, timestamp: Optional[int] = None
):
    """Replace the global ID generator with a freshly configured one.

    Args:
        instance (int): Identifies the source of the generated IDs in a
            distributed environment. The traditional Snowflake algorithm
            splits this into a datacenter ID and a worker ID; here both are
            combined into a single value for simplicity.

        seq (int, optional): Initial sequence number. Default is 0. The
            sequence increments for IDs generated within the same millisecond
            and resets when the timestamp advances.

        epoch (int, optional): Epoch offset in milliseconds. A custom "start
            time" for the timestamp component shortens the generated number.
            Default is 0.

        timestamp (int, optional): Initial timestamp in milliseconds since
            epoch. When omitted, the generator uses the current system time.
            Useful for tests or when a fixed start time is required.
    """
    global _GLOBAL_GENERATOR
    fresh_generator = SnowflakeGenerator(
        instance, seq=seq, epoch=epoch, timestamp=timestamp
    )
    _GLOBAL_GENERATOR = fresh_generator
|
||||
|
||||
|
||||
def new_id() -> int:
    """Generate a new Snowflake ID.

    Returns:
        int: A new Snowflake ID.
    """
    # NOTE(review): the snowflake-id generator appears to yield ``None``
    # instead of an ID when the per-millisecond sequence is exhausted or the
    # clock has moved backwards - confirm against the library. Retrying keeps
    # the ``-> int`` contract; in the normal case the first attempt succeeds.
    while True:
        candidate = next(_GLOBAL_GENERATOR)
        if candidate is not None:
            return candidate
|
||||
|
||||
|
||||
def parse(snowflake_id: int, epoch: int = 0) -> Snowflake:
    """Break a Snowflake ID down into its components.

    Example:
        .. code-block:: python

            from dbgpt.util.id_generator import parse, new_id

            snowflake_id = new_id()
            snowflake = parse(snowflake_id)
            print(snowflake.timestamp)
            print(snowflake.instance)
            print(snowflake.seq)
            print(snowflake.datetime)

    Args:
        snowflake_id (int): The Snowflake ID to parse.
        epoch (int, optional): The epoch offset in milliseconds that the
            generator was configured with.

    Returns:
        Snowflake: The parsed Snowflake object.
    """
    parsed = Snowflake.parse(snowflake_id, epoch=epoch)
    return parsed
|
@@ -1,10 +1,53 @@
|
||||
"""Utility functions for calculating similarity."""
|
||||
from typing import TYPE_CHECKING, Any, Sequence
|
||||
from typing import TYPE_CHECKING, Any, List, Sequence
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dbgpt.core.interface.embeddings import Embeddings
|
||||
|
||||
|
||||
def cosine_similarity(embedding1: List[float], embedding2: List[float]) -> float:
    """Calculate the cosine similarity between two vectors.

    Args:
        embedding1(List[float]): The first vector.
        embedding2(List[float]): The second vector.

    Returns:
        float: The cosine similarity, or ``0.0`` when either vector has zero
            length (the similarity is undefined there and the plain formula
            would divide by zero).

    Raises:
        ImportError: If numpy is not installed.
    """
    try:
        import numpy as np
    except ImportError as e:
        raise ImportError("numpy is required for SimilarityMetric") from e
    norm_product = np.linalg.norm(embedding1) * np.linalg.norm(embedding2)
    # Guard the zero-vector case instead of producing NaN with a warning.
    if norm_product == 0:
        return 0.0
    return np.dot(embedding1, embedding2) / norm_product
|
||||
|
||||
|
||||
def sigmoid_function(x: float) -> float:
    """Evaluate the sigmoid function at ``x``.

    The sigmoid function is defined as:

    .. math::
        f(x) = \\frac{1}{1 + e^{-x}}

    It maps the input to a value between 0 and 1.

    Args:
        x(float): The input to the sigmoid function.

    Returns:
        float: The output of the sigmoid function.

    Raises:
        ImportError: If numpy is not installed.
    """
    try:
        import numpy as np
    except ImportError:
        raise ImportError("numpy is required for sigmoid_function")
    denominator = 1 + np.exp(-x)
    return 1 / denominator
|
||||
|
||||
|
||||
def calculate_cosine_similarity(
|
||||
embeddings: "Embeddings", prediction: str, contexts: Sequence[str]
|
||||
) -> Any:
|
||||
|
36
dbgpt/util/time_utils.py
Normal file
36
dbgpt/util/time_utils.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import contextlib
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@contextlib.contextmanager
def mock_now(dt_value):  # type: ignore
    """Context manager for mocking out datetime.now() in unit tests.

    Adapted from langchain.utils.mock_now.

    While the context is active, ``datetime.datetime`` in the ``datetime``
    module is replaced by a subclass whose ``now()`` returns a copy of
    ``dt_value``. The real class is restored on exit, even if the body
    raises. Code that bound the class earlier with
    ``from datetime import datetime`` keeps the real class and is not
    affected.

    Example:
        import datetime

        with mock_now(datetime.datetime(2011, 2, 3, 10, 11)):
            assert datetime.datetime.now() == datetime.datetime(2011, 2, 3, 10, 11)

    Args:
        dt_value: The datetime that ``now()`` should report.

    Yields:
        The mock datetime class installed as ``datetime.datetime``.
    """
    # Import the module locally under its own name: this file binds
    # ``datetime`` to the class (``from datetime import datetime``), so
    # ``datetime.datetime`` would raise AttributeError.
    import datetime as datetime_module

    real_datetime = datetime_module.datetime

    class MockDateTime(real_datetime):
        @classmethod
        def now(cls):  # type: ignore
            # Create a copy of dt_value.
            return real_datetime(
                dt_value.year,
                dt_value.month,
                dt_value.day,
                dt_value.hour,
                dt_value.minute,
                dt_value.second,
                dt_value.microsecond,
                dt_value.tzinfo,
            )

    datetime_module.datetime = MockDateTime
    try:
        yield datetime_module.datetime
    finally:
        datetime_module.datetime = real_datetime
|
Reference in New Issue
Block a user