refactor(agent): Agent modular refactoring (#1487)

This commit is contained in:
Fangyin Cheng
2024-05-07 09:45:26 +08:00
committed by GitHub
parent 2a418f91e8
commit 863b5404dd
86 changed files with 4513 additions and 967 deletions

View File

@@ -67,6 +67,44 @@ def DeveloperAPI(*args, **kwargs):
return decorator
def mutable(func):
"""Decorator to mark a method of an instance will change the instance state.
Examples:
>>> from dbgpt.util.annotations import mutable
>>> class Foo:
... def __init__(self):
... self.a = 1
...
... @mutable
... def change_a(self):
... self.a = 2
...
"""
_modify_mutability(func, mutability=True)
return func
def immutable(func):
"""Decorator to mark a method of an instance will not change the instance state.
Examples:
>>> from dbgpt.util.annotations import immutable
>>> class Foo:
... def __init__(self):
... self.a = 1
...
... @immutable
... def get_a(self):
... return self.a
...
"""
_modify_mutability(func, mutability=False)
return func
def _modify_docstring(obj, message: Optional[str] = None):
if not message:
return
@@ -94,3 +132,7 @@ def _modify_annotation(obj, stability) -> None:
obj._public_stability = stability
if hasattr(obj, "__name__"):
obj._annotated = obj.__name__
def _modify_mutability(obj, mutability) -> None:
obj._mutability = mutability

View File

@@ -70,6 +70,7 @@ def extract_code(
text: Union[str, List],
pattern: str = CODE_BLOCK_PATTERN,
detect_single_line_code: bool = False,
default_lang: str = "python",
) -> List[Tuple[str, str]]:
"""Extract code from a text.
@@ -80,6 +81,7 @@ def extract_code(
code block. Defaults to CODE_BLOCK_PATTERN.
detect_single_line_code (bool, optional): Enable the new feature for
extracting single line code. Defaults to False.
default_lang (str, optional): The default language to use when the language
Returns:
list: A list of tuples, each containing the language and the code.
@@ -89,7 +91,7 @@ def extract_code(
text = content_str(text)
if not detect_single_line_code:
match = re.findall(pattern, text, flags=re.DOTALL)
return match if match else [(UNKNOWN, text)]
return match if match else [(default_lang, text)]
# Extract both multi-line and single-line code block, separated by the | operator
# `([^`]+)`: Matches inline code.

View File

@@ -0,0 +1 @@
from .base import ConfigInfo, ConfigProvider, DynConfig

View File

@@ -0,0 +1,163 @@
"""Configuration base module."""
import logging
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Optional, Union
logger = logging.getLogger(__name__)
class _MISSING_TYPE:
pass
_MISSING = _MISSING_TYPE()
class ConfigCategory(str, Enum):
"""The configuration category."""
AGENT = "agent"
class ProviderType(str, Enum):
"""The provider type."""
ENV = "env"
PROMPT_MANAGER = "prompt_manager"
class ConfigProvider(ABC):
"""The configuration provider."""
name: ProviderType
@abstractmethod
def query(self, key: str, **kwargs) -> Any:
"""Query the configuration value by key."""
class EnvironmentConfigProvider(ConfigProvider):
"""Environment configuration provider.
Obtain the configuration value from the environment variable.
"""
name: ProviderType = ProviderType.ENV
def query(self, key: str, **kwargs) -> Any:
import os
return os.environ.get(key, None)
class PromptManagerConfigProvider(ConfigProvider):
"""Prompt manager configuration provider.
Obtain the configuration value from the prompt manager.
It is valid only when DB-GPT web server is running for now.
"""
name: ProviderType = ProviderType.PROMPT_MANAGER
def query(self, key: str, **kwargs) -> Any:
from dbgpt._private.config import Config
try:
from dbgpt.serve.prompt.serve import Serve
except ImportError:
logger.warning("Prompt manager is not available.")
return None
cfg = Config()
sys_app = cfg.SYSTEM_APP
if not sys_app:
return None
prompt_serve = Serve.get_instance(sys_app)
if not prompt_serve or not prompt_serve.prompt_manager:
return None
prompt_manager = prompt_serve.prompt_manager
value = prompt_manager.prefer_query(key, **kwargs)
if not value:
return None
# Just return the first value
return value[0].to_prompt_template().template
class ConfigInfo:
def __init__(
self,
default: Any,
key: Optional[str] = None,
provider: Optional[Union[str, ConfigProvider]] = None,
is_list: bool = False,
separator: str = "[LIST_SEP]",
description: Optional[str] = None,
):
self.default = default
self.key = key
self.provider = provider
self.is_list = is_list
self.separator = separator
self.description = description
def query(self, **kwargs) -> Any:
if self.key is None:
return self.default
value: Any = None
if isinstance(self.provider, ConfigProvider):
value = self.provider.query(self.key, **kwargs)
elif self.provider == ProviderType.ENV:
value = EnvironmentConfigProvider().query(self.key, **kwargs)
elif self.provider == ProviderType.PROMPT_MANAGER:
value = PromptManagerConfigProvider().query(self.key, **kwargs)
if value is None:
value = self.default
if value and self.is_list and isinstance(value, str):
value = value.split(self.separator)
return value
def DynConfig(
default: Any = _MISSING,
*,
category: str | ConfigCategory | None = None,
key: str | None = None,
provider: str | ProviderType | ConfigProvider | None = None,
is_list: bool = False,
separator: str = "[LIST_SEP]",
description: str | None = None,
) -> Any:
"""Dynamic configuration.
It allows to query the configuration value dynamically.
It can obtain the configuration value from the specified provider.
**Note**: Now just support obtaining string value or string list value.
Args:
default (Any): The default value.
category (str | ConfigCategory | None): The configuration category.
key (str | None): The configuration key.
provider (str | ProviderType | ConfigProvider | None): The configuration
provider.
is_list (bool): Whether the value is a list.
separator (str): The separator to split the list value.
description (str | None): The configuration description.
"""
if provider is None and category == ConfigCategory.AGENT:
provider = ProviderType.PROMPT_MANAGER
if default == _MISSING and key is None:
raise ValueError("Default value or key is required.")
if default != _MISSING and isinstance(default, list):
is_list = True
return ConfigInfo(
default=default,
key=key,
provider=provider,
is_list=is_list,
separator=separator,
description=description,
)

View File

@@ -252,11 +252,12 @@ import asyncio
from typing import Optional, Tuple
from dbgpt.agent import (
AgentMessage,
Action,
ActionOutput,
AgentMessage,
AgentResource,
ConversableAgent,
ProfileConfig,
)
from dbgpt.agent.util import cmp_string_equal
@@ -264,21 +265,25 @@ _HELLO_WORLD = "Hello world"
class HelloWorldSpeakerAgent(ConversableAgent):
name: str = "Hodor"
profile: str = "HelloWorldSpeaker"
goal: str = f"answer any question from user with '{_HELLO_WORLD}'"
desc: str = f"You can answer any question from user with '{_HELLO_WORLD}'"
constraints: list[str] = [
"You can only answer with '{fix_message}'",
"You can't use any other words",
]
examples: str = (
f"user: What's your name?\\nassistant: {_HELLO_WORLD}\\n\\n",
f"user: What's the weather today?\\nassistant: {_HELLO_WORLD}\\n\\n",
f"user: Can you help me?\\nassistant: {_HELLO_WORLD}\\n\\n",
f"user: Please tell me a joke.\\nassistant: {_HELLO_WORLD}\\n\\n",
f"user: Please answer me without '{_HELLO_WORLD}'.\\nassistant: {_HELLO_WORLD}"
"\\n\\n",
profile: ProfileConfig = ProfileConfig(
name="Hodor",
role="HelloWorldSpeaker",
goal=f"answer any question from user with '{_HELLO_WORLD}'",
desc=f"You can answer any question from user with '{_HELLO_WORLD}'",
constraints=[
"You can only answer with '{{ fix_message }}'",
f"You can't use any other words",
],
examples=(
f"user: What's your name?\\nassistant: {_HELLO_WORLD}\\n\\n"
f"user: What's the weather today?\\nassistant: {_HELLO_WORLD}\\n\\n"
f"user: Can you help me?\\nassistant: {_HELLO_WORLD}\\n\\n"
f"user: Please tell me a joke.\\nassistant: {_HELLO_WORLD}\\n\\n"
f"user: Please answer me without '{_HELLO_WORLD}'.\\nassistant: "
f"{_HELLO_WORLD}"
"\\n\\n"
),
)
def __init__(self, **kwargs):
@@ -330,28 +335,28 @@ async def _test_agent():
It will not run in the production environment.
\"\"\"
from dbgpt.model.proxy import OpenAILLMClient
from dbgpt.agent import AgentContext, GptsMemory, UserProxyAgent, LLMConfig
from dbgpt.agent import AgentContext, AgentMemory, UserProxyAgent, LLMConfig
llm_client = OpenAILLMClient(model_alias="gpt-3.5-turbo")
context: AgentContext = AgentContext(conv_id="summarize")
default_memory: GptsMemory = GptsMemory()
agent_memory: AgentMemory = AgentMemory()
speaker = (
await HelloWorldSpeakerAgent()
.bind(context)
.bind(LLMConfig(llm_client=llm_client))
.bind(default_memory)
.bind(agent_memory)
.build()
)
user_proxy = await UserProxyAgent().bind(default_memory).bind(context).build()
user_proxy = await UserProxyAgent().bind(agent_memory).bind(context).build()
await user_proxy.initiate_chat(
recipient=speaker,
reviewer=user_proxy,
message="What's your name?",
)
print(await default_memory.one_chat_completions("summarize"))
print(await agent_memory.gpts_memory.one_chat_completions("summarize"))
if __name__ == "__main__":

View File

@@ -0,0 +1,73 @@
from typing import Optional
from snowflake import Snowflake, SnowflakeGenerator
_GLOBAL_GENERATOR = SnowflakeGenerator(42)
def initialize_id_generator(
instance: int, *, seq: int = 0, epoch: int = 0, timestamp: Optional[int] = None
):
"""Initialize the global ID generator.
Args:
instance (int): The identifier combining both data center and machine ID in
traditional Snowflake algorithm. This single value serves to uniquely
identify the source of the ID generation request within distributed
environments. In standard Snowflake, this would be split into datacenter_id
and worker_id, but here it is combined into one for simplicity.
seq (int, optional): The initial sequence number for the generator. Default is
0. The sequence number increments within the same millisecond to allow
multiple IDs to be generated in quick succession. It resets when the
timestamp advances.
epoch (int, optional): The epoch time in milliseconds that acts as an offset
for the generator. This value helps to reduce the length of the generated
number by setting a custom "start time" for the timestamp component.
Default is 0.
timestamp (int, optional): The initial timestamp for the generator in
milliseconds since epoch. If not provided, the generator will use the
current system time. This can be used for testing or in scenarios where a
fixed start time is required.
"""
global _GLOBAL_GENERATOR
_GLOBAL_GENERATOR = SnowflakeGenerator(
instance, seq=seq, epoch=epoch, timestamp=timestamp
)
def new_id() -> int:
"""Generate a new Snowflake ID.
Returns:
int: A new Snowflake ID.
"""
return next(_GLOBAL_GENERATOR)
def parse(snowflake_id: int, epoch: int = 0) -> Snowflake:
"""Parse a Snowflake ID into its components.
Example:
.. code-block:: python
from dbgpt.util.id_generator import parse, new_id
snowflake_id = new_id()
snowflake = parse(snowflake_id)
print(snowflake.timestamp)
print(snowflake.instance)
print(snowflake.seq)
print(snowflake.datetime)
Args:
snowflake_id (int): The Snowflake ID to parse.
epoch (int, optional): The epoch time in milliseconds that acts as an offset
for the generator.
Returns:
Snowflake: The parsed Snowflake object.
"""
return Snowflake.parse(snowflake_id, epoch=epoch)

View File

@@ -1,10 +1,53 @@
"""Utility functions for calculating similarity."""
from typing import TYPE_CHECKING, Any, Sequence
from typing import TYPE_CHECKING, Any, List, Sequence
if TYPE_CHECKING:
from dbgpt.core.interface.embeddings import Embeddings
def cosine_similarity(embedding1: List[float], embedding2: List[float]) -> float:
"""Calculate the cosine similarity between two vectors.
Args:
embedding1(List[float]): The first vector.
embedding2(List[float]): The second vector.
Returns:
float: The cosine similarity.
"""
try:
import numpy as np
except ImportError:
raise ImportError("numpy is required for SimilarityMetric")
dot_product = np.dot(embedding1, embedding2)
norm1 = np.linalg.norm(embedding1)
norm2 = np.linalg.norm(embedding2)
similarity = dot_product / (norm1 * norm2)
return similarity
def sigmoid_function(x: float) -> float:
"""Calculate the sigmoid function.
The sigmoid function is defined as:
.. math::
f(x) = \\frac{1}{1 + e^{-x}}
It is used to map the input to a value between 0 and 1.
Args:
x(float): The input to the sigmoid function.
Returns:
float: The output of the sigmoid function.
"""
try:
import numpy as np
except ImportError:
raise ImportError("numpy is required for sigmoid_function")
return 1 / (1 + np.exp(-x))
def calculate_cosine_similarity(
embeddings: "Embeddings", prediction: str, contexts: Sequence[str]
) -> Any:

36
dbgpt/util/time_utils.py Normal file
View File

@@ -0,0 +1,36 @@
import contextlib
from datetime import datetime
@contextlib.contextmanager
def mock_now(dt_value): # type: ignore
"""Context manager for mocking out datetime.now() in unit tests.
Adapted from langchain.utils.mock_now.
Example:
with mock_now(datetime.datetime(2011, 2, 3, 10, 11)):
assert datetime.datetime.now() == datetime.datetime(2011, 2, 3, 10, 11)
"""
class MockDateTime(datetime.datetime):
@classmethod
def now(cls): # type: ignore
# Create a copy of dt_value.
return datetime.datetime(
dt_value.year,
dt_value.month,
dt_value.day,
dt_value.hour,
dt_value.minute,
dt_value.second,
dt_value.microsecond,
dt_value.tzinfo,
)
real_datetime = datetime.datetime
datetime.datetime = MockDateTime
try:
yield datetime.datetime
finally:
datetime.datetime = real_datetime