mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-17 23:18:20 +00:00
impl llm utils
This commit is contained in:
@@ -3,6 +3,8 @@
|
||||
|
||||
from pilot.singleton import Singleton
|
||||
from pilot.configs.config import Config
|
||||
from typing import List
|
||||
from pilot.model.base import Message
|
||||
|
||||
class AgentManager(metaclass=Singleton):
|
||||
"""Agent manager for managing DB-GPT agents
|
||||
@@ -30,7 +32,16 @@ class AgentManager(metaclass=Singleton):
|
||||
Returns:
|
||||
The key of the new agent
|
||||
"""
|
||||
pass
|
||||
messages: List[Message] = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
for plugin in self.cfg.plugins:
|
||||
if not plugin.can_handle_pre_instruction():
|
||||
continue
|
||||
if plugin_messages := plugin.pre_instruction(messages):
|
||||
messages.extend(iter(plugin_messages))
|
||||
#
|
||||
|
||||
def message_agent(self):
|
||||
pass
|
||||
|
@@ -17,6 +17,7 @@ class Config(metaclass=Singleton):
|
||||
self.execute_local_commands = (
|
||||
os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
|
||||
)
|
||||
self.temperature = float(os.getenv("TEMPERATURE", "0.7"))
|
||||
|
||||
self.plugins_dir = os.getenv("PLUGINS_DIR", 'plugins')
|
||||
self.plugins:List[AutoGPTPluginTemplate] = []
|
||||
@@ -29,4 +30,8 @@ class Config(metaclass=Singleton):
|
||||
"""Set the plugins value."""
|
||||
self.plugins = value
|
||||
|
||||
def set_temperature(self, value: int) -> None:
|
||||
""" Set the temperature value."""
|
||||
self.temperature = value
|
||||
|
||||
|
47
pilot/model/llm_utils.py
Normal file
47
pilot/model/llm_utils.py
Normal file
@@ -0,0 +1,47 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
from typing import List, Optional
|
||||
from pilot.model.base import Message
|
||||
from pilot.configs.config import Config
|
||||
from pilot.server.vicuna_server import generate_output
|
||||
|
||||
def create_chat_completion(
|
||||
messages: List[Message], # type: ignore
|
||||
model: Optional[str] = None,
|
||||
temperature: float = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> str:
|
||||
"""Create a chat completion using the vicuna local model
|
||||
|
||||
Args:
|
||||
messages(List[Message]): The messages to send to the chat completion
|
||||
model (str, optional): The model to use. Defaults to None.
|
||||
temperature (float, optional): The temperature to use. Defaults to 0.7.
|
||||
max_tokens (int, optional): The max tokens to use. Defaults to None
|
||||
|
||||
Returns:
|
||||
str: The response from chat completion
|
||||
"""
|
||||
cfg = Config()
|
||||
if temperature is None:
|
||||
temperature = cfg.temperature
|
||||
|
||||
for plugin in cfg.plugins:
|
||||
if plugin.can_handle_chat_completion(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
):
|
||||
message = plugin.handle_chat_completion(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
if message is not None:
|
||||
return message
|
||||
|
||||
response = None
|
||||
# TODO impl this use vicuna server api
|
Reference in New Issue
Block a user