diff --git a/.gitignore b/.gitignore
index 0b232d95a..78f55da35 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,6 +6,7 @@ __pycache__/
 # C extensions
 *.so
+.idea
 .vscode
 
 # Distribution / packaging
 .Python
diff --git a/.idea/misc.xml b/.idea/misc.xml
index e965926fe..4ab5f746d 100644
--- a/.idea/misc.xml
+++ b/.idea/misc.xml
@@ -1,4 +1,7 @@
+
+
\ No newline at end of file
diff --git a/pilot/agent/json_fix_llm.py b/pilot/agent/json_fix_llm.py
new file mode 100644
index 000000000..3ca8f85b0
--- /dev/null
+++ b/pilot/agent/json_fix_llm.py
@@ -0,0 +1,166 @@
+
+import json
+from typing import Any, Dict
+import contextlib
+from colorama import Fore
+from regex import regex
+
+from pilot.configs.config import Config
+from pilot.logs import logger
+from pilot.speech import say_text
+
+from pilot.json_utils.json_fix_general import fix_invalid_escape, add_quotes_to_property_names, balance_braces
+
+CFG = Config()
+
+
+def fix_and_parse_json(
+    json_to_load: str, try_to_fix_with_gpt: bool = True
+) -> Dict[Any, Any]:
+    """Fix and parse JSON string
+
+    Args:
+        json_to_load (str): The JSON string.
+        try_to_fix_with_gpt (bool, optional): Try to fix the JSON with GPT.
+            Defaults to True.
+
+    Returns:
+        str or dict[Any, Any]: The parsed JSON.
+    """
+
+    with contextlib.suppress(json.JSONDecodeError):
+        json_to_load = json_to_load.replace("\t", "")
+        return json.loads(json_to_load)
+
+    with contextlib.suppress(json.JSONDecodeError):
+        json_to_load = correct_json(json_to_load)
+        return json.loads(json_to_load)
+    # Let's do something manually:
+    # sometimes GPT responds with something BEFORE the braces:
+    # "I'm sorry, I don't understand. Please try again."
+    # {"text": "I'm sorry, I don't understand. Please try again.",
+    #  "confidence": 0.0}
+    # So let's try to find the first brace and then parse the rest
+    # of the string
+    try:
+        brace_index = json_to_load.index("{")
+        maybe_fixed_json = json_to_load[brace_index:]
+        last_brace_index = maybe_fixed_json.rindex("}")
+        maybe_fixed_json = maybe_fixed_json[: last_brace_index + 1]
+        return json.loads(maybe_fixed_json)
+    except (json.JSONDecodeError, ValueError) as e:
+        logger.error("Failed to parse JSON arguments", e)
+        return {}
+
+
+def fix_json_using_multiple_techniques(assistant_reply: str) -> Dict[Any, Any]:
+    """Fix the given JSON string to make it parseable and fully compliant with two techniques.
+
+    Args:
+        assistant_reply (str): The JSON string to fix.
+
+    Returns:
+        Dict[Any, Any]: The fixed JSON object.
+ """ + assistant_reply = assistant_reply.strip() + if assistant_reply.startswith("```json"): + assistant_reply = assistant_reply[7:] + if assistant_reply.endswith("```"): + assistant_reply = assistant_reply[:-3] + try: + return json.loads(assistant_reply) # just check the validity + except json.JSONDecodeError as e: # noqa: E722 + print(f"JSONDecodeError: {e}") + pass + + if assistant_reply.startswith("json "): + assistant_reply = assistant_reply[5:] + assistant_reply = assistant_reply.strip() + try: + return json.loads(assistant_reply) # just check the validity + except json.JSONDecodeError: # noqa: E722 + pass + + # Parse and print Assistant response + assistant_reply_json = fix_and_parse_json(assistant_reply) + logger.debug("Assistant reply JSON: %s", str(assistant_reply_json)) + if assistant_reply_json == {}: + assistant_reply_json = attempt_to_fix_json_by_finding_outermost_brackets( + assistant_reply + ) + + logger.debug("Assistant reply JSON 2: %s", str(assistant_reply_json)) + if assistant_reply_json != {}: + return assistant_reply_json + + logger.error( + "Error: The following AI output couldn't be converted to a JSON:\n", + assistant_reply, + ) + if CFG.speak_mode: + say_text("I have received an invalid JSON response from the OpenAI API.") + + return {} + + +def correct_json(json_to_load: str) -> str: + """ + Correct common JSON errors. + Args: + json_to_load (str): The JSON string. + """ + + try: + logger.debug("json", json_to_load) + json.loads(json_to_load) + return json_to_load + except json.JSONDecodeError as e: + logger.debug("json loads error", e) + error_message = str(e) + if error_message.startswith("Invalid \\escape"): + json_to_load = fix_invalid_escape(json_to_load, error_message) + if error_message.startswith( + "Expecting property name enclosed in double quotes" + ): + json_to_load = add_quotes_to_property_names(json_to_load) + try: + json.loads(json_to_load) + return json_to_load + except json.JSONDecodeError as e: + logger.debug("json loads error - add quotes", e) + error_message = str(e) + if balanced_str := balance_braces(json_to_load): + return balanced_str + return json_to_load + + +def attempt_to_fix_json_by_finding_outermost_brackets(json_string: str): + if CFG.speak_mode and CFG.debug_mode: + say_text( + "I have received an invalid JSON response from the OpenAI API. " + "Trying to fix it now." + ) + logger.error("Attempting to fix JSON by finding outermost brackets\n") + + try: + json_pattern = regex.compile(r"\{(?:[^{}]|(?R))*\}") + json_match = json_pattern.search(json_string) + + if json_match: + # Extract the valid JSON object from the string + json_string = json_match.group(0) + logger.typewriter_log( + title="Apparently json was fixed.", title_color=Fore.GREEN + ) + if CFG.speak_mode and CFG.debug_mode: + say_text("Apparently json was fixed.") + else: + return {} + + except (json.JSONDecodeError, ValueError): + if CFG.debug_mode: + logger.error(f"Error: Invalid JSON: {json_string}\n") + if CFG.speak_mode: + say_text("Didn't work. 
+            say_text("Didn't work. I will have to ignore this response then.")
+        logger.error("Error: Invalid JSON, setting it to empty JSON now.\n")
+        return {}
+
+    return fix_and_parse_json(json_string)
diff --git a/pilot/commands/command.py b/pilot/commands/command.py
index 6a987d723..36d3e324b 100644
--- a/pilot/commands/command.py
+++ b/pilot/commands/command.py
@@ -1,36 +1,153 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-import functools
-import importlib
-import inspect
-from typing import Any, Callable, Optional
+from pilot.prompts.generator import PromptGenerator
+from typing import Dict, List, NoReturn, Union
+from pilot.configs.config import Config
 
-class Command:
-    """A class representing a command.
+from pilot.speech import say_text
 
-    Attributes:
-        name (str): The name of the command.
-        description (str): A brief description of what the command does.
-        signature (str): The signature of the function that the command executes. Default to None.
+from pilot.agent.json_fix_llm import fix_json_using_multiple_techniques
+from pilot.commands.exception_not_commands import NotCommands
+import json
+
+
+def _resolve_pathlike_command_args(command_args):
+    if "directory" in command_args and command_args["directory"] in {"", "/"}:
+        # todo
+        command_args["directory"] = ""
+    else:
+        for pathlike in ["filename", "directory", "clone_path"]:
+            if pathlike in command_args:
+                # todo
+                command_args[pathlike] = ""
+    return command_args
+
+
+def execute_ai_response_json(
+    prompt: PromptGenerator,
+    ai_response: str,
+    user_input: str = None,
+) -> str:
     """
-    def __init__(self,
-                 name: str,
-                 description: str,
-                 method: Callable[..., Any],
-                 signature: str = "",
-                 enabled: bool = True,
-                 disabled_reason: Optional[str] = None,
-                 ) -> None:
-        self.name = name
-        self.description = description
-        self.method = method
-        self.signature = signature if signature else str(inspect.signature(self.method))
-        self.enabled = enabled
-        self.disabled_reason = disabled_reason
+    Args:
+        prompt: The prompt generator holding the command registry.
+        ai_response: The raw reply from the model.
+        user_input: Optional human feedback.
 
-    def __call__(self, *args: Any, **kwds: Any) -> Any:
-        if not self.enabled:
-            return f"Command '{self.name}' is disabled: {self.disabled_reason}"
-        return self.method(*args, **kwds)
\ No newline at end of file
+    Returns:
+        str: The result of the executed command.
+    """
+    cfg = Config()
+    try:
+        assistant_reply_json = fix_json_using_multiple_techniques(ai_response)
+    except (json.JSONDecodeError, ValueError) as e:
+        raise NotCommands("The reply is not an executable command structure")
+    command_name, arguments = get_command(assistant_reply_json)
+    if cfg.speak_mode:
+        say_text(f"I want to execute {command_name}")
+
+    arguments = _resolve_pathlike_command_args(arguments)
+    # Execute command
+    if command_name is not None and command_name.lower().startswith("error"):
+        result = (
+            f"Command {command_name} threw the following error: {arguments}"
+        )
+    elif command_name == "human_feedback":
+        result = f"Human feedback: {user_input}"
+    else:
+        for plugin in cfg.plugins:
+            if not plugin.can_handle_pre_command():
+                continue
+            command_name, arguments = plugin.pre_command(
+                command_name, arguments
+            )
+        command_result = execute_command(
+            command_name,
+            arguments,
+            prompt,
+        )
+        result = f"Command {command_name} returned: " f"{command_result}"
+    return result
+
+
+def execute_command(
+    command_name: str,
+    arguments,
+    prompt: PromptGenerator,
+):
+    """Execute the command and return the result
+
+    Args:
+        command_name (str): The name of the command to execute
+        arguments (dict): The arguments for the command
+
+    Returns:
+        str: The result of the command
+    """
+
+    cmd = prompt.command_registry.commands.get(command_name)
+
+    # If the command is found, call it with the provided arguments
+    if cmd:
+        try:
+            return cmd(**arguments)
+        except Exception as e:
+            return f"Error: {str(e)}"
+    # TODO: Change these to take in a file rather than pasted code, if
+    # non-file is given, return instructions "Input should be a python
+    # filepath, write your code to file and try again
+    else:
+        for command in prompt.commands:
+            if (
+                command_name == command["label"].lower()
+                or command_name == command["name"].lower()
+            ):
+                try:
+                    return command["function"](**arguments)
+                except Exception as e:
+                    return f"Error: {str(e)}"
+        raise NotCommands("Unknown command: " + command_name)
+
+
+def get_command(response_json: Dict):
+    """Parse the response and return the command name and arguments
+
+    Args:
+        response_json (json): The response from the AI
+
+    Returns:
+        tuple: The command name and arguments
+
+    Raises:
+        json.decoder.JSONDecodeError: If the response is not valid JSON
+
+        Exception: If any other error occurs
+    """
+    try:
+        if "command" not in response_json:
+            return "Error:", "Missing 'command' object in JSON"
+
+        if not isinstance(response_json, dict):
+            return "Error:", f"'response_json' object is not dictionary {response_json}"
+
+        command = response_json["command"]
+        if not isinstance(command, dict):
+            return "Error:", "'command' object is not a dictionary"
+
+        if "name" not in command:
+            return "Error:", "Missing 'name' field in 'command' object"
+
+        command_name = command["name"]
+
+        # Use an empty dictionary if 'args' field is not present in 'command' object
+        arguments = command.get("args", {})
+
+        return command_name, arguments
+    except json.decoder.JSONDecodeError:
+        return "Error:", "Invalid JSON"
+    # All other errors, return "Error: + error message"
+    except Exception as e:
+        return "Error:", str(e)
diff --git a/pilot/commands/exception_not_commands.py b/pilot/commands/exception_not_commands.py
new file mode 100644
index 000000000..283e618e1
--- /dev/null
+++ b/pilot/commands/exception_not_commands.py
@@ -0,0 +1,4 @@
+class NotCommands(Exception):
+    def __init__(self, message, error_code=None):
+        super().__init__(message)
+        self.error_code = error_code
\ No newline at end of file
diff --git a/pilot/configs/config.py b/pilot/configs/config.py
index 7d5ee7eea..57650fee2 100644
--- a/pilot/configs/config.py
+++ b/pilot/configs/config.py
@@ -42,7 +42,12 @@ class Config(metaclass=Singleton):
         self.milvus_collection = os.getenv("MILVUS_COLLECTION", "dbgpt")
         self.milvus_secure = os.getenv("MILVUS_SECURE") == "True"
 
-        self.plugins_dir = os.getenv("PLUGINS_DIR", "plugins")
+        self.authorise_key = os.getenv("AUTHORISE_COMMAND_KEY", "y")
+        self.exit_key = os.getenv("EXIT_KEY", "n")
+        self.image_provider = bool(os.getenv("IMAGE_PROVIDER", True))
+        self.image_size = int(os.getenv("IMAGE_SIZE", 256))
+
+        self.plugins_dir = os.getenv("PLUGINS_DIR", "../../plugins")
 
         self.plugins: List[AutoGPTPluginTemplate] = []
         self.plugins_openai = []
@@ -57,6 +62,7 @@ class Config(metaclass=Singleton):
         self.huggingface_audio_to_text_model = os.getenv(
             "HUGGINGFACE_AUDIO_TO_TEXT_MODEL"
         )
+        self.speak_mode = False
 
         disabled_command_categories = os.getenv("DISABLED_COMMAND_CATEGORIES")
         if disabled_command_categories:
@@ -92,4 +98,6 @@ class Config(metaclass=Singleton):
         """Set the temperature value."""
         self.temperature = value
 
-
\ No newline at end of file
+    def set_speak_mode(self, value: bool) -> None:
+        """Set the speak mode value."""
+        self.speak_mode = value
diff --git a/pilot/json_utils/__init__.py b/pilot/json_utils/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/json_utils/json_fix_general.py b/pilot/json_utils/json_fix_general.py
new file mode 100644
index 000000000..eecf83568
--- /dev/null
+++ b/pilot/json_utils/json_fix_general.py
@@ -0,0 +1,121 @@
+"""This module contains functions to fix JSON strings using general programmatic approaches, suitable for addressing
+common JSON formatting issues."""
+from __future__ import annotations
+
+import contextlib
+import json
+import re
+from typing import Optional
+
+from pilot.configs.config import Config
+from pilot.logs import logger
+from pilot.json_utils.utilities import extract_char_position
+
+CFG = Config()
+
+
+def fix_invalid_escape(json_to_load: str, error_message: str) -> str:
+    """Fix invalid escape sequences in JSON strings.
+
+    Args:
+        json_to_load (str): The JSON string.
+        error_message (str): The error message from the JSONDecodeError
+            exception.
+
+    Returns:
+        str: The JSON string with invalid escape sequences fixed.
+    """
+    while error_message.startswith("Invalid \\escape"):
+        bad_escape_location = extract_char_position(error_message)
+        json_to_load = (
+            json_to_load[:bad_escape_location] + json_to_load[bad_escape_location + 1 :]
+        )
+        try:
+            json.loads(json_to_load)
+            return json_to_load
+        except json.JSONDecodeError as e:
+            logger.debug("json loads error - fix invalid escape", e)
+            error_message = str(e)
+    return json_to_load
+
+
+def balance_braces(json_string: str) -> Optional[str]:
+    """
+    Balance the braces in a JSON string.
+
+    Args:
+        json_string (str): The JSON string.
+
+    Returns:
+        str: The JSON string with braces balanced.
+    """
+
+    open_braces_count = json_string.count("{")
+    close_braces_count = json_string.count("}")
+
+    while open_braces_count > close_braces_count:
+        json_string += "}"
+        close_braces_count += 1
+
+    while close_braces_count > open_braces_count:
+        json_string = json_string.rstrip("}")
+        close_braces_count -= 1
+
+    with contextlib.suppress(json.JSONDecodeError):
+        json.loads(json_string)
+        return json_string
+
+
+def add_quotes_to_property_names(json_string: str) -> str:
+    """
+    Add quotes to property names in a JSON string.
+
+    Args:
+        json_string (str): The JSON string.
+
+    Returns:
+        str: The JSON string with quotes added to property names.
+    """
+
+    def replace_func(match: re.Match) -> str:
+        return f'"{match[1]}":'
+
+    property_name_pattern = re.compile(r"(\w+):")
+    corrected_json_string = property_name_pattern.sub(replace_func, json_string)
+
+    try:
+        json.loads(corrected_json_string)
+        return corrected_json_string
+    except json.JSONDecodeError as e:
+        raise e
+
+
+def correct_json(json_to_load: str) -> str:
+    """
+    Correct common JSON errors.
+
+    Args:
+        json_to_load (str): The JSON string.
+ """ + + try: + logger.debug("json", json_to_load) + json.loads(json_to_load) + return json_to_load + except json.JSONDecodeError as e: + logger.debug("json loads error", e) + error_message = str(e) + if error_message.startswith("Invalid \\escape"): + json_to_load = fix_invalid_escape(json_to_load, error_message) + if error_message.startswith( + "Expecting property name enclosed in double quotes" + ): + json_to_load = add_quotes_to_property_names(json_to_load) + try: + json.loads(json_to_load) + return json_to_load + except json.JSONDecodeError as e: + logger.debug("json loads error - add quotes", e) + error_message = str(e) + if balanced_str := balance_braces(json_to_load): + return balanced_str + return json_to_load diff --git a/pilot/json_utils/utilities.py b/pilot/json_utils/utilities.py new file mode 100644 index 000000000..c8d207807 --- /dev/null +++ b/pilot/json_utils/utilities.py @@ -0,0 +1,81 @@ +"""Utilities for the json_fixes package.""" +import json +import os.path +import re + +from jsonschema import Draft7Validator + +from pilot.configs.config import Config +from pilot.logs import logger + +CFG = Config() +LLM_DEFAULT_RESPONSE_FORMAT = "llm_response_format_1" + + +def extract_char_position(error_message: str) -> int: + """Extract the character position from the JSONDecodeError message. + + Args: + error_message (str): The error message from the JSONDecodeError + exception. + + Returns: + int: The character position. + """ + + char_pattern = re.compile(r"\(char (\d+)\)") + if match := char_pattern.search(error_message): + return int(match[1]) + else: + raise ValueError("Character position not found in the error message.") + + +def validate_json(json_object: object, schema_name: str) -> dict | None: + """ + :type schema_name: object + :param schema_name: str + :type json_object: object + """ + scheme_file = os.path.join(os.path.dirname(__file__), f"{schema_name}.json") + with open(scheme_file, "r") as f: + schema = json.load(f) + validator = Draft7Validator(schema) + + if errors := sorted(validator.iter_errors(json_object), key=lambda e: e.path): + logger.error("The JSON object is invalid.") + if CFG.debug_mode: + logger.error( + json.dumps(json_object, indent=4) + ) # Replace 'json_object' with the variable containing the JSON data + logger.error("The following issues were found:") + + for error in errors: + logger.error(f"Error: {error.message}") + else: + logger.debug("The JSON object is valid.") + + return json_object + + +def validate_json_string(json_string: str, schema_name: str) -> dict | None: + """ + :type schema_name: object + :param schema_name: str + :type json_object: object + """ + + try: + json_loaded = json.loads(json_string) + return validate_json(json_loaded, schema_name) + except: + return None + + +def is_string_valid_json(json_string: str, schema_name: str) -> bool: + """ + :type schema_name: object + :param schema_name: str + :type json_object: object + """ + + return validate_json_string(json_string, schema_name) is not None diff --git a/pilot/model/inference.py b/pilot/model/inference.py index 532be9c33..c62b0e255 100644 --- a/pilot/model/inference.py +++ b/pilot/model/inference.py @@ -72,7 +72,7 @@ def generate_stream(model, tokenizer, params, device, def generate_output(model, tokenizer, params, device, context_len=2048, stream_interval=2): prompt = params["prompt"] temperature = float(params.get("temperature", 1.0)) - max_new_tokens = int(params.get("max_new_tokens", 1024)) + max_new_tokens = int(params.get("max_new_tokens", 2048)) stop_parameter = 
params.get("stop", None) if stop_parameter == tokenizer.eos_token: stop_parameter = None diff --git a/pilot/plugins.py b/pilot/plugins.py index 72b0a13a8..5f99e34ff 100644 --- a/pilot/plugins.py +++ b/pilot/plugins.py @@ -207,6 +207,8 @@ def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate List[Tuple[str, Path]]: List of plugins. """ loaded_plugins = [] + current_dir = os.getcwd() + print(current_dir) # Generic plugins plugins_path_path = Path(cfg.plugins_dir) diff --git a/pilot/prompts/first_conversation_prompt.py b/pilot/prompts/first_conversation_prompt.py index 8f26bcc54..9b9afc025 100644 --- a/pilot/prompts/first_conversation_prompt.py +++ b/pilot/prompts/first_conversation_prompt.py @@ -38,7 +38,6 @@ class FirstPrompt: def construct_first_prompt( self, - command_registry: [] = None, fisrt_message: [str]=[], prompt_generator: Optional[PromptGenerator] = None ) -> str: @@ -64,7 +63,7 @@ class FirstPrompt: if prompt_generator is None: prompt_generator = build_default_prompt_generator() prompt_generator.goals = fisrt_message - prompt_generator.command_registry = command_registry + prompt_generator.command_registry = self.command_registry # 加载插件中可用命令 cfg = Config() for plugin in cfg.plugins: diff --git a/pilot/server/__init__.py b/pilot/server/__init__.py index e69de29bb..909f8bf4b 100644 --- a/pilot/server/__init__.py +++ b/pilot/server/__init__.py @@ -0,0 +1,14 @@ +import os +import random +import sys + +from dotenv import load_dotenv + +if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"): + print("Setting random seed to 42") + random.seed(42) + +# Load the users .env file into environment variables +load_dotenv(verbose=True, override=True) + +del load_dotenv diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index 6882cd2fc..db538da24 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -20,8 +20,6 @@ from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, D from pilot.plugins import scan_plugins from pilot.configs.config import Config from pilot.commands.command_mange import CommandRegistry -from pilot.prompts.prompt import build_default_prompt_generator - from pilot.prompts.first_conversation_prompt import FirstPrompt from pilot.conversation import ( @@ -39,6 +37,8 @@ from pilot.utils import ( from pilot.server.gradio_css import code_highlight_css from pilot.server.gradio_patch import Chatbot as grChatbot +from pilot.commands.command import execute_ai_response_json + logger = build_logger("webserver", LOGDIR + "webserver.log") headers = {"User-Agent": "dbgpt Client"} @@ -172,12 +172,11 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr. 
     if len(state.messages) == state.offset + 2:
         query = state.messages[-2][1]
         # The first round of conversation needs the prompt added
+        cfg = Config()
+        first_prompt = FirstPrompt()
+        first_prompt.command_registry = cfg.command_registry
         if(autogpt):
             # The first round in autogpt mode needs a dedicated prompt
-            cfg = Config()
-            first_prompt = FirstPrompt()
-            first_prompt.command_registry = cfg.command_registry
-
             system_prompt = first_prompt.construct_first_prompt(fisrt_message=[query])
             logger.info("[TEST]:" + system_prompt)
             template_name = "auto_dbgpt_one_shot"
@@ -456,7 +455,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
     logger.info(f"args: {args}")
 
-    dbs = get_database_list()
+    # dbs = get_database_list()
 
     # Load the plugins
     cfg = Config()
@@ -464,7 +463,6 @@ if __name__ == "__main__":
     cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
 
     # Load the executable commands provided by plugins
-    command_registry = CommandRegistry()
     command_categories = [
         "pilot.commands.audio_text",
         "pilot.commands.image_gen",
@@ -473,11 +471,11 @@
     command_categories = [
         x for x in command_categories if x not in cfg.disabled_command_categories
     ]
+    command_registry = CommandRegistry()
     for command_category in command_categories:
        command_registry.import_commands(command_category)
 
-    cfg.command_registry =command_category
-
+    cfg.command_registry = command_registry
 
     logger.info(args)
     demo = build_webdemo()
diff --git a/plugins/Db-GPT-SimpleChart-Plugin.zip b/plugins/Db-GPT-SimpleChart-Plugin.zip
new file mode 100644
index 000000000..03d995339
Binary files /dev/null and b/plugins/Db-GPT-SimpleChart-Plugin.zip differ
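A minimal usage sketch of the JSON-repair path added by this patch, assuming the pilot package is importable in the current environment; the reply text and the db_query command name below are illustrative placeholders, not values taken from the repository.

from pilot.agent.json_fix_llm import fix_json_using_multiple_techniques
from pilot.commands.command import get_command

# A reply with prose before the JSON object, as models often produce.
raw_reply = 'Sure, here is the command:\n{"command": {"name": "db_query", "args": {"sql": "SELECT 1"}}}'

# json.loads() fails on the leading prose, so fix_and_parse_json() falls back to
# locating the outermost braces and recovers the embedded object as a dict.
parsed = fix_json_using_multiple_techniques(raw_reply)

# get_command() then splits the parsed reply into the command name and its args.
command_name, arguments = get_command(parsed)
print(command_name, arguments)  # db_query {'sql': 'SELECT 1'}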