fix error

csunny 2023-05-12 20:15:38 +08:00
commit 0a27d8ff7e
5 changed files with 82 additions and 1 deletion

View File

@@ -5,6 +5,9 @@ from __future__ import annotations
from pilot.configs.config import Config
from pilot.singleton import Singleton
from pilot.configs.config import Config
from typing import List
from pilot.model.base import Message
class AgentManager(metaclass=Singleton):
@@ -14,6 +17,19 @@ class AgentManager(metaclass=Singleton):
        self.next_key = 0
        self.agents = {}  # key, (task, full_message_history, model)
        self.cfg = Config()
"""Agent manager for managing DB-GPT agents
In order to compatible auto gpt plugins,
we use the same template with it.
Args: next_keys
agents
cfg
"""
def __init__(self) -> None:
self.next_key = 0
self.agents = {} #TODO need to define
self.cfg = Config()
    # Create new GPT agent
    # TODO: Centralise use of create_chat_completion() to globally enforce token limit
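For orientation, here is a minimal sketch of how the registry above might be used, following the (task, full_message_history, model) layout noted in the comment; the module path and the literal values are assumptions for illustration, not part of this commit.

from pilot.agent.agent_manager import AgentManager  # assumed module path

manager = AgentManager()  # Singleton metaclass: always the same instance
key = manager.next_key    # each agent gets an incrementing integer key
manager.agents[key] = (
    "summarize the sales table",  # task
    [],                           # full_message_history, grows per turn
    "vicuna-13b",                 # model name
)
manager.next_key += 1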

View File

@@ -3,6 +3,7 @@
import os
from typing import List
from auto_gpt_plugin_template import AutoGPTPluginTemplate
from pilot.singleton import Singleton
@@ -18,7 +19,10 @@ class Config(metaclass=Singleton):
        self.temperature = float(os.getenv("TEMPERATURE", 0.7))
        # TODO: move this setting into model_config
        self.execute_local_commands = (
            os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
        )
        # User agent header to use when making HTTP requests
        # Some websites might just completely deny the request with an error code if
        # no user agent was found.
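As a side note on the flag added above: environment variables are always strings, so the commit compares against the literal "True". A self-contained sketch of the parsing behavior:

import os

# Unset, "false", "0", or "true" (lowercase) all leave the flag disabled;
# only the exact string "True" enables it.
os.environ["EXECUTE_LOCAL_COMMANDS"] = "True"
execute_local_commands = os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
assert execute_local_commands is True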

pilot/model/base.py Normal file (11 lines)
View File

@@ -0,0 +1,11 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import List, TypedDict
class Message(TypedDict):
    """An LLM message, typically a (role, content) pair."""

    role: str
    content: str
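A short usage sketch, not from this commit: because Message is a TypedDict, values are plain dicts at runtime while type checkers verify the keys and value types.

from typing import List
from pilot.model.base import Message

messages: List[Message] = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]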

pilot/model/chat.py Normal file (3 lines)
View File

@@ -0,0 +1,3 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

pilot/model/llm_utils.py Normal file (47 lines)
View File

@@ -0,0 +1,47 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
from typing import List, Optional
from pilot.model.base import Message
from pilot.configs.config import Config
from pilot.server.vicuna_server import generate_output
def create_chat_completion(
    messages: List[Message],  # type: ignore
    model: Optional[str] = None,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
) -> str:
"""Create a chat completion using the vicuna local model
Args:
messages(List[Message]): The messages to send to the chat completion
model (str, optional): The model to use. Defaults to None.
temperature (float, optional): The temperature to use. Defaults to 0.7.
max_tokens (int, optional): The max tokens to use. Defaults to None
Returns:
str: The response from chat completion
"""
    cfg = Config()
    if temperature is None:
        temperature = cfg.temperature
    # Let a plugin handle the completion first, if one claims it.
    for plugin in cfg.plugins:
        if plugin.can_handle_chat_completion(
            messages=messages,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
        ):
            message = plugin.handle_chat_completion(
                messages=messages,
                model=model,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            if message is not None:
                return message
    response = None
    # TODO: implement the fallback using the vicuna server API
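One way the TODO might eventually be filled in, sketched purely as an assumption: flatten the messages into a single prompt and hand it to the generate_output imported above. The call shape and parameter names are guesses; the real pilot.server.vicuna_server API may differ.

    # Hypothetical sketch, NOT part of this commit: generate_output's real
    # signature and parameter names are assumed here.
    prompt = "\n".join(f"{m['role']}: {m['content']}" for m in messages)
    params = {
        "prompt": prompt,
        "temperature": temperature,
        "max_new_tokens": max_tokens,
    }
    response = generate_output(model, params)  # assumed call shape
    return response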