Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-05 02:51:07 +00:00)
refactor(agent): Agent modular refactoring (#1487)
dbgpt/agent/util/llm/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
"""LLM for agents."""
dbgpt/agent/util/llm/llm.py (new file, 113 lines)
@@ -0,0 +1,113 @@
"""LLM module."""
import logging
from collections import defaultdict
from enum import Enum
from typing import Any, Dict, List, Optional, Type

from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
from dbgpt.core import LLMClient, ModelMetadata, ModelRequest

logger = logging.getLogger(__name__)


def _build_model_request(input_value: Dict) -> ModelRequest:
    """Build model request from input value.

    Args:
        input_value(dict): input value

    Returns:
        ModelRequest: model request, pass to llm client
    """
    params = {
        "model": input_value.get("model"),
        "messages": input_value.get("messages"),
        "temperature": input_value.get("temperature", None),
        "max_new_tokens": input_value.get("max_new_tokens", None),
        "stop": input_value.get("stop", None),
        "stop_token_ids": input_value.get("stop_token_ids", None),
        "context_len": input_value.get("context_len", None),
        "echo": input_value.get("echo", None),
        "span_id": input_value.get("span_id", None),
    }

    return ModelRequest(**params)


class LLMStrategyType(Enum):
    """LLM strategy type."""

    Priority = "priority"
    Auto = "auto"
    Default = "default"


class LLMStrategy:
    """LLM strategy base class."""

    def __init__(self, llm_client: LLMClient, context: Optional[str] = None):
        """Create an LLMStrategy instance."""
        self._llm_client = llm_client
        self._context = context

    @property
    def type(self) -> LLMStrategyType:
        """Return the strategy type."""
        return LLMStrategyType.Default

    def _excluded_models(
        self,
        all_models: List[ModelMetadata],
        excluded_models: List[str],
        need_uses: Optional[List[str]] = None,
    ):
        # Return the models in need_uses that are deployed and not excluded.
        if not need_uses:
            need_uses = []
        can_uses = []
        for item in all_models:
            if item.model in need_uses and item.model not in excluded_models:
                can_uses.append(item)
        return can_uses

    async def next_llm(self, excluded_models: Optional[List[str]] = None):
        """Return next available llm model name.

        Args:
            excluded_models(List[str]): excluded models

        Returns:
            str: Next available llm model name
        """
        if not excluded_models:
            excluded_models = []
        try:
            all_models = await self._llm_client.models()
            available_llms = self._excluded_models(all_models, excluded_models, None)
            if available_llms and len(available_llms) > 0:
                return available_llms[0].model
            else:
                raise ValueError("No model service available!")

        except Exception as e:
            logger.error(f"{self.type} get next llm failed! {str(e)}")
            raise ValueError(f"Failed to allocate model service, {str(e)}!") from e


llm_strategies: Dict[LLMStrategyType, List[Type[LLMStrategy]]] = defaultdict(list)


def register_llm_strategy(
    llm_strategy_type: LLMStrategyType, strategy: Type[LLMStrategy]
):
    """Register llm strategy."""
    llm_strategies[llm_strategy_type].append(strategy)


class LLMConfig(BaseModel):
    """LLM configuration."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    llm_client: Optional[LLMClient] = Field(default_factory=LLMClient)
    llm_strategy: LLMStrategyType = Field(default=LLMStrategyType.Default)
    strategy_context: Optional[Any] = None
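Note: the following is a minimal wiring sketch, not part of this commit, showing how the strategy registry and LLMConfig above are meant to fit together. The OpenAILLMClient import is an assumption (any concrete LLMClient implementation would do), and the model names are illustrative.

# Hypothetical wiring sketch; OpenAILLMClient stands in for any concrete LLMClient.
from dbgpt.model.proxy import OpenAILLMClient  # assumed client implementation

from dbgpt.agent.util.llm.llm import LLMConfig, LLMStrategyType, register_llm_strategy
from dbgpt.agent.util.llm.strategy.priority import LLMStrategyPriority

# Make the priority strategy discoverable for its strategy type.
register_llm_strategy(LLMStrategyType.Priority, LLMStrategyPriority)

# Configure an agent to prefer gpt-4 and fall back to gpt-3.5-turbo.
llm_config = LLMConfig(
    llm_client=OpenAILLMClient(),
    llm_strategy=LLMStrategyType.Priority,
    strategy_context='["gpt-4", "gpt-3.5-turbo"]',  # JSON list, parsed by the strategy
)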
dbgpt/agent/util/llm/llm_client.py (new file, 183 lines)
@@ -0,0 +1,183 @@
"""AIWrapper for LLM."""
import json
import logging
import traceback
from typing import Callable, Dict, Optional, Union

from dbgpt.core import LLMClient
from dbgpt.core.interface.output_parser import BaseOutputParser
from dbgpt.util.error_types import LLMChatError
from dbgpt.util.tracer import root_tracer

from .llm import _build_model_request

logger = logging.getLogger(__name__)


class AIWrapper:
    """AIWrapper for LLM."""

    cache_path_root: str = ".cache"
    extra_kwargs = {
        "cache_seed",
        "filter_func",
        "allow_format_str_template",
        "context",
        "llm_model",
    }

    def __init__(
        self, llm_client: LLMClient, output_parser: Optional[BaseOutputParser] = None
    ):
        """Create an AIWrapper instance."""
        self.llm_echo = False
        self.model_cache_enable = False
        self._llm_client = llm_client
        self._output_parser = output_parser or BaseOutputParser(is_stream_out=False)

    @classmethod
    def instantiate(
        cls,
        template: Optional[Union[str, Callable]] = None,
        context: Optional[Dict] = None,
        allow_format_str_template: Optional[bool] = False,
    ):
        """Instantiate the template with the context."""
        if not context or template is None:
            return template
        if isinstance(template, str):
            return template.format(**context) if allow_format_str_template else template
        return template(context)

    def _construct_create_params(self, create_config: Dict, extra_kwargs: Dict) -> Dict:
        """Prime the create_config with additional_kwargs."""
        # Validate the config
        prompt = create_config.get("prompt")
        messages = create_config.get("messages")
        if prompt is None and messages is None:
            raise ValueError(
                "Either prompt or messages must be provided in the create config."
            )

        context = extra_kwargs.get("context")
        if context is None:
            # No need to instantiate if no context is provided.
            return create_config
        # Instantiate the prompt or messages
        allow_format_str_template = extra_kwargs.get("allow_format_str_template", False)
        # Make a copy of the config
        params = create_config.copy()
        if prompt is not None:
            # Instantiate the prompt
            params["prompt"] = self.instantiate(
                prompt, context, allow_format_str_template
            )
        elif context and messages and isinstance(messages, list):
            # Instantiate the messages
            params["messages"] = [
                {
                    **m,
                    "content": self.instantiate(
                        m["content"], context, allow_format_str_template
                    ),
                }
                if m.get("content")
                else m
                for m in messages
            ]
        return params

    def _separate_create_config(self, config):
        """Separate the config into create_config and extra_kwargs."""
        create_config = {k: v for k, v in config.items() if k not in self.extra_kwargs}
        extra_kwargs = {k: v for k, v in config.items() if k in self.extra_kwargs}
        return create_config, extra_kwargs

    def _get_key(self, config):
        """Get a unique identifier of a configuration.

        Args:
            config (dict or list): A configuration.

        Returns:
            str: A unique identifier which can be used as a key for a dict.
        """
        non_cache_key = ["api_key", "base_url", "api_type", "api_version"]
        copied = False
        for key in non_cache_key:
            if key in config:
                if not copied:
                    config = config.copy()
                    copied = True
                config.pop(key)
        return json.dumps(config, sort_keys=True, ensure_ascii=False)

    async def create(self, **config) -> Optional[str]:
        """Create a response from the input config."""
        # Make a copy of the input config
        full_config = {**config}
        # separate the config into create_config and extra_kwargs
        create_config, extra_kwargs = self._separate_create_config(full_config)

        # construct the create params
        params = self._construct_create_params(create_config, extra_kwargs)
        filter_func = extra_kwargs.get("filter_func")
        context = extra_kwargs.get("context")
        llm_model = extra_kwargs.get("llm_model")
        try:
            response = await self._completions_create(llm_model, params)
        except LLMChatError as e:
            logger.debug(f"{llm_model} generate failed! {str(e)}")
            raise e
        else:
            pass_filter = filter_func is None or filter_func(
                context=context, response=response
            )
            if pass_filter:
                # Return the response if it passes the filter
                return response
            else:
                return None

    def _get_span_metadata(self, payload: Dict) -> Dict:
        metadata = {k: v for k, v in payload.items()}

        metadata["messages"] = list(
            map(lambda m: m if isinstance(m, dict) else m.dict(), metadata["messages"])
        )
        return metadata

    def _llm_messages_convert(self, params):
        gpts_messages = params["messages"]
        # TODO

        return gpts_messages

    async def _completions_create(self, llm_model, params) -> str:
        payload = {
            "model": llm_model,
            "prompt": params.get("prompt"),
            "messages": self._llm_messages_convert(params),
            "temperature": float(params.get("temperature")),
            "max_new_tokens": int(params.get("max_new_tokens")),
            "echo": self.llm_echo,
        }
        logger.info(f"Request: \n{payload}")
        span = root_tracer.start_span(
            "Agent.llm_client.no_streaming_call",
            metadata=self._get_span_metadata(payload),
        )
        payload["span_id"] = span.span_id
        payload["model_cache_enable"] = self.model_cache_enable
        try:
            model_request = _build_model_request(payload)
            model_output = await self._llm_client.generate(model_request.copy())
            parsed_output = self._output_parser.parse_model_nostream_resp(
                model_output, "#########################"
            )
            return parsed_output
        except Exception as e:
            logger.error(
                f"Call LLMClient error, {str(e)}, detail: {traceback.format_exc()}"
            )
            raise LLMChatError(original_exception=e) from e
        finally:
            span.end()
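Note: a rough usage sketch for AIWrapper, not part of this commit, assuming a configured OpenAI-compatible LLMClient; the model name, message content, and sampling values are illustrative. temperature and max_new_tokens are passed explicitly because _completions_create casts them without a fallback.

import asyncio

from dbgpt.agent.util.llm.llm_client import AIWrapper
from dbgpt.model.proxy import OpenAILLMClient  # assumed client implementation


async def main():
    wrapper = AIWrapper(llm_client=OpenAILLMClient())
    # "llm_model" lands in extra_kwargs; the remaining keys become the model request.
    answer = await wrapper.create(
        llm_model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Say hello in one word."}],
        temperature=0.5,
        max_new_tokens=64,
    )
    print(answer)


asyncio.run(main())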
dbgpt/agent/util/llm/strategy/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
"""LLM strategy module."""
dbgpt/agent/util/llm/strategy/priority.py (new file, 37 lines)
@@ -0,0 +1,37 @@
"""Priority strategy for LLM."""

import json
import logging
from typing import List, Optional

from ..llm import LLMStrategy, LLMStrategyType

logger = logging.getLogger(__name__)


class LLMStrategyPriority(LLMStrategy):
    """Priority strategy for llm model service."""

    @property
    def type(self) -> LLMStrategyType:
        """Return the strategy type."""
        return LLMStrategyType.Priority

    async def next_llm(self, excluded_models: Optional[List[str]] = None) -> str:
        """Return next available llm model name."""
        try:
            if not excluded_models:
                excluded_models = []
            all_models = await self._llm_client.models()
            if not self._context:
                raise ValueError("No context provided for priority strategy!")
            priority: List[str] = json.loads(self._context)
            can_uses = self._excluded_models(all_models, excluded_models, priority)
            if can_uses and len(can_uses) > 0:
                return can_uses[0].model
            else:
                raise ValueError("No model service available!")

        except Exception as e:
            logger.error(f"{self.type} get next llm failed! {str(e)}")
            raise ValueError(f"Failed to allocate model service, {str(e)}!") from e
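Note: a short sketch, not part of this commit, of using LLMStrategyPriority directly, under the same assumed client as above. The context is a JSON-encoded list of model names in preference order; next_llm returns the first listed model that is deployed and not excluded.

import asyncio

from dbgpt.agent.util.llm.strategy.priority import LLMStrategyPriority
from dbgpt.model.proxy import OpenAILLMClient  # assumed client implementation


async def pick_model():
    strategy = LLMStrategyPriority(
        llm_client=OpenAILLMClient(),
        context='["gpt-4", "gpt-3.5-turbo"]',
    )
    # Skip gpt-4 (e.g. it already failed), so gpt-3.5-turbo is chosen if deployed.
    model = await strategy.next_llm(excluded_models=["gpt-4"])
    print(model)


asyncio.run(pick_model())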