Plugin startup integration (插件启动接入)

tuyang.yhj 2023-05-12 23:40:37 +08:00
parent 192db2236a
commit 416205ae7a
15 changed files with 557 additions and 43 deletions

.gitignore (1 addition)

@@ -6,6 +6,7 @@ __pycache__/
# C extensions
*.so
.idea
.vscode
# Distribution / packaging
.Python


@@ -1,4 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (Auto-GPT)" project-jdk-type="Python SDK" />
  <component name="PythonCompatibilityInspectionAdvertiser">
    <option name="version" value="3" />
  </component>
</project>

pilot/agent/json_fix_llm.py (new file, 166 additions)

@@ -0,0 +1,166 @@
import json
from typing import Any, Dict
import contextlib

from colorama import Fore
from regex import regex

from pilot.configs.config import Config
from pilot.logs import logger
from pilot.speech import say_text
from pilot.json_utils.json_fix_general import fix_invalid_escape, add_quotes_to_property_names, balance_braces

CFG = Config()


def fix_and_parse_json(
    json_to_load: str, try_to_fix_with_gpt: bool = True
) -> Dict[Any, Any]:
    """Fix and parse JSON string

    Args:
        json_to_load (str): The JSON string.
        try_to_fix_with_gpt (bool, optional): Try to fix the JSON with GPT.
            Defaults to True.

    Returns:
        str or dict[Any, Any]: The parsed JSON.
    """
    with contextlib.suppress(json.JSONDecodeError):
        json_to_load = json_to_load.replace("\t", "")
        return json.loads(json_to_load)

    with contextlib.suppress(json.JSONDecodeError):
        json_to_load = correct_json(json_to_load)
        return json.loads(json_to_load)
    # Let's do something manually:
    # sometimes GPT responds with something BEFORE the braces:
    # "I'm sorry, I don't understand. Please try again."
    # {"text": "I'm sorry, I don't understand. Please try again.",
    # "confidence": 0.0}
    # So let's try to find the first brace and then parse the rest
    # of the string
    try:
        brace_index = json_to_load.index("{")
        maybe_fixed_json = json_to_load[brace_index:]
        last_brace_index = maybe_fixed_json.rindex("}")
        maybe_fixed_json = maybe_fixed_json[: last_brace_index + 1]
        return json.loads(maybe_fixed_json)
    except (json.JSONDecodeError, ValueError) as e:
        logger.error("Argument parsing error", e)


def fix_json_using_multiple_techniques(assistant_reply: str) -> Dict[Any, Any]:
    """Fix the given JSON string to make it parseable and fully compliant with two techniques.

    Args:
        assistant_reply (str): The JSON string to fix.

    Returns:
        str: The fixed JSON string.
    """
    assistant_reply = assistant_reply.strip()
    if assistant_reply.startswith("```json"):
        assistant_reply = assistant_reply[7:]
    if assistant_reply.endswith("```"):
        assistant_reply = assistant_reply[:-3]
    try:
        return json.loads(assistant_reply)  # just check the validity
    except json.JSONDecodeError as e:  # noqa: E722
        print(f"JSONDecodeError: {e}")
        pass

    if assistant_reply.startswith("json "):
        assistant_reply = assistant_reply[5:]
        assistant_reply = assistant_reply.strip()
    try:
        return json.loads(assistant_reply)  # just check the validity
    except json.JSONDecodeError:  # noqa: E722
        pass

    # Parse and print Assistant response
    assistant_reply_json = fix_and_parse_json(assistant_reply)
    logger.debug("Assistant reply JSON: %s", str(assistant_reply_json))
    if assistant_reply_json == {}:
        assistant_reply_json = attempt_to_fix_json_by_finding_outermost_brackets(
            assistant_reply
        )

    logger.debug("Assistant reply JSON 2: %s", str(assistant_reply_json))
    if assistant_reply_json != {}:
        return assistant_reply_json

    logger.error(
        "Error: The following AI output couldn't be converted to a JSON:\n",
        assistant_reply,
    )
    if CFG.speak_mode:
        say_text("I have received an invalid JSON response from the OpenAI API.")

    return {}


def correct_json(json_to_load: str) -> str:
    """
    Correct common JSON errors.

    Args:
        json_to_load (str): The JSON string.
    """
    try:
        logger.debug("json", json_to_load)
        json.loads(json_to_load)
        return json_to_load
    except json.JSONDecodeError as e:
        logger.debug("json loads error", e)
        error_message = str(e)
        if error_message.startswith("Invalid \\escape"):
            json_to_load = fix_invalid_escape(json_to_load, error_message)
        if error_message.startswith(
            "Expecting property name enclosed in double quotes"
        ):
            json_to_load = add_quotes_to_property_names(json_to_load)
            try:
                json.loads(json_to_load)
                return json_to_load
            except json.JSONDecodeError as e:
                logger.debug("json loads error - add quotes", e)
                error_message = str(e)
        if balanced_str := balance_braces(json_to_load):
            return balanced_str
    return json_to_load


def attempt_to_fix_json_by_finding_outermost_brackets(json_string: str):
    if CFG.speak_mode and CFG.debug_mode:
        say_text(
            "I have received an invalid JSON response from the OpenAI API. "
            "Trying to fix it now."
        )
        logger.error("Attempting to fix JSON by finding outermost brackets\n")

    try:
        json_pattern = regex.compile(r"\{(?:[^{}]|(?R))*\}")
        json_match = json_pattern.search(json_string)

        if json_match:
            # Extract the valid JSON object from the string
            json_string = json_match.group(0)
            logger.typewriter_log(
                title="Apparently json was fixed.", title_color=Fore.GREEN
            )
            if CFG.speak_mode and CFG.debug_mode:
                say_text("Apparently json was fixed.")
        else:
            return {}

    except (json.JSONDecodeError, ValueError):
        if CFG.debug_mode:
            logger.error(f"Error: Invalid JSON: {json_string}\n")
        if CFG.speak_mode:
            say_text("Didn't work. I will have to ignore this response then.")
        logger.error("Error: Invalid JSON, setting it to empty JSON now.\n")
        json_string = {}

    return fix_and_parse_json(json_string)

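Usage sketch (illustrative, not part of this commit): feeding a messy model reply through the new repair helpers; the reply text below is invented.

from pilot.agent.json_fix_llm import fix_json_using_multiple_techniques

# A made-up assistant reply: fenced with ```json and missing one closing brace.
raw_reply = (
    "```json\n"
    '{"command": {"name": "db_query", "args": {"sql": "SELECT 1"}}\n'
    "```"
)

# The fence is stripped, the first parse fails (a JSONDecodeError notice is
# printed), then balance_braces() pads the missing "}" and the dict is returned.
parsed = fix_json_using_multiple_techniques(raw_reply)
print(parsed)  # {'command': {'name': 'db_query', 'args': {'sql': 'SELECT 1'}}}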

@@ -1,36 +1,153 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import functools
import importlib
import inspect
from typing import Any, Callable, Optional
from typing import Dict, List, NoReturn, Union
import json

from pilot.prompts.generator import PromptGenerator
from pilot.configs.config import Config
from pilot.speech import say_text
from pilot.agent.json_fix_llm import fix_json_using_multiple_techniques
from pilot.commands.exception_not_commands import NotCommands

-class Command:
-    """A class representing a command.
-
-    Attributes:
-        name (str): The name of the command.
-        description (str): A brief description of what the command does.
-        signature (str): The signature of the function that the command executes. Default to None.
-    """
-
-    def __init__(self,
-        name: str,
-        description: str,
-        method: Callable[..., Any],
-        signature: str = "",
-        enabled: bool = True,
-        disabled_reason: Optional[str] = None,
-    ) -> None:
-        self.name = name
-        self.description = description
-        self.method = method
-        self.signature = signature if signature else str(inspect.signature(self.method))
-        self.enabled = enabled
-        self.disabled_reason = disabled_reason
-
-    def __call__(self, *args: Any, **kwds: Any) -> Any:
-        if not self.enabled:
-            return f"Command '{self.name}' is disabled: {self.disabled_reason}"
-        return self.method(*args, **kwds)


def _resolve_pathlike_command_args(command_args):
    if "directory" in command_args and command_args["directory"] in {"", "/"}:
        # todo
        command_args["directory"] = ""
    else:
        for pathlike in ["filename", "directory", "clone_path"]:
            if pathlike in command_args:
                # todo
                command_args[pathlike] = ""
    return command_args


def execute_ai_response_json(
    prompt: PromptGenerator,
    ai_response: str,
    user_input: str = None,
) -> str:
    """

    Args:
        command_registry:
        ai_response:
        prompt:

    Returns:

    """
    cfg = Config()

    try:
        assistant_reply_json = fix_json_using_multiple_techniques(ai_response)
    except (json.JSONDecodeError, ValueError) as e:
        raise NotCommands("Not an executable command structure")

    command_name, arguments = get_command(assistant_reply_json)
    if cfg.speak_mode:
        say_text(f"I want to execute {command_name}")

    arguments = _resolve_pathlike_command_args(arguments)
    # Execute command
    if command_name is not None and command_name.lower().startswith("error"):
        result = (
            f"Command {command_name} threw the following error: {arguments}"
        )
    elif command_name == "human_feedback":
        result = f"Human feedback: {user_input}"
    else:
        for plugin in cfg.plugins:
            if not plugin.can_handle_pre_command():
                continue
            command_name, arguments = plugin.pre_command(
                command_name, arguments
            )
        command_result = execute_command(
            command_name,
            arguments,
            prompt,
        )
        result = f"Command {command_name} returned: " f"{command_result}"
    return result


def execute_command(
    command_name: str,
    arguments,
    prompt: PromptGenerator,
):
    """Execute the command and return the result

    Args:
        command_name (str): The name of the command to execute
        arguments (dict): The arguments for the command

    Returns:
        str: The result of the command
    """
    cmd = prompt.command_registry.commands.get(command_name)

    # If the command is found, call it with the provided arguments
    if cmd:
        try:
            return cmd(**arguments)
        except Exception as e:
            return f"Error: {str(e)}"
    # TODO: Change these to take in a file rather than pasted code, if
    # non-file is given, return instructions "Input should be a python
    # filepath, write your code to file and try again
    else:
        for command in prompt.commands:
            if (
                command_name == command["label"].lower()
                or command_name == command["name"].lower()
            ):
                try:
                    return command["function"](**arguments)
                except Exception as e:
                    return f"Error: {str(e)}"
        raise NotCommands("Not an available command: " + command_name)


def get_command(response_json: Dict):
    """Parse the response and return the command name and arguments

    Args:
        response_json (json): The response from the AI

    Returns:
        tuple: The command name and arguments

    Raises:
        json.decoder.JSONDecodeError: If the response is not valid JSON
        Exception: If any other error occurs
    """
    try:
        if "command" not in response_json:
            return "Error:", "Missing 'command' object in JSON"

        if not isinstance(response_json, dict):
            return "Error:", f"'response_json' object is not dictionary {response_json}"

        command = response_json["command"]
        if not isinstance(command, dict):
            return "Error:", "'command' object is not a dictionary"

        if "name" not in command:
            return "Error:", "Missing 'name' field in 'command' object"

        command_name = command["name"]

        # Use an empty dictionary if 'args' field is not present in 'command' object
        arguments = command.get("args", {})

        return command_name, arguments
    except json.decoder.JSONDecodeError:
        return "Error:", "Invalid JSON"
    # All other errors, return "Error: + error message"
    except Exception as e:
        return "Error:", str(e)

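For reference, a sketch of the reply shape the new executor consumes (the command name and arguments are placeholders, not real registered commands):

from pilot.commands.command import get_command

assistant_reply = {
    "command": {
        "name": "db_query",  # hypothetical command name
        "args": {"sql": "SELECT COUNT(*) FROM users"},
    }
}

# get_command() pulls the name and the args dict out of the parsed reply;
# execute_ai_response_json() then dispatches them via prompt.command_registry
# (or prompt.commands for plugin-supplied commands).
command_name, arguments = get_command(assistant_reply)
print(command_name, arguments)  # db_query {'sql': 'SELECT COUNT(*) FROM users'}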

@@ -0,0 +1,4 @@
class NotCommands(Exception):
    def __init__(self, message, error_code=None):  # error_code optional, so callers may raise with just a message
        super().__init__(message)
        self.error_code = error_code

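A minimal sketch of how the exception behaves; the message text is an example only.

from pilot.commands.exception_not_commands import NotCommands

try:
    raise NotCommands("Not an available command: db_query", None)
except NotCommands as e:
    print(e)             # Not an available command: db_query
    print(e.error_code)  # None

Callers such as the webserver can treat NotCommands as "the model replied with plain text rather than a tool call" and fall back to ordinary chat handling.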

@@ -42,7 +42,12 @@ class Config(metaclass=Singleton):
        self.milvus_collection = os.getenv("MILVUS_COLLECTION", "dbgpt")
        self.milvus_secure = os.getenv("MILVUS_SECURE") == "True"
        self.plugins_dir = os.getenv("PLUGINS_DIR", "plugins")
        self.authorise_key = os.getenv("AUTHORISE_COMMAND_KEY", "y")
        self.exit_key = os.getenv("EXIT_KEY", "n")
        self.image_provider = bool(os.getenv("IMAGE_PROVIDER", True))
        self.image_size = int(os.getenv("IMAGE_SIZE", 256))
        self.plugins_dir = os.getenv("PLUGINS_DIR", "../../plugins")
        self.plugins: List[AutoGPTPluginTemplate] = []
        self.plugins_openai = []
@@ -57,6 +62,7 @@ class Config(metaclass=Singleton):
        self.huggingface_audio_to_text_model = os.getenv(
            "HUGGINGFACE_AUDIO_TO_TEXT_MODEL"
        )
        self.speak_mode = False

        disabled_command_categories = os.getenv("DISABLED_COMMAND_CATEGORIES")
        if disabled_command_categories:
@@ -92,4 +98,6 @@ class Config(metaclass=Singleton):
        """Set the temperature value."""
        self.temperature = value

    def set_speak_mode(self, value: bool) -> None:
        """Set the speak mode value."""
        self.speak_mode = value

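A small sketch of the new speak_mode switch; because Config uses the Singleton metaclass, the value set here is visible everywhere else in the process.

from pilot.configs.config import Config

cfg = Config()
print(cfg.speak_mode)       # False by default (set in __init__ above)
cfg.set_speak_mode(True)    # enable spoken feedback for command execution
print(Config().speak_mode)  # True, same singleton instance

AUTHORISE_COMMAND_KEY, EXIT_KEY, IMAGE_PROVIDER, IMAGE_SIZE and PLUGINS_DIR are read from the environment once, when the singleton is first constructed.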


@@ -0,0 +1,121 @@
"""This module contains functions to fix JSON strings using general programmatic approaches,
suitable for addressing common JSON formatting issues."""
from __future__ import annotations

import contextlib
import json
import re
from typing import Optional

from pilot.configs.config import Config
from pilot.logs import logger
from pilot.json_utils.utilities import extract_char_position

CFG = Config()


def fix_invalid_escape(json_to_load: str, error_message: str) -> str:
    """Fix invalid escape sequences in JSON strings.

    Args:
        json_to_load (str): The JSON string.
        error_message (str): The error message from the JSONDecodeError
            exception.

    Returns:
        str: The JSON string with invalid escape sequences fixed.
    """
    while error_message.startswith("Invalid \\escape"):
        bad_escape_location = extract_char_position(error_message)
        json_to_load = (
            json_to_load[:bad_escape_location] + json_to_load[bad_escape_location + 1 :]
        )
        try:
            json.loads(json_to_load)
            return json_to_load
        except json.JSONDecodeError as e:
            logger.debug("json loads error - fix invalid escape", e)
            error_message = str(e)
    return json_to_load


def balance_braces(json_string: str) -> Optional[str]:
    """
    Balance the braces in a JSON string.

    Args:
        json_string (str): The JSON string.

    Returns:
        str: The JSON string with braces balanced.
    """
    open_braces_count = json_string.count("{")
    close_braces_count = json_string.count("}")

    while open_braces_count > close_braces_count:
        json_string += "}"
        close_braces_count += 1

    while close_braces_count > open_braces_count:
        json_string = json_string.rstrip("}")
        close_braces_count -= 1

    with contextlib.suppress(json.JSONDecodeError):
        json.loads(json_string)
        return json_string


def add_quotes_to_property_names(json_string: str) -> str:
    """
    Add quotes to property names in a JSON string.

    Args:
        json_string (str): The JSON string.

    Returns:
        str: The JSON string with quotes added to property names.
    """

    def replace_func(match: re.Match) -> str:
        return f'"{match[1]}":'

    property_name_pattern = re.compile(r"(\w+):")
    corrected_json_string = property_name_pattern.sub(replace_func, json_string)

    try:
        json.loads(corrected_json_string)
        return corrected_json_string
    except json.JSONDecodeError as e:
        raise e


def correct_json(json_to_load: str) -> str:
    """
    Correct common JSON errors.

    Args:
        json_to_load (str): The JSON string.
    """
    try:
        logger.debug("json", json_to_load)
        json.loads(json_to_load)
        return json_to_load
    except json.JSONDecodeError as e:
        logger.debug("json loads error", e)
        error_message = str(e)
        if error_message.startswith("Invalid \\escape"):
            json_to_load = fix_invalid_escape(json_to_load, error_message)
        if error_message.startswith(
            "Expecting property name enclosed in double quotes"
        ):
            json_to_load = add_quotes_to_property_names(json_to_load)
            try:
                json.loads(json_to_load)
                return json_to_load
            except json.JSONDecodeError as e:
                logger.debug("json loads error - add quotes", e)
                error_message = str(e)
        if balanced_str := balance_braces(json_to_load):
            return balanced_str
    return json_to_load

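Quick sketches of the two pure string fixers added here (the inputs are invented):

from pilot.json_utils.json_fix_general import add_quotes_to_property_names, balance_braces

# Pads the missing closing brace; returns None if the padded string still fails to parse.
print(balance_braces('{"a": {"b": 1}'))        # {"a": {"b": 1}}

# Wraps bare property names in double quotes; raises JSONDecodeError if the result is still invalid.
print(add_quotes_to_property_names('{a: 1}'))  # {"a": 1}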

@@ -0,0 +1,81 @@
"""Utilities for the json_fixes package."""
import json
import os.path
import re

from jsonschema import Draft7Validator

from pilot.configs.config import Config
from pilot.logs import logger

CFG = Config()
LLM_DEFAULT_RESPONSE_FORMAT = "llm_response_format_1"


def extract_char_position(error_message: str) -> int:
    """Extract the character position from the JSONDecodeError message.

    Args:
        error_message (str): The error message from the JSONDecodeError
            exception.

    Returns:
        int: The character position.
    """
    char_pattern = re.compile(r"\(char (\d+)\)")
    if match := char_pattern.search(error_message):
        return int(match[1])
    else:
        raise ValueError("Character position not found in the error message.")


def validate_json(json_object: object, schema_name: str) -> dict | None:
    """
    :type schema_name: object
    :param schema_name: str
    :type json_object: object
    """
    scheme_file = os.path.join(os.path.dirname(__file__), f"{schema_name}.json")
    with open(scheme_file, "r") as f:
        schema = json.load(f)
    validator = Draft7Validator(schema)

    if errors := sorted(validator.iter_errors(json_object), key=lambda e: e.path):
        logger.error("The JSON object is invalid.")
        if CFG.debug_mode:
            logger.error(
                json.dumps(json_object, indent=4)
            )  # Replace 'json_object' with the variable containing the JSON data
            logger.error("The following issues were found:")

            for error in errors:
                logger.error(f"Error: {error.message}")
    else:
        logger.debug("The JSON object is valid.")

    return json_object


def validate_json_string(json_string: str, schema_name: str) -> dict | None:
    """
    :type schema_name: object
    :param schema_name: str
    :type json_object: object
    """
    try:
        json_loaded = json.loads(json_string)
        return validate_json(json_loaded, schema_name)
    except:
        return None


def is_string_valid_json(json_string: str, schema_name: str) -> bool:
    """
    :type schema_name: object
    :param schema_name: str
    :type json_object: object
    """
    return validate_json_string(json_string, schema_name) is not None

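Sketch of extract_char_position(), which the escape-fixing loop above relies on to know which character to drop; the input string is an invented example.

import json

from pilot.json_utils.utilities import extract_char_position

try:
    json.loads('{"path": "C:\\x"}')  # \x is not a valid JSON escape
except json.JSONDecodeError as e:
    print(str(e))                         # e.g. Invalid \escape: line 1 column 14 (char 13)
    print(extract_char_position(str(e)))  # 13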

@@ -72,7 +72,7 @@ def generate_stream(model, tokenizer, params, device,
def generate_output(model, tokenizer, params, device, context_len=2048, stream_interval=2):
    prompt = params["prompt"]
    temperature = float(params.get("temperature", 1.0))
-    max_new_tokens = int(params.get("max_new_tokens", 1024))
    max_new_tokens = int(params.get("max_new_tokens", 2048))
    stop_parameter = params.get("stop", None)
    if stop_parameter == tokenizer.eos_token:
        stop_parameter = None

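The effect of the changed default, as a sketch (the params dict is an invented example):

# A request that omits max_new_tokens now gets up to 2048 new tokens instead of 1024.
params = {
    "prompt": "SELECT 1;",      # placeholder prompt
    "temperature": 0.7,
    # "max_new_tokens": 512,    # callers can still cap generation explicitly
}
max_new_tokens = int(params.get("max_new_tokens", 2048))
print(max_new_tokens)  # 2048 when the caller does not set it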

@@ -207,6 +207,8 @@ def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate
        List[Tuple[str, Path]]: List of plugins.
    """
    loaded_plugins = []
    current_dir = os.getcwd()
    print(current_dir)
    # Generic plugins
    plugins_path_path = Path(cfg.plugins_dir)

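How the scan is wired up by the webserver below, condensed into a sketch:

from pilot.configs.config import Config
from pilot.plugins import scan_plugins

cfg = Config()
# scan_plugins() walks cfg.plugins_dir (PLUGINS_DIR, default "../../plugins"
# after the Config change above) and returns the loaded plugin objects.
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
print(f"loaded {len(cfg.plugins)} plugin(s) from {cfg.plugins_dir}")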

@@ -38,7 +38,6 @@ class FirstPrompt:
    def construct_first_prompt(
        self,
-        command_registry: [] = None,
        fisrt_message: [str]=[],
        prompt_generator: Optional[PromptGenerator] = None
    ) -> str:
@@ -64,7 +63,7 @@ class FirstPrompt:
        if prompt_generator is None:
            prompt_generator = build_default_prompt_generator()
        prompt_generator.goals = fisrt_message
-        prompt_generator.command_registry = command_registry
        prompt_generator.command_registry = self.command_registry
        # Load the commands available from plugins
        cfg = Config()
        for plugin in cfg.plugins:

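The call pattern after this change, mirroring the webserver code further down: the registry is attached to the FirstPrompt instance instead of being passed into construct_first_prompt(). This sketch assumes cfg.command_registry has already been populated at startup, and the query string is invented.

from pilot.configs.config import Config
from pilot.prompts.first_conversation_prompt import FirstPrompt

cfg = Config()
first_prompt = FirstPrompt()
first_prompt.command_registry = cfg.command_registry
# fisrt_message carries the user's first query (spelling as in the code).
system_prompt = first_prompt.construct_first_prompt(fisrt_message=["show all databases"])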

@@ -0,0 +1,14 @@
import os
import random
import sys

from dotenv import load_dotenv

if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
    print("Setting random seed to 42")
    random.seed(42)

# Load the users .env file into environment variables
load_dotenv(verbose=True, override=True)

del load_dotenv

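A sketch of what importing this module does, assuming a .env file is present in the working directory; the variable name is only an example.

import os

from dotenv import load_dotenv

# override=True means values from .env win over variables already exported in the shell.
load_dotenv(verbose=True, override=True)
print(os.getenv("PLUGINS_DIR", "plugins"))  # value from .env if present, else the default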

@@ -20,8 +20,6 @@ from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, D
from pilot.plugins import scan_plugins
from pilot.configs.config import Config
from pilot.commands.command_mange import CommandRegistry
from pilot.prompts.prompt import build_default_prompt_generator
from pilot.prompts.first_conversation_prompt import FirstPrompt
from pilot.conversation import (
@@ -39,6 +37,8 @@ from pilot.utils import (
from pilot.server.gradio_css import code_highlight_css
from pilot.server.gradio_patch import Chatbot as grChatbot
from pilot.commands.command import execute_ai_response_json

logger = build_logger("webserver", LOGDIR + "webserver.log")
headers = {"User-Agent": "dbgpt Client"}
@@ -172,12 +172,11 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
    if len(state.messages) == state.offset + 2:
        query = state.messages[-2][1]
        # The first round of conversation needs the prompt added
        cfg = Config()
        first_prompt = FirstPrompt()
        first_prompt.command_registry = cfg.command_registry
        if(autogpt):
            # In autogpt mode, the first round of conversation builds its own dedicated prompt
-            cfg = Config()
-            first_prompt = FirstPrompt()
-            first_prompt.command_registry = cfg.command_registry
            system_prompt = first_prompt.construct_first_prompt(fisrt_message=[query])
            logger.info("[TEST]:" + system_prompt)
            template_name = "auto_dbgpt_one_shot"
@@ -456,7 +455,7 @@ if __name__ == "__main__":
    args = parser.parse_args()
    logger.info(f"args: {args}")
-    dbs = get_database_list()
    # dbs = get_database_list()

    # Load plugins
    cfg = Config()
@@ -464,7 +463,6 @@ if __name__ == "__main__":
    cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
    # Load the executable commands provided by plugins
-    command_registry = CommandRegistry()
    command_categories = [
        "pilot.commands.audio_text",
        "pilot.commands.image_gen",
@@ -473,11 +471,11 @@ if __name__ == "__main__":
    command_categories = [
        x for x in command_categories if x not in cfg.disabled_command_categories
    ]
    command_registry = CommandRegistry()
    for command_category in command_categories:
        command_registry.import_commands(command_category)
-    cfg.command_registry =command_category
    cfg.command_registry = command_registry

    logger.info(args)
    demo = build_webdemo()

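A condensed, hypothetical sketch of how the pieces added in this commit fit together at request time. It assumes the registry built in __main__ above, that PromptGenerator can be constructed without arguments, and an invented command name and reply.

from pilot.commands.command import execute_ai_response_json
from pilot.commands.exception_not_commands import NotCommands
from pilot.configs.config import Config
from pilot.prompts.generator import PromptGenerator

cfg = Config()
prompt = PromptGenerator()                       # assumed no-arg construction
prompt.command_registry = cfg.command_registry   # registry built in __main__ above

ai_response = '{"command": {"name": "generate_image", "args": {"prompt": "a cat"}}}'
try:
    observation = execute_ai_response_json(prompt, ai_response)
except NotCommands:
    observation = ai_response  # plain chat answer, nothing to execute
print(observation)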
Binary file not shown.