mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-22 20:01:46 +00:00
fix error
This commit is contained in:
commit
0a27d8ff7e
@ -5,6 +5,9 @@ from __future__ import annotations
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.singleton import Singleton
|
||||
from pilot.configs.config import Config
|
||||
from typing import List
|
||||
from pilot.model.base import Message
|
||||
|
||||
|
||||
class AgentManager(metaclass=Singleton):
|
||||
@ -14,6 +17,19 @@ class AgentManager(metaclass=Singleton):
|
||||
self.next_key = 0
|
||||
self.agents = {} # key, (task, full_message_history, model)
|
||||
self.cfg = Config()
|
||||
"""Agent manager for managing DB-GPT agents
|
||||
In order to compatible auto gpt plugins,
|
||||
we use the same template with it.
|
||||
|
||||
Args: next_keys
|
||||
agents
|
||||
cfg
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.next_key = 0
|
||||
self.agents = {} #TODO need to define
|
||||
self.cfg = Config()
|
||||
|
||||
# Create new GPT agent
|
||||
# TODO: Centralise use of create_chat_completion() to globally enforce token limit
|
||||
|
@ -3,6 +3,7 @@
|
||||
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||
from pilot.singleton import Singleton
|
||||
|
||||
@ -18,7 +19,10 @@ class Config(metaclass=Singleton):
|
||||
|
||||
self.temperature = float(os.getenv("TEMPERATURE", 0.7))
|
||||
|
||||
|
||||
# TODO change model_config there
|
||||
self.execute_local_commands = (
|
||||
os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
|
||||
)
|
||||
# User agent header to use when making HTTP requests
|
||||
# Some websites might just completely deny request with an error code if
|
||||
# no user agent was found.
|
||||
|
11
pilot/model/base.py
Normal file
11
pilot/model/base.py
Normal file
@ -0,0 +1,11 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import List, TypedDict
|
||||
|
||||
class Message(TypedDict):
|
||||
"""LLM Message object containing usually like (role: content) """
|
||||
|
||||
role: str
|
||||
content: str
|
||||
|
3
pilot/model/chat.py
Normal file
3
pilot/model/chat.py
Normal file
@ -0,0 +1,3 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
47
pilot/model/llm_utils.py
Normal file
47
pilot/model/llm_utils.py
Normal file
@ -0,0 +1,47 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
from typing import List, Optional
|
||||
from pilot.model.base import Message
|
||||
from pilot.configs.config import Config
|
||||
from pilot.server.vicuna_server import generate_output
|
||||
|
||||
def create_chat_completion(
|
||||
messages: List[Message], # type: ignore
|
||||
model: Optional[str] = None,
|
||||
temperature: float = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> str:
|
||||
"""Create a chat completion using the vicuna local model
|
||||
|
||||
Args:
|
||||
messages(List[Message]): The messages to send to the chat completion
|
||||
model (str, optional): The model to use. Defaults to None.
|
||||
temperature (float, optional): The temperature to use. Defaults to 0.7.
|
||||
max_tokens (int, optional): The max tokens to use. Defaults to None
|
||||
|
||||
Returns:
|
||||
str: The response from chat completion
|
||||
"""
|
||||
cfg = Config()
|
||||
if temperature is None:
|
||||
temperature = cfg.temperature
|
||||
|
||||
for plugin in cfg.plugins:
|
||||
if plugin.can_handle_chat_completion(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
):
|
||||
message = plugin.handle_chat_completion(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
if message is not None:
|
||||
return message
|
||||
|
||||
response = None
|
||||
# TODO impl this use vicuna server api
|
Loading…
Reference in New Issue
Block a user