mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-26 21:37:40 +00:00
插件启动接入
This commit is contained in:
parent
192db2236a
commit
416205ae7a
1
.gitignore
vendored
1
.gitignore
vendored
@ -6,6 +6,7 @@ __pycache__/
|
|||||||
# C extensions
|
# C extensions
|
||||||
*.so
|
*.so
|
||||||
|
|
||||||
|
.idea
|
||||||
.vscode
|
.vscode
|
||||||
# Distribution / packaging
|
# Distribution / packaging
|
||||||
.Python
|
.Python
|
||||||
|
@ -1,4 +1,7 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
<project version="4">
|
<project version="4">
|
||||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (Auto-GPT)" project-jdk-type="Python SDK" />
|
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (Auto-GPT)" project-jdk-type="Python SDK" />
|
||||||
|
<component name="PythonCompatibilityInspectionAdvertiser">
|
||||||
|
<option name="version" value="3" />
|
||||||
|
</component>
|
||||||
</project>
|
</project>
|
166
pilot/agent/json_fix_llm.py
Normal file
166
pilot/agent/json_fix_llm.py
Normal file
@ -0,0 +1,166 @@
|
|||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any, Dict
|
||||||
|
import contextlib
|
||||||
|
from colorama import Fore
|
||||||
|
from regex import regex
|
||||||
|
|
||||||
|
from pilot.configs.config import Config
|
||||||
|
from pilot.logs import logger
|
||||||
|
from pilot.speech import say_text
|
||||||
|
|
||||||
|
from pilot.json_utils.json_fix_general import fix_invalid_escape,add_quotes_to_property_names,balance_braces
|
||||||
|
|
||||||
|
CFG = Config()
|
||||||
|
|
||||||
|
def fix_and_parse_json(
|
||||||
|
json_to_load: str, try_to_fix_with_gpt: bool = True
|
||||||
|
) -> Dict[Any, Any]:
|
||||||
|
"""Fix and parse JSON string
|
||||||
|
|
||||||
|
Args:
|
||||||
|
json_to_load (str): The JSON string.
|
||||||
|
try_to_fix_with_gpt (bool, optional): Try to fix the JSON with GPT.
|
||||||
|
Defaults to True.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str or dict[Any, Any]: The parsed JSON.
|
||||||
|
"""
|
||||||
|
|
||||||
|
with contextlib.suppress(json.JSONDecodeError):
|
||||||
|
json_to_load = json_to_load.replace("\t", "")
|
||||||
|
return json.loads(json_to_load)
|
||||||
|
|
||||||
|
with contextlib.suppress(json.JSONDecodeError):
|
||||||
|
json_to_load = correct_json(json_to_load)
|
||||||
|
return json.loads(json_to_load)
|
||||||
|
# Let's do something manually:
|
||||||
|
# sometimes GPT responds with something BEFORE the braces:
|
||||||
|
# "I'm sorry, I don't understand. Please try again."
|
||||||
|
# {"text": "I'm sorry, I don't understand. Please try again.",
|
||||||
|
# "confidence": 0.0}
|
||||||
|
# So let's try to find the first brace and then parse the rest
|
||||||
|
# of the string
|
||||||
|
try:
|
||||||
|
brace_index = json_to_load.index("{")
|
||||||
|
maybe_fixed_json = json_to_load[brace_index:]
|
||||||
|
last_brace_index = maybe_fixed_json.rindex("}")
|
||||||
|
maybe_fixed_json = maybe_fixed_json[: last_brace_index + 1]
|
||||||
|
return json.loads(maybe_fixed_json)
|
||||||
|
except (json.JSONDecodeError, ValueError) as e:
|
||||||
|
logger.error("参数解析错误", e)
|
||||||
|
|
||||||
|
|
||||||
|
def fix_json_using_multiple_techniques(assistant_reply: str) -> Dict[Any, Any]:
|
||||||
|
"""Fix the given JSON string to make it parseable and fully compliant with two techniques.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
json_string (str): The JSON string to fix.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The fixed JSON string.
|
||||||
|
"""
|
||||||
|
assistant_reply = assistant_reply.strip()
|
||||||
|
if assistant_reply.startswith("```json"):
|
||||||
|
assistant_reply = assistant_reply[7:]
|
||||||
|
if assistant_reply.endswith("```"):
|
||||||
|
assistant_reply = assistant_reply[:-3]
|
||||||
|
try:
|
||||||
|
return json.loads(assistant_reply) # just check the validity
|
||||||
|
except json.JSONDecodeError as e: # noqa: E722
|
||||||
|
print(f"JSONDecodeError: {e}")
|
||||||
|
pass
|
||||||
|
|
||||||
|
if assistant_reply.startswith("json "):
|
||||||
|
assistant_reply = assistant_reply[5:]
|
||||||
|
assistant_reply = assistant_reply.strip()
|
||||||
|
try:
|
||||||
|
return json.loads(assistant_reply) # just check the validity
|
||||||
|
except json.JSONDecodeError: # noqa: E722
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Parse and print Assistant response
|
||||||
|
assistant_reply_json = fix_and_parse_json(assistant_reply)
|
||||||
|
logger.debug("Assistant reply JSON: %s", str(assistant_reply_json))
|
||||||
|
if assistant_reply_json == {}:
|
||||||
|
assistant_reply_json = attempt_to_fix_json_by_finding_outermost_brackets(
|
||||||
|
assistant_reply
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("Assistant reply JSON 2: %s", str(assistant_reply_json))
|
||||||
|
if assistant_reply_json != {}:
|
||||||
|
return assistant_reply_json
|
||||||
|
|
||||||
|
logger.error(
|
||||||
|
"Error: The following AI output couldn't be converted to a JSON:\n",
|
||||||
|
assistant_reply,
|
||||||
|
)
|
||||||
|
if CFG.speak_mode:
|
||||||
|
say_text("I have received an invalid JSON response from the OpenAI API.")
|
||||||
|
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def correct_json(json_to_load: str) -> str:
|
||||||
|
"""
|
||||||
|
Correct common JSON errors.
|
||||||
|
Args:
|
||||||
|
json_to_load (str): The JSON string.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.debug("json", json_to_load)
|
||||||
|
json.loads(json_to_load)
|
||||||
|
return json_to_load
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.debug("json loads error", e)
|
||||||
|
error_message = str(e)
|
||||||
|
if error_message.startswith("Invalid \\escape"):
|
||||||
|
json_to_load = fix_invalid_escape(json_to_load, error_message)
|
||||||
|
if error_message.startswith(
|
||||||
|
"Expecting property name enclosed in double quotes"
|
||||||
|
):
|
||||||
|
json_to_load = add_quotes_to_property_names(json_to_load)
|
||||||
|
try:
|
||||||
|
json.loads(json_to_load)
|
||||||
|
return json_to_load
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.debug("json loads error - add quotes", e)
|
||||||
|
error_message = str(e)
|
||||||
|
if balanced_str := balance_braces(json_to_load):
|
||||||
|
return balanced_str
|
||||||
|
return json_to_load
|
||||||
|
|
||||||
|
|
||||||
|
def attempt_to_fix_json_by_finding_outermost_brackets(json_string: str):
|
||||||
|
if CFG.speak_mode and CFG.debug_mode:
|
||||||
|
say_text(
|
||||||
|
"I have received an invalid JSON response from the OpenAI API. "
|
||||||
|
"Trying to fix it now."
|
||||||
|
)
|
||||||
|
logger.error("Attempting to fix JSON by finding outermost brackets\n")
|
||||||
|
|
||||||
|
try:
|
||||||
|
json_pattern = regex.compile(r"\{(?:[^{}]|(?R))*\}")
|
||||||
|
json_match = json_pattern.search(json_string)
|
||||||
|
|
||||||
|
if json_match:
|
||||||
|
# Extract the valid JSON object from the string
|
||||||
|
json_string = json_match.group(0)
|
||||||
|
logger.typewriter_log(
|
||||||
|
title="Apparently json was fixed.", title_color=Fore.GREEN
|
||||||
|
)
|
||||||
|
if CFG.speak_mode and CFG.debug_mode:
|
||||||
|
say_text("Apparently json was fixed.")
|
||||||
|
else:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
if CFG.debug_mode:
|
||||||
|
logger.error(f"Error: Invalid JSON: {json_string}\n")
|
||||||
|
if CFG.speak_mode:
|
||||||
|
say_text("Didn't work. I will have to ignore this response then.")
|
||||||
|
logger.error("Error: Invalid JSON, setting it to empty JSON now.\n")
|
||||||
|
json_string = {}
|
||||||
|
|
||||||
|
return fix_and_parse_json(json_string)
|
@ -1,36 +1,153 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
import functools
|
from pilot.prompts.generator import PromptGenerator
|
||||||
import importlib
|
from typing import Dict, List, NoReturn, Union
|
||||||
import inspect
|
from pilot.configs.config import Config
|
||||||
from typing import Any, Callable, Optional
|
|
||||||
|
|
||||||
class Command:
|
from pilot.speech import say_text
|
||||||
"""A class representing a command.
|
|
||||||
|
|
||||||
Attributes:
|
from pilot.agent.json_fix_llm import fix_json_using_multiple_techniques
|
||||||
name (str): The name of the command.
|
from pilot.commands.exception_not_commands import NotCommands
|
||||||
description (str): A brief description of what the command does.
|
import json
|
||||||
signature (str): The signature of the function that the command executes. Default to None.
|
|
||||||
|
|
||||||
|
def _resolve_pathlike_command_args(command_args):
|
||||||
|
if "directory" in command_args and command_args["directory"] in {"", "/"}:
|
||||||
|
# todo
|
||||||
|
command_args["directory"] = ""
|
||||||
|
else:
|
||||||
|
for pathlike in ["filename", "directory", "clone_path"]:
|
||||||
|
if pathlike in command_args:
|
||||||
|
# todo
|
||||||
|
command_args[pathlike] = ""
|
||||||
|
return command_args
|
||||||
|
|
||||||
|
|
||||||
|
def execute_ai_response_json(
|
||||||
|
prompt: PromptGenerator,
|
||||||
|
ai_response: str,
|
||||||
|
user_input: str = None,
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
Args:
|
||||||
name: str,
|
command_registry:
|
||||||
description: str,
|
ai_response:
|
||||||
method: Callable[..., Any],
|
prompt:
|
||||||
signature: str = "",
|
|
||||||
enabled: bool = True,
|
|
||||||
disabled_reason: Optional[str] = None,
|
|
||||||
) -> None:
|
|
||||||
self.name = name
|
|
||||||
self.description = description
|
|
||||||
self.method = method
|
|
||||||
self.signature = signature if signature else str(inspect.signature(self.method))
|
|
||||||
self.enabled = enabled
|
|
||||||
self.disabled_reason = disabled_reason
|
|
||||||
|
|
||||||
def __call__(self, *args: Any, **kwds: Any) -> Any:
|
Returns:
|
||||||
if not self.enabled:
|
|
||||||
return f"Command '{self.name}' is disabled: {self.disabled_reason}"
|
"""
|
||||||
return self.method(*args, **kwds)
|
cfg = Config()
|
||||||
|
try:
|
||||||
|
assistant_reply_json = fix_json_using_multiple_techniques(ai_response)
|
||||||
|
except (json.JSONDecodeError, ValueError) as e:
|
||||||
|
raise NotCommands("非可执行命令结构")
|
||||||
|
command_name, arguments = get_command(assistant_reply_json)
|
||||||
|
if cfg.speak_mode:
|
||||||
|
say_text(f"I want to execute {command_name}")
|
||||||
|
|
||||||
|
arguments = _resolve_pathlike_command_args(arguments)
|
||||||
|
# Execute command
|
||||||
|
if command_name is not None and command_name.lower().startswith("error"):
|
||||||
|
result = (
|
||||||
|
f"Command {command_name} threw the following error: {arguments}"
|
||||||
|
)
|
||||||
|
elif command_name == "human_feedback":
|
||||||
|
result = f"Human feedback: {user_input}"
|
||||||
|
else:
|
||||||
|
for plugin in cfg.plugins:
|
||||||
|
if not plugin.can_handle_pre_command():
|
||||||
|
continue
|
||||||
|
command_name, arguments = plugin.pre_command(
|
||||||
|
command_name, arguments
|
||||||
|
)
|
||||||
|
command_result = execute_command(
|
||||||
|
command_name,
|
||||||
|
arguments,
|
||||||
|
prompt,
|
||||||
|
)
|
||||||
|
result = f"Command {command_name} returned: " f"{command_result}"
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def execute_command(
|
||||||
|
command_name: str,
|
||||||
|
arguments,
|
||||||
|
prompt: PromptGenerator,
|
||||||
|
):
|
||||||
|
"""Execute the command and return the result
|
||||||
|
|
||||||
|
Args:
|
||||||
|
command_name (str): The name of the command to execute
|
||||||
|
arguments (dict): The arguments for the command
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The result of the command
|
||||||
|
"""
|
||||||
|
|
||||||
|
cmd = prompt.command_registry.commands.get(command_name)
|
||||||
|
|
||||||
|
# If the command is found, call it with the provided arguments
|
||||||
|
if cmd:
|
||||||
|
try:
|
||||||
|
return cmd(**arguments)
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error: {str(e)}"
|
||||||
|
# TODO: Change these to take in a file rather than pasted code, if
|
||||||
|
# non-file is given, return instructions "Input should be a python
|
||||||
|
# filepath, write your code to file and try again
|
||||||
|
else:
|
||||||
|
for command in prompt.commands:
|
||||||
|
if (
|
||||||
|
command_name == command["label"].lower()
|
||||||
|
or command_name == command["name"].lower()
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
|
||||||
|
return command["function"](**arguments)
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error: {str(e)}"
|
||||||
|
raise NotCommands("非可用命令" + command)
|
||||||
|
|
||||||
|
|
||||||
|
def get_command(response_json: Dict):
|
||||||
|
"""Parse the response and return the command name and arguments
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response_json (json): The response from the AI
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: The command name and arguments
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
json.decoder.JSONDecodeError: If the response is not valid JSON
|
||||||
|
|
||||||
|
Exception: If any other error occurs
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if "command" not in response_json:
|
||||||
|
return "Error:", "Missing 'command' object in JSON"
|
||||||
|
|
||||||
|
if not isinstance(response_json, dict):
|
||||||
|
return "Error:", f"'response_json' object is not dictionary {response_json}"
|
||||||
|
|
||||||
|
command = response_json["command"]
|
||||||
|
if not isinstance(command, dict):
|
||||||
|
return "Error:", "'command' object is not a dictionary"
|
||||||
|
|
||||||
|
if "name" not in command:
|
||||||
|
return "Error:", "Missing 'name' field in 'command' object"
|
||||||
|
|
||||||
|
command_name = command["name"]
|
||||||
|
|
||||||
|
# Use an empty dictionary if 'args' field is not present in 'command' object
|
||||||
|
arguments = command.get("args", {})
|
||||||
|
|
||||||
|
return command_name, arguments
|
||||||
|
except json.decoder.JSONDecodeError:
|
||||||
|
return "Error:", "Invalid JSON"
|
||||||
|
# All other errors, return "Error: + error message"
|
||||||
|
except Exception as e:
|
||||||
|
return "Error:", str(e)
|
||||||
|
4
pilot/commands/exception_not_commands.py
Normal file
4
pilot/commands/exception_not_commands.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
class NotCommands(Exception):
|
||||||
|
def __init__(self, message, error_code):
|
||||||
|
super().__init__(message)
|
||||||
|
self.error_code = error_code
|
@ -42,7 +42,12 @@ class Config(metaclass=Singleton):
|
|||||||
self.milvus_collection = os.getenv("MILVUS_COLLECTION", "dbgpt")
|
self.milvus_collection = os.getenv("MILVUS_COLLECTION", "dbgpt")
|
||||||
self.milvus_secure = os.getenv("MILVUS_SECURE") == "True"
|
self.milvus_secure = os.getenv("MILVUS_SECURE") == "True"
|
||||||
|
|
||||||
self.plugins_dir = os.getenv("PLUGINS_DIR", "plugins")
|
self.authorise_key = os.getenv("AUTHORISE_COMMAND_KEY", "y")
|
||||||
|
self.exit_key = os.getenv("EXIT_KEY", "n")
|
||||||
|
self.image_provider = bool(os.getenv("IMAGE_PROVIDER", True))
|
||||||
|
self.image_size = int(os.getenv("IMAGE_SIZE", 256))
|
||||||
|
|
||||||
|
self.plugins_dir = os.getenv("PLUGINS_DIR", "../../plugins")
|
||||||
self.plugins: List[AutoGPTPluginTemplate] = []
|
self.plugins: List[AutoGPTPluginTemplate] = []
|
||||||
self.plugins_openai = []
|
self.plugins_openai = []
|
||||||
|
|
||||||
@ -57,6 +62,7 @@ class Config(metaclass=Singleton):
|
|||||||
self.huggingface_audio_to_text_model = os.getenv(
|
self.huggingface_audio_to_text_model = os.getenv(
|
||||||
"HUGGINGFACE_AUDIO_TO_TEXT_MODEL"
|
"HUGGINGFACE_AUDIO_TO_TEXT_MODEL"
|
||||||
)
|
)
|
||||||
|
self.speak_mode = False
|
||||||
|
|
||||||
disabled_command_categories = os.getenv("DISABLED_COMMAND_CATEGORIES")
|
disabled_command_categories = os.getenv("DISABLED_COMMAND_CATEGORIES")
|
||||||
if disabled_command_categories:
|
if disabled_command_categories:
|
||||||
@ -92,4 +98,6 @@ class Config(metaclass=Singleton):
|
|||||||
"""Set the temperature value."""
|
"""Set the temperature value."""
|
||||||
self.temperature = value
|
self.temperature = value
|
||||||
|
|
||||||
|
def set_speak_mode(self, value: bool) -> None:
|
||||||
|
"""Set the speak mode value."""
|
||||||
|
self.speak_mode = value
|
||||||
|
0
pilot/json_utils/__init__.py
Normal file
0
pilot/json_utils/__init__.py
Normal file
121
pilot/json_utils/json_fix_general.py
Normal file
121
pilot/json_utils/json_fix_general.py
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
"""This module contains functions to fix JSON strings using general programmatic approaches, suitable for addressing
|
||||||
|
common JSON formatting issues."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pilot.configs.config import Config
|
||||||
|
from pilot.logs import logger
|
||||||
|
from pilot.json_utils.utilities import extract_char_position
|
||||||
|
|
||||||
|
CFG = Config()
|
||||||
|
|
||||||
|
|
||||||
|
def fix_invalid_escape(json_to_load: str, error_message: str) -> str:
|
||||||
|
"""Fix invalid escape sequences in JSON strings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
json_to_load (str): The JSON string.
|
||||||
|
error_message (str): The error message from the JSONDecodeError
|
||||||
|
exception.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The JSON string with invalid escape sequences fixed.
|
||||||
|
"""
|
||||||
|
while error_message.startswith("Invalid \\escape"):
|
||||||
|
bad_escape_location = extract_char_position(error_message)
|
||||||
|
json_to_load = (
|
||||||
|
json_to_load[:bad_escape_location] + json_to_load[bad_escape_location + 1 :]
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
json.loads(json_to_load)
|
||||||
|
return json_to_load
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.debug("json loads error - fix invalid escape", e)
|
||||||
|
error_message = str(e)
|
||||||
|
return json_to_load
|
||||||
|
|
||||||
|
|
||||||
|
def balance_braces(json_string: str) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Balance the braces in a JSON string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
json_string (str): The JSON string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The JSON string with braces balanced.
|
||||||
|
"""
|
||||||
|
|
||||||
|
open_braces_count = json_string.count("{")
|
||||||
|
close_braces_count = json_string.count("}")
|
||||||
|
|
||||||
|
while open_braces_count > close_braces_count:
|
||||||
|
json_string += "}"
|
||||||
|
close_braces_count += 1
|
||||||
|
|
||||||
|
while close_braces_count > open_braces_count:
|
||||||
|
json_string = json_string.rstrip("}")
|
||||||
|
close_braces_count -= 1
|
||||||
|
|
||||||
|
with contextlib.suppress(json.JSONDecodeError):
|
||||||
|
json.loads(json_string)
|
||||||
|
return json_string
|
||||||
|
|
||||||
|
|
||||||
|
def add_quotes_to_property_names(json_string: str) -> str:
|
||||||
|
"""
|
||||||
|
Add quotes to property names in a JSON string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
json_string (str): The JSON string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The JSON string with quotes added to property names.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def replace_func(match: re.Match) -> str:
|
||||||
|
return f'"{match[1]}":'
|
||||||
|
|
||||||
|
property_name_pattern = re.compile(r"(\w+):")
|
||||||
|
corrected_json_string = property_name_pattern.sub(replace_func, json_string)
|
||||||
|
|
||||||
|
try:
|
||||||
|
json.loads(corrected_json_string)
|
||||||
|
return corrected_json_string
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
def correct_json(json_to_load: str) -> str:
|
||||||
|
"""
|
||||||
|
Correct common JSON errors.
|
||||||
|
Args:
|
||||||
|
json_to_load (str): The JSON string.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.debug("json", json_to_load)
|
||||||
|
json.loads(json_to_load)
|
||||||
|
return json_to_load
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.debug("json loads error", e)
|
||||||
|
error_message = str(e)
|
||||||
|
if error_message.startswith("Invalid \\escape"):
|
||||||
|
json_to_load = fix_invalid_escape(json_to_load, error_message)
|
||||||
|
if error_message.startswith(
|
||||||
|
"Expecting property name enclosed in double quotes"
|
||||||
|
):
|
||||||
|
json_to_load = add_quotes_to_property_names(json_to_load)
|
||||||
|
try:
|
||||||
|
json.loads(json_to_load)
|
||||||
|
return json_to_load
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.debug("json loads error - add quotes", e)
|
||||||
|
error_message = str(e)
|
||||||
|
if balanced_str := balance_braces(json_to_load):
|
||||||
|
return balanced_str
|
||||||
|
return json_to_load
|
81
pilot/json_utils/utilities.py
Normal file
81
pilot/json_utils/utilities.py
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
"""Utilities for the json_fixes package."""
|
||||||
|
import json
|
||||||
|
import os.path
|
||||||
|
import re
|
||||||
|
|
||||||
|
from jsonschema import Draft7Validator
|
||||||
|
|
||||||
|
from pilot.configs.config import Config
|
||||||
|
from pilot.logs import logger
|
||||||
|
|
||||||
|
CFG = Config()
|
||||||
|
LLM_DEFAULT_RESPONSE_FORMAT = "llm_response_format_1"
|
||||||
|
|
||||||
|
|
||||||
|
def extract_char_position(error_message: str) -> int:
|
||||||
|
"""Extract the character position from the JSONDecodeError message.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error_message (str): The error message from the JSONDecodeError
|
||||||
|
exception.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: The character position.
|
||||||
|
"""
|
||||||
|
|
||||||
|
char_pattern = re.compile(r"\(char (\d+)\)")
|
||||||
|
if match := char_pattern.search(error_message):
|
||||||
|
return int(match[1])
|
||||||
|
else:
|
||||||
|
raise ValueError("Character position not found in the error message.")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_json(json_object: object, schema_name: str) -> dict | None:
|
||||||
|
"""
|
||||||
|
:type schema_name: object
|
||||||
|
:param schema_name: str
|
||||||
|
:type json_object: object
|
||||||
|
"""
|
||||||
|
scheme_file = os.path.join(os.path.dirname(__file__), f"{schema_name}.json")
|
||||||
|
with open(scheme_file, "r") as f:
|
||||||
|
schema = json.load(f)
|
||||||
|
validator = Draft7Validator(schema)
|
||||||
|
|
||||||
|
if errors := sorted(validator.iter_errors(json_object), key=lambda e: e.path):
|
||||||
|
logger.error("The JSON object is invalid.")
|
||||||
|
if CFG.debug_mode:
|
||||||
|
logger.error(
|
||||||
|
json.dumps(json_object, indent=4)
|
||||||
|
) # Replace 'json_object' with the variable containing the JSON data
|
||||||
|
logger.error("The following issues were found:")
|
||||||
|
|
||||||
|
for error in errors:
|
||||||
|
logger.error(f"Error: {error.message}")
|
||||||
|
else:
|
||||||
|
logger.debug("The JSON object is valid.")
|
||||||
|
|
||||||
|
return json_object
|
||||||
|
|
||||||
|
|
||||||
|
def validate_json_string(json_string: str, schema_name: str) -> dict | None:
|
||||||
|
"""
|
||||||
|
:type schema_name: object
|
||||||
|
:param schema_name: str
|
||||||
|
:type json_object: object
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
json_loaded = json.loads(json_string)
|
||||||
|
return validate_json(json_loaded, schema_name)
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def is_string_valid_json(json_string: str, schema_name: str) -> bool:
|
||||||
|
"""
|
||||||
|
:type schema_name: object
|
||||||
|
:param schema_name: str
|
||||||
|
:type json_object: object
|
||||||
|
"""
|
||||||
|
|
||||||
|
return validate_json_string(json_string, schema_name) is not None
|
@ -72,7 +72,7 @@ def generate_stream(model, tokenizer, params, device,
|
|||||||
def generate_output(model, tokenizer, params, device, context_len=2048, stream_interval=2):
|
def generate_output(model, tokenizer, params, device, context_len=2048, stream_interval=2):
|
||||||
prompt = params["prompt"]
|
prompt = params["prompt"]
|
||||||
temperature = float(params.get("temperature", 1.0))
|
temperature = float(params.get("temperature", 1.0))
|
||||||
max_new_tokens = int(params.get("max_new_tokens", 1024))
|
max_new_tokens = int(params.get("max_new_tokens", 2048))
|
||||||
stop_parameter = params.get("stop", None)
|
stop_parameter = params.get("stop", None)
|
||||||
if stop_parameter == tokenizer.eos_token:
|
if stop_parameter == tokenizer.eos_token:
|
||||||
stop_parameter = None
|
stop_parameter = None
|
||||||
|
@ -207,6 +207,8 @@ def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate
|
|||||||
List[Tuple[str, Path]]: List of plugins.
|
List[Tuple[str, Path]]: List of plugins.
|
||||||
"""
|
"""
|
||||||
loaded_plugins = []
|
loaded_plugins = []
|
||||||
|
current_dir = os.getcwd()
|
||||||
|
print(current_dir)
|
||||||
# Generic plugins
|
# Generic plugins
|
||||||
plugins_path_path = Path(cfg.plugins_dir)
|
plugins_path_path = Path(cfg.plugins_dir)
|
||||||
|
|
||||||
|
@ -38,7 +38,6 @@ class FirstPrompt:
|
|||||||
|
|
||||||
def construct_first_prompt(
|
def construct_first_prompt(
|
||||||
self,
|
self,
|
||||||
command_registry: [] = None,
|
|
||||||
fisrt_message: [str]=[],
|
fisrt_message: [str]=[],
|
||||||
prompt_generator: Optional[PromptGenerator] = None
|
prompt_generator: Optional[PromptGenerator] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
@ -64,7 +63,7 @@ class FirstPrompt:
|
|||||||
if prompt_generator is None:
|
if prompt_generator is None:
|
||||||
prompt_generator = build_default_prompt_generator()
|
prompt_generator = build_default_prompt_generator()
|
||||||
prompt_generator.goals = fisrt_message
|
prompt_generator.goals = fisrt_message
|
||||||
prompt_generator.command_registry = command_registry
|
prompt_generator.command_registry = self.command_registry
|
||||||
# 加载插件中可用命令
|
# 加载插件中可用命令
|
||||||
cfg = Config()
|
cfg = Config()
|
||||||
for plugin in cfg.plugins:
|
for plugin in cfg.plugins:
|
||||||
|
@ -0,0 +1,14 @@
|
|||||||
|
import os
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
|
||||||
|
print("Setting random seed to 42")
|
||||||
|
random.seed(42)
|
||||||
|
|
||||||
|
# Load the users .env file into environment variables
|
||||||
|
load_dotenv(verbose=True, override=True)
|
||||||
|
|
||||||
|
del load_dotenv
|
@ -20,8 +20,6 @@ from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, D
|
|||||||
from pilot.plugins import scan_plugins
|
from pilot.plugins import scan_plugins
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.commands.command_mange import CommandRegistry
|
from pilot.commands.command_mange import CommandRegistry
|
||||||
from pilot.prompts.prompt import build_default_prompt_generator
|
|
||||||
|
|
||||||
from pilot.prompts.first_conversation_prompt import FirstPrompt
|
from pilot.prompts.first_conversation_prompt import FirstPrompt
|
||||||
|
|
||||||
from pilot.conversation import (
|
from pilot.conversation import (
|
||||||
@ -39,6 +37,8 @@ from pilot.utils import (
|
|||||||
from pilot.server.gradio_css import code_highlight_css
|
from pilot.server.gradio_css import code_highlight_css
|
||||||
from pilot.server.gradio_patch import Chatbot as grChatbot
|
from pilot.server.gradio_patch import Chatbot as grChatbot
|
||||||
|
|
||||||
|
from pilot.commands.command import execute_ai_response_json
|
||||||
|
|
||||||
logger = build_logger("webserver", LOGDIR + "webserver.log")
|
logger = build_logger("webserver", LOGDIR + "webserver.log")
|
||||||
headers = {"User-Agent": "dbgpt Client"}
|
headers = {"User-Agent": "dbgpt Client"}
|
||||||
|
|
||||||
@ -172,12 +172,11 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
|
|||||||
if len(state.messages) == state.offset + 2:
|
if len(state.messages) == state.offset + 2:
|
||||||
query = state.messages[-2][1]
|
query = state.messages[-2][1]
|
||||||
# 第一轮对话需要加入提示Prompt
|
# 第一轮对话需要加入提示Prompt
|
||||||
if(autogpt):
|
|
||||||
# autogpt模式的第一轮对话需要 构建专属prompt
|
|
||||||
cfg = Config()
|
cfg = Config()
|
||||||
first_prompt = FirstPrompt()
|
first_prompt = FirstPrompt()
|
||||||
first_prompt.command_registry = cfg.command_registry
|
first_prompt.command_registry = cfg.command_registry
|
||||||
|
if(autogpt):
|
||||||
|
# autogpt模式的第一轮对话需要 构建专属prompt
|
||||||
system_prompt = first_prompt.construct_first_prompt(fisrt_message=[query])
|
system_prompt = first_prompt.construct_first_prompt(fisrt_message=[query])
|
||||||
logger.info("[TEST]:" + system_prompt)
|
logger.info("[TEST]:" + system_prompt)
|
||||||
template_name = "auto_dbgpt_one_shot"
|
template_name = "auto_dbgpt_one_shot"
|
||||||
@ -456,7 +455,7 @@ if __name__ == "__main__":
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
logger.info(f"args: {args}")
|
logger.info(f"args: {args}")
|
||||||
|
|
||||||
dbs = get_database_list()
|
# dbs = get_database_list()
|
||||||
|
|
||||||
# 加载插件
|
# 加载插件
|
||||||
cfg = Config()
|
cfg = Config()
|
||||||
@ -464,7 +463,6 @@ if __name__ == "__main__":
|
|||||||
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
|
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
|
||||||
|
|
||||||
# 加载插件可执行命令
|
# 加载插件可执行命令
|
||||||
command_registry = CommandRegistry()
|
|
||||||
command_categories = [
|
command_categories = [
|
||||||
"pilot.commands.audio_text",
|
"pilot.commands.audio_text",
|
||||||
"pilot.commands.image_gen",
|
"pilot.commands.image_gen",
|
||||||
@ -473,11 +471,11 @@ if __name__ == "__main__":
|
|||||||
command_categories = [
|
command_categories = [
|
||||||
x for x in command_categories if x not in cfg.disabled_command_categories
|
x for x in command_categories if x not in cfg.disabled_command_categories
|
||||||
]
|
]
|
||||||
|
command_registry = CommandRegistry()
|
||||||
for command_category in command_categories:
|
for command_category in command_categories:
|
||||||
command_registry.import_commands(command_category)
|
command_registry.import_commands(command_category)
|
||||||
|
|
||||||
cfg.command_registry =command_category
|
cfg.command_registry =command_registry
|
||||||
|
|
||||||
|
|
||||||
logger.info(args)
|
logger.info(args)
|
||||||
demo = build_webdemo()
|
demo = build_webdemo()
|
||||||
|
BIN
plugins/Db-GPT-SimpleChart-Plugin.zip
Normal file
BIN
plugins/Db-GPT-SimpleChart-Plugin.zip
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user