mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-23 12:21:08 +00:00
fix error
This commit is contained in:
commit
0a27d8ff7e
@ -5,6 +5,9 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.singleton import Singleton
|
from pilot.singleton import Singleton
|
||||||
|
from pilot.configs.config import Config
|
||||||
|
from typing import List
|
||||||
|
from pilot.model.base import Message
|
||||||
|
|
||||||
|
|
||||||
class AgentManager(metaclass=Singleton):
|
class AgentManager(metaclass=Singleton):
|
||||||
@ -14,6 +17,19 @@ class AgentManager(metaclass=Singleton):
|
|||||||
self.next_key = 0
|
self.next_key = 0
|
||||||
self.agents = {} # key, (task, full_message_history, model)
|
self.agents = {} # key, (task, full_message_history, model)
|
||||||
self.cfg = Config()
|
self.cfg = Config()
|
||||||
|
"""Agent manager for managing DB-GPT agents
|
||||||
|
In order to compatible auto gpt plugins,
|
||||||
|
we use the same template with it.
|
||||||
|
|
||||||
|
Args: next_keys
|
||||||
|
agents
|
||||||
|
cfg
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.next_key = 0
|
||||||
|
self.agents = {} #TODO need to define
|
||||||
|
self.cfg = Config()
|
||||||
|
|
||||||
# Create new GPT agent
|
# Create new GPT agent
|
||||||
# TODO: Centralise use of create_chat_completion() to globally enforce token limit
|
# TODO: Centralise use of create_chat_completion() to globally enforce token limit
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||||
from pilot.singleton import Singleton
|
from pilot.singleton import Singleton
|
||||||
|
|
||||||
@ -18,7 +19,10 @@ class Config(metaclass=Singleton):
|
|||||||
|
|
||||||
self.temperature = float(os.getenv("TEMPERATURE", 0.7))
|
self.temperature = float(os.getenv("TEMPERATURE", 0.7))
|
||||||
|
|
||||||
|
# TODO change model_config there
|
||||||
|
self.execute_local_commands = (
|
||||||
|
os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
|
||||||
|
)
|
||||||
# User agent header to use when making HTTP requests
|
# User agent header to use when making HTTP requests
|
||||||
# Some websites might just completely deny request with an error code if
|
# Some websites might just completely deny request with an error code if
|
||||||
# no user agent was found.
|
# no user agent was found.
|
||||||
|
11
pilot/model/base.py
Normal file
11
pilot/model/base.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
from typing import List, TypedDict
|
||||||
|
|
||||||
|
class Message(TypedDict):
|
||||||
|
"""LLM Message object containing usually like (role: content) """
|
||||||
|
|
||||||
|
role: str
|
||||||
|
content: str
|
||||||
|
|
3
pilot/model/chat.py
Normal file
3
pilot/model/chat.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding:utf-8 -*-
|
||||||
|
|
47
pilot/model/llm_utils.py
Normal file
47
pilot/model/llm_utils.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding:utf-8 -*-
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
from pilot.model.base import Message
|
||||||
|
from pilot.configs.config import Config
|
||||||
|
from pilot.server.vicuna_server import generate_output
|
||||||
|
|
||||||
|
def create_chat_completion(
|
||||||
|
messages: List[Message], # type: ignore
|
||||||
|
model: Optional[str] = None,
|
||||||
|
temperature: float = None,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
) -> str:
|
||||||
|
"""Create a chat completion using the vicuna local model
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages(List[Message]): The messages to send to the chat completion
|
||||||
|
model (str, optional): The model to use. Defaults to None.
|
||||||
|
temperature (float, optional): The temperature to use. Defaults to 0.7.
|
||||||
|
max_tokens (int, optional): The max tokens to use. Defaults to None
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The response from chat completion
|
||||||
|
"""
|
||||||
|
cfg = Config()
|
||||||
|
if temperature is None:
|
||||||
|
temperature = cfg.temperature
|
||||||
|
|
||||||
|
for plugin in cfg.plugins:
|
||||||
|
if plugin.can_handle_chat_completion(
|
||||||
|
messages=messages,
|
||||||
|
model=model,
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
):
|
||||||
|
message = plugin.handle_chat_completion(
|
||||||
|
messages=messages,
|
||||||
|
model=model,
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
)
|
||||||
|
if message is not None:
|
||||||
|
return message
|
||||||
|
|
||||||
|
response = None
|
||||||
|
# TODO impl this use vicuna server api
|
Loading…
Reference in New Issue
Block a user