refactor(agent): Agent modular refactoring (#1487)

Author: Fangyin Cheng
Committed by: GitHub
Date: 2024-05-07 09:45:26 +08:00
Parent: 2a418f91e8
Commit: 863b5404dd
86 changed files with 4513 additions and 967 deletions


@@ -0,0 +1 @@
"""LLM for agents."""

dbgpt/agent/util/llm/llm.py Normal file (113 lines)

@@ -0,0 +1,113 @@
"""LLM module."""
import logging
from collections import defaultdict
from enum import Enum
from typing import Any, Dict, List, Optional, Type
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
from dbgpt.core import LLMClient, ModelMetadata, ModelRequest
logger = logging.getLogger(__name__)
def _build_model_request(input_value: Dict) -> ModelRequest:
"""Build model request from input value.
Args:
input_value(str or dict): input value
Returns:
ModelRequest: model request, pass to llm client
"""
parm = {
"model": input_value.get("model"),
"messages": input_value.get("messages"),
"temperature": input_value.get("temperature", None),
"max_new_tokens": input_value.get("max_new_tokens", None),
"stop": input_value.get("stop", None),
"stop_token_ids": input_value.get("stop_token_ids", None),
"context_len": input_value.get("context_len", None),
"echo": input_value.get("echo", None),
"span_id": input_value.get("span_id", None),
}
return ModelRequest(**parm)
class LLMStrategyType(Enum):
    """LLM strategy type."""

    Priority = "priority"
    Auto = "auto"
    Default = "default"


class LLMStrategy:
    """LLM strategy base class."""

    def __init__(self, llm_client: LLMClient, context: Optional[str] = None):
        """Create an LLMStrategy instance."""
        self._llm_client = llm_client
        self._context = context

    @property
    def type(self) -> LLMStrategyType:
        """Return the strategy type."""
        return LLMStrategyType.Default

    def _excluded_models(
        self,
        all_models: List[ModelMetadata],
        excluded_models: List[str],
        need_uses: Optional[List[str]] = None,
    ):
        # Keep only models that are requested via need_uses and not excluded.
        if not need_uses:
            need_uses = []
        can_uses = []
        for item in all_models:
            if item.model in need_uses and item.model not in excluded_models:
                can_uses.append(item)
        return can_uses
    async def next_llm(self, excluded_models: Optional[List[str]] = None):
        """Return the next available llm model name.

        Args:
            excluded_models(List[str]): models to exclude

        Returns:
            str: next available llm model name
        """
        if not excluded_models:
            excluded_models = []
        try:
            all_models = await self._llm_client.models()
            available_llms = self._excluded_models(all_models, excluded_models, None)
            if available_llms and len(available_llms) > 0:
                return available_llms[0].model
            else:
                raise ValueError("No model service available!")
        except Exception as e:
            logger.error(f"{self.type} get next llm failed! {str(e)}")
            raise ValueError(f"Failed to allocate model service: {str(e)}!")
llm_strategies: Dict[LLMStrategyType, List[Type[LLMStrategy]]] = defaultdict(list)


def register_llm_strategy(
    llm_strategy_type: LLMStrategyType, strategy: Type[LLMStrategy]
):
    """Register llm strategy."""
    llm_strategies[llm_strategy_type].append(strategy)


class LLMConfig(BaseModel):
    """LLM configuration."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    llm_client: Optional[LLMClient] = Field(default_factory=LLMClient)
    llm_strategy: LLMStrategyType = Field(default=LLMStrategyType.Default)
    strategy_context: Optional[Any] = None
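
A minimal usage sketch of the pieces above, assuming some concrete LLMClient implementation is on hand; the custom strategy class and the build_config helper are illustrative placeholders, not part of this commit:

    from dbgpt.agent.util.llm.llm import (
        LLMConfig,
        LLMStrategy,
        LLMStrategyType,
        register_llm_strategy,
    )

    class MyDefaultStrategy(LLMStrategy):
        """Placeholder strategy that keeps the base-class behaviour."""

    # Make the strategy class discoverable for the Default strategy type.
    register_llm_strategy(LLMStrategyType.Default, MyDefaultStrategy)

    def build_config(llm_client) -> LLMConfig:
        # llm_client is assumed to be a concrete LLMClient implementation.
        return LLMConfig(llm_client=llm_client, llm_strategy=LLMStrategyType.Default)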


@@ -0,0 +1,183 @@
"""AIWrapper for LLM."""
import json
import logging
import traceback
from typing import Callable, Dict, Optional, Union
from dbgpt.core import LLMClient
from dbgpt.core.interface.output_parser import BaseOutputParser
from dbgpt.util.error_types import LLMChatError
from dbgpt.util.tracer import root_tracer
from .llm import _build_model_request
logger = logging.getLogger(__name__)
class AIWrapper:
"""AIWrapper for LLM."""
cache_path_root: str = ".cache"
extra_kwargs = {
"cache_seed",
"filter_func",
"allow_format_str_template",
"context",
"llm_model",
}
def __init__(
self, llm_client: LLMClient, output_parser: Optional[BaseOutputParser] = None
):
"""Create an AIWrapper instance."""
self.llm_echo = False
self.model_cache_enable = False
self._llm_client = llm_client
self._output_parser = output_parser or BaseOutputParser(is_stream_out=False)
    @classmethod
    def instantiate(
        cls,
        template: Optional[Union[str, Callable]] = None,
        context: Optional[Dict] = None,
        allow_format_str_template: Optional[bool] = False,
    ):
        """Instantiate the template with the context."""
        if not context or template is None:
            return template
        if isinstance(template, str):
            return template.format(**context) if allow_format_str_template else template
        return template(context)
    def _construct_create_params(self, create_config: Dict, extra_kwargs: Dict) -> Dict:
        """Prime the create_config with additional_kwargs."""
        # Validate the config
        prompt = create_config.get("prompt")
        messages = create_config.get("messages")
        if prompt is None and messages is None:
            raise ValueError(
                "Either prompt or messages must be provided in the create config."
            )
        context = extra_kwargs.get("context")
        if context is None:
            # No need to instantiate if no context is provided.
            return create_config
        # Instantiate the prompt or messages
        allow_format_str_template = extra_kwargs.get("allow_format_str_template", False)
        # Make a copy of the config
        params = create_config.copy()
        if prompt is not None:
            # Instantiate the prompt
            params["prompt"] = self.instantiate(
                prompt, context, allow_format_str_template
            )
        elif context and messages and isinstance(messages, list):
            # Instantiate the messages
            params["messages"] = [
                {
                    **m,
                    "content": self.instantiate(
                        m["content"], context, allow_format_str_template
                    ),
                }
                if m.get("content")
                else m
                for m in messages
            ]
        return params
    def _separate_create_config(self, config):
        """Separate the config into create_config and extra_kwargs."""
        create_config = {k: v for k, v in config.items() if k not in self.extra_kwargs}
        extra_kwargs = {k: v for k, v in config.items() if k in self.extra_kwargs}
        return create_config, extra_kwargs

    def _get_key(self, config):
        """Get a unique identifier of a configuration.

        Args:
            config (dict or list): A configuration.

        Returns:
            str: A unique identifier which can be used as a key for a dict.
        """
        # Credential and endpoint fields must not leak into the cache key.
        non_cache_key = ["api_key", "base_url", "api_type", "api_version"]
        copied = False
        for key in non_cache_key:
            if key in config:
                config, copied = config.copy() if not copied else config, True
                config.pop(key)
        return json.dumps(config, sort_keys=True, ensure_ascii=False)
    async def create(self, **config) -> Optional[str]:
        """Create a response from the input config."""
        # Make a copy of the input config
        full_config = {**config}
        # Separate the config into create_config and extra_kwargs
        create_config, extra_kwargs = self._separate_create_config(full_config)
        # Construct the create params
        params = self._construct_create_params(create_config, extra_kwargs)
        filter_func = extra_kwargs.get("filter_func")
        context = extra_kwargs.get("context")
        llm_model = extra_kwargs.get("llm_model")

        try:
            response = await self._completions_create(llm_model, params)
        except LLMChatError as e:
            logger.debug(f"{llm_model} generate failed! {str(e)}")
            raise e
        else:
            pass_filter = filter_func is None or filter_func(
                context=context, response=response
            )
            if pass_filter:
                # Return the response if it passes the filter
                return response
            else:
                return None
    def _get_span_metadata(self, payload: Dict) -> Dict:
        metadata = {k: v for k, v in payload.items()}
        metadata["messages"] = list(
            map(lambda m: m if isinstance(m, dict) else m.dict(), metadata["messages"])
        )
        return metadata

    def _llm_messages_convert(self, params):
        gpts_messages = params["messages"]
        # TODO
        return gpts_messages

    async def _completions_create(self, llm_model, params) -> str:
        payload = {
            "model": llm_model,
            "prompt": params.get("prompt"),
            "messages": self._llm_messages_convert(params),
            "temperature": float(params.get("temperature")),
            "max_new_tokens": int(params.get("max_new_tokens")),
            "echo": self.llm_echo,
        }
        logger.info(f"Request: \n{payload}")
        span = root_tracer.start_span(
            "Agent.llm_client.no_streaming_call",
            metadata=self._get_span_metadata(payload),
        )
        payload["span_id"] = span.span_id
        payload["model_cache_enable"] = self.model_cache_enable
        try:
            model_request = _build_model_request(payload)
            model_output = await self._llm_client.generate(model_request.copy())
            parsed_output = self._output_parser.parse_model_nostream_resp(
                model_output, "#########################"
            )
            return parsed_output
        except Exception as e:
            logger.error(
                f"Call LLMClient error, {str(e)}, detail: {traceback.format_exc()}"
            )
            raise LLMChatError(original_exception=e) from e
        finally:
            span.end()
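
A hedged sketch of how AIWrapper.create might be called, assuming an existing LLMClient instance is available and AIWrapper is importable from wherever this file lands in the package; the model name and message content are illustrative placeholders, not values taken from this commit:

    import asyncio

    async def ask(llm_client, question: str):
        wrapper = AIWrapper(llm_client=llm_client)
        # "llm_model" travels through extra_kwargs; the remaining keys form the create config.
        return await wrapper.create(
            llm_model="chatgpt_proxyllm",  # placeholder model name
            messages=[{"role": "user", "content": question}],
            temperature=0.5,
            max_new_tokens=1024,
        )

    # asyncio.run(ask(my_llm_client, "Summarize this refactoring."))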


@@ -0,0 +1 @@
"""LLM strategy module."""


@@ -0,0 +1,37 @@
"""Priority strategy for LLM."""
import json
import logging
from typing import List, Optional
from ..llm import LLMStrategy, LLMStrategyType
logger = logging.getLogger(__name__)
class LLMStrategyPriority(LLMStrategy):
"""Priority strategy for llm model service."""
@property
def type(self) -> LLMStrategyType:
"""Return the strategy type."""
return LLMStrategyType.Priority
async def next_llm(self, excluded_models: Optional[List[str]] = None) -> str:
"""Return next available llm model name."""
try:
if not excluded_models:
excluded_models = []
all_models = await self._llm_client.models()
if not self._context:
raise ValueError("No context provided for priority strategy!")
priority: List[str] = json.loads(self._context)
can_uses = self._excluded_models(all_models, excluded_models, priority)
if can_uses and len(can_uses) > 0:
return can_uses[0].model
else:
raise ValueError("No model service available!")
except Exception as e:
logger.error(f"{self.type} get next llm failed!{str(e)}")
raise ValueError(f"Failed to allocate model service,{str(e)}!")
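
A small usage sketch for the priority strategy; the strategy context is a JSON-encoded list of candidate model names, and the names and client variable below are placeholders, not values taken from this commit:

    import asyncio
    import json

    async def pick_model(llm_client):
        # Only models named in the priority list and not excluded are considered.
        priority = json.dumps(["gpt-4", "vicuna-13b-v1.5"])
        strategy = LLMStrategyPriority(llm_client=llm_client, context=priority)
        return await strategy.next_llm(excluded_models=["gpt-4"])

    # asyncio.run(pick_model(my_llm_client))  # e.g. "vicuna-13b-v1.5" if that model is deployed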