mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-08 12:30:14 +00:00
合并
This commit is contained in:
@@ -1,107 +0,0 @@
|
||||
import json
|
||||
|
||||
|
||||
def is_valid_int(value: str) -> bool:
|
||||
"""Check if the value is a valid integer
|
||||
|
||||
Args:
|
||||
value (str): The value to check
|
||||
|
||||
Returns:
|
||||
bool: True if the value is a valid integer, False otherwise
|
||||
"""
|
||||
try:
|
||||
int(value)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def get_command(response_json: Dict):
|
||||
"""Parse the response and return the command name and arguments
|
||||
|
||||
Args:
|
||||
response_json (json): The response from the AI
|
||||
|
||||
Returns:
|
||||
tuple: The command name and arguments
|
||||
|
||||
Raises:
|
||||
json.decoder.JSONDecodeError: If the response is not valid JSON
|
||||
|
||||
Exception: If any other error occurs
|
||||
"""
|
||||
try:
|
||||
if "command" not in response_json:
|
||||
return "Error:", "Missing 'command' object in JSON"
|
||||
|
||||
if not isinstance(response_json, dict):
|
||||
return "Error:", f"'response_json' object is not dictionary {response_json}"
|
||||
|
||||
command = response_json["command"]
|
||||
if not isinstance(command, dict):
|
||||
return "Error:", "'command' object is not a dictionary"
|
||||
|
||||
if "name" not in command:
|
||||
return "Error:", "Missing 'name' field in 'command' object"
|
||||
|
||||
command_name = command["name"]
|
||||
|
||||
# Use an empty dictionary if 'args' field is not present in 'command' object
|
||||
arguments = command.get("args", {})
|
||||
|
||||
return command_name, arguments
|
||||
except json.decoder.JSONDecodeError:
|
||||
return "Error:", "Invalid JSON"
|
||||
# All other errors, return "Error: + error message"
|
||||
except Exception as e:
|
||||
return "Error:", str(e)
|
||||
|
||||
|
||||
|
||||
|
||||
def execute_command(
|
||||
command_registry: CommandRegistry,
|
||||
command_name: str,
|
||||
arguments,
|
||||
prompt: PromptGenerator,
|
||||
):
|
||||
"""Execute the command and return the result
|
||||
|
||||
Args:
|
||||
command_name (str): The name of the command to execute
|
||||
arguments (dict): The arguments for the command
|
||||
|
||||
Returns:
|
||||
str: The result of the command
|
||||
"""
|
||||
try:
|
||||
cmd = command_registry.commands.get(command_name)
|
||||
|
||||
# If the command is found, call it with the provided arguments
|
||||
if cmd:
|
||||
return cmd(**arguments)
|
||||
|
||||
# TODO: Remove commands below after they are moved to the command registry.
|
||||
command_name = map_command_synonyms(command_name.lower())
|
||||
|
||||
if command_name == "memory_add":
|
||||
return get_memory(CFG).add(arguments["string"])
|
||||
|
||||
# TODO: Change these to take in a file rather than pasted code, if
|
||||
# non-file is given, return instructions "Input should be a python
|
||||
# filepath, write your code to file and try again
|
||||
else:
|
||||
for command in prompt.commands:
|
||||
if (
|
||||
command_name == command["label"].lower()
|
||||
or command_name == command["name"].lower()
|
||||
):
|
||||
return command["function"](**arguments)
|
||||
return (
|
||||
f"Unknown command '{command_name}'. Please refer to the 'COMMANDS'"
|
||||
" list for available commands and only respond in the specified JSON"
|
||||
" format."
|
||||
)
|
||||
except Exception as e:
|
||||
return f"Error: {str(e)}"
|
@@ -1,6 +1,5 @@
|
||||
from pilot.configs.config import Config
|
||||
from pilot.prompts.generator import PromptGenerator
|
||||
from __future__ import annotations
|
||||
from typing import Any, Optional, Type
|
||||
from pilot.prompts.prompt import build_default_prompt_generator
|
||||
|
||||
|
@@ -3,7 +3,6 @@ import io
|
||||
import uuid
|
||||
from base64 import b64decode
|
||||
|
||||
import openai
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
@@ -26,12 +25,9 @@ def generate_image(prompt: str, size: int = 256) -> str:
|
||||
str: The filename of the image
|
||||
"""
|
||||
filename = f"{CFG.workspace_path}/{str(uuid.uuid4())}.jpg"
|
||||
|
||||
# DALL-E
|
||||
if CFG.image_provider == "dalle":
|
||||
return generate_image_with_dalle(prompt, filename, size)
|
||||
|
||||
# HuggingFace
|
||||
elif CFG.image_provider == "huggingface":
|
||||
if CFG.image_provider == "huggingface":
|
||||
return generate_image_with_hf(prompt, filename)
|
||||
# SD WebUI
|
||||
elif CFG.image_provider == "sdwebui":
|
||||
@@ -76,45 +72,6 @@ def generate_image_with_hf(prompt: str, filename: str) -> str:
|
||||
|
||||
return f"Saved to disk:{filename}"
|
||||
|
||||
|
||||
def generate_image_with_dalle(prompt: str, filename: str, size: int) -> str:
|
||||
"""Generate an image with DALL-E.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt to use
|
||||
filename (str): The filename to save the image to
|
||||
size (int): The size of the image
|
||||
|
||||
Returns:
|
||||
str: The filename of the image
|
||||
"""
|
||||
|
||||
# Check for supported image sizes
|
||||
if size not in [256, 512, 1024]:
|
||||
closest = min([256, 512, 1024], key=lambda x: abs(x - size))
|
||||
logger.info(
|
||||
f"DALL-E only supports image sizes of 256x256, 512x512, or 1024x1024. Setting to {closest}, was {size}."
|
||||
)
|
||||
size = closest
|
||||
|
||||
response = openai.Image.create(
|
||||
prompt=prompt,
|
||||
n=1,
|
||||
size=f"{size}x{size}",
|
||||
response_format="b64_json",
|
||||
api_key=CFG.openai_api_key,
|
||||
)
|
||||
|
||||
logger.info(f"Image Generated for prompt:{prompt}")
|
||||
|
||||
image_data = b64decode(response["data"][0]["b64_json"])
|
||||
|
||||
with open(filename, mode="wb") as png:
|
||||
png.write(image_data)
|
||||
|
||||
return f"Saved to disk:{filename}"
|
||||
|
||||
|
||||
def generate_image_with_sd_webui(
|
||||
prompt: str,
|
||||
filename: str,
|
||||
|
Reference in New Issue
Block a user