fix error

This commit is contained in:
csunny 2023-05-12 20:15:38 +08:00
commit 0a27d8ff7e
5 changed files with 82 additions and 1 deletions

View File

@ -5,6 +5,9 @@ from __future__ import annotations
from pilot.configs.config import Config
from pilot.singleton import Singleton
from pilot.configs.config import Config
from typing import List
from pilot.model.base import Message
class AgentManager(metaclass=Singleton):
@ -14,6 +17,19 @@ class AgentManager(metaclass=Singleton):
self.next_key = 0
self.agents = {} # key, (task, full_message_history, model)
self.cfg = Config()
"""Agent manager for managing DB-GPT agents
In order to compatible auto gpt plugins,
we use the same template with it.
Args: next_keys
agents
cfg
"""
def __init__(self) -> None:
self.next_key = 0
self.agents = {} #TODO need to define
self.cfg = Config()
# Create new GPT agent
# TODO: Centralise use of create_chat_completion() to globally enforce token limit

View File

@ -3,6 +3,7 @@
import os
from typing import List
from auto_gpt_plugin_template import AutoGPTPluginTemplate
from pilot.singleton import Singleton
@ -18,7 +19,10 @@ class Config(metaclass=Singleton):
self.temperature = float(os.getenv("TEMPERATURE", 0.7))
# TODO change model_config there
self.execute_local_commands = (
os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
)
# User agent header to use when making HTTP requests
# Some websites might just completely deny request with an error code if
# no user agent was found.

11
pilot/model/base.py Normal file
View File

@ -0,0 +1,11 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import List, TypedDict
class Message(TypedDict):
"""LLM Message object containing usually like (role: content) """
role: str
content: str

3
pilot/model/chat.py Normal file
View File

@ -0,0 +1,3 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

47
pilot/model/llm_utils.py Normal file
View File

@ -0,0 +1,47 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
from typing import List, Optional
from pilot.model.base import Message
from pilot.configs.config import Config
from pilot.server.vicuna_server import generate_output
def create_chat_completion(
messages: List[Message], # type: ignore
model: Optional[str] = None,
temperature: float = None,
max_tokens: Optional[int] = None,
) -> str:
"""Create a chat completion using the vicuna local model
Args:
messages(List[Message]): The messages to send to the chat completion
model (str, optional): The model to use. Defaults to None.
temperature (float, optional): The temperature to use. Defaults to 0.7.
max_tokens (int, optional): The max tokens to use. Defaults to None
Returns:
str: The response from chat completion
"""
cfg = Config()
if temperature is None:
temperature = cfg.temperature
for plugin in cfg.plugins:
if plugin.can_handle_chat_completion(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
):
message = plugin.handle_chat_completion(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
)
if message is not None:
return message
response = None
# TODO impl this use vicuna server api