mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-23 20:26:15 +00:00
插件启动接入
This commit is contained in:
parent
192db2236a
commit
416205ae7a
1
.gitignore
vendored
1
.gitignore
vendored
@ -6,6 +6,7 @@ __pycache__/
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
.idea
|
||||
.vscode
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
|
@ -1,4 +1,7 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (Auto-GPT)" project-jdk-type="Python SDK" />
|
||||
<component name="PythonCompatibilityInspectionAdvertiser">
|
||||
<option name="version" value="3" />
|
||||
</component>
|
||||
</project>
|
166
pilot/agent/json_fix_llm.py
Normal file
166
pilot/agent/json_fix_llm.py
Normal file
@ -0,0 +1,166 @@
|
||||
|
||||
import json
|
||||
from typing import Any, Dict
|
||||
import contextlib
|
||||
from colorama import Fore
|
||||
from regex import regex
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.logs import logger
|
||||
from pilot.speech import say_text
|
||||
|
||||
from pilot.json_utils.json_fix_general import fix_invalid_escape,add_quotes_to_property_names,balance_braces
|
||||
|
||||
CFG = Config()
|
||||
|
||||
def fix_and_parse_json(
|
||||
json_to_load: str, try_to_fix_with_gpt: bool = True
|
||||
) -> Dict[Any, Any]:
|
||||
"""Fix and parse JSON string
|
||||
|
||||
Args:
|
||||
json_to_load (str): The JSON string.
|
||||
try_to_fix_with_gpt (bool, optional): Try to fix the JSON with GPT.
|
||||
Defaults to True.
|
||||
|
||||
Returns:
|
||||
str or dict[Any, Any]: The parsed JSON.
|
||||
"""
|
||||
|
||||
with contextlib.suppress(json.JSONDecodeError):
|
||||
json_to_load = json_to_load.replace("\t", "")
|
||||
return json.loads(json_to_load)
|
||||
|
||||
with contextlib.suppress(json.JSONDecodeError):
|
||||
json_to_load = correct_json(json_to_load)
|
||||
return json.loads(json_to_load)
|
||||
# Let's do something manually:
|
||||
# sometimes GPT responds with something BEFORE the braces:
|
||||
# "I'm sorry, I don't understand. Please try again."
|
||||
# {"text": "I'm sorry, I don't understand. Please try again.",
|
||||
# "confidence": 0.0}
|
||||
# So let's try to find the first brace and then parse the rest
|
||||
# of the string
|
||||
try:
|
||||
brace_index = json_to_load.index("{")
|
||||
maybe_fixed_json = json_to_load[brace_index:]
|
||||
last_brace_index = maybe_fixed_json.rindex("}")
|
||||
maybe_fixed_json = maybe_fixed_json[: last_brace_index + 1]
|
||||
return json.loads(maybe_fixed_json)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
logger.error("参数解析错误", e)
|
||||
|
||||
|
||||
def fix_json_using_multiple_techniques(assistant_reply: str) -> Dict[Any, Any]:
|
||||
"""Fix the given JSON string to make it parseable and fully compliant with two techniques.
|
||||
|
||||
Args:
|
||||
json_string (str): The JSON string to fix.
|
||||
|
||||
Returns:
|
||||
str: The fixed JSON string.
|
||||
"""
|
||||
assistant_reply = assistant_reply.strip()
|
||||
if assistant_reply.startswith("```json"):
|
||||
assistant_reply = assistant_reply[7:]
|
||||
if assistant_reply.endswith("```"):
|
||||
assistant_reply = assistant_reply[:-3]
|
||||
try:
|
||||
return json.loads(assistant_reply) # just check the validity
|
||||
except json.JSONDecodeError as e: # noqa: E722
|
||||
print(f"JSONDecodeError: {e}")
|
||||
pass
|
||||
|
||||
if assistant_reply.startswith("json "):
|
||||
assistant_reply = assistant_reply[5:]
|
||||
assistant_reply = assistant_reply.strip()
|
||||
try:
|
||||
return json.loads(assistant_reply) # just check the validity
|
||||
except json.JSONDecodeError: # noqa: E722
|
||||
pass
|
||||
|
||||
# Parse and print Assistant response
|
||||
assistant_reply_json = fix_and_parse_json(assistant_reply)
|
||||
logger.debug("Assistant reply JSON: %s", str(assistant_reply_json))
|
||||
if assistant_reply_json == {}:
|
||||
assistant_reply_json = attempt_to_fix_json_by_finding_outermost_brackets(
|
||||
assistant_reply
|
||||
)
|
||||
|
||||
logger.debug("Assistant reply JSON 2: %s", str(assistant_reply_json))
|
||||
if assistant_reply_json != {}:
|
||||
return assistant_reply_json
|
||||
|
||||
logger.error(
|
||||
"Error: The following AI output couldn't be converted to a JSON:\n",
|
||||
assistant_reply,
|
||||
)
|
||||
if CFG.speak_mode:
|
||||
say_text("I have received an invalid JSON response from the OpenAI API.")
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def correct_json(json_to_load: str) -> str:
|
||||
"""
|
||||
Correct common JSON errors.
|
||||
Args:
|
||||
json_to_load (str): The JSON string.
|
||||
"""
|
||||
|
||||
try:
|
||||
logger.debug("json", json_to_load)
|
||||
json.loads(json_to_load)
|
||||
return json_to_load
|
||||
except json.JSONDecodeError as e:
|
||||
logger.debug("json loads error", e)
|
||||
error_message = str(e)
|
||||
if error_message.startswith("Invalid \\escape"):
|
||||
json_to_load = fix_invalid_escape(json_to_load, error_message)
|
||||
if error_message.startswith(
|
||||
"Expecting property name enclosed in double quotes"
|
||||
):
|
||||
json_to_load = add_quotes_to_property_names(json_to_load)
|
||||
try:
|
||||
json.loads(json_to_load)
|
||||
return json_to_load
|
||||
except json.JSONDecodeError as e:
|
||||
logger.debug("json loads error - add quotes", e)
|
||||
error_message = str(e)
|
||||
if balanced_str := balance_braces(json_to_load):
|
||||
return balanced_str
|
||||
return json_to_load
|
||||
|
||||
|
||||
def attempt_to_fix_json_by_finding_outermost_brackets(json_string: str):
|
||||
if CFG.speak_mode and CFG.debug_mode:
|
||||
say_text(
|
||||
"I have received an invalid JSON response from the OpenAI API. "
|
||||
"Trying to fix it now."
|
||||
)
|
||||
logger.error("Attempting to fix JSON by finding outermost brackets\n")
|
||||
|
||||
try:
|
||||
json_pattern = regex.compile(r"\{(?:[^{}]|(?R))*\}")
|
||||
json_match = json_pattern.search(json_string)
|
||||
|
||||
if json_match:
|
||||
# Extract the valid JSON object from the string
|
||||
json_string = json_match.group(0)
|
||||
logger.typewriter_log(
|
||||
title="Apparently json was fixed.", title_color=Fore.GREEN
|
||||
)
|
||||
if CFG.speak_mode and CFG.debug_mode:
|
||||
say_text("Apparently json was fixed.")
|
||||
else:
|
||||
return {}
|
||||
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
if CFG.debug_mode:
|
||||
logger.error(f"Error: Invalid JSON: {json_string}\n")
|
||||
if CFG.speak_mode:
|
||||
say_text("Didn't work. I will have to ignore this response then.")
|
||||
logger.error("Error: Invalid JSON, setting it to empty JSON now.\n")
|
||||
json_string = {}
|
||||
|
||||
return fix_and_parse_json(json_string)
|
@ -1,36 +1,153 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import functools
|
||||
import importlib
|
||||
import inspect
|
||||
from typing import Any, Callable, Optional
|
||||
from pilot.prompts.generator import PromptGenerator
|
||||
from typing import Dict, List, NoReturn, Union
|
||||
from pilot.configs.config import Config
|
||||
|
||||
class Command:
|
||||
"""A class representing a command.
|
||||
from pilot.speech import say_text
|
||||
|
||||
Attributes:
|
||||
name (str): The name of the command.
|
||||
description (str): A brief description of what the command does.
|
||||
signature (str): The signature of the function that the command executes. Default to None.
|
||||
from pilot.agent.json_fix_llm import fix_json_using_multiple_techniques
|
||||
from pilot.commands.exception_not_commands import NotCommands
|
||||
import json
|
||||
|
||||
|
||||
def _resolve_pathlike_command_args(command_args):
|
||||
if "directory" in command_args and command_args["directory"] in {"", "/"}:
|
||||
# todo
|
||||
command_args["directory"] = ""
|
||||
else:
|
||||
for pathlike in ["filename", "directory", "clone_path"]:
|
||||
if pathlike in command_args:
|
||||
# todo
|
||||
command_args[pathlike] = ""
|
||||
return command_args
|
||||
|
||||
|
||||
def execute_ai_response_json(
|
||||
prompt: PromptGenerator,
|
||||
ai_response: str,
|
||||
user_input: str = None,
|
||||
) -> str:
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
name: str,
|
||||
description: str,
|
||||
method: Callable[..., Any],
|
||||
signature: str = "",
|
||||
enabled: bool = True,
|
||||
disabled_reason: Optional[str] = None,
|
||||
) -> None:
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.method = method
|
||||
self.signature = signature if signature else str(inspect.signature(self.method))
|
||||
self.enabled = enabled
|
||||
self.disabled_reason = disabled_reason
|
||||
Args:
|
||||
command_registry:
|
||||
ai_response:
|
||||
prompt:
|
||||
|
||||
def __call__(self, *args: Any, **kwds: Any) -> Any:
|
||||
if not self.enabled:
|
||||
return f"Command '{self.name}' is disabled: {self.disabled_reason}"
|
||||
return self.method(*args, **kwds)
|
||||
Returns:
|
||||
|
||||
"""
|
||||
cfg = Config()
|
||||
try:
|
||||
assistant_reply_json = fix_json_using_multiple_techniques(ai_response)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
raise NotCommands("非可执行命令结构")
|
||||
command_name, arguments = get_command(assistant_reply_json)
|
||||
if cfg.speak_mode:
|
||||
say_text(f"I want to execute {command_name}")
|
||||
|
||||
arguments = _resolve_pathlike_command_args(arguments)
|
||||
# Execute command
|
||||
if command_name is not None and command_name.lower().startswith("error"):
|
||||
result = (
|
||||
f"Command {command_name} threw the following error: {arguments}"
|
||||
)
|
||||
elif command_name == "human_feedback":
|
||||
result = f"Human feedback: {user_input}"
|
||||
else:
|
||||
for plugin in cfg.plugins:
|
||||
if not plugin.can_handle_pre_command():
|
||||
continue
|
||||
command_name, arguments = plugin.pre_command(
|
||||
command_name, arguments
|
||||
)
|
||||
command_result = execute_command(
|
||||
command_name,
|
||||
arguments,
|
||||
prompt,
|
||||
)
|
||||
result = f"Command {command_name} returned: " f"{command_result}"
|
||||
return result
|
||||
|
||||
|
||||
def execute_command(
|
||||
command_name: str,
|
||||
arguments,
|
||||
prompt: PromptGenerator,
|
||||
):
|
||||
"""Execute the command and return the result
|
||||
|
||||
Args:
|
||||
command_name (str): The name of the command to execute
|
||||
arguments (dict): The arguments for the command
|
||||
|
||||
Returns:
|
||||
str: The result of the command
|
||||
"""
|
||||
|
||||
cmd = prompt.command_registry.commands.get(command_name)
|
||||
|
||||
# If the command is found, call it with the provided arguments
|
||||
if cmd:
|
||||
try:
|
||||
return cmd(**arguments)
|
||||
except Exception as e:
|
||||
return f"Error: {str(e)}"
|
||||
# TODO: Change these to take in a file rather than pasted code, if
|
||||
# non-file is given, return instructions "Input should be a python
|
||||
# filepath, write your code to file and try again
|
||||
else:
|
||||
for command in prompt.commands:
|
||||
if (
|
||||
command_name == command["label"].lower()
|
||||
or command_name == command["name"].lower()
|
||||
):
|
||||
try:
|
||||
|
||||
return command["function"](**arguments)
|
||||
except Exception as e:
|
||||
return f"Error: {str(e)}"
|
||||
raise NotCommands("非可用命令" + command)
|
||||
|
||||
|
||||
def get_command(response_json: Dict):
|
||||
"""Parse the response and return the command name and arguments
|
||||
|
||||
Args:
|
||||
response_json (json): The response from the AI
|
||||
|
||||
Returns:
|
||||
tuple: The command name and arguments
|
||||
|
||||
Raises:
|
||||
json.decoder.JSONDecodeError: If the response is not valid JSON
|
||||
|
||||
Exception: If any other error occurs
|
||||
"""
|
||||
try:
|
||||
if "command" not in response_json:
|
||||
return "Error:", "Missing 'command' object in JSON"
|
||||
|
||||
if not isinstance(response_json, dict):
|
||||
return "Error:", f"'response_json' object is not dictionary {response_json}"
|
||||
|
||||
command = response_json["command"]
|
||||
if not isinstance(command, dict):
|
||||
return "Error:", "'command' object is not a dictionary"
|
||||
|
||||
if "name" not in command:
|
||||
return "Error:", "Missing 'name' field in 'command' object"
|
||||
|
||||
command_name = command["name"]
|
||||
|
||||
# Use an empty dictionary if 'args' field is not present in 'command' object
|
||||
arguments = command.get("args", {})
|
||||
|
||||
return command_name, arguments
|
||||
except json.decoder.JSONDecodeError:
|
||||
return "Error:", "Invalid JSON"
|
||||
# All other errors, return "Error: + error message"
|
||||
except Exception as e:
|
||||
return "Error:", str(e)
|
||||
|
4
pilot/commands/exception_not_commands.py
Normal file
4
pilot/commands/exception_not_commands.py
Normal file
@ -0,0 +1,4 @@
|
||||
class NotCommands(Exception):
|
||||
def __init__(self, message, error_code):
|
||||
super().__init__(message)
|
||||
self.error_code = error_code
|
@ -42,7 +42,12 @@ class Config(metaclass=Singleton):
|
||||
self.milvus_collection = os.getenv("MILVUS_COLLECTION", "dbgpt")
|
||||
self.milvus_secure = os.getenv("MILVUS_SECURE") == "True"
|
||||
|
||||
self.plugins_dir = os.getenv("PLUGINS_DIR", "plugins")
|
||||
self.authorise_key = os.getenv("AUTHORISE_COMMAND_KEY", "y")
|
||||
self.exit_key = os.getenv("EXIT_KEY", "n")
|
||||
self.image_provider = bool(os.getenv("IMAGE_PROVIDER", True))
|
||||
self.image_size = int(os.getenv("IMAGE_SIZE", 256))
|
||||
|
||||
self.plugins_dir = os.getenv("PLUGINS_DIR", "../../plugins")
|
||||
self.plugins: List[AutoGPTPluginTemplate] = []
|
||||
self.plugins_openai = []
|
||||
|
||||
@ -57,6 +62,7 @@ class Config(metaclass=Singleton):
|
||||
self.huggingface_audio_to_text_model = os.getenv(
|
||||
"HUGGINGFACE_AUDIO_TO_TEXT_MODEL"
|
||||
)
|
||||
self.speak_mode = False
|
||||
|
||||
disabled_command_categories = os.getenv("DISABLED_COMMAND_CATEGORIES")
|
||||
if disabled_command_categories:
|
||||
@ -92,4 +98,6 @@ class Config(metaclass=Singleton):
|
||||
"""Set the temperature value."""
|
||||
self.temperature = value
|
||||
|
||||
|
||||
def set_speak_mode(self, value: bool) -> None:
|
||||
"""Set the speak mode value."""
|
||||
self.speak_mode = value
|
||||
|
0
pilot/json_utils/__init__.py
Normal file
0
pilot/json_utils/__init__.py
Normal file
121
pilot/json_utils/json_fix_general.py
Normal file
121
pilot/json_utils/json_fix_general.py
Normal file
@ -0,0 +1,121 @@
|
||||
"""This module contains functions to fix JSON strings using general programmatic approaches, suitable for addressing
|
||||
common JSON formatting issues."""
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import json
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.logs import logger
|
||||
from pilot.json_utils.utilities import extract_char_position
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
def fix_invalid_escape(json_to_load: str, error_message: str) -> str:
|
||||
"""Fix invalid escape sequences in JSON strings.
|
||||
|
||||
Args:
|
||||
json_to_load (str): The JSON string.
|
||||
error_message (str): The error message from the JSONDecodeError
|
||||
exception.
|
||||
|
||||
Returns:
|
||||
str: The JSON string with invalid escape sequences fixed.
|
||||
"""
|
||||
while error_message.startswith("Invalid \\escape"):
|
||||
bad_escape_location = extract_char_position(error_message)
|
||||
json_to_load = (
|
||||
json_to_load[:bad_escape_location] + json_to_load[bad_escape_location + 1 :]
|
||||
)
|
||||
try:
|
||||
json.loads(json_to_load)
|
||||
return json_to_load
|
||||
except json.JSONDecodeError as e:
|
||||
logger.debug("json loads error - fix invalid escape", e)
|
||||
error_message = str(e)
|
||||
return json_to_load
|
||||
|
||||
|
||||
def balance_braces(json_string: str) -> Optional[str]:
|
||||
"""
|
||||
Balance the braces in a JSON string.
|
||||
|
||||
Args:
|
||||
json_string (str): The JSON string.
|
||||
|
||||
Returns:
|
||||
str: The JSON string with braces balanced.
|
||||
"""
|
||||
|
||||
open_braces_count = json_string.count("{")
|
||||
close_braces_count = json_string.count("}")
|
||||
|
||||
while open_braces_count > close_braces_count:
|
||||
json_string += "}"
|
||||
close_braces_count += 1
|
||||
|
||||
while close_braces_count > open_braces_count:
|
||||
json_string = json_string.rstrip("}")
|
||||
close_braces_count -= 1
|
||||
|
||||
with contextlib.suppress(json.JSONDecodeError):
|
||||
json.loads(json_string)
|
||||
return json_string
|
||||
|
||||
|
||||
def add_quotes_to_property_names(json_string: str) -> str:
|
||||
"""
|
||||
Add quotes to property names in a JSON string.
|
||||
|
||||
Args:
|
||||
json_string (str): The JSON string.
|
||||
|
||||
Returns:
|
||||
str: The JSON string with quotes added to property names.
|
||||
"""
|
||||
|
||||
def replace_func(match: re.Match) -> str:
|
||||
return f'"{match[1]}":'
|
||||
|
||||
property_name_pattern = re.compile(r"(\w+):")
|
||||
corrected_json_string = property_name_pattern.sub(replace_func, json_string)
|
||||
|
||||
try:
|
||||
json.loads(corrected_json_string)
|
||||
return corrected_json_string
|
||||
except json.JSONDecodeError as e:
|
||||
raise e
|
||||
|
||||
|
||||
def correct_json(json_to_load: str) -> str:
|
||||
"""
|
||||
Correct common JSON errors.
|
||||
Args:
|
||||
json_to_load (str): The JSON string.
|
||||
"""
|
||||
|
||||
try:
|
||||
logger.debug("json", json_to_load)
|
||||
json.loads(json_to_load)
|
||||
return json_to_load
|
||||
except json.JSONDecodeError as e:
|
||||
logger.debug("json loads error", e)
|
||||
error_message = str(e)
|
||||
if error_message.startswith("Invalid \\escape"):
|
||||
json_to_load = fix_invalid_escape(json_to_load, error_message)
|
||||
if error_message.startswith(
|
||||
"Expecting property name enclosed in double quotes"
|
||||
):
|
||||
json_to_load = add_quotes_to_property_names(json_to_load)
|
||||
try:
|
||||
json.loads(json_to_load)
|
||||
return json_to_load
|
||||
except json.JSONDecodeError as e:
|
||||
logger.debug("json loads error - add quotes", e)
|
||||
error_message = str(e)
|
||||
if balanced_str := balance_braces(json_to_load):
|
||||
return balanced_str
|
||||
return json_to_load
|
81
pilot/json_utils/utilities.py
Normal file
81
pilot/json_utils/utilities.py
Normal file
@ -0,0 +1,81 @@
|
||||
"""Utilities for the json_fixes package."""
|
||||
import json
|
||||
import os.path
|
||||
import re
|
||||
|
||||
from jsonschema import Draft7Validator
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.logs import logger
|
||||
|
||||
CFG = Config()
|
||||
LLM_DEFAULT_RESPONSE_FORMAT = "llm_response_format_1"
|
||||
|
||||
|
||||
def extract_char_position(error_message: str) -> int:
|
||||
"""Extract the character position from the JSONDecodeError message.
|
||||
|
||||
Args:
|
||||
error_message (str): The error message from the JSONDecodeError
|
||||
exception.
|
||||
|
||||
Returns:
|
||||
int: The character position.
|
||||
"""
|
||||
|
||||
char_pattern = re.compile(r"\(char (\d+)\)")
|
||||
if match := char_pattern.search(error_message):
|
||||
return int(match[1])
|
||||
else:
|
||||
raise ValueError("Character position not found in the error message.")
|
||||
|
||||
|
||||
def validate_json(json_object: object, schema_name: str) -> dict | None:
|
||||
"""
|
||||
:type schema_name: object
|
||||
:param schema_name: str
|
||||
:type json_object: object
|
||||
"""
|
||||
scheme_file = os.path.join(os.path.dirname(__file__), f"{schema_name}.json")
|
||||
with open(scheme_file, "r") as f:
|
||||
schema = json.load(f)
|
||||
validator = Draft7Validator(schema)
|
||||
|
||||
if errors := sorted(validator.iter_errors(json_object), key=lambda e: e.path):
|
||||
logger.error("The JSON object is invalid.")
|
||||
if CFG.debug_mode:
|
||||
logger.error(
|
||||
json.dumps(json_object, indent=4)
|
||||
) # Replace 'json_object' with the variable containing the JSON data
|
||||
logger.error("The following issues were found:")
|
||||
|
||||
for error in errors:
|
||||
logger.error(f"Error: {error.message}")
|
||||
else:
|
||||
logger.debug("The JSON object is valid.")
|
||||
|
||||
return json_object
|
||||
|
||||
|
||||
def validate_json_string(json_string: str, schema_name: str) -> dict | None:
|
||||
"""
|
||||
:type schema_name: object
|
||||
:param schema_name: str
|
||||
:type json_object: object
|
||||
"""
|
||||
|
||||
try:
|
||||
json_loaded = json.loads(json_string)
|
||||
return validate_json(json_loaded, schema_name)
|
||||
except:
|
||||
return None
|
||||
|
||||
|
||||
def is_string_valid_json(json_string: str, schema_name: str) -> bool:
|
||||
"""
|
||||
:type schema_name: object
|
||||
:param schema_name: str
|
||||
:type json_object: object
|
||||
"""
|
||||
|
||||
return validate_json_string(json_string, schema_name) is not None
|
@ -72,7 +72,7 @@ def generate_stream(model, tokenizer, params, device,
|
||||
def generate_output(model, tokenizer, params, device, context_len=2048, stream_interval=2):
|
||||
prompt = params["prompt"]
|
||||
temperature = float(params.get("temperature", 1.0))
|
||||
max_new_tokens = int(params.get("max_new_tokens", 1024))
|
||||
max_new_tokens = int(params.get("max_new_tokens", 2048))
|
||||
stop_parameter = params.get("stop", None)
|
||||
if stop_parameter == tokenizer.eos_token:
|
||||
stop_parameter = None
|
||||
|
@ -207,6 +207,8 @@ def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate
|
||||
List[Tuple[str, Path]]: List of plugins.
|
||||
"""
|
||||
loaded_plugins = []
|
||||
current_dir = os.getcwd()
|
||||
print(current_dir)
|
||||
# Generic plugins
|
||||
plugins_path_path = Path(cfg.plugins_dir)
|
||||
|
||||
|
@ -38,7 +38,6 @@ class FirstPrompt:
|
||||
|
||||
def construct_first_prompt(
|
||||
self,
|
||||
command_registry: [] = None,
|
||||
fisrt_message: [str]=[],
|
||||
prompt_generator: Optional[PromptGenerator] = None
|
||||
) -> str:
|
||||
@ -64,7 +63,7 @@ class FirstPrompt:
|
||||
if prompt_generator is None:
|
||||
prompt_generator = build_default_prompt_generator()
|
||||
prompt_generator.goals = fisrt_message
|
||||
prompt_generator.command_registry = command_registry
|
||||
prompt_generator.command_registry = self.command_registry
|
||||
# 加载插件中可用命令
|
||||
cfg = Config()
|
||||
for plugin in cfg.plugins:
|
||||
|
@ -0,0 +1,14 @@
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
|
||||
print("Setting random seed to 42")
|
||||
random.seed(42)
|
||||
|
||||
# Load the users .env file into environment variables
|
||||
load_dotenv(verbose=True, override=True)
|
||||
|
||||
del load_dotenv
|
@ -20,8 +20,6 @@ from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, D
|
||||
from pilot.plugins import scan_plugins
|
||||
from pilot.configs.config import Config
|
||||
from pilot.commands.command_mange import CommandRegistry
|
||||
from pilot.prompts.prompt import build_default_prompt_generator
|
||||
|
||||
from pilot.prompts.first_conversation_prompt import FirstPrompt
|
||||
|
||||
from pilot.conversation import (
|
||||
@ -39,6 +37,8 @@ from pilot.utils import (
|
||||
from pilot.server.gradio_css import code_highlight_css
|
||||
from pilot.server.gradio_patch import Chatbot as grChatbot
|
||||
|
||||
from pilot.commands.command import execute_ai_response_json
|
||||
|
||||
logger = build_logger("webserver", LOGDIR + "webserver.log")
|
||||
headers = {"User-Agent": "dbgpt Client"}
|
||||
|
||||
@ -172,12 +172,11 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
|
||||
if len(state.messages) == state.offset + 2:
|
||||
query = state.messages[-2][1]
|
||||
# 第一轮对话需要加入提示Prompt
|
||||
cfg = Config()
|
||||
first_prompt = FirstPrompt()
|
||||
first_prompt.command_registry = cfg.command_registry
|
||||
if(autogpt):
|
||||
# autogpt模式的第一轮对话需要 构建专属prompt
|
||||
cfg = Config()
|
||||
first_prompt = FirstPrompt()
|
||||
first_prompt.command_registry = cfg.command_registry
|
||||
|
||||
system_prompt = first_prompt.construct_first_prompt(fisrt_message=[query])
|
||||
logger.info("[TEST]:" + system_prompt)
|
||||
template_name = "auto_dbgpt_one_shot"
|
||||
@ -456,7 +455,7 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
logger.info(f"args: {args}")
|
||||
|
||||
dbs = get_database_list()
|
||||
# dbs = get_database_list()
|
||||
|
||||
# 加载插件
|
||||
cfg = Config()
|
||||
@ -464,7 +463,6 @@ if __name__ == "__main__":
|
||||
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
|
||||
|
||||
# 加载插件可执行命令
|
||||
command_registry = CommandRegistry()
|
||||
command_categories = [
|
||||
"pilot.commands.audio_text",
|
||||
"pilot.commands.image_gen",
|
||||
@ -473,11 +471,11 @@ if __name__ == "__main__":
|
||||
command_categories = [
|
||||
x for x in command_categories if x not in cfg.disabled_command_categories
|
||||
]
|
||||
command_registry = CommandRegistry()
|
||||
for command_category in command_categories:
|
||||
command_registry.import_commands(command_category)
|
||||
|
||||
cfg.command_registry =command_category
|
||||
|
||||
cfg.command_registry =command_registry
|
||||
|
||||
logger.info(args)
|
||||
demo = build_webdemo()
|
||||
|
BIN
plugins/Db-GPT-SimpleChart-Plugin.zip
Normal file
BIN
plugins/Db-GPT-SimpleChart-Plugin.zip
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user