diff --git a/pilot/agent/agent_manager.py b/pilot/agent/agent_manager.py
index 1ceecec6f..89754bd1c 100644
--- a/pilot/agent/agent_manager.py
+++ b/pilot/agent/agent_manager.py
@@ -5,6 +5,9 @@ from __future__ import annotations
 
 from pilot.configs.config import Config
 from pilot.singleton import Singleton
+from pilot.configs.config import Config
+from typing import List
+from pilot.model.base import Message
 
 
 class AgentManager(metaclass=Singleton):
@@ -14,6 +17,19 @@ class AgentManager(metaclass=Singleton):
         self.next_key = 0
         self.agents = {}  # key, (task, full_message_history, model)
         self.cfg = Config()
+    """Agent manager for managing DB-GPT agents.
+    To stay compatible with Auto-GPT plugins,
+    we reuse the same agent template.
+
+    Attributes: next_key
+                agents
+                cfg
+    """
+
+    def __init__(self) -> None:
+        self.next_key = 0
+        self.agents = {}  # TODO: define the value structure for each agent
+        self.cfg = Config()
 
     # Create new GPT agent
     # TODO: Centralise use of create_chat_completion() to globally enforce token limit
diff --git a/pilot/configs/config.py b/pilot/configs/config.py
index 7d5ee7eea..f9dded64d 100644
--- a/pilot/configs/config.py
+++ b/pilot/configs/config.py
@@ -3,6 +3,7 @@
 import os
 
 from typing import List
+
 from auto_gpt_plugin_template import AutoGPTPluginTemplate
 
 from pilot.singleton import Singleton
@@ -18,7 +19,10 @@ class Config(metaclass=Singleton):
 
         self.temperature = float(os.getenv("TEMPERATURE", 0.7))
 
-
+        # TODO: move this setting into model_config
+        self.execute_local_commands = (
+            os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
+        )
         # User agent header to use when making HTTP requests
         # Some websites might just completely deny request with an error code if
         # no user agent was found.
diff --git a/pilot/model/base.py b/pilot/model/base.py
new file mode 100644
index 000000000..8199198eb
--- /dev/null
+++ b/pilot/model/base.py
@@ -0,0 +1,11 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+from typing import List, TypedDict
+
+class Message(TypedDict):
+    """A single LLM chat message, i.e. a (role, content) pair."""
+
+    role: str
+    content: str
+
diff --git a/pilot/model/chat.py b/pilot/model/chat.py
new file mode 100644
index 000000000..97206f2d5
--- /dev/null
+++ b/pilot/model/chat.py
@@ -0,0 +1,3 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
diff --git a/pilot/model/llm_utils.py b/pilot/model/llm_utils.py
new file mode 100644
index 000000000..ee4e3b6e9
--- /dev/null
+++ b/pilot/model/llm_utils.py
@@ -0,0 +1,47 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+from typing import List, Optional
+from pilot.model.base import Message
+from pilot.configs.config import Config
+from pilot.server.vicuna_server import generate_output
+
+def create_chat_completion(
+    messages: List[Message],  # type: ignore
+    model: Optional[str] = None,
+    temperature: Optional[float] = None,
+    max_tokens: Optional[int] = None,
+) -> str:
+    """Create a chat completion using the local vicuna model.
+
+    Args:
+        messages (List[Message]): The messages to send to the chat completion.
+        model (str, optional): The model to use. Defaults to None.
+        temperature (float, optional): The sampling temperature. Defaults to the configured value.
+        max_tokens (int, optional): The maximum number of tokens to generate. Defaults to None.
+
+    Returns:
+        str: The response from the chat completion.
+    """
+    cfg = Config()
+    if temperature is None:
+        temperature = cfg.temperature
+
+    for plugin in cfg.plugins:
+        if plugin.can_handle_chat_completion(
+            messages=messages,
+            model=model,
+            temperature=temperature,
+            max_tokens=max_tokens,
+        ):
+            message = plugin.handle_chat_completion(
+                messages=messages,
+                model=model,
+                temperature=temperature,
+                max_tokens=max_tokens,
+            )
+            if message is not None:
+                return message
+
+    response = None
+    # TODO: implement this using the vicuna server API (generate_output)
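
The new pilot/model/llm_utils.py stops at a TODO: when no plugin handles the request, the completion still has to come from the local vicuna model. The sketch below shows one way that gap could be filled; it is an illustration only, not the project's implementation. The endpoint URL, the payload fields (prompt, temperature, max_new_tokens) and the "text" response field are assumptions, and the real pilot.server.vicuna_server interface (e.g. generate_output) may look quite different.

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Hypothetical sketch of the missing vicuna call; NOT the project's actual code.

from typing import List

import requests

from pilot.model.base import Message

# Assumption: a vicuna worker listens locally and accepts a JSON payload.
VICUNA_GENERATE_URL = "http://127.0.0.1:8000/generate"


def messages_to_prompt(messages: List[Message]) -> str:
    """Flatten (role, content) messages into a single plain-text prompt."""
    lines = [f"{m['role']}: {m['content']}" for m in messages]
    lines.append("assistant:")  # cue the model to produce the reply
    return "\n".join(lines)


def vicuna_chat_completion(
    messages: List[Message],
    temperature: float = 0.7,
    max_tokens: int = 512,
) -> str:
    """Send the flattened prompt to the local server and return the generated text."""
    payload = {
        "prompt": messages_to_prompt(messages),
        "temperature": temperature,
        "max_new_tokens": max_tokens,  # assumed parameter name
    }
    resp = requests.post(VICUNA_GENERATE_URL, json=payload, timeout=120)
    resp.raise_for_status()
    # Assumption: the server answers with JSON containing a "text" field.
    return resp.json().get("text", "")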
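
For context on how the new helper is meant to be called, here is a minimal usage sketch built on the types introduced in this change; the plugin list on Config is assumed to be populated elsewhere.

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Usage sketch for create_chat_completion; assumes the DB-GPT package layout above.

from typing import List

from pilot.model.base import Message
from pilot.model.llm_utils import create_chat_completion

# Build the (role, content) messages the helper expects.
messages: List[Message] = [
    {"role": "system", "content": "You are a helpful database assistant."},
    {"role": "user", "content": "List three ways to speed up a slow SQL query."},
]

# Registered plugins get the first chance to answer; otherwise the local
# vicuna model (once the TODO above is implemented) produces the completion.
reply = create_chat_completion(messages, temperature=0.7, max_tokens=256)
print(reply)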