Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-07-22 20:01:46 +00:00)
Merge remote-tracking branch 'origin/plugin_init' into source_embedding
This commit is contained in:
commit e44dff4170

.env.template (new file, 72 lines)
@@ -0,0 +1,72 @@
#*******************************************************************#
#** DB-GPT - GENERAL SETTINGS **#
#*******************************************************************#
## DISABLED_COMMAND_CATEGORIES - The list of categories of commands that are disabled. Each of the below is an option:
## pilot.commands.query_execute

## For example, to disable coding related features, uncomment the next line
# DISABLED_COMMAND_CATEGORIES=


#*******************************************************************#
#*** LLM PROVIDER ***#
#*******************************************************************#

# TEMPERATURE=0

#*******************************************************************#
#** LLM MODELS **#
#*******************************************************************#

## SMART_LLM_MODEL - Smart language model (Default: vicuna-13b)
## FAST_LLM_MODEL - Fast language model (Default: chatglm-6b)
# SMART_LLM_MODEL=vicuna-13b
# FAST_LLM_MODEL=chatglm-6b


### EMBEDDINGS
## EMBEDDING_MODEL - Model to use for creating embeddings
## EMBEDDING_TOKENIZER - Tokenizer to use for chunking large inputs
## EMBEDDING_TOKEN_LIMIT - Chunk size limit for large inputs
# EMBEDDING_MODEL=all-MiniLM-L6-v2
# EMBEDDING_TOKENIZER=all-MiniLM-L6-v2
# EMBEDDING_TOKEN_LIMIT=8191


#*******************************************************************#
#** DATABASE SETTINGS **#
#*******************************************************************#
DB_SETTINGS_MYSQL_USER=root
DB_SETTINGS_MYSQL_PASSWORD=password
DB_SETTINGS_MYSQL_HOST=localhost
DB_SETTINGS_MYSQL_PORT=3306


### MILVUS
## MILVUS_ADDR - Milvus remote address (e.g. localhost:19530)
## MILVUS_USERNAME - username for your Milvus database
## MILVUS_PASSWORD - password for your Milvus database
## MILVUS_SECURE - True to enable TLS. (Default: False)
## Setting MILVUS_ADDR to a `https://` URL will override this setting.
## MILVUS_COLLECTION - Milvus collection, change it if you want to start a new memory and retain the old memory.
# MILVUS_ADDR=localhost:19530
# MILVUS_USERNAME=
# MILVUS_PASSWORD=
# MILVUS_SECURE=
# MILVUS_COLLECTION=dbgpt

#*******************************************************************#
#** ALLOWLISTED PLUGINS **#
#*******************************************************************#

#ALLOWLISTED_PLUGINS - Sets the listed plugins that are allowed (Example: plugin1,plugin2,plugin3)
#DENYLISTED_PLUGINS - Sets the listed plugins that are not allowed (Example: plugin1,plugin2,plugin3)
ALLOWLISTED_PLUGINS=
DENYLISTED_PLUGINS=


#*******************************************************************#
#** CHAT PLUGIN SETTINGS **#
#*******************************************************************#
# CHAT_MESSAGES_ENABLED - Enable chat messages (Default: False)
# CHAT_MESSAGES_ENABLED=False
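To illustrate how the database settings above might be consumed at startup, here is a minimal Python sketch; the variable names come from this template and the fallback values are simply the template defaults, not a statement of how DB-GPT itself loads them.

```
import os

# Read the MySQL settings defined in .env.template, falling back to the template defaults.
db_settings = {
    "user": os.getenv("DB_SETTINGS_MYSQL_USER", "root"),
    "password": os.getenv("DB_SETTINGS_MYSQL_PASSWORD", "password"),
    "host": os.getenv("DB_SETTINGS_MYSQL_HOST", "localhost"),
    "port": int(os.getenv("DB_SETTINGS_MYSQL_PORT", "3306")),
}
print(db_settings)
```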
.gitignore (vendored, 4 changes)
@@ -8,6 +8,8 @@ __pycache__/

.idea
.vscode
.idea
.chroma
# Distribution / packaging
.Python
build/
@@ -132,5 +134,5 @@ dmypy.json
.pyre/
.DS_Store
logs

nltk_data
.vectordb
.vscode/launch.json (vendored, deleted; 25 lines)
@@ -1,25 +0,0 @@
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python: Current File",
            "type": "python",
            "request": "launch",
            "program": "${file}",
            "console": "integratedTerminal",
            "justMyCode": true,
            "env": {"PYTHONPATH": "${workspaceFolder}"},
            "envFile": "${workspaceFolder}/.env"
        },
        {
            "name": "Python: Module",
            "type": "python",
            "request": "launch",
            "module": "pilot",
            "justMyCode": true,
        }
    ]
}
CONTRIBUTING.md (new file, 65 lines)
@@ -0,0 +1,65 @@

To contribute to this GitHub project, you can follow these steps:

1. Fork the repository you want to contribute to by clicking the "Fork" button on the project page.

2. Clone the repository to your local machine using the following command:

```
git clone https://github.com/<YOUR-GITHUB-USERNAME>/DB-GPT
```

3. Install the project requirements:
```
pip install -r requirements.txt
```

4. Install the pre-commit hooks:
```
pre-commit install
```

5. Create a new branch for your changes using the following command:

```
git checkout -b "branch-name"
```

6. Make your changes to the code or documentation.
   - Example: improve the user interface or add documentation.

7. Add the changes to the staging area using the following command:
```
git add .
```

8. Commit the changes with a meaningful commit message using the following command:
```
git commit -m "your commit message"
```

9. Push the changes to your forked repository using the following command:
```
git push origin branch-name
```

10. Go to the GitHub website and navigate to your forked repository.

11. Click the "New pull request" button.

12. Select the branch you just pushed to and the branch you want to merge into on the original repository.

13. Add a description of your changes and click the "Create pull request" button.

14. Wait for the project maintainer to review your changes and provide feedback.

15. Make any necessary changes based on feedback and repeat steps 5-12 until your changes are accepted and merged into the main project.

16. Once your changes are merged, you can update your forked repository and local copy of the repository with the following commands:

```
git fetch upstream
git checkout master
git merge upstream/master
```

Finally, delete the branch you created with the following command:
```
git branch -d branch-name
```

That's it, you made it! 🐣⭐⭐
README.md (21 changes)
@@ -28,15 +28,23 @@ Run on an RTX 4090 GPU (the original video is not sped up!, [YouTube link](https://www

<img src="https://github.com/csunny/DB-GPT/blob/main/asserts/SQLGEN.png" width="600" margin-left="auto" margin-right="auto" >

The generated SQL is runnable.

<img src="https://github.com/csunny/DB-GPT/blob/main/asserts/exeable.png" width="600" margin-left="auto" margin-right="auto" >

- Database QA example

<img src="https://github.com/csunny/DB-GPT/blob/main/asserts/DB_QA.png" margin-left="auto" margin-right="auto" width="600">

QA based on the default built-in knowledge base

<img src="https://github.com/csunny/DB-GPT/blob/main/asserts/VectorDBQA.png" width="600" margin-left="auto" margin-right="auto" >

# Dependencies
1. First you need to install the python requirements.
```
python>=3.9
pip install -r requirements
python>=3.10
pip install -r requirements.txt
```
or, if you use a conda environment, you can use this command
```
@@ -55,12 +63,12 @@ The password is just for testing; you can change it if necessary
# Install
1. Download the base model
For the base model, you can merge the weights yourself by following the [vicuna](https://github.com/lm-sys/FastChat/blob/main/README.md#model-weights) tutorial.
If this step is difficult, you can instead use a model from [Hugging Face](https://huggingface.co/) as a substitute. Substitute model: [vicuna-13b](https://huggingface.co/Tribbiani/vicuna-13b)
If this step is difficult, you can instead use a model from [Hugging Face](https://huggingface.co/) as a substitute. [Substitute model](https://huggingface.co/Tribbiani/vicuna-7b)

2. Run model server
```
cd pilot/server
python vicuna_server.py
python llmserver.py
```

3. Run gradio webui
@@ -75,6 +83,7 @@ python webserver.py

Overall, it is a sophisticated and innovative AI tool for databases. If you have any specific questions about using or implementing DB-GPT in your work, please contact me and I will do my best to help; everyone is also welcome to take part in building the project and do something interesting.


# Contribute
[Contribute](https://github.com/csunny/DB-GPT/blob/main/CONTRIBUTING)
# Licence
[MIT](https://github.com/csunny/DB-GPT/blob/main/LICENSE)
[MIT](https://github.com/csunny/DB-GPT/blob/main/LICENSE)
asserts/VectorDBQA.png (new binary file, 560 KiB; not shown)
@@ -7,11 +7,9 @@ dependencies:
- python=3.9
- cudatoolkit
- pip
- pytorch=1.12.1
- pytorch-mutex=1.0=cuda
- torchaudio=0.12.1
- torchvision=0.13.1
- pip:
- pytorch
- accelerate==0.16.0
- aiohttp==3.8.4
- aiosignal==1.3.1
@@ -30,7 +28,6 @@ dependencies:
- kiwisolver==1.4.4
- matplotlib==3.7.0
- multidict==6.0.4
- openai==0.27.0
- packaging==23.0
- psutil==5.9.4
- pycocotools==2.0.6
@@ -57,11 +54,15 @@ dependencies:
- sentence-transformers
- umap-learn
- notebook
- gradio==3.24.1
- gradio==3.23
- gradio-client==0.0.8
- wandb
- fschat=0.1.10
- llama-index=0.5.27
- llama-index==0.5.27
- pymysql
- unstructured==0.6.3
- pytesseract==0.3.10
- markdown2
- chromadb
- colorama
- playsound
- distro
examples/gradio_test.py (new file, 19 lines)
@@ -0,0 +1,19 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

import gradio as gr

def change_tab():
    return gr.Tabs.update(selected=1)

with gr.Blocks() as demo:
    with gr.Tabs() as tabs:
        with gr.TabItem("Train", id=0):
            t = gr.Textbox()
        with gr.TabItem("Inference", id=1):
            i = gr.Image()

    btn = gr.Button()
    btn.click(change_tab, None, tabs)

demo.launch()
@@ -3,5 +3,10 @@


class Agent:
    """Agent class for interacting with DB-GPT """
    pass
    """Agent class for interacting with DB-GPT

    Attributes:
    """

    def __init__(self) -> None:
        pass
@ -1,23 +1,86 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
"""Agent manager for managing GPT agents"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.singleton import Singleton
|
||||
from pilot.configs.config import Config
|
||||
from typing import List
|
||||
from pilot.model.base import Message
|
||||
|
||||
|
||||
class AgentManager(metaclass=Singleton):
|
||||
"""Agent manager for managing DB-GPT agents"""
|
||||
"""Agent manager for managing GPT agents"""
|
||||
|
||||
def __init__(self):
|
||||
self.next_key = 0
|
||||
self.agents = {} # key, (task, full_message_history, model)
|
||||
self.cfg = Config()
|
||||
"""Agent manager for managing DB-GPT agents
|
||||
In order to be compatible with Auto-GPT plugins,
we use the same template as they do.
|
||||
|
||||
Args: next_keys
|
||||
agents
|
||||
cfg
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
|
||||
self.next_key = 0
|
||||
self.agents = {} #TODO need to define
|
||||
self.cfg = Config()
|
||||
|
||||
def create_agent(self):
|
||||
pass
|
||||
# Create new GPT agent
|
||||
# TODO: Centralise use of create_chat_completion() to globally enforce token limit
|
||||
|
||||
def message_agent(self):
|
||||
pass
|
||||
def create_agent(self, task: str, prompt: str, model: str) -> tuple[int, str]:
|
||||
"""Create a new agent and return its key
|
||||
|
||||
def list_agents(self):
|
||||
pass
|
||||
Args:
|
||||
task: The task to perform
|
||||
prompt: The prompt to use
|
||||
model: The model to use
|
||||
|
||||
def delete_agent(self):
|
||||
pass
|
||||
Returns:
|
||||
The key of the new agent
|
||||
"""
|
||||
|
||||
|
||||
def message_agent(self, key: str | int, message: str) -> str:
|
||||
"""Send a message to an agent and return its response
|
||||
|
||||
Args:
|
||||
key: The key of the agent to message
|
||||
message: The message to send to the agent
|
||||
|
||||
Returns:
|
||||
The agent's response
|
||||
"""
|
||||
|
||||
|
||||
def list_agents(self) -> list[tuple[str | int, str]]:
|
||||
"""Return a list of all agents
|
||||
|
||||
Returns:
|
||||
A list of tuples of the form (key, task)
|
||||
"""
|
||||
|
||||
# Return a list of agent keys and their tasks
|
||||
return [(key, task) for key, (task, _, _) in self.agents.items()]
|
||||
|
||||
def delete_agent(self, key: str | int) -> bool:
|
||||
"""Delete an agent from the agent manager
|
||||
|
||||
Args:
|
||||
key: The key of the agent to delete
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
|
||||
try:
|
||||
del self.agents[int(key)]
|
||||
return True
|
||||
except KeyError:
|
||||
return False
|
||||
|
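The refactored AgentManager is a Singleton, so every caller shares one registry of agents. A minimal sketch of the parts that are already implemented (list_agents and delete_agent); the module path is assumed from the repository layout and create_agent/message_agent are still stubs at this commit.

```
from pilot.agent.agent_manager import AgentManager  # assumed module path

manager = AgentManager()
assert manager is AgentManager()  # Singleton metaclass: the same instance everywhere

print(manager.list_agents())      # [] until create_agent() is implemented
print(manager.delete_agent(0))    # False: no agent registered under that key
```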
0
pilot/commands/__init__.py
Normal file
0
pilot/commands/__init__.py
Normal file
61
pilot/commands/audio_text.py
Normal file
61
pilot/commands/audio_text.py
Normal file
@ -0,0 +1,61 @@
|
||||
"""Commands for converting audio to text."""
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
from pilot.commands.command_mange import command
|
||||
from pilot.configs.config import Config
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
@command(
|
||||
"read_audio_from_file",
|
||||
"Convert Audio to text",
|
||||
'"filename": "<filename>"',
|
||||
CFG.huggingface_audio_to_text_model,
|
||||
"Configure huggingface_audio_to_text_model.",
|
||||
)
|
||||
def read_audio_from_file(filename: str) -> str:
|
||||
"""
|
||||
Convert audio to text.
|
||||
|
||||
Args:
|
||||
filename (str): The path to the audio file
|
||||
|
||||
Returns:
|
||||
str: The text from the audio
|
||||
"""
|
||||
with open(filename, "rb") as audio_file:
|
||||
audio = audio_file.read()
|
||||
return read_audio(audio)
|
||||
|
||||
|
||||
def read_audio(audio: bytes) -> str:
|
||||
"""
|
||||
Convert audio to text.
|
||||
|
||||
Args:
|
||||
audio (bytes): The audio to convert
|
||||
|
||||
Returns:
|
||||
str: The text from the audio
|
||||
"""
|
||||
model = CFG.huggingface_audio_to_text_model
|
||||
api_url = f"https://api-inference.huggingface.co/models/{model}"
|
||||
api_token = CFG.huggingface_api_token
|
||||
headers = {"Authorization": f"Bearer {api_token}"}
|
||||
|
||||
if api_token is None:
|
||||
raise ValueError(
|
||||
"You need to set your Hugging Face API token in the config file."
|
||||
)
|
||||
|
||||
response = requests.post(
|
||||
api_url,
|
||||
headers=headers,
|
||||
data=audio,
|
||||
)
|
||||
|
||||
text = json.loads(response.content.decode("utf-8"))["text"]
|
||||
return f"The audio says: {text}"
|
36
pilot/commands/command.py
Normal file
36
pilot/commands/command.py
Normal file
@ -0,0 +1,36 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import functools
|
||||
import importlib
|
||||
import inspect
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
class Command:
|
||||
"""A class representing a command.
|
||||
|
||||
Attributes:
|
||||
name (str): The name of the command.
|
||||
description (str): A brief description of what the command does.
|
||||
signature (str): The signature of the function that the command executes. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
name: str,
|
||||
description: str,
|
||||
method: Callable[..., Any],
|
||||
signature: str = "",
|
||||
enabled: bool = True,
|
||||
disabled_reason: Optional[str] = None,
|
||||
) -> None:
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.method = method
|
||||
self.signature = signature if signature else str(inspect.signature(self.method))
|
||||
self.enabled = enabled
|
||||
self.disabled_reason = disabled_reason
|
||||
|
||||
def __call__(self, *args: Any, **kwds: Any) -> Any:
|
||||
if not self.enabled:
|
||||
return f"Command '{self.name}' is disabled: {self.disabled_reason}"
|
||||
return self.method(*args, **kwds)
|
156
pilot/commands/command_mange.py
Normal file
156
pilot/commands/command_mange.py
Normal file
@ -0,0 +1,156 @@
|
||||
import functools
|
||||
import importlib
|
||||
import inspect
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
# Unique identifier for auto-gpt commands
|
||||
AUTO_GPT_COMMAND_IDENTIFIER = "auto_gpt_command"
|
||||
|
||||
|
||||
class Command:
|
||||
"""A class representing a command.
|
||||
|
||||
Attributes:
|
||||
name (str): The name of the command.
|
||||
description (str): A brief description of what the command does.
|
||||
signature (str): The signature of the function that the command executes. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
method: Callable[..., Any],
|
||||
signature: str = "",
|
||||
enabled: bool = True,
|
||||
disabled_reason: Optional[str] = None,
|
||||
):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.method = method
|
||||
self.signature = signature if signature else str(inspect.signature(self.method))
|
||||
self.enabled = enabled
|
||||
self.disabled_reason = disabled_reason
|
||||
|
||||
def __call__(self, *args, **kwargs) -> Any:
|
||||
if not self.enabled:
|
||||
return f"Command '{self.name}' is disabled: {self.disabled_reason}"
|
||||
return self.method(*args, **kwargs)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.name}: {self.description}, args: {self.signature}"
|
||||
|
||||
|
||||
class CommandRegistry:
|
||||
"""
|
||||
The CommandRegistry class is a manager for a collection of Command objects.
|
||||
It allows the registration, modification, and retrieval of Command objects,
|
||||
as well as the scanning and loading of command plugins from a specified
|
||||
directory.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.commands = {}
|
||||
|
||||
def _import_module(self, module_name: str) -> Any:
|
||||
return importlib.import_module(module_name)
|
||||
|
||||
def _reload_module(self, module: Any) -> Any:
|
||||
return importlib.reload(module)
|
||||
|
||||
def register(self, cmd: Command) -> None:
|
||||
self.commands[cmd.name] = cmd
|
||||
|
||||
def unregister(self, command_name: str):
|
||||
if command_name in self.commands:
|
||||
del self.commands[command_name]
|
||||
else:
|
||||
raise KeyError(f"Command '{command_name}' not found in registry.")
|
||||
|
||||
def reload_commands(self) -> None:
|
||||
"""Reloads all loaded command plugins."""
|
||||
for cmd_name in self.commands:
|
||||
cmd = self.commands[cmd_name]
|
||||
module = self._import_module(cmd.__module__)
|
||||
reloaded_module = self._reload_module(module)
|
||||
if hasattr(reloaded_module, "register"):
|
||||
reloaded_module.register(self)
|
||||
|
||||
def get_command(self, name: str) -> Callable[..., Any]:
|
||||
return self.commands[name]
|
||||
|
||||
def call(self, command_name: str, **kwargs) -> Any:
|
||||
if command_name not in self.commands:
|
||||
raise KeyError(f"Command '{command_name}' not found in registry.")
|
||||
command = self.commands[command_name]
|
||||
return command(**kwargs)
|
||||
|
||||
def command_prompt(self) -> str:
|
||||
"""
|
||||
Returns a string representation of all registered `Command` objects for use in a prompt
|
||||
"""
|
||||
commands_list = [
|
||||
f"{idx + 1}. {str(cmd)}" for idx, cmd in enumerate(self.commands.values())
|
||||
]
|
||||
return "\n".join(commands_list)
|
||||
|
||||
def import_commands(self, module_name: str) -> None:
|
||||
"""
|
||||
Imports the specified Python module containing command plugins.
|
||||
|
||||
This method imports the associated module and registers any functions or
|
||||
classes that are decorated with the `AUTO_GPT_COMMAND_IDENTIFIER` attribute
|
||||
as `Command` objects. The registered `Command` objects are then added to the
|
||||
`commands` dictionary of the `CommandRegistry` object.
|
||||
|
||||
Args:
|
||||
module_name (str): The name of the module to import for command plugins.
|
||||
"""
|
||||
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
for attr_name in dir(module):
|
||||
attr = getattr(module, attr_name)
|
||||
# Register decorated functions
|
||||
if hasattr(attr, AUTO_GPT_COMMAND_IDENTIFIER) and getattr(
|
||||
attr, AUTO_GPT_COMMAND_IDENTIFIER
|
||||
):
|
||||
self.register(attr.command)
|
||||
# Register command classes
|
||||
elif (
|
||||
inspect.isclass(attr) and issubclass(attr, Command) and attr != Command
|
||||
):
|
||||
cmd_instance = attr()
|
||||
self.register(cmd_instance)
|
||||
|
||||
|
||||
def command(
|
||||
name: str,
|
||||
description: str,
|
||||
signature: str = "",
|
||||
enabled: bool = True,
|
||||
disabled_reason: Optional[str] = None,
|
||||
) -> Callable[..., Any]:
|
||||
"""The command decorator is used to create Command objects from ordinary functions."""
|
||||
|
||||
def decorator(func: Callable[..., Any]) -> Command:
|
||||
cmd = Command(
|
||||
name=name,
|
||||
description=description,
|
||||
method=func,
|
||||
signature=signature,
|
||||
enabled=enabled,
|
||||
disabled_reason=disabled_reason,
|
||||
)
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs) -> Any:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
wrapper.command = cmd
|
||||
|
||||
setattr(wrapper, AUTO_GPT_COMMAND_IDENTIFIER, True)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
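A hypothetical usage sketch of the @command decorator and CommandRegistry defined above: the decorator attaches a Command object to the wrapped function, which can then be registered and invoked by name. The get_datetime command here is only an illustration, not part of the commit.

```
from datetime import datetime

from pilot.commands.command_mange import CommandRegistry, command


@command("get_datetime", "Return the current date and time", '"fmt": "<fmt>"')
def get_datetime(fmt: str) -> str:
    return datetime.now().strftime(fmt)


registry = CommandRegistry()
registry.register(get_datetime.command)  # the decorator stored the Command object on the wrapper

print(registry.command_prompt())                      # 1. get_datetime: Return the current date and time, args: ...
print(registry.call("get_datetime", fmt="%Y-%m-%d"))  # runs the wrapped function by name
```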
pilot/commands/commands_load.py (new file, 27 lines)
@@ -0,0 +1,27 @@
from pilot.configs.config import Config
from pilot.prompts.generator import PromptGenerator
from typing import Any, Optional, Type
from pilot.prompts.prompt import build_default_prompt_generator


class CommandsLoad:
    """
    Load plugin command info to help build the system prompt.
    """

    def __init__(self) -> None:
        self.command_registry = None

    def getCommandInfos(self, prompt_generator: Optional[PromptGenerator] = None) -> str:
        cfg = Config()
        if prompt_generator is None:
            prompt_generator = build_default_prompt_generator()
        for plugin in cfg.plugins:
            if not plugin.can_handle_post_prompt():
                continue
            prompt_generator = plugin.post_prompt(prompt_generator)
        self.prompt_generator = prompt_generator
        command_infos = ""
        command_infos += f"\n\n{prompt_generator.commands()}"
        return command_infos
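A small usage sketch of CommandsLoad: it builds the default prompt generator, lets any loaded plugins post-process it, and returns the command listing that gets appended to the system prompt. This assumes the PromptGenerator returned by build_default_prompt_generator() provides the commands() helper used above.

```
from pilot.commands.commands_load import CommandsLoad

loader = CommandsLoad()
# With no plugins configured, this returns only the built-in command listing.
print(loader.getCommandInfos())
```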
122
pilot/commands/image_gen.py
Normal file
122
pilot/commands/image_gen.py
Normal file
@ -0,0 +1,122 @@
|
||||
""" Image Generation Module for AutoGPT."""
|
||||
import io
|
||||
import uuid
|
||||
from base64 import b64decode
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
from pilot.commands.command_mange import command
|
||||
from pilot.configs.config import Config
|
||||
from pilot.logs import logger
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
@command("generate_image", "Generate Image", '"prompt": "<prompt>"', CFG.image_provider)
|
||||
def generate_image(prompt: str, size: int = 256) -> str:
|
||||
"""Generate an image from a prompt.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt to use
|
||||
size (int, optional): The size of the image. Defaults to 256. (Not supported by HuggingFace)
|
||||
|
||||
Returns:
|
||||
str: The filename of the image
|
||||
"""
|
||||
filename = f"{CFG.workspace_path}/{str(uuid.uuid4())}.jpg"
|
||||
|
||||
# HuggingFace
|
||||
if CFG.image_provider == "huggingface":
|
||||
return generate_image_with_hf(prompt, filename)
|
||||
# SD WebUI
|
||||
elif CFG.image_provider == "sdwebui":
|
||||
return generate_image_with_sd_webui(prompt, filename, size)
|
||||
return "No Image Provider Set"
|
||||
|
||||
|
||||
def generate_image_with_hf(prompt: str, filename: str) -> str:
|
||||
"""Generate an image with HuggingFace's API.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt to use
|
||||
filename (str): The filename to save the image to
|
||||
|
||||
Returns:
|
||||
str: The filename of the image
|
||||
"""
|
||||
API_URL = (
|
||||
f"https://api-inference.huggingface.co/models/{CFG.huggingface_image_model}"
|
||||
)
|
||||
if CFG.huggingface_api_token is None:
|
||||
raise ValueError(
|
||||
"You need to set your Hugging Face API token in the config file."
|
||||
)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {CFG.huggingface_api_token}",
|
||||
"X-Use-Cache": "false",
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
API_URL,
|
||||
headers=headers,
|
||||
json={
|
||||
"inputs": prompt,
|
||||
},
|
||||
)
|
||||
|
||||
image = Image.open(io.BytesIO(response.content))
|
||||
logger.info(f"Image Generated for prompt:{prompt}")
|
||||
|
||||
image.save(filename)
|
||||
|
||||
return f"Saved to disk:{filename}"
|
||||
|
||||
def generate_image_with_sd_webui(
|
||||
prompt: str,
|
||||
filename: str,
|
||||
size: int = 512,
|
||||
negative_prompt: str = "",
|
||||
extra: dict = {},
|
||||
) -> str:
|
||||
"""Generate an image with Stable Diffusion webui.
|
||||
Args:
|
||||
prompt (str): The prompt to use
|
||||
filename (str): The filename to save the image to
|
||||
size (int, optional): The size of the image. Defaults to 256.
|
||||
negative_prompt (str, optional): The negative prompt to use. Defaults to "".
|
||||
extra (dict, optional): Extra parameters to pass to the API. Defaults to {}.
|
||||
Returns:
|
||||
str: The filename of the image
|
||||
"""
|
||||
# Create a session and set the basic auth if needed
|
||||
s = requests.Session()
|
||||
if CFG.sd_webui_auth:
|
||||
username, password = CFG.sd_webui_auth.split(":")
|
||||
s.auth = (username, password or "")
|
||||
|
||||
# Generate the images
|
||||
response = requests.post(
|
||||
f"{CFG.sd_webui_url}/sdapi/v1/txt2img",
|
||||
json={
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"sampler_index": "DDIM",
|
||||
"steps": 20,
|
||||
"cfg_scale": 7.0,
|
||||
"width": size,
|
||||
"height": size,
|
||||
"n_iter": 1,
|
||||
**extra,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"Image Generated for prompt:{prompt}")
|
||||
|
||||
# Save the image to disk
|
||||
response = response.json()
|
||||
b64 = b64decode(response["images"][0].split(",", 1)[0])
|
||||
image = Image.open(io.BytesIO(b64))
|
||||
image.save(filename)
|
||||
|
||||
return f"Saved to disk:{filename}"
|
pilot/commands/times.py (new file, 10 lines)
@@ -0,0 +1,10 @@
from datetime import datetime


def get_datetime() -> str:
    """Return the current date and time

    Returns:
        str: The current date and time
    """
    return "Current date and time: " + datetime.now().strftime("%Y-%m-%d %H:%M:%S")
169
pilot/configs/ai_config.py
Normal file
169
pilot/configs/ai_config.py
Normal file
@ -0,0 +1,169 @@
|
||||
# sourcery skip: do-not-use-staticmethod
|
||||
"""
|
||||
A module that contains the AIConfig class object that contains the configuration
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import platform
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
import distro
|
||||
import yaml
|
||||
|
||||
from pilot.prompts.generator import PromptGenerator
|
||||
from pilot.configs.config import Config
|
||||
from pilot.prompts.prompt import build_default_prompt_generator
|
||||
|
||||
# Soon this will go in a folder where it remembers more stuff about the run(s)
|
||||
SAVE_FILE = str(Path(os.getcwd()) / "ai_settings.yaml")
|
||||
|
||||
|
||||
class AIConfig:
|
||||
"""
|
||||
A class object that contains the configuration information for the AI
|
||||
|
||||
Attributes:
|
||||
ai_name (str): The name of the AI.
|
||||
ai_role (str): The description of the AI's role.
|
||||
ai_goals (list): The list of objectives the AI is supposed to complete.
|
||||
api_budget (float): The maximum dollar value for API calls (0.0 means infinite)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ai_name: str = "",
|
||||
ai_role: str = "",
|
||||
ai_goals: list | None = None,
|
||||
api_budget: float = 0.0,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a class instance
|
||||
|
||||
Parameters:
|
||||
ai_name (str): The name of the AI.
|
||||
ai_role (str): The description of the AI's role.
|
||||
ai_goals (list): The list of objectives the AI is supposed to complete.
|
||||
api_budget (float): The maximum dollar value for API calls (0.0 means infinite)
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if ai_goals is None:
|
||||
ai_goals = []
|
||||
self.ai_name = ai_name
|
||||
self.ai_role = ai_role
|
||||
self.ai_goals = ai_goals
|
||||
self.api_budget = api_budget
|
||||
self.prompt_generator = None
|
||||
self.command_registry = None
|
||||
|
||||
@staticmethod
|
||||
def load(config_file: str = SAVE_FILE) -> "AIConfig":
|
||||
"""
|
||||
Returns class object with parameters (ai_name, ai_role, ai_goals, api_budget) loaded from
|
||||
yaml file if yaml file exists,
|
||||
else returns class with no parameters.
|
||||
|
||||
Parameters:
|
||||
config_file (int): The path to the config yaml file.
|
||||
DEFAULT: "../ai_settings.yaml"
|
||||
|
||||
Returns:
|
||||
cls (object): An instance of given cls object
|
||||
"""
|
||||
|
||||
try:
|
||||
with open(config_file, encoding="utf-8") as file:
|
||||
config_params = yaml.load(file, Loader=yaml.FullLoader)
|
||||
except FileNotFoundError:
|
||||
config_params = {}
|
||||
|
||||
ai_name = config_params.get("ai_name", "")
|
||||
ai_role = config_params.get("ai_role", "")
|
||||
ai_goals = [
|
||||
str(goal).strip("{}").replace("'", "").replace('"', "")
|
||||
if isinstance(goal, dict)
|
||||
else str(goal)
|
||||
for goal in config_params.get("ai_goals", [])
|
||||
]
|
||||
api_budget = config_params.get("api_budget", 0.0)
|
||||
# type: Type[AIConfig]
|
||||
return AIConfig(ai_name, ai_role, ai_goals, api_budget)
|
||||
|
||||
def save(self, config_file: str = SAVE_FILE) -> None:
|
||||
"""
|
||||
Saves the class parameters to the specified file yaml file path as a yaml file.
|
||||
|
||||
Parameters:
|
||||
config_file(str): The path to the config yaml file.
|
||||
DEFAULT: "../ai_settings.yaml"
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
config = {
|
||||
"ai_name": self.ai_name,
|
||||
"ai_role": self.ai_role,
|
||||
"ai_goals": self.ai_goals,
|
||||
"api_budget": self.api_budget,
|
||||
}
|
||||
with open(config_file, "w", encoding="utf-8") as file:
|
||||
yaml.dump(config, file, allow_unicode=True)
|
||||
|
||||
def construct_full_prompt(
|
||||
self, prompt_generator: Optional[PromptGenerator] = None
|
||||
) -> str:
|
||||
"""
|
||||
Returns a prompt to the user with the class information in an organized fashion.
|
||||
|
||||
Parameters:
|
||||
None
|
||||
|
||||
Returns:
|
||||
full_prompt (str): A string containing the initial prompt for the user
|
||||
including the ai_name, ai_role, ai_goals, and api_budget.
|
||||
"""
|
||||
|
||||
prompt_start = (
|
||||
"Your decisions must always be made independently without"
|
||||
" seeking user assistance. Play to your strengths as an LLM and pursue"
|
||||
" simple strategies with no legal complications."
|
||||
""
|
||||
)
|
||||
|
||||
|
||||
|
||||
cfg = Config()
|
||||
if prompt_generator is None:
|
||||
prompt_generator = build_default_prompt_generator()
|
||||
prompt_generator.goals = self.ai_goals
|
||||
prompt_generator.name = self.ai_name
|
||||
prompt_generator.role = self.ai_role
|
||||
prompt_generator.command_registry = self.command_registry
|
||||
for plugin in cfg.plugins:
|
||||
if not plugin.can_handle_post_prompt():
|
||||
continue
|
||||
prompt_generator = plugin.post_prompt(prompt_generator)
|
||||
|
||||
if cfg.execute_local_commands:
|
||||
# add OS info to prompt
|
||||
os_name = platform.system()
|
||||
os_info = (
|
||||
platform.platform(terse=True)
|
||||
if os_name != "Linux"
|
||||
else distro.name(pretty=True)
|
||||
)
|
||||
|
||||
prompt_start += f"\nThe OS you are running on is: {os_info}"
|
||||
|
||||
# Construct full prompt
|
||||
full_prompt = f"You are {prompt_generator.name}, {prompt_generator.role}\n{prompt_start}\n\nGOALS:\n\n"
|
||||
for i, goal in enumerate(self.ai_goals):
|
||||
full_prompt += f"{i+1}. {goal}\n"
|
||||
if self.api_budget > 0.0:
|
||||
full_prompt += f"\nIt takes money to let you run. Your API budget is ${self.api_budget:.3f}"
|
||||
self.prompt_generator = prompt_generator
|
||||
full_prompt += f"\n\n{prompt_generator.generate_prompt_string()}"
|
||||
return full_prompt
|
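A minimal sketch of the AIConfig save/load round trip shown above; the file name is hypothetical, and construct_full_prompt is omitted because it additionally needs the plugin configuration from Config.

```
from pilot.configs.ai_config import AIConfig

cfg = AIConfig(
    ai_name="DB-GPT",
    ai_role="an AI assistant for database questions",
    ai_goals=["answer questions about the dbgpt database"],
)
cfg.save("ai_settings_demo.yaml")            # hypothetical file name for the sketch

loaded = AIConfig.load("ai_settings_demo.yaml")
print(loaded.ai_name, loaded.ai_goals)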
@ -1,6 +1,9 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||
from pilot.singleton import Singleton
|
||||
|
||||
@ -8,5 +11,89 @@ class Config(metaclass=Singleton):
|
||||
"""Configuration class to store the state of bools for different scripts access"""
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the Config class"""
|
||||
pass
|
||||
|
||||
# TODO change model_config there
|
||||
|
||||
self.debug_mode = False
|
||||
self.skip_reprompt = False
|
||||
|
||||
self.temperature = float(os.getenv("TEMPERATURE", 0.7))
|
||||
|
||||
# TODO change model_config there
|
||||
self.execute_local_commands = (
|
||||
os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
|
||||
)
|
||||
# User agent header to use when making HTTP requests
|
||||
# Some websites might just completely deny request with an error code if
|
||||
# no user agent was found.
|
||||
self.user_agent = os.getenv(
|
||||
"USER_AGENT",
|
||||
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) AppleWebKit/537.36"
|
||||
" (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36",
|
||||
)
|
||||
|
||||
self.elevenlabs_api_key = os.getenv("ELEVENLABS_API_KEY")
|
||||
self.elevenlabs_voice_1_id = os.getenv("ELEVENLABS_VOICE_1_ID")
|
||||
self.elevenlabs_voice_2_id = os.getenv("ELEVENLABS_VOICE_2_ID")
|
||||
|
||||
self.use_mac_os_tts = False
|
||||
self.use_mac_os_tts = os.getenv("USE_MAC_OS_TTS")
|
||||
|
||||
# milvus or zilliz cloud configuration
|
||||
self.milvus_addr = os.getenv("MILVUS_ADDR", "localhost:19530")
|
||||
self.milvus_username = os.getenv("MILVUS_USERNAME")
|
||||
self.milvus_password = os.getenv("MILVUS_PASSWORD")
|
||||
self.milvus_collection = os.getenv("MILVUS_COLLECTION", "dbgpt")
|
||||
self.milvus_secure = os.getenv("MILVUS_SECURE") == "True"
|
||||
|
||||
self.plugins_dir = os.getenv("PLUGINS_DIR", "plugins")
|
||||
self.plugins: List[AutoGPTPluginTemplate] = []
|
||||
self.plugins_openai = []
|
||||
|
||||
self.command_registry = []
|
||||
|
||||
self.huggingface_api_token = os.getenv("HUGGINGFACE_API_TOKEN")
|
||||
self.image_provider = os.getenv("IMAGE_PROVIDER")
|
||||
self.image_size = int(os.getenv("IMAGE_SIZE", 256))
|
||||
self.huggingface_image_model = os.getenv(
|
||||
"HUGGINGFACE_IMAGE_MODEL", "CompVis/stable-diffusion-v1-4"
|
||||
)
|
||||
self.huggingface_audio_to_text_model = os.getenv(
|
||||
"HUGGINGFACE_AUDIO_TO_TEXT_MODEL"
|
||||
)
|
||||
|
||||
disabled_command_categories = os.getenv("DISABLED_COMMAND_CATEGORIES")
|
||||
if disabled_command_categories:
|
||||
self.disabled_command_categories = disabled_command_categories.split(",")
|
||||
else:
|
||||
self.disabled_command_categories = []
|
||||
|
||||
self.execute_local_commands = (
|
||||
os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
|
||||
)
|
||||
|
||||
plugins_allowlist = os.getenv("ALLOWLISTED_PLUGINS")
|
||||
if plugins_allowlist:
|
||||
self.plugins_allowlist = plugins_allowlist.split(",")
|
||||
else:
|
||||
self.plugins_allowlist = []
|
||||
|
||||
plugins_denylist = os.getenv("DENYLISTED_PLUGINS")
|
||||
if plugins_denylist:
|
||||
self.plugins_denylist = plugins_denylist.split(",")
|
||||
else:
|
||||
self.plugins_denylist = []
|
||||
|
||||
def set_debug_mode(self, value: bool) -> None:
|
||||
"""Set the debug mode value"""
|
||||
self.debug_mode = value
|
||||
|
||||
def set_plugins(self, value: list) -> None:
|
||||
"""Set the plugins value. """
|
||||
self.plugins = value
|
||||
|
||||
def set_templature(self, value: int) -> None:
|
||||
"""Set the temperature value."""
|
||||
self.temperature = value
|
||||
|
||||
|
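Because Config is a Singleton that reads its values from environment variables when it is first constructed, settings such as the Milvus address can be overridden before the first instantiation. A small illustrative sketch (the value shown is just the default from .env.template):

```
import os

os.environ["MILVUS_ADDR"] = "localhost:19530"  # must be set before Config() is first created

from pilot.configs.config import Config

cfg = Config()
assert cfg is Config()         # Singleton: every caller sees the same configuration object
print(cfg.milvus_addr)         # "localhost:19530"
print(cfg.plugins_allowlist)   # [] unless ALLOWLISTED_PLUGINS is set
```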
@@ -5,6 +5,7 @@ import torch
import os
import nltk


ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
MODEL_PATH = os.path.join(ROOT_PATH, "models")
PILOT_PATH = os.path.join(ROOT_PATH, "pilot")
@@ -26,9 +27,8 @@ LLM_MODEL_CONFIG = {
VECTOR_SEARCH_TOP_K = 3
LLM_MODEL = "vicuna-13b"
LIMIT_MODEL_CONCURRENCY = 5
MAX_POSITION_EMBEDDINGS = 2048
VICUNA_MODEL_SERVER = "http://192.168.31.114:8000"

MAX_POSITION_EMBEDDINGS = 4096
VICUNA_MODEL_SERVER = "http://121.41.167.183:8000"

# Load model config
ISLOAD_8BIT = True
@@ -37,7 +37,7 @@ ISDEBUG = False

DB_SETTINGS = {
    "user": "root",
    "password": "aa123456",
    "host": "localhost",
    "password": "aa12345678",
    "host": "127.0.0.1",
    "port": 3306
}
@@ -16,6 +16,7 @@ class MySQLOperator:
        self.conn = pymysql.connect(
            host=host,
            user=user,
            port=port,
            passwd=password,
            charset="utf8mb4",
            cursorclass=pymysql.cursors.DictCursor
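To show how the DB_SETTINGS dictionary feeds the MySQLOperator connection above, here is a minimal pymysql sketch; the credentials are the test defaults from model_config, and the SHOW DATABASES query is only an illustration.

```
import pymysql

from pilot.configs.model_config import DB_SETTINGS

conn = pymysql.connect(
    host=DB_SETTINGS["host"],
    user=DB_SETTINGS["user"],
    port=DB_SETTINGS["port"],
    passwd=DB_SETTINGS["password"],
    charset="utf8mb4",
    cursorclass=pymysql.cursors.DictCursor,
)
with conn.cursor() as cursor:
    cursor.execute("SHOW DATABASES;")
    print(cursor.fetchall())
conn.close()
```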
@ -6,18 +6,20 @@ from enum import auto, Enum
|
||||
from typing import List, Any
|
||||
from pilot.configs.model_config import DB_SETTINGS
|
||||
|
||||
|
||||
class SeparatorStyle(Enum):
|
||||
|
||||
SINGLE = auto()
|
||||
TWO = auto()
|
||||
THREE = auto()
|
||||
FOUR = auto()
|
||||
|
||||
@dataclasses.dataclass
|
||||
@ dataclasses.dataclass
|
||||
class Conversation:
|
||||
"""This class keeps all conversation history. """
|
||||
|
||||
system: str
|
||||
roles: List[str]
|
||||
messages: List[List[str]]
|
||||
messages: List[List[str]]
|
||||
offset: int
|
||||
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
||||
sep: str = "###"
|
||||
@ -32,7 +34,7 @@ class Conversation:
|
||||
ret = self.system + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
ret += role + ": " + message + self.sep
|
||||
ret += role + ": " + message + self.sep
|
||||
else:
|
||||
ret += role + ":"
|
||||
return ret
|
||||
@ -46,14 +48,12 @@ class Conversation:
|
||||
else:
|
||||
ret += role + ":"
|
||||
return ret
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid style: {self.sep_style}")
|
||||
|
||||
|
||||
def append_message(self, role, message):
|
||||
self.messages.append([role, message])
|
||||
|
||||
|
||||
def to_gradio_chatbot(self):
|
||||
ret = []
|
||||
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
||||
@ -101,13 +101,14 @@ def gen_sqlgen_conversation(dbname):
|
||||
message += s["schema_info"] + ";"
|
||||
return f"数据库{dbname}的Schema信息如下: {message}\n"
|
||||
|
||||
|
||||
conv_one_shot = Conversation(
|
||||
system="A chat between a curious human and an artificial intelligence assistant, who very familiar with database related knowledge. "
|
||||
"The assistant gives helpful, detailed, professional and polite answers to the human's questions. ",
|
||||
roles=("Human", "Assistant"),
|
||||
system="A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. "
|
||||
"The assistant gives helpful, detailed, professional and polite answers to the user's questions. ",
|
||||
roles=("USER", "Assistant"),
|
||||
messages=(
|
||||
(
|
||||
"Human",
|
||||
"USER",
|
||||
"What are the key differences between mysql and postgres?",
|
||||
),
|
||||
(
|
||||
@ -134,10 +135,10 @@ conv_one_shot = Conversation(
|
||||
sep_style=SeparatorStyle.SINGLE,
|
||||
sep="###"
|
||||
)
|
||||
|
||||
|
||||
conv_vicuna_v1 = Conversation(
|
||||
system = "A chat between a curious user and an artificial intelligence assistant. who very familiar with database related knowledge. "
|
||||
"The assistant gives helpful, detailed, professional and polite answers to the user's questions. ",
|
||||
system="A chat between a curious user and an artificial intelligence assistant. who very familiar with database related knowledge. "
|
||||
"The assistant gives helpful, detailed, professional and polite answers to the user's questions. ",
|
||||
roles=("USER", "ASSISTANT"),
|
||||
messages=(),
|
||||
offset=0,
|
||||
@ -146,10 +147,100 @@ conv_vicuna_v1 = Conversation(
|
||||
sep2="</s>",
|
||||
)
|
||||
|
||||
auto_dbgpt_one_shot = Conversation(
|
||||
system="You are DB-GPT, an AI designed to answer questions about HackerNews by query `hackerbews` database in MySQL. "
|
||||
"Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications.",
|
||||
roles=("USER", "ASSISTANT"),
|
||||
messages=(
|
||||
(
|
||||
"USER",
|
||||
""" Answer how many users does app_users have by query ob database
|
||||
Constraints:
|
||||
1. If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.
|
||||
2. No user assistance
|
||||
3. Exclusively use the commands listed in double quotes e.g. "command name"
|
||||
|
||||
Commands:
|
||||
1. analyze_code: Analyze Code, args: "code": "<full_code_string>"
|
||||
2. execute_python_file: Execute Python File, args: "filename": "<filename>"
|
||||
3. append_to_file: Append to file, args: "filename": "<filename>", "text": "<text>"
|
||||
4. delete_file: Delete file, args: "filename": "<filename>"
|
||||
5. list_files: List Files in Directory, args: "directory": "<directory>"
|
||||
6. read_file: Read file, args: "filename": "<filename>"
|
||||
7. write_to_file: Write to file, args: "filename": "<filename>", "text": "<text>"
|
||||
8. ob_sql_executor: "Execute SQL in OB Database.", args: "sql": "<sql>"
|
||||
|
||||
Resources:
|
||||
1. Internet access for searches and information gathering.
|
||||
2. Long Term memory management.
|
||||
3. vicuna powered Agents for delegation of simple tasks.
|
||||
|
||||
Performance Evaluation:
|
||||
1. Continuously review and analyze your actions to ensure you are performing to the best of your abilities.
|
||||
2. Constructively self-criticize your big-picture behavior constantly.
|
||||
3. Reflect on past decisions and strategies to refine your approach.
|
||||
4. Every command has a cost, so be smart and efficient. Aim to complete tasks in the least number of steps.
|
||||
5. Write all code to a file.
|
||||
|
||||
You should only respond in JSON format as described below
|
||||
Response Format:
|
||||
{
|
||||
"thoughts": {
|
||||
"text": "thought",
|
||||
"reasoning": "reasoning",
|
||||
"plan": "- short bulleted\n- list that conveys\n- long-term plan",
|
||||
"criticism": "constructive self-criticism",
|
||||
"speak": "thoughts summary to say to user"
|
||||
},
|
||||
"command": {
|
||||
"name": "command name",
|
||||
"args": {
|
||||
"arg name": "value"
|
||||
}
|
||||
}
|
||||
}
|
||||
Ensure the response can be parsed by Python json.loads
|
||||
"""
|
||||
),
|
||||
(
|
||||
"ASSISTANT",
|
||||
"""
|
||||
{
|
||||
"thoughts": {
|
||||
"text": "To answer how many users by query database we need to write SQL query to get the count of the distinct users from the database. We can use ob_sql_executor command to execute the SQL query in database.",
|
||||
"reasoning": "We can use the sql_executor command to execute the SQL query for getting count of distinct users from the users database. We can select the count of the distinct users from the users table.",
|
||||
"plan": "- Write SQL query to get count of distinct users from users database\n- Use ob_sql_executor to execute the SQL query in OB database\n- Parse the SQL result to get the count\n- Respond with the count as the answer",
|
||||
"criticism": "None",
|
||||
"speak": "To get the number of users in users, I will execute an SQL query in OB database using the ob_sql_executor command and respond with the count."
|
||||
},
|
||||
"command": {
|
||||
"name": "ob_sql_executor",
|
||||
"args": {
|
||||
"sql": "SELECT COUNT(DISTINCT(*)) FROM users ;"
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
)
|
||||
),
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.SINGLE,
|
||||
sep="###",
|
||||
)
|
||||
|
||||
conv_qa_prompt_template = """ 基于以下已知的信息, 专业、详细的回答用户的问题。
|
||||
如果无法从提供的恶内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题", 但是你可以给出一些与问题相关答案的建议:
|
||||
|
||||
auto_dbgpt_without_shot = Conversation(
|
||||
system="You are DB-GPT, an AI designed to answer questions about users by query `users` database in MySQL. "
|
||||
"Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications.",
|
||||
roles=("USER", "ASSISTANT"),
|
||||
messages=(),
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.SINGLE,
|
||||
sep=" ",
|
||||
sep2="</s>",
|
||||
)
|
||||
|
||||
conv_qa_prompt_template = """ 基于以下已知的信息, 专业、详细的回答用户的问题,
|
||||
如果无法从提供的恶内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题", 但是你可以给出一些与问题相关答案的建议。
|
||||
已知内容:
|
||||
{context}
|
||||
问题:
|
||||
@ -161,15 +252,15 @@ default_conversation = conv_one_shot
|
||||
conversation_types = {
|
||||
"native": "LLM原生对话",
|
||||
"default_knownledge": "默认知识库对话",
|
||||
"custome": "新增知识库对话",
|
||||
"custome": "新增知识库对话",
|
||||
}
|
||||
|
||||
conv_templates = {
|
||||
"conv_one_shot": conv_one_shot,
|
||||
"vicuna_v1": conv_vicuna_v1,
|
||||
"auto_dbgpt_one_shot": auto_dbgpt_one_shot
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
message = gen_sqlgen_conversation("dbgpt")
|
||||
print(message)
|
||||
print(message)
|
||||
|
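A minimal sketch of how the Conversation dataclass above assembles a prompt, assuming the module is importable as pilot.conversation; the roles and question are taken from the one-shot template in this file.

```
from pilot.conversation import Conversation, SeparatorStyle  # assumed module path

conv = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant.",
    roles=("USER", "ASSISTANT"),
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)
conv.append_message("USER", "What are the key differences between mysql and postgres?")
conv.append_message("ASSISTANT", None)
print(conv.get_prompt())
```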
Binary file not shown.
Binary file not shown.
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
206
pilot/datasets/oceanbase/OceanBase_Introduction.md
Normal file
206
pilot/datasets/oceanbase/OceanBase_Introduction.md
Normal file
@ -0,0 +1,206 @@
|
||||
OceanBase 数据库(OceanBase Database)是一款完全自研的企业级原生分布式数据库,在普通硬件上实现金融级高可用,首创“三地五中心”城市级故障自动无损容灾新标准,刷新 TPC-C 标准测试,单集群规模超过 1500 节点,具有云原生、强一致性、高度兼容 Oracle/MySQL 等特性。
|
||||
|
||||
核心特性
|
||||
高可用
|
||||
独创 “三地五中心” 容灾架构方案,建立金融行业无损容灾新标准。支持同城/异地容灾,可实现多地多活,满足金融行业 6 级容灾标准(RPO=0,RTO< 8s),数据零丢失。
|
||||
高兼容
|
||||
高度兼容 Oracle 和 MySQL,覆盖绝大多数常见功能,支持过程语言、触发器等高级特性,提供自动迁移工具,支持迁移评估和反向同步以保障数据迁移安全,可支撑金融、政府、运营商等关键行业核心场景替代。
|
||||
水平扩展
|
||||
实现透明水平扩展,支持业务快速的扩容缩容,同时通过准内存处理架构实现高性能。支持集群节点超过数千个,单集群最大数据量超过 3PB,最大单表行数达万亿级。
|
||||
低成本
|
||||
基于 LSM-Tree 的高压缩引擎,存储成本降低 70% - 90%;原生支持多租户架构,同集群可为多个独立业务提供服务,租户间数据隔离,降低部署和运维成本。
|
||||
实时 HTAP
|
||||
基于“同一份数据,同一个引擎”,同时支持在线实时交易及实时分析两种场景,“一份数据”的多个副本可以存储成多种形态,用于不同工作负载,从根本上保持数据一致性。
|
||||
安全可靠
|
||||
12 年完全自主研发,代码级可控,自研单机分布式一体化架构,大规模金融核心场景 9 年可靠性验证;完备的角色权限管理体系,数据存储和通信全链路透明加密,支持国密算法,通过等保三级专项合规检测。
|
||||
深入了解 OceanBase 数据库
|
||||
您可以通过以下内容更深入地了解 OceanBase 数据库:
|
||||
|
||||
OceanBase 使用通用服务器硬件,依赖本地存储,分布式部署使用的多个服务器也是对等的,没有特殊的硬件要求。OceanBase 的分布式数据库处理采用 Shared Nothing 架构,数据库内的 SQL 执行引擎具有分布式执行能力。
|
||||
|
||||
OceanBase 在服务器上会运行叫做 observer 的单进程程序作为数据库的运行实例,使用本地的文件存储数据和事务 Redo 日志。
|
||||
|
||||
OceanBase 集群部署需要配置可用区(Zone),由若干个服务器组成。可用区是一个逻辑概念,表示集群内具有相似硬件可用性的一组节点,它在不同的部署模式下代表不同的含义。例如,当整个集群部署在同一个数据中心(IDC)内的时候,一个可用区的节点可以属于同一个机架,同一个交换机等。当集群分布在多个数据中心的时候,每个可用区可以对应于一个数据中心。
|
||||
|
||||
用户存储的数据在分布式集群内部可以存储多个副本,用于故障容灾,也可以用于分散读取压力。在一个可用区内部数据只有一个副本,不同的可用区可以存储同一个数据的多个副本,副本之间由共识协议保证数据的一致性。
|
||||
|
||||
OceanBase 内置多租户特性,每个租户对于使用者是一个独立的数据库,一个租户能够在租户级别设置租户的分布式部署方式。租户之间 CPU、内存和 IO 都是隔离的。
|
||||
|
||||
OceanBase的数据库实例内部由不同的组件相互协作,这些组件从底层向上由存储层、复制层、均衡层、事务层、SQL 层、接入层组成。
|
||||
|
||||
存储层
|
||||
存储层以一张表或者一个分区为粒度提供数据存储与访问,每个分区对应一个用于存储数据的Tablet(分片),用户定义的非分区表也会对应一个 Tablet。
|
||||
|
||||
Tablet 的内部是分层存储的结构,总共有 4 层。DML 操作插入、更新、删除等首先写入 MemTable,等到 MemTable 达到一定大小时转储到磁盘成为 L0 SSTable。L0 SSTable 个数达到阈值后会将多个 L0 SSTable 合并成一个 L1 SSTable。在每天配置的业务低峰期,系统会将所有的 MemTable、L0 SSTable 和 L1 SSTable 合并成一个 Major SSTable。
|
||||
|
||||
每个 SSTable 内部是以 2MB 定长宏块为基本单位,每个宏块内部由多个不定长微块组成。
|
||||
|
||||
Major SSTable 的微块会在合并过程中用编码方式进行格式转换,微块内的数据会按照列维度分别进行列内的编码,编码规则包括字典/游程/常量/差值等,每一列压缩结束后,还会进一步对多列进行列间等值/子串等规则编码。编码能对数据大幅压缩,同时提炼的列内特征信息还能进一步加速后续的查询速度。
|
||||
|
||||
在编码压缩之后,还可以根据用户指定的通用压缩算法进行无损压缩,进一步提升数据压缩率。
|
||||
|
||||
复制层
|
||||
复制层使用日志流(LS、Log Stream)在多副本之间同步状态。每个 Tablet 都会对应一个确定的日志流,DML 操作写入 Tablet 的数据所产生的 Redo 日志会持久化在日志流中。日志流的多个副本会分布在不同的可用区中,多个副本之间维持了共识算法,选择其中一个副本作为主副本,其他的副本皆为从副本。Tablet 的 DML 和强一致性查询只在其对应的日志流的主副本上进行。
|
||||
|
||||
通常情况下,每个租户在每台机器上只会有一个日志流的主副本,可能存在多个其他日志流的从副本。租户的总日志流个数取决于 Primary Zone 和 Locality 的配置。
|
||||
|
||||
日志流使用自研的 Paxos 协议将 Redo 日志在本服务器持久化,同时通过网络发送给日志流的从副本,从副本在完成各自持久化后应答主副本,主副本在确认有多数派副本都持久化成功后确认对应的 Redo 日志持久化成功。从副本利用 Redo 日志的内容实时回放,保证自己的状态与主副本一致。
|
||||
|
||||
日志流的主副本在被选举成为主后会获得租约(Lease),正常工作的主副本在租约有效期内会不停的通过选举协议延长租约期。主副本只会在租约有效时执行主的工作,租约机制保证了数据库异常处理的能力。
|
||||
|
||||
复制层能够自动应对服务器故障,保障数据库服务的持续可用。如果出现少于半数的从副本所在服务器出现问题,因为还有多于半数的副本正常工作,数据库的服务不受影响。如果主副本所在服务器出现问题,其租约会得不到延续,待其租约失效后,其他从副本会通过选举协议选举出新的主副本并授予新的租约,之后即可恢复数据库的服务。
|
||||
|
||||
均衡层
|
||||
新建表和新增分区时,系统会按照均衡原则选择合适的日志流创建 Tablet。当租户的属性发生变更,新增了机器资源,或者经过长时间使用后,Tablet 在各台机器上不再均衡时,均衡层通过日志流的分裂和合并操作,并在这个过程中配合日志流副本的移动,让数据和服务在多个服务器之间再次均衡。
|
||||
|
||||
当租户有扩容操作,获得更多服务器资源时,均衡层会将租户内已有的日志流进行分裂,并选择合适数量的 Tablet 一同分裂到新的日志流中,再将新日志流迁移到新增的服务器上,以充分利用扩容后的资源。当租户有缩容操作时,均衡层会把需要缩减的服务器上的日志流迁移到其他服务器上,并和其他服务器上已有的日志流进行合并,以缩减机器的资源占用。
|
||||
|
||||
当数据库长期使用后,随着持续创建删除表格,并且写入更多的数据,即使没有服务器资源数量变化,原本均衡的情况可能被破坏。最常见的情况是,当用户删除了一批表格后,删除的表格可能原本聚集在某一些机器上,删除后这些机器上的 Tablet 数量就变少了,应该把其他机器的 Tablet 均衡一些到这些少的机器上。均衡层会定期生成均衡计划,将 Tablet 多的服务器上日志流分裂出临时日志流并携带需要移动的 Tablet,临时日志流迁移到目的服务器后再和目的服务器上的日志流进行合并,以达成均衡的效果。
|
||||
|
||||
事务层
|
||||
事务层保证了单个日志流和多个日志流DML操作提交的原子性,也保证了并发事务之间的多版本隔离能力。
|
||||
|
||||
原子性
|
||||
一个日志流上事务的修改,即使涉及多个 Tablet,通过日志流的 write-ahead log 可以保证事务提交的原子性。事务的修改涉及多个日志流时,每个日志流会产生并持久化各自的write-ahead log,事务层通过优化的两阶段提交协议来保证提交的原子性。
|
||||
|
||||
事务层会选择一个事务修改的一个日志流产生协调者状态机,协调者会与事务修改的所有日志流通信,判断 write-ahead log 是否持久化,当所有日志流都完成持久化后,事务进入提交状态,协调者会再驱动所有日志流写下这个事务的 Commit 日志,表示事务最终的提交状态。当从副本回放或者数据库重启时,已经完成提交的事务都会通过 Commit 日志确定各自日志流事务的状态。
|
||||
|
||||
宕机重启场景下,宕机前还未完成的事务,会出现写完 write-ahead log 但是还没有Commit 日志的情况,每个日志流的 write-ahead log 都会包含事务的所有日志流列表,通过此信息可以重新确定哪个日志流是协调者并恢复协调者的状态,再次推进两阶段状态机,直到事务最终的 Commit 或 Abort 状态。
|
||||
|
||||
隔离性
|
||||
GTS 服务是一个租户内产生连续增长的时间戳的服务,其通过多副本保证可用性,底层机制与上面复制层所描述的日志流副本同步机制是一样的。
|
||||
|
||||
每个事务在提交时会从 GTS 获取一个时间戳作为事务的提交版本号并持久化在日志流的write-ahead log 中,事务内所有修改的数据都以此提交版本号标记。
|
||||
|
||||
每个语句开始时(对于 Read Committed 隔离级别)或者每个事务开始时(对于Repeatable Read 和 Serializable 隔离级别)会从 GTS 获取一个时间戳作为语句或事务的读取版本号。在读取数据时,会跳过事务版本号比读取版本号大的数据,通过这种方式为读取操作提供了统一的全局数据快照。
|
||||
|
||||
SQL 层
|
||||
SQL 层将用户的 SQL 请求转化成对一个或多个 Tablet 的数据访问。
|
||||
|
||||
SQL 层组件
|
||||
SQL 层处理一个请求的执行流程是:Parser、Resolver、Transformer、Optimizer、Code Generator、Executor。
|
||||
|
||||
Parser 负责词法/语法解析,Parser 会将用户的 SQL 分成一个个的 "Token",并根据预先设定好的语法规则解析整个请求,转换成语法树(Syntax Tree)。
|
||||
|
||||
Resolver 负责语义解析,将根据数据库元信息将 SQL 请求中的 Token 翻译成对应的对象(例如库、表、列、索引等),生成的数据结构叫做 Statement Tree。
|
||||
|
||||
Transformer 负责逻辑改写,根据内部的规则或代价模型,将 SQL 改写为与之等价的其他形式,并将其提供给后续的优化器做进一步的优化。Transformer 的工作方式是在原Statement Tree 上做等价变换,变换的结果仍然是一棵 Statement Tree。
|
||||
|
||||
Optimizer(优化器)为 SQL 请求生成最佳的执行计划,需要综合考虑 SQL 请求的语义、对象数据特征、对象物理分布等多方面因素,解决访问路径选择、联接顺序选择、联接算法选择、分布式计划生成等问题,最终生成执行计划。
|
||||
|
||||
Code Generator(代码生成器)将执行计划转换为可执行的代码,但是不做任何优化选择。
|
||||
|
||||
Executor(执行器)启动 SQL 的执行过程。
|
||||
|
||||
在标准的 SQL 流程之外,SQL 层还有 Plan Cache 能力,将历史的执行计划缓存在内存中,后续的执行可以反复执行这个计划,避免了重复查询优化的过程。配合 Fast-parser 模块,仅使用词法分析对文本串直接参数化,获取参数化后的文本及常量参数,让 SQL 直接命中 Plan Cache,加速频繁执行的 SQL。
|
||||
|
||||
多种计划
|
||||
SQL 层的执行计划分为本地、远程和分布式三种。本地执行计划只访问本服务器的数据。远程执行计划只访问非本地的一台服务器的数据。分布式计划会访问超过一台服务器的数据,执行计划会分成多个子计划在多个服务器上执行。
|
||||
|
||||
SQL 层并行化执行能力可以将执行计划分解成多个部分,由多个执行线程执行,通过一定的调度的方式,实现执行计划的并行处理。并行化执行可以充分发挥服务器 CPU 和 IO 处理能力,缩短单个查询的响应时间。并行查询技术可以用于分布式执行计划,也可以用于本地执行计划。
|
||||
|
||||
接入层
|
||||
obproxy 是 OceanBase 数据库的接入层,负责将用户的请求转发到合适的 OceanBase 实例上进行处理。
|
||||
|
||||
obproxy 是独立的进程实例,独立于 OceanBase 的数据库实例部署。obproxy 监听网络端口,兼容 MySQL 网络协议,支持使用 MySQL 驱动的应用直接连接 OceanBase。
|
||||
|
||||
obproxy 能够自动发现 OceanBase 集群的数据分布信息,对于代理的每一条 SQL 语句,会尽可能识别出语句将访问的数据,并将语句直接转发到数据所在服务器的 OceanBase 实例。
|
||||
|
||||
obproxy 有两种部署方式,一种是部署在每一个需要访问数据库的应用服务器上,另一种是部署在与 OceanBase 相同的机器上。第一种部署方式下,应用程序直接连接部署在同一台服务器上的 obproxy,所有的请求会由 obproxy 发送到合适的 OceanBase 服务器。第二种部署方式下,需要使用网络负载均衡服务将多个 obproxy 聚合成同一个对应用提供服务的入口地址。
|
||||
|
||||
OceanBase 数据库采用 Shared-Nothing 架构,各个节点之间完全对等,每个节点都有自己的 SQL 引擎、存储引擎、事务引擎,运行在普通 PC 服务器组成的集群之上,具备高可扩展性、高可用性、高性能、低成本、与主流数据库高兼容等核心特性。
|
||||
|
||||
OceanBase 数据库的一个集群由若干个节点组成。这些节点分属于若干个可用区(Zone),每个节点属于一个可用区。可用区是一个逻辑概念,表示集群内具有相似硬件可用性的一组节点,它在不同的部署模式下代表不同的含义。例如,当整个集群部署在同一个数据中心(IDC)内的时候,一个可用区的节点可以属于同一个机架,同一个交换机等。当集群分布在多个数据中心的时候,每个可用区可以对应于一个数据中心。每个可用区具有 IDC 和地域(Region)两个属性,描述该可用区所在的 IDC 及 IDC 所属的地域。一般地,地域指 IDC 所在的城市。可用区的 IDC 和 Region 属性需要反映部署时候的实际情况,以便集群内的自动容灾处理和优化策略能更好地工作。根据业务对数据库系统不同的高可用性需求,OceanBase 集群提供了多种部署模式,参见 高可用架构概述。
|
||||
|
||||
在 OceanBase 数据库中,一个表的数据可以按照某种划分规则水平拆分为多个分片,每个分片叫做一个表分区,简称分区(Partition)。某行数据属于且只属于一个分区。分区的规则由用户在建表的时候指定,包括hash、range、list等类型的分区,还支持二级分区。例如,交易库中的订单表,可以先按照用户 ID 划分为若干一级分区,再按照月份把每个一级分区划分为若干二级分区。对于二级分区表,第二级的每个子分区是一个物理分区,而第一级分区只是逻辑概念。一个表的若干个分区可以分布在一个可用区内的多个节点上。每个物理分区有一个用于存储数据的存储层对象,叫做 Tablet ,用于存储有序的数据记录。
|
||||
|
||||
当用户对 Tablet 中记录进行修改的时候,为了保证数据持久化,需要记录重做日志(REDO)到 Tablet 对应的日志流(Log Stream)里。每个日志流服务了其所在节点上的多个 Tablet。为了能够保护数据,并在节点发生故障的时候不中断服务,每个日志流及其所属的 Tablet 有多个副本。一般来说,多个副本分散在多个不同的可用区里。多个副本中有且只有一个副本接受修改操作,叫做主副本(Leader),其他副本叫做从副本(Follower)。主从副本之间通过基于 Multi-Paxos 的分布式共识协议实现了副本之间数据的一致性。当主副本所在节点发生故障的时候,一个从副本会被选举为新的主副本并继续提供服务。
|
||||
|
||||
在集群的每个节点上会运行一个叫做 observer 的服务进程,它内部包含多个操作系统线程。节点的功能都是对等的。每个服务负责自己所在节点上分区数据的存取,也负责路由到本机的 SQL 语句的解析和执行。这些服务进程之间通过 TCP/IP 协议进行通信。同时,每个服务会监听来自外部应用的连接请求,建立连接和数据库会话,并提供数据库服务。关于 observer 服务进程的更多信息,参见 线程简介。
|
||||
|
||||
为了简化大规模部署多个业务数据库的管理并降低资源成本,OceanBase 数据库提供了独特的多租户特性。在一个 OceanBase 集群内,可以创建很多个互相之间隔离的数据库"实例",叫做一个租户。从应用程序的视角来看,每个租户是一个独立的数据库。不仅如此,每个租户可以选择 MySQL 或 Oracle 兼容模式。应用连接到 MySQL 租户后,可以在租户下创建用户、database,与一个独立的 MySQL 库的使用体验是一样的。同样的,应用连接到 Oracle 租户后,可以在租户下创建 schema、管理角色等,与一个独立的 Oracle 库的使用体验是一样的。一个新的集群初始化之后,就会存在一个特殊的名为 sys 的租户,叫做系统租户。系统租户中保存了集群的元数据,是一个 MySQL 兼容模式的租户。
|
||||
|
||||
为了隔离租户的资源,每个 observer 进程内可以有多个属于不同租户的虚拟容器,叫做资源单元(UNIT)。每个租户在多个节点上的资源单元组成一个资源池。资源单元包括 CPU 和内存资源。
|
||||
|
||||
为了使 OceanBase 数据库对应用程序屏蔽内部分区和副本分布等细节,使应用访问分布式数据库像访问单机数据库一样简单,我们提供了 obproxy 代理服务。应用程序并不会直接与 OBServer 建立连接,而是连接obproxy,然后由 obproxy 转发 SQL 请求到合适的 OBServer 节点。obproxy 是无状态的服务,多个 obproxy 节点通过网络负载均衡(SLB)对应用提供统一的网络地址。
|
||||
|
||||
|
||||
OceanBase was born out of the growth of Alibaba's e-commerce business, grew up with the development of Ant Group's mobile payment business, and, after more than a decade of use and refinement across many kinds of workloads, was finally brought to the external market. This section briefly describes some milestone events in the development of OceanBase.

Birth

In 2010, Dr. Yang Zhenkun, the founder of OceanBase, led the initial team to start the OceanBase project. Its first application was Taobao's Favorites service, which is still an OceanBase customer today. The Favorites service has a very large single-table data volume, and OceanBase used an original approach to meet its requirement for highly concurrent joins between a large table and a small table.

Relational database

In the early versions, applications accessed OceanBase through a custom API library. In 2012, OceanBase released a version with SQL support and became, in its first form, a functionally complete general-purpose relational database.

First financial workloads

OceanBase entered Alipay (later Ant Group) and began to be used in financial-grade business scenarios. During the 2014 "Double 11" shopping festival, OceanBase started to carry part of the traffic of the transaction database. After that, the newly founded MYbank ran all of its core transaction databases on OceanBase.

Financial-grade core databases

In 2016, OceanBase released version 1.0 with a redesigned architecture. It supported distributed transactions, improved scalability for high-concurrency write workloads, and introduced the multi-tenant architecture; this overall architecture has been kept ever since. By the 2016 "Double 11", 100% of the traffic of all of Alipay's core databases ran on OceanBase, including the transaction, payment, membership, and, most importantly, accounting databases.

Entering the external market

In 2017, OceanBase began piloting external business and was successfully adopted by the Bank of Nanjing.

Accelerating commercialization

In 2018, OceanBase released version 2.0, which introduced Oracle compatibility mode. This feature lowered the cost of adapting applications for migration and spread quickly among external customers.

Climbing the peak

In 2019, OceanBase V2.2 took part in TPC-C, the most authoritative benchmark for OLTP databases, and ranked first in the world with a result of 60 million tpmC. In 2020 it set a new record of 700 million tpmC, and it has stayed at the top of the list ever since, which fully demonstrates OceanBase's excellent scalability and stability. OceanBase is the first, and so far the only, Chinese database product to appear on the TPC-C list.

HTAP mixed workloads

In 2021, OceanBase V3.0, built on a completely new vectorized execution engine, topped the TPC-H 30,000 GB benchmark with a result of 15.26 million QphH. This marked a fundamental breakthrough in OceanBase's ability to handle mixed AP and TP workloads with a single engine.

Open source

On June 1, 2021, OceanBase announced that it was fully open source, embracing open collaboration and joint ecosystem building.
OceanBase uses a single-cluster, multi-tenant design that naturally supports cloud database architectures, and it can be deployed in public cloud, private cloud, hybrid cloud, and other forms.

Architecture

OceanBase isolates resources through tenants, so that each database service instance is unaware of the other instances, and it uses access control to keep tenant data secure. Combined with OceanBase's strong scalability, this makes it possible to provide secure and flexible DBaaS services.
A tenant is a logical concept. In OceanBase, tenants are the unit of resource allocation and the basis of database object management and resource management, and they have an important impact on system operations, especially the operation of cloud databases. A tenant is, to some extent, equivalent to the "instance" concept of a traditional database. Tenants are completely isolated from one another. For data security, OceanBase does not allow cross-tenant data access, ensuring that one tenant's data assets cannot be stolen by another tenant. For resource usage, each tenant appears to "own" its resource quota exclusively. Overall, a tenant is both a container for all kinds of database objects and a container for resources (CPU, memory, IO, and so on).

OceanBase supports both MySQL-mode and Oracle-mode tenants in the same system. When creating a tenant, the user chooses either MySQL compatibility mode or Oracle compatibility mode. The compatibility mode cannot be changed once the tenant is created, and all data types, SQL features, views, and so on then follow MySQL or Oracle accordingly.
MySQL mode

MySQL mode is a tenant type provided to reduce the cost of modifying business systems when migrating from MySQL to OceanBase, and to let database designers, developers, and administrators reuse their accumulated MySQL knowledge and get up to speed with OceanBase quickly. OceanBase's MySQL mode is compatible with the vast majority of MySQL 5.7 features and syntax, including the full set of MySQL 5.7 JSON functions and part of the 8.0 JSON functions, so applications built on MySQL can migrate smoothly.

Oracle mode

OceanBase has supported Oracle compatibility mode since V2.x.x. Oracle mode is a tenant type provided to reduce the cost of modifying business systems when migrating from Oracle to OceanBase, and to let database designers, developers, and administrators reuse their accumulated Oracle knowledge and get up to speed with OceanBase quickly. Oracle mode currently supports most Oracle syntax and procedural language features, so most Oracle workloads can be migrated largely automatically after minor modifications.
OceanBase has a multi-tenant architecture. Before V4.0.0 there were only two tenant types: the system tenant and user tenants. Starting with V4.0.0, the Meta tenant was introduced, so in the current version three tenant types are visible to users: the system tenant, user tenants, and Meta tenants.
System tenant

The system tenant is created by default with the cluster and shares the cluster's life cycle; it manages the life cycle of the cluster and of all tenants. The system tenant has only log stream 1, supports only single-point writes, and cannot be scaled out.

The system tenant can create user tables; all of its user tables and system tables are served by log stream 1. The system tenant's data is private to the cluster and does not support physical synchronization between primary and standby clusters or physical backup and restore.
User tenant

User tenants are created by users and provide complete database functionality, supporting both MySQL and Oracle compatibility modes. A user tenant can scale its service capacity out across multiple machines, supports dynamic scale-out and scale-in, and automatically creates and deletes log streams according to the user's configuration.

A user tenant's data has stronger protection and availability requirements and supports cross-cluster physical synchronization and physical backup and restore; typical data includes schema data, user table data, transaction data, and so on.
Meta tenant

A Meta tenant is a tenant managed internally by OceanBase. Whenever a user tenant is created, the system automatically creates a corresponding Meta tenant whose life cycle matches that of the user tenant.

The Meta tenant stores and manages the cluster-private data of its user tenant. This data needs neither cross-cluster physical synchronization nor physical backup and restore; it includes configuration items, location information, replica information, log stream status, backup and restore information, major-compaction (merge) information, and so on.
Tenant comparison

From the user's point of view, the differences between the system tenant, user tenants, and Meta tenants are shown in the table below.
OceanBase is a multi-tenant database system: one cluster can contain multiple mutually independent tenants, and each tenant provides an independent database service. OceanBase manages the resources available to each tenant through three concepts: resource configurations (unit_config), resource pools, and resource units.
Before creating a tenant, you first need to determine the tenant's resource configuration, the scope of resources it may use, and so on. The general flow for creating a tenant is as follows:
A resource configuration describes the configuration of a resource pool: it specifies the CPU, memory, storage space, IOPS, and other specifications available to each resource unit in the pool. Modifying a resource configuration dynamically adjusts the specifications of its resource units. Note that a resource configuration specifies the service capacity the corresponding resource units can provide, not their real-time load. An example statement for creating a resource configuration is shown below:
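The example statement itself is not preserved in this extract, so the following is only a hedged sketch of how the flow might look when driven from Python against the sys tenant: create a resource configuration (resource unit), build a resource pool from it, and create a MySQL-mode tenant on that pool. The connection parameters are placeholders, and the exact option names (MAX_CPU, MEMORY_SIZE, UNIT_NUM, ZONE_LIST, ob_compatibility_mode) vary between OceanBase versions, so the official documentation should be treated as authoritative.

import pymysql

# Hypothetical connection to the sys tenant; host, port and credentials are placeholders.
conn = pymysql.connect(host="127.0.0.1", port=2881, user="root@sys", password="******")

statements = [
    # 1. Resource configuration: the per-unit CPU / memory specification.
    "CREATE RESOURCE UNIT unit_4c16g MAX_CPU = 4, MEMORY_SIZE = '16G'",
    # 2. Resource pool: one unit of that specification in each listed zone.
    "CREATE RESOURCE POOL pool_demo UNIT = 'unit_4c16g', UNIT_NUM = 1, ZONE_LIST = ('zone1')",
    # 3. Tenant: a MySQL-compatible tenant backed by the pool.
    "CREATE TENANT demo_tenant RESOURCE_POOL_LIST = ('pool_demo') "
    "SET ob_compatibility_mode = 'mysql'",
]

with conn.cursor() as cur:
    for sql in statements:
        cur.execute(sql)
conn.close()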
0
pilot/log/__init__.py
Normal file
0
pilot/log/__init__.py
Normal file
20
pilot/log/json_handler.py
Normal file
20
pilot/log/json_handler.py
Normal file
@ -0,0 +1,20 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
|
||||
class JsonFileHandler(logging.FileHandler):
|
||||
def __init__(self, filename, mode="a", encoding=None, delay=False):
|
||||
super().__init__(filename, mode, encoding, delay)
|
||||
|
||||
def emit(self, record):
|
||||
json_data = json.loads(self.format(record))
|
||||
with open(self.baseFilename, "w", encoding="utf-8") as f:
|
||||
json.dump(json_data, f, ensure_ascii=False, indent=4)
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
class JsonFormatter(logging.Formatter):
|
||||
def format(self, record):
|
||||
return record.msg
|
287
pilot/logs.py
Normal file
287
pilot/logs.py
Normal file
@ -0,0 +1,287 @@
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from logging import LogRecord
|
||||
from typing import Any
|
||||
|
||||
from colorama import Fore, Style
|
||||
|
||||
from pilot.log.json_handler import JsonFileHandler, JsonFormatter
|
||||
from pilot.singleton import Singleton
|
||||
from pilot.speech import say_text
|
||||
|
||||
|
||||
class Logger(metaclass=Singleton):
|
||||
"""
|
||||
Logger that handles titles in different colors.
|
||||
Outputs logs in console, activity.log, and errors.log
|
||||
For console handler: simulates typing
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# create log directory if it doesn't exist
|
||||
this_files_dir_path = os.path.dirname(__file__)
|
||||
log_dir = os.path.join(this_files_dir_path, "../logs")
|
||||
if not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir)
|
||||
|
||||
log_file = "activity.log"
|
||||
error_file = "error.log"
|
||||
|
||||
console_formatter = DbGptFormatter("%(title_color)s %(message)s")
|
||||
|
||||
# Create a handler for console which simulate typing
|
||||
self.typing_console_handler = TypingConsoleHandler()
|
||||
self.typing_console_handler.setLevel(logging.INFO)
|
||||
self.typing_console_handler.setFormatter(console_formatter)
|
||||
|
||||
# Create a handler for console without typing simulation
|
||||
self.console_handler = ConsoleHandler()
|
||||
self.console_handler.setLevel(logging.DEBUG)
|
||||
self.console_handler.setFormatter(console_formatter)
|
||||
|
||||
# Info handler in activity.log
|
||||
self.file_handler = logging.FileHandler(
|
||||
os.path.join(log_dir, log_file), "a", "utf-8"
|
||||
)
|
||||
self.file_handler.setLevel(logging.DEBUG)
|
||||
info_formatter = DbGptFormatter(
|
||||
"%(asctime)s %(levelname)s %(title)s %(message_no_color)s"
|
||||
)
|
||||
self.file_handler.setFormatter(info_formatter)
|
||||
|
||||
# Error handler error.log
|
||||
error_handler = logging.FileHandler(
|
||||
os.path.join(log_dir, error_file), "a", "utf-8"
|
||||
)
|
||||
error_handler.setLevel(logging.ERROR)
|
||||
error_formatter = DbGptFormatter(
|
||||
"%(asctime)s %(levelname)s %(module)s:%(funcName)s:%(lineno)d %(title)s"
|
||||
" %(message_no_color)s"
|
||||
)
|
||||
error_handler.setFormatter(error_formatter)
|
||||
|
||||
self.typing_logger = logging.getLogger("TYPER")
|
||||
self.typing_logger.addHandler(self.typing_console_handler)
|
||||
self.typing_logger.addHandler(self.file_handler)
|
||||
self.typing_logger.addHandler(error_handler)
|
||||
self.typing_logger.setLevel(logging.DEBUG)
|
||||
|
||||
self.logger = logging.getLogger("LOGGER")
|
||||
self.logger.addHandler(self.console_handler)
|
||||
self.logger.addHandler(self.file_handler)
|
||||
self.logger.addHandler(error_handler)
|
||||
self.logger.setLevel(logging.DEBUG)
|
||||
|
||||
self.json_logger = logging.getLogger("JSON_LOGGER")
|
||||
self.json_logger.addHandler(self.file_handler)
|
||||
self.json_logger.addHandler(error_handler)
|
||||
self.json_logger.setLevel(logging.DEBUG)
|
||||
|
||||
self.speak_mode = False
|
||||
self.chat_plugins = []
|
||||
|
||||
def typewriter_log(
|
||||
self, title="", title_color="", content="", speak_text=False, level=logging.INFO
|
||||
):
|
||||
if speak_text and self.speak_mode:
|
||||
say_text(f"{title}. {content}")
|
||||
|
||||
for plugin in self.chat_plugins:
|
||||
plugin.report(f"{title}. {content}")
|
||||
|
||||
if content:
|
||||
if isinstance(content, list):
|
||||
content = " ".join(content)
|
||||
else:
|
||||
content = ""
|
||||
|
||||
self.typing_logger.log(
|
||||
level, content, extra={"title": title, "color": title_color}
|
||||
)
|
||||
|
||||
def debug(
|
||||
self,
|
||||
message,
|
||||
title="",
|
||||
title_color="",
|
||||
):
|
||||
self._log(title, title_color, message, logging.DEBUG)
|
||||
|
||||
def info(
|
||||
self,
|
||||
message,
|
||||
title="",
|
||||
title_color="",
|
||||
):
|
||||
self._log(title, title_color, message, logging.INFO)
|
||||
|
||||
def warn(
|
||||
self,
|
||||
message,
|
||||
title="",
|
||||
title_color="",
|
||||
):
|
||||
self._log(title, title_color, message, logging.WARN)
|
||||
|
||||
def error(self, title, message=""):
|
||||
self._log(title, Fore.RED, message, logging.ERROR)
|
||||
|
||||
def _log(
|
||||
self,
|
||||
title: str = "",
|
||||
title_color: str = "",
|
||||
message: str = "",
|
||||
level=logging.INFO,
|
||||
):
|
||||
if message:
|
||||
if isinstance(message, list):
|
||||
message = " ".join(message)
|
||||
self.logger.log(
|
||||
level, message, extra={"title": str(title), "color": str(title_color)}
|
||||
)
|
||||
|
||||
def set_level(self, level):
|
||||
self.logger.setLevel(level)
|
||||
self.typing_logger.setLevel(level)
|
||||
|
||||
def double_check(self, additionalText=None):
|
||||
if not additionalText:
|
||||
additionalText = (
|
||||
"Please ensure you've setup and configured everything"
|
||||
" correctly. Read https://github.com/Torantulino/Auto-GPT#readme to "
|
||||
"double check. You can also create a github issue or join the discord"
|
||||
" and ask there!"
|
||||
)
|
||||
|
||||
self.typewriter_log("DOUBLE CHECK CONFIGURATION", Fore.YELLOW, additionalText)
|
||||
|
||||
def log_json(self, data: Any, file_name: str) -> None:
|
||||
# Define log directory
|
||||
this_files_dir_path = os.path.dirname(__file__)
|
||||
log_dir = os.path.join(this_files_dir_path, "../logs")
|
||||
|
||||
# Create a handler for JSON files
|
||||
json_file_path = os.path.join(log_dir, file_name)
|
||||
json_data_handler = JsonFileHandler(json_file_path)
|
||||
json_data_handler.setFormatter(JsonFormatter())
|
||||
|
||||
# Log the JSON data using the custom file handler
|
||||
self.json_logger.addHandler(json_data_handler)
|
||||
self.json_logger.debug(data)
|
||||
self.json_logger.removeHandler(json_data_handler)
|
||||
|
||||
def get_log_directory(self):
|
||||
this_files_dir_path = os.path.dirname(__file__)
|
||||
log_dir = os.path.join(this_files_dir_path, "../logs")
|
||||
return os.path.abspath(log_dir)
|
||||
|
||||
"""
|
||||
Output stream to console using simulated typing
|
||||
"""
|
||||
|
||||
class TypingConsoleHandler(logging.StreamHandler):
|
||||
def emit(self, record):
|
||||
min_typing_speed = 0.05
|
||||
max_typing_speed = 0.01
|
||||
|
||||
msg = self.format(record)
|
||||
try:
|
||||
words = msg.split()
|
||||
for i, word in enumerate(words):
|
||||
print(word, end="", flush=True)
|
||||
if i < len(words) - 1:
|
||||
print(" ", end="", flush=True)
|
||||
typing_speed = random.uniform(min_typing_speed, max_typing_speed)
|
||||
time.sleep(typing_speed)
|
||||
# type faster after each word
|
||||
min_typing_speed = min_typing_speed * 0.95
|
||||
max_typing_speed = max_typing_speed * 0.95
|
||||
print()
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
|
||||
class ConsoleHandler(logging.StreamHandler):
|
||||
def emit(self, record) -> None:
|
||||
msg = self.format(record)
|
||||
try:
|
||||
print(msg)
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
|
||||
|
||||
class DbGptFormatter(logging.Formatter):
|
||||
"""
|
||||
Handles the custom placeholders 'title_color' and 'message_no_color'.
|
||||
To use this formatter, make sure to pass 'color', 'title' as log extras.
|
||||
"""
|
||||
|
||||
def format(self, record: LogRecord) -> str:
|
||||
if hasattr(record, "color"):
|
||||
record.title_color = (
|
||||
getattr(record, "color")
|
||||
+ getattr(record, "title", "")
|
||||
+ " "
|
||||
+ Style.RESET_ALL
|
||||
)
|
||||
else:
|
||||
record.title_color = getattr(record, "title", "")
|
||||
|
||||
# Add this line to set 'title' to an empty string if it doesn't exist
|
||||
record.title = getattr(record, "title", "")
|
||||
|
||||
if hasattr(record, "msg"):
|
||||
record.message_no_color = remove_color_codes(getattr(record, "msg"))
|
||||
else:
|
||||
record.message_no_color = ""
|
||||
return super().format(record)
|
||||
|
||||
|
||||
def remove_color_codes(s: str) -> str:
|
||||
ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
|
||||
return ansi_escape.sub("", s)
|
||||
|
||||
|
||||
logger = Logger()
|
||||
|
||||
|
||||
def print_assistant_thoughts(
|
||||
ai_name: object,
|
||||
assistant_reply_json_valid: object,
|
||||
speak_mode: bool = False,
|
||||
) -> None:
|
||||
assistant_thoughts_reasoning = None
|
||||
assistant_thoughts_plan = None
|
||||
assistant_thoughts_speak = None
|
||||
assistant_thoughts_criticism = None
|
||||
|
||||
assistant_thoughts = assistant_reply_json_valid.get("thoughts", {})
|
||||
assistant_thoughts_text = assistant_thoughts.get("text")
|
||||
if assistant_thoughts:
|
||||
assistant_thoughts_reasoning = assistant_thoughts.get("reasoning")
|
||||
assistant_thoughts_plan = assistant_thoughts.get("plan")
|
||||
assistant_thoughts_criticism = assistant_thoughts.get("criticism")
|
||||
assistant_thoughts_speak = assistant_thoughts.get("speak")
|
||||
logger.typewriter_log(
|
||||
f"{ai_name.upper()} THOUGHTS:", Fore.YELLOW, f"{assistant_thoughts_text}"
|
||||
)
|
||||
logger.typewriter_log("REASONING:", Fore.YELLOW, f"{assistant_thoughts_reasoning}")
|
||||
if assistant_thoughts_plan:
|
||||
logger.typewriter_log("PLAN:", Fore.YELLOW, "")
|
||||
# If it's a list, join it into a string
|
||||
if isinstance(assistant_thoughts_plan, list):
|
||||
assistant_thoughts_plan = "\n".join(assistant_thoughts_plan)
|
||||
elif isinstance(assistant_thoughts_plan, dict):
|
||||
assistant_thoughts_plan = str(assistant_thoughts_plan)
|
||||
|
||||
# Split the input_string using the newline character and dashes
|
||||
lines = assistant_thoughts_plan.split("\n")
|
||||
for line in lines:
|
||||
line = line.lstrip("- ")
|
||||
logger.typewriter_log("- ", Fore.GREEN, line.strip())
|
||||
logger.typewriter_log("CRITICISM:", Fore.YELLOW, f"{assistant_thoughts_criticism}")
|
||||
# Speak the assistant's thoughts
|
||||
if speak_mode and assistant_thoughts_speak:
|
||||
say_text(assistant_thoughts_speak)
|
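As a quick orientation to the Logger defined above, here is a hedged usage sketch; the titles, messages, and JSON file name are made up for illustration.

import json
import logging

from colorama import Fore

from pilot.logs import logger

logger.set_level(logging.DEBUG)
logger.typewriter_log("SETUP", Fore.GREEN, "DB-GPT logger initialised")   # typed to console + activity.log
logger.info("model server is starting", title="LLMSERVER", title_color=Fore.CYAN)
logger.error("CONFIG", "missing .env file")                               # also routed to error.log
logger.log_json(json.dumps({"event": "startup"}), "startup.json")         # written via JsonFileHandler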
11
pilot/model/base.py
Normal file
11
pilot/model/base.py
Normal file
@ -0,0 +1,11 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import List, TypedDict
|
||||
|
||||
class Message(TypedDict):
|
||||
"""LLM Message object containing usually like (role: content) """
|
||||
|
||||
role: str
|
||||
content: str
|
||||
|
3
pilot/model/chat.py
Normal file
3
pilot/model/chat.py
Normal file
@ -0,0 +1,3 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
121
pilot/model/compression.py
Normal file
121
pilot/model/compression.py
Normal file
@ -0,0 +1,121 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
import dataclasses
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CompressionConfig:
|
||||
"""Group-wise quantization."""
|
||||
num_bits: int
|
||||
group_size: int
|
||||
group_dim: int
|
||||
symmetric: bool
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
default_compression_config = CompressionConfig(
|
||||
num_bits=8, group_size=256, group_dim=1, symmetric=True, enabled=True)
|
||||
|
||||
|
||||
class CLinear(nn.Module):
|
||||
"""Compressed Linear Layer."""
|
||||
|
||||
def __init__(self, weight, bias, device):
|
||||
super().__init__()
|
||||
|
||||
self.weight = compress(weight.data.to(device), default_compression_config)
|
||||
self.bias = bias
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
weight = decompress(self.weight, default_compression_config)
|
||||
return F.linear(input, weight, self.bias)
|
||||
|
||||
|
||||
def compress_module(module, target_device):
|
||||
for attr_str in dir(module):
|
||||
target_attr = getattr(module, attr_str)
|
||||
if type(target_attr) == torch.nn.Linear:
|
||||
setattr(module, attr_str,
|
||||
CLinear(target_attr.weight, target_attr.bias, target_device))
|
||||
for name, child in module.named_children():
|
||||
compress_module(child, target_device)
|
||||
|
||||
|
||||
def compress(tensor, config):
|
||||
"""Simulate group-wise quantization."""
|
||||
if not config.enabled:
|
||||
return tensor
|
||||
|
||||
group_size, num_bits, group_dim, symmetric = (
|
||||
config.group_size, config.num_bits, config.group_dim, config.symmetric)
|
||||
assert num_bits <= 8
|
||||
|
||||
original_shape = tensor.shape
|
||||
num_groups = (original_shape[group_dim] + group_size - 1) // group_size
|
||||
new_shape = (original_shape[:group_dim] + (num_groups, group_size) +
|
||||
original_shape[group_dim+1:])
|
||||
|
||||
# Pad
|
||||
pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
|
||||
if pad_len != 0:
|
||||
pad_shape = original_shape[:group_dim] + (pad_len,) + original_shape[group_dim+1:]
|
||||
tensor = torch.cat([
|
||||
tensor,
|
||||
torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)],
|
||||
dim=group_dim)
|
||||
data = tensor.view(new_shape)
|
||||
|
||||
# Quantize
|
||||
if symmetric:
|
||||
B = 2 ** (num_bits - 1) - 1
|
||||
scale = B / torch.max(data.abs(), dim=group_dim + 1, keepdim=True)[0]
|
||||
data = data * scale
|
||||
data = data.clamp_(-B, B).round_().to(torch.int8)
|
||||
return data, scale, original_shape
|
||||
else:
|
||||
B = 2 ** num_bits - 1
|
||||
mn = torch.min(data, dim=group_dim + 1, keepdim=True)[0]
|
||||
mx = torch.max(data, dim=group_dim + 1, keepdim=True)[0]
|
||||
|
||||
scale = B / (mx - mn)
|
||||
data = data - mn
|
||||
data.mul_(scale)
|
||||
|
||||
data = data.clamp_(0, B).round_().to(torch.uint8)
|
||||
return data, mn, scale, original_shape
|
||||
|
||||
|
||||
def decompress(packed_data, config):
|
||||
"""Simulate group-wise dequantization."""
|
||||
if not config.enabled:
|
||||
return packed_data
|
||||
|
||||
group_size, num_bits, group_dim, symmetric = (
|
||||
config.group_size, config.num_bits, config.group_dim, config.symmetric)
|
||||
|
||||
# Dequantize
|
||||
if symmetric:
|
||||
data, scale, original_shape = packed_data
|
||||
data = data / scale
|
||||
else:
|
||||
data, mn, scale, original_shape = packed_data
|
||||
data = data / scale
|
||||
data.add_(mn)
|
||||
|
||||
# Unpad
|
||||
pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
|
||||
if pad_len:
|
||||
padded_original_shape = (
|
||||
original_shape[:group_dim] +
|
||||
(original_shape[group_dim] + pad_len,) +
|
||||
original_shape[group_dim+1:])
|
||||
data = data.reshape(padded_original_shape)
|
||||
indices = [slice(0, x) for x in original_shape]
|
||||
return data[indices].contiguous()
|
||||
else:
|
||||
return data.view(original_shape)
|
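To show how the two halves fit together, here is a small round-trip sketch for the group-wise quantization helpers above; the tensor shape and configuration values are illustrative assumptions.

import torch

from pilot.model.compression import CompressionConfig, compress, decompress

config = CompressionConfig(num_bits=8, group_size=256, group_dim=1, symmetric=True)
weight = torch.randn(4, 1000)           # a small stand-in for a linear layer weight
packed = compress(weight, config)       # -> (int8 data, per-group scale, original shape)
restored = decompress(packed, config)   # approximate float reconstruction
print((weight - restored).abs().max())  # the quantization error should stay small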
@ -5,13 +5,13 @@ import torch
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate_stream(model, tokenizer, params, device,
|
||||
context_len=2048, stream_interval=2):
|
||||
context_len=4096, stream_interval=2):
|
||||
|
||||
"""Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py """
|
||||
prompt = params["prompt"]
|
||||
l_prompt = len(prompt)
|
||||
temperature = float(params.get("temperature", 1.0))
|
||||
max_new_tokens = int(params.get("max_new_tokens", 256))
|
||||
max_new_tokens = int(params.get("max_new_tokens", 2048))
|
||||
stop_str = params.get("stop", None)
|
||||
|
||||
input_ids = tokenizer(prompt).input_ids
|
||||
|
34
pilot/model/llm/base.py
Normal file
34
pilot/model/llm/base.py
Normal file
@ -0,0 +1,34 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, TypedDict
|
||||
|
||||
|
||||
class Message(TypedDict):
|
||||
"""Vicuna Message object containing a role and the message content """
|
||||
|
||||
role: str
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelInfo:
|
||||
"""Struct for model information.
|
||||
|
||||
Would be lovely to eventually get this directly from APIs
|
||||
"""
|
||||
name: str
|
||||
max_tokens: int
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
"""Standard response struct for a response from a LLM model."""
|
||||
model_info: ModelInfo = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatModelResponse(LLMResponse):
|
||||
"""Standard response struct for a response from an LLM model."""
|
||||
|
||||
content: str = None
|
108
pilot/model/llm/llm_utils.py
Normal file
108
pilot/model/llm/llm_utils.py
Normal file
@ -0,0 +1,108 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import abc
|
||||
import time
|
||||
import functools
|
||||
from typing import List, Optional
|
||||
from pilot.model.llm.base import Message
|
||||
from pilot.conversation import conv_templates, Conversation, conv_one_shot, auto_dbgpt_one_shot
|
||||
from pilot.configs.config import Config
|
||||
|
||||
|
||||
# TODO Rewrite this
|
||||
def retry_stream_api(
|
||||
num_retries: int = 10,
|
||||
backoff_base: float = 2.0,
|
||||
warn_user: bool = True
|
||||
):
|
||||
"""Retry an Vicuna Server call.
|
||||
|
||||
Args:
|
||||
num_retries int: Number of retries. Defaults to 10.
|
||||
backoff_base float: Base for exponential backoff. Defaults to 2.
|
||||
warn_user bool: Whether to warn the user. Defaults to True.
|
||||
"""
|
||||
retry_limit_msg = f"Error: Reached rate limit, passing..."
|
||||
backoff_msg = (f"Error: API Bad gateway. Waiting {{backoff}} seconds...")
|
||||
|
||||
def _wrapper(func):
|
||||
@functools.wraps(func)
|
||||
def _wrapped(*args, **kwargs):
|
||||
user_warned = not warn_user
|
||||
num_attempts = num_retries + 1 # +1 for the first attempt
|
||||
for attempt in range(1, num_attempts + 1):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
if (e.http_status != 502) or (attempt == num_attempts):
|
||||
raise
|
||||
|
||||
backoff = backoff_base ** (attempt + 2)
|
||||
time.sleep(backoff)
|
||||
return _wrapped
|
||||
return _wrapper
|
||||
|
||||
# Overly simple abstraction until we create something better
|
||||
# simple retry mechanism when getting a rate error or a bad gateway
|
||||
def create_chat_competion(
|
||||
conv: Conversation,
|
||||
model: Optional[str] = None,
|
||||
temperature: float = None,
|
||||
max_new_tokens: Optional[int] = None,
|
||||
) -> str:
|
||||
"""Create a chat completion using the Vicuna-13b
|
||||
|
||||
Args:
|
||||
messages(List[Message]): The messages to send to the chat completion
|
||||
model (str, optional): The model to use. Default to None.
|
||||
temperature (float, optional): The temperature to use. Defaults to 0.7.
|
||||
max_tokens (int, optional): The max tokens to use. Defaults to None.
|
||||
|
||||
Returns:
|
||||
str: The response from the chat completion
|
||||
"""
|
||||
cfg = Config()
|
||||
if temperature is None:
|
||||
temperature = cfg.temperature
|
||||
|
||||
# TODO request vicuna model get response
|
||||
# convert vicuna message to chat completion.
|
||||
for plugin in cfg.plugins:
|
||||
if plugin.can_handle_chat_completion():
|
||||
pass
|
||||
|
||||
|
||||
class ChatIO(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def prompt_for_input(self, role: str) -> str:
|
||||
"""Prompt for input from a role."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def prompt_for_output(self, role: str) -> str:
|
||||
"""Prompt for output from a role."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def stream_output(self, output_stream, skip_echo_len: int):
|
||||
"""Stream output."""
|
||||
|
||||
|
||||
class SimpleChatIO(ChatIO):
|
||||
def prompt_for_input(self, role: str) -> str:
|
||||
return input(f"{role}: ")
|
||||
|
||||
def prompt_for_output(self, role: str) -> str:
|
||||
print(f"{role}: ", end="", flush=True)
|
||||
|
||||
def stream_output(self, output_stream, skip_echo_len: int):
|
||||
pre = 0
|
||||
for outputs in output_stream:
|
||||
outputs = outputs[skip_echo_len:].strip()
|
||||
now = len(outputs) - 1
|
||||
if now > pre:
|
||||
print(" ".join(outputs[pre:now]), end=" ", flush=True)
|
||||
pre = now
|
||||
|
||||
print(" ".join(outputs[pre:]), flush=True)
|
||||
return " ".join(outputs)
|
||||
|
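retry_stream_api is a parameterized decorator, so it is applied with arguments. A hedged sketch of the intended usage follows; call_vicuna_server is a hypothetical placeholder, not an existing function in this repository.

from pilot.model.llm.llm_utils import retry_stream_api

@retry_stream_api(num_retries=3, backoff_base=2.0)
def call_vicuna_server(payload: dict) -> str:
    # Placeholder for the real HTTP call to the Vicuna server; an exception carrying
    # http_status == 502 would be retried here with exponential backoff.
    raise NotImplementedError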
47
pilot/model/llm_utils.py
Normal file
47
pilot/model/llm_utils.py
Normal file
@ -0,0 +1,47 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
from typing import List, Optional
|
||||
from pilot.model.base import Message
|
||||
from pilot.configs.config import Config
|
||||
from pilot.server.llmserver import generate_output
|
||||
|
||||
def create_chat_completion(
|
||||
messages: List[Message], # type: ignore
|
||||
model: Optional[str] = None,
|
||||
temperature: float = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> str:
|
||||
"""Create a chat completion using the vicuna local model
|
||||
|
||||
Args:
|
||||
messages(List[Message]): The messages to send to the chat completion
|
||||
model (str, optional): The model to use. Defaults to None.
|
||||
temperature (float, optional): The temperature to use. Defaults to 0.7.
|
||||
max_tokens (int, optional): The max tokens to use. Defaults to None
|
||||
|
||||
Returns:
|
||||
str: The response from chat completion
|
||||
"""
|
||||
cfg = Config()
|
||||
if temperature is None:
|
||||
temperature = cfg.temperature
|
||||
|
||||
for plugin in cfg.plugins:
|
||||
if plugin.can_handle_chat_completion(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
):
|
||||
message = plugin.handle_chat_completion(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
if message is not None:
|
||||
return message
|
||||
|
||||
response = None
|
||||
# TODO impl this use vicuna server api
|
@ -2,15 +2,17 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
from pilot.singleton import Singleton
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForCausalLM,
|
||||
AutoModel
|
||||
)
|
||||
|
||||
from fastchat.serve.compression import compress_module
|
||||
from pilot.model.compression import compress_module
|
||||
|
||||
class ModelLoader:
|
||||
class ModelLoader(metaclass=Singleton):
|
||||
"""Model loader is a class for model load
|
||||
|
||||
Args: model_path
|
||||
|
134
pilot/plugins.py
Normal file
134
pilot/plugins.py
Normal file
@ -0,0 +1,134 @@
|
||||
"""加载组件"""
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
from zipimport import zipimporter
|
||||
|
||||
import requests
|
||||
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.logs import logger
|
||||
|
||||
def inspect_zip_for_modules(zip_path: str, debug: bool = False) -> list[str]:
|
||||
"""
|
||||
Load plugins from a zip file; fully compatible with Auto-GPT plugins.
|
||||
|
||||
Args:
|
||||
zip_path (str): Path to the zipfile.
|
||||
debug (bool, optional): Enable debug logging. Defaults to False.
|
||||
|
||||
Returns:
|
||||
list[str]: The list of module names found or empty list if none were found.
|
||||
"""
|
||||
result = []
|
||||
with zipfile.ZipFile(zip_path, "r") as zfile:
|
||||
for name in zfile.namelist():
|
||||
if name.endswith("__init__.py") and not name.startswith("__MACOSX"):
|
||||
logger.debug(f"Found module '{name}' in the zipfile at: {name}")
|
||||
result.append(name)
|
||||
if len(result) == 0:
|
||||
logger.debug(f"Module '__init__.py' not found in the zipfile @ {zip_path}.")
|
||||
return result
|
||||
|
||||
def write_dict_to_json_file(data: dict, file_path: str) -> None:
|
||||
"""
|
||||
Write a dictionary to a JSON file.
|
||||
Args:
|
||||
data (dict): Dictionary to write.
|
||||
file_path (str): Path to the file.
|
||||
"""
|
||||
with open(file_path, "w") as file:
|
||||
json.dump(data, file, indent=4)
|
||||
|
||||
def create_directory_if_not_exists(directory_path: str) -> bool:
|
||||
"""
|
||||
Create a directory if it does not exist.
|
||||
Args:
|
||||
directory_path (str): Path to the directory.
|
||||
Returns:
|
||||
bool: True if the directory was created, else False.
|
||||
"""
|
||||
if not os.path.exists(directory_path):
|
||||
try:
|
||||
os.makedirs(directory_path)
|
||||
logger.debug(f"Created directory: {directory_path}")
|
||||
return True
|
||||
except OSError as e:
|
||||
logger.warn(f"Error creating directory {directory_path}: {e}")
|
||||
return False
|
||||
else:
|
||||
logger.info(f"Directory {directory_path} already exists")
|
||||
return True
|
||||
|
||||
def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]:
|
||||
"""Scan the plugins directory for plugins and loads them.
|
||||
|
||||
Args:
|
||||
cfg (Config): Config instance including plugins config
|
||||
debug (bool, optional): Enable debug logging. Defaults to False.
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, Path]]: List of plugins.
|
||||
"""
|
||||
loaded_plugins = []
|
||||
# Generic plugins
|
||||
plugins_path_path = Path(cfg.plugins_dir)
|
||||
|
||||
logger.debug(f"Allowlisted Plugins: {cfg.plugins_allowlist}")
|
||||
logger.debug(f"Denylisted Plugins: {cfg.plugins_denylist}")
|
||||
|
||||
for plugin in plugins_path_path.glob("*.zip"):
|
||||
if moduleList := inspect_zip_for_modules(str(plugin), debug):
|
||||
for module in moduleList:
|
||||
plugin = Path(plugin)
|
||||
module = Path(module)
|
||||
logger.debug(f"Plugin: {plugin} Module: {module}")
|
||||
zipped_package = zipimporter(str(plugin))
|
||||
zipped_module = zipped_package.load_module(str(module.parent))
|
||||
for key in dir(zipped_module):
|
||||
if key.startswith("__"):
|
||||
continue
|
||||
a_module = getattr(zipped_module, key)
|
||||
a_keys = dir(a_module)
|
||||
if (
|
||||
"_abc_impl" in a_keys
|
||||
and a_module.__name__ != "AutoGPTPluginTemplate"
|
||||
and denylist_allowlist_check(a_module.__name__, cfg)
|
||||
):
|
||||
loaded_plugins.append(a_module())
|
||||
|
||||
if loaded_plugins:
|
||||
logger.info(f"\nPlugins found: {len(loaded_plugins)}\n" "--------------------")
|
||||
for plugin in loaded_plugins:
|
||||
logger.info(f"{plugin._name}: {plugin._version} - {plugin._description}")
|
||||
return loaded_plugins
|
||||
|
||||
|
||||
def denylist_allowlist_check(plugin_name: str, cfg: Config) -> bool:
|
||||
"""Check if the plugin is in the allowlist or denylist.
|
||||
|
||||
Args:
|
||||
plugin_name (str): Name of the plugin.
|
||||
cfg (Config): Config object.
|
||||
|
||||
Returns:
|
||||
True or False
|
||||
"""
|
||||
logger.debug(f"Checking if plugin {plugin_name} should be loaded")
|
||||
if plugin_name in cfg.plugins_denylist:
|
||||
logger.debug(f"Not loading plugin {plugin_name} as it was in the denylist.")
|
||||
return False
|
||||
if plugin_name in cfg.plugins_allowlist:
|
||||
logger.debug(f"Loading plugin {plugin_name} as it was in the allowlist.")
|
||||
return True
|
||||
ack = input(
|
||||
f"WARNING: Plugin {plugin_name} found. But not in the"
|
||||
f" allowlist... Load? ({cfg.authorise_key}/{cfg.exit_key}): "
|
||||
)
|
||||
return ack.lower() == cfg.authorise_key
|
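A hedged sketch of how these helpers are wired together at startup (the web server changes later in this commit do essentially the same thing); it assumes plugin zip files live under cfg.plugins_dir.

from pilot.configs.config import Config
from pilot.plugins import scan_plugins

cfg = Config()
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))   # loads allowlisted plugin zips
for plugin in cfg.plugins:
    print(plugin._name, plugin._version, plugin._description)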
0
pilot/prompts/__init__.py
Normal file
0
pilot/prompts/__init__.py
Normal file
96
pilot/prompts/first_conversation_prompt.py
Normal file
96
pilot/prompts/first_conversation_prompt.py
Normal file
@ -0,0 +1,96 @@
|
||||
from pilot.prompts.generator import PromptGenerator
|
||||
from typing import Any, Optional, Type
|
||||
import os
|
||||
import platform
|
||||
from pathlib import Path
|
||||
|
||||
import distro
|
||||
import yaml
|
||||
from pilot.configs.config import Config
|
||||
from pilot.prompts.prompt import build_default_prompt_generator
|
||||
|
||||
|
||||
class FirstPrompt:
|
||||
"""
|
||||
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
ai_goals: list | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a class instance
|
||||
|
||||
Parameters:
|
||||
ai_name (str): The name of the AI.
|
||||
ai_role (str): The description of the AI's role.
|
||||
ai_goals (list): The list of objectives the AI is supposed to complete.
|
||||
api_budget (float): The maximum dollar value for API calls (0.0 means infinite)
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if ai_goals is None:
|
||||
ai_goals = []
|
||||
self.ai_goals = ai_goals
|
||||
self.prompt_generator = None
|
||||
self.command_registry = None
|
||||
|
||||
|
||||
def construct_first_prompt(
|
||||
self,
|
||||
command_registry: [] = None,
|
||||
fisrt_message: [str]=[],
|
||||
prompt_generator: Optional[PromptGenerator] = None
|
||||
) -> str:
|
||||
"""
|
||||
Build the full prompt from the user's initial conversation input.
|
||||
Args:
|
||||
self:
|
||||
prompt_generator:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
prompt_start = (
|
||||
"Your decisions must always be made independently without"
|
||||
" seeking user assistance. Play to your strengths as an LLM and pursue"
|
||||
" simple strategies with no legal complications."
|
||||
""
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
if prompt_generator is None:
|
||||
prompt_generator = build_default_prompt_generator()
|
||||
prompt_generator.goals = fisrt_message
|
||||
prompt_generator.command_registry = command_registry
|
||||
# Load the commands made available by plugins
|
||||
cfg = Config()
|
||||
for plugin in cfg.plugins:
|
||||
if not plugin.can_handle_post_prompt():
|
||||
continue
|
||||
prompt_generator = plugin.post_prompt(prompt_generator)
|
||||
if cfg.execute_local_commands:
|
||||
# add OS info to prompt
|
||||
os_name = platform.system()
|
||||
os_info = (
|
||||
platform.platform(terse=True)
|
||||
if os_name != "Linux"
|
||||
else distro.name(pretty=True)
|
||||
)
|
||||
|
||||
prompt_start += f"\nThe OS you are running on is: {os_info}"
|
||||
|
||||
# Construct full prompt
|
||||
full_prompt = f"{prompt_start}\n\nGOALS:\n\n"
|
||||
|
||||
if not self.ai_goals :
|
||||
self.ai_goals = fisrt_message
|
||||
for i, goal in enumerate(self.ai_goals):
|
||||
full_prompt += f"{i+1}. {goal}\n"
|
||||
# if self.api_budget > 0.0:
|
||||
# full_prompt += f"\nIt takes money to let you run. Your API budget is ${self.api_budget:.3f}"
|
||||
self.prompt_generator = prompt_generator
|
||||
full_prompt += f"\n\n{prompt_generator.generate_prompt_string()}"
|
||||
return full_prompt
|
155
pilot/prompts/generator.py
Normal file
155
pilot/prompts/generator.py
Normal file
@ -0,0 +1,155 @@
|
||||
""" A module for generating custom prompt strings."""
|
||||
import json
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
|
||||
class PromptGenerator:
|
||||
"""
|
||||
A class for generating custom prompt strings based on constraints, commands,
|
||||
resources, and performance evaluations.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
Initialize the PromptGenerator object with empty lists of constraints,
|
||||
commands, resources, and performance evaluations.
|
||||
"""
|
||||
self.constraints = []
|
||||
self.commands = []
|
||||
self.resources = []
|
||||
self.performance_evaluation = []
|
||||
self.goals = []
|
||||
self.command_registry = None
|
||||
self.name = "Bob"
|
||||
self.role = "AI"
|
||||
self.response_format = {
|
||||
"thoughts": {
|
||||
"text": "thought",
|
||||
"reasoning": "reasoning",
|
||||
"plan": "- short bulleted\n- list that conveys\n- long-term plan",
|
||||
"criticism": "constructive self-criticism",
|
||||
"speak": "thoughts summary to say to user",
|
||||
},
|
||||
"command": {"name": "command name", "args": {"arg name": "value"}},
|
||||
}
|
||||
|
||||
def add_constraint(self, constraint: str) -> None:
|
||||
"""
|
||||
Add a constraint to the constraints list.
|
||||
|
||||
Args:
|
||||
constraint (str): The constraint to be added.
|
||||
"""
|
||||
self.constraints.append(constraint)
|
||||
|
||||
def add_command(
|
||||
self,
|
||||
command_label: str,
|
||||
command_name: str,
|
||||
args=None,
|
||||
function: Optional[Callable] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Add a command to the commands list with a label, name, and optional arguments.
|
||||
|
||||
Args:
|
||||
command_label (str): The label of the command.
|
||||
command_name (str): The name of the command.
|
||||
args (dict, optional): A dictionary containing argument names and their
|
||||
values. Defaults to None.
|
||||
function (callable, optional): A callable function to be called when
|
||||
the command is executed. Defaults to None.
|
||||
"""
|
||||
if args is None:
|
||||
args = {}
|
||||
|
||||
command_args = {arg_key: arg_value for arg_key, arg_value in args.items()}
|
||||
|
||||
command = {
|
||||
"label": command_label,
|
||||
"name": command_name,
|
||||
"args": command_args,
|
||||
"function": function,
|
||||
}
|
||||
|
||||
self.commands.append(command)
|
||||
|
||||
def _generate_command_string(self, command: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Generate a formatted string representation of a command.
|
||||
|
||||
Args:
|
||||
command (dict): A dictionary containing command information.
|
||||
|
||||
Returns:
|
||||
str: The formatted command string.
|
||||
"""
|
||||
args_string = ", ".join(
|
||||
f'"{key}": "{value}"' for key, value in command["args"].items()
|
||||
)
|
||||
return f'{command["label"]}: "{command["name"]}", args: {args_string}'
|
||||
|
||||
def add_resource(self, resource: str) -> None:
|
||||
"""
|
||||
Add a resource to the resources list.
|
||||
|
||||
Args:
|
||||
resource (str): The resource to be added.
|
||||
"""
|
||||
self.resources.append(resource)
|
||||
|
||||
def add_performance_evaluation(self, evaluation: str) -> None:
|
||||
"""
|
||||
Add a performance evaluation item to the performance_evaluation list.
|
||||
|
||||
Args:
|
||||
evaluation (str): The evaluation item to be added.
|
||||
"""
|
||||
self.performance_evaluation.append(evaluation)
|
||||
|
||||
def _generate_numbered_list(self, items: List[Any], item_type="list") -> str:
|
||||
"""
|
||||
Generate a numbered list from given items based on the item_type.
|
||||
|
||||
Args:
|
||||
items (list): A list of items to be numbered.
|
||||
item_type (str, optional): The type of items in the list.
|
||||
Defaults to 'list'.
|
||||
|
||||
Returns:
|
||||
str: The formatted numbered list.
|
||||
"""
|
||||
if item_type == "command":
|
||||
command_strings = []
|
||||
if self.command_registry:
|
||||
command_strings += [
|
||||
str(item)
|
||||
for item in self.command_registry.commands.values()
|
||||
if item.enabled
|
||||
]
|
||||
# terminate command is added manually
|
||||
command_strings += [self._generate_command_string(item) for item in items]
|
||||
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings))
|
||||
else:
|
||||
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items))
|
||||
|
||||
def generate_prompt_string(self) -> str:
|
||||
"""
|
||||
Generate a prompt string based on the constraints, commands, resources,
|
||||
and performance evaluations.
|
||||
|
||||
Returns:
|
||||
str: The generated prompt string.
|
||||
"""
|
||||
formatted_response_format = json.dumps(self.response_format, indent=4)
|
||||
return (
|
||||
f"Constraints:\n{self._generate_numbered_list(self.constraints)}\n\n"
|
||||
"Commands:\n"
|
||||
f"{self._generate_numbered_list(self.commands, item_type='command')}\n\n"
|
||||
f"Resources:\n{self._generate_numbered_list(self.resources)}\n\n"
|
||||
"Performance Evaluation:\n"
|
||||
f"{self._generate_numbered_list(self.performance_evaluation)}\n\n"
|
||||
"You should only respond in JSON format as described below \nResponse"
|
||||
f" Format: \n{formatted_response_format} \nEnsure the response can be"
|
||||
" parsed by Python json.loads"
|
||||
)
|
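For orientation, a hedged usage sketch of PromptGenerator; the constraint, command, and resource strings below are invented for illustration.

from pilot.prompts.generator import PromptGenerator

gen = PromptGenerator()
gen.add_constraint("No user assistance")
gen.add_command("Execute SQL", "query_execute", {"sql": "<query>"})
gen.add_resource("A MySQL database reachable through DB-GPT")
gen.add_performance_evaluation("Answer with the fewest commands possible")
print(gen.generate_prompt_string())   # constraints, commands, resources, response format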
65
pilot/prompts/prompt.py
Normal file
65
pilot/prompts/prompt.py
Normal file
@ -0,0 +1,65 @@
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.prompts.generator import PromptGenerator
|
||||
|
||||
|
||||
CFG = Config()
|
||||
|
||||
DEFAULT_TRIGGERING_PROMPT = (
|
||||
"Determine which next command to use, and respond using the format specified above:"
|
||||
)
|
||||
|
||||
|
||||
def build_default_prompt_generator() -> PromptGenerator:
|
||||
"""
|
||||
This function generates a prompt string that includes various constraints,
|
||||
commands, resources, and performance evaluations.
|
||||
|
||||
Returns:
|
||||
str: The generated prompt string.
|
||||
"""
|
||||
|
||||
# Initialize the PromptGenerator object
|
||||
prompt_generator = PromptGenerator()
|
||||
|
||||
# Add constraints to the PromptGenerator object
|
||||
prompt_generator.add_constraint(
|
||||
"~4000 word limit for short term memory. Your short term memory is short, so"
|
||||
" immediately save important information to files."
|
||||
)
|
||||
prompt_generator.add_constraint(
|
||||
"If you are unsure how you previously did something or want to recall past"
|
||||
" events, thinking about similar events will help you remember."
|
||||
)
|
||||
prompt_generator.add_constraint("No user assistance")
|
||||
prompt_generator.add_constraint(
|
||||
'Exclusively use the commands listed in double quotes e.g. "command name"'
|
||||
)
|
||||
|
||||
# Add resources to the PromptGenerator object
|
||||
prompt_generator.add_resource(
|
||||
"Internet access for searches and information gathering."
|
||||
)
|
||||
prompt_generator.add_resource("Long Term memory management.")
|
||||
prompt_generator.add_resource(
|
||||
"GPT-3.5 powered Agents for delegation of simple tasks."
|
||||
)
|
||||
# prompt_generator.add_resource("File output.")
|
||||
|
||||
# Add performance evaluations to the PromptGenerator object
|
||||
prompt_generator.add_performance_evaluation(
|
||||
"Continuously review and analyze your actions to ensure you are performing to"
|
||||
" the best of your abilities."
|
||||
)
|
||||
prompt_generator.add_performance_evaluation(
|
||||
"Constructively self-criticize your big-picture behavior constantly."
|
||||
)
|
||||
prompt_generator.add_performance_evaluation(
|
||||
"Reflect on past decisions and strategies to refine your approach."
|
||||
)
|
||||
prompt_generator.add_performance_evaluation(
|
||||
"Every command has a cost, so be smart and efficient. Aim to complete tasks in"
|
||||
" the least number of steps."
|
||||
)
|
||||
# prompt_generator.add_performance_evaluation("Write all code to a file.")
|
||||
return prompt_generator
|
@ -10,8 +10,6 @@ from fastapi.responses import StreamingResponse
|
||||
from pilot.model.inference import generate_stream
|
||||
from pydantic import BaseModel
|
||||
from pilot.model.inference import generate_output, get_embeddings
|
||||
from fastchat.serve.inference import load_model
|
||||
|
||||
|
||||
from pilot.model.loader import ModelLoader
|
||||
from pilot.configs.model_config import *
|
@ -17,6 +17,13 @@ from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownled
|
||||
|
||||
from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, DATASETS_DIR
|
||||
|
||||
from pilot.plugins import scan_plugins
|
||||
from pilot.configs.config import Config
|
||||
from pilot.commands.command_mange import CommandRegistry
|
||||
from pilot.prompts.prompt import build_default_prompt_generator
|
||||
|
||||
from pilot.prompts.first_conversation_prompt import FirstPrompt
|
||||
|
||||
from pilot.conversation import (
|
||||
default_conversation,
|
||||
conv_templates,
|
||||
@ -24,11 +31,9 @@ from pilot.conversation import (
|
||||
SeparatorStyle
|
||||
)
|
||||
|
||||
from fastchat.utils import (
|
||||
from pilot.utils import (
|
||||
build_logger,
|
||||
server_error_msg,
|
||||
violates_moderation,
|
||||
moderation_msg
|
||||
)
|
||||
|
||||
from pilot.server.gradio_css import code_highlight_css
|
||||
@ -45,6 +50,7 @@ enable_moderation = False
|
||||
models = []
|
||||
dbs = []
|
||||
vs_list = ["新建知识库"] + get_vector_storelist()
|
||||
autogpt = False
|
||||
|
||||
priority = {
|
||||
"vicuna-13b": "aaa"
|
||||
@ -58,8 +64,6 @@ def get_simlar(q):
|
||||
contents = [dc.page_content for dc, _ in docs]
|
||||
return "\n".join(contents)
|
||||
|
||||
|
||||
|
||||
def gen_sqlgen_conversation(dbname):
|
||||
mo = MySQLOperator(
|
||||
**DB_SETTINGS
|
||||
@ -118,6 +122,8 @@ def regenerate(state, request: gr.Request):
|
||||
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
|
||||
|
||||
def clear_history(request: gr.Request):
|
||||
|
||||
|
||||
logger.info(f"clear_history. ip: {request.client.host}")
|
||||
state = None
|
||||
return (state, [], "") + (disable_btn,) * 5
|
||||
@ -128,14 +134,9 @@ def add_text(state, text, request: gr.Request):
|
||||
if len(text) <= 0:
|
||||
state.skip_next = True
|
||||
return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
|
||||
if args.moderate:
|
||||
flagged = violates_moderation(text)
|
||||
if flagged:
|
||||
state.skip_next = True
|
||||
return (state, state.to_gradio_chatbot(), moderation_msg) + (
|
||||
no_change_btn,) * 5
|
||||
|
||||
text = text[:1536] # Hard cut-off
|
||||
""" Default support 4000 tokens, if tokens too lang, we will cut off """
|
||||
text = text[:4000]
|
||||
state.append_message(state.roles[0], text)
|
||||
state.append_message(state.roles[1], None)
|
||||
state.skip_next = False
|
||||
@ -152,6 +153,10 @@ def post_process_code(code):
|
||||
return code
|
||||
|
||||
def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.Request):
|
||||
|
||||
# MOCK
|
||||
autogpt = True
|
||||
print("是否是AUTO-GPT模式.", autogpt)
|
||||
start_tstamp = time.time()
|
||||
model_name = LLM_MODEL
|
||||
|
||||
@ -162,27 +167,39 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
|
||||
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
||||
return
|
||||
|
||||
|
||||
|
||||
# TODO when tab mode is AUTO_GPT, Prompt need to rebuild.
|
||||
if len(state.messages) == state.offset + 2:
|
||||
# The first round of conversation needs the prompt prepended
|
||||
|
||||
template_name = "conv_one_shot"
|
||||
new_state = conv_templates[template_name].copy()
|
||||
new_state.conv_id = uuid.uuid4().hex
|
||||
|
||||
query = state.messages[-2][1]
|
||||
# The first round of conversation needs the prompt prepended
|
||||
if(autogpt):
|
||||
# In auto-gpt mode, the first round needs its own dedicated prompt
|
||||
cfg = Config()
|
||||
first_prompt = FirstPrompt()
|
||||
first_prompt.command_registry = cfg.command_registry
|
||||
|
||||
# Add knowledge-base context to the prompt. Should the context hint go only in the first round, or be added in every round?
|
||||
# If the user's questions vary widely, the hint should be added in every round.
|
||||
if db_selector:
|
||||
new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query)
|
||||
new_state.append_message(new_state.roles[1], None)
|
||||
state = new_state
|
||||
system_prompt = first_prompt.construct_first_prompt(fisrt_message=[query])
|
||||
logger.info("[TEST]:" + system_prompt)
|
||||
template_name = "auto_dbgpt_one_shot"
|
||||
new_state = conv_templates[template_name].copy()
|
||||
new_state.append_message(role='USER', message=system_prompt)
|
||||
else:
|
||||
new_state.append_message(new_state.roles[0], query)
|
||||
new_state.append_message(new_state.roles[1], None)
|
||||
state = new_state
|
||||
template_name = "conv_one_shot"
|
||||
new_state = conv_templates[template_name].copy()
|
||||
|
||||
new_state.conv_id = uuid.uuid4().hex
|
||||
|
||||
if not autogpt:
|
||||
# Add knowledge-base context to the prompt. Should the context hint go only in the first round, or be added in every round?
|
||||
# If the user's questions vary widely, the hint should be added in every round.
|
||||
if db_selector:
|
||||
new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query)
|
||||
new_state.append_message(new_state.roles[1], None)
|
||||
else:
|
||||
new_state.append_message(new_state.roles[0], query)
|
||||
new_state.append_message(new_state.roles[1], None)
|
||||
|
||||
state = new_state
|
||||
if mode == conversation_types["default_knownledge"] and not db_selector:
|
||||
query = state.messages[-2][1]
|
||||
knqa = KnownLedgeBaseQA()
|
||||
@ -251,29 +268,28 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
|
||||
block_css = (
|
||||
code_highlight_css
|
||||
+ """
|
||||
pre {
|
||||
white-space: pre-wrap; /* Since CSS 2.1 */
|
||||
white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
|
||||
white-space: -pre-wrap; /* Opera 4-6 */
|
||||
white-space: -o-pre-wrap; /* Opera 7 */
|
||||
word-wrap: break-word; /* Internet Explorer 5.5+ */
|
||||
}
|
||||
#notice_markdown th {
|
||||
display: none;
|
||||
}
|
||||
"""
|
||||
pre {
|
||||
white-space: pre-wrap; /* Since CSS 2.1 */
|
||||
white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
|
||||
white-space: -pre-wrap; /* Opera 4-6 */
|
||||
white-space: -o-pre-wrap; /* Opera 7 */
|
||||
word-wrap: break-word; /* Internet Explorer 5.5+ */
|
||||
}
|
||||
#notice_markdown th {
|
||||
display: none;
|
||||
}
|
||||
"""
|
||||
)
|
||||
|
||||
def change_tab(tab):
|
||||
pass
|
||||
|
||||
def change_mode(mode):
|
||||
if mode in ["默认知识库对话", "LLM原生对话"]:
|
||||
return gr.update(visible=False)
|
||||
else:
|
||||
return gr.update(visible=True)
|
||||
|
||||
|
||||
def change_tab():
|
||||
autogpt = True
|
||||
|
||||
def build_single_model_ui():
|
||||
|
||||
notice_markdown = """
|
||||
@ -301,16 +317,17 @@ def build_single_model_ui():
|
||||
|
||||
max_output_tokens = gr.Slider(
|
||||
minimum=0,
|
||||
maximum=1024,
|
||||
value=512,
|
||||
maximum=4096,
|
||||
value=2048,
|
||||
step=64,
|
||||
interactive=True,
|
||||
label="最大输出Token数",
|
||||
)
|
||||
tabs = gr.Tabs()
|
||||
tabs= gr.Tabs()
|
||||
with tabs:
|
||||
with gr.TabItem("SQL生成与诊断", elem_id="SQL"):
|
||||
# TODO A selector to choose database
|
||||
tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL")
|
||||
with tab_sql:
|
||||
# TODO A selector to choose database
|
||||
with gr.Row(elem_id="db_selector"):
|
||||
db_selector = gr.Dropdown(
|
||||
label="请选择数据库",
|
||||
@ -318,9 +335,12 @@ def build_single_model_ui():
|
||||
value=dbs[0] if len(models) > 0 else "",
|
||||
interactive=True,
|
||||
show_label=True).style(container=False)
|
||||
tab_auto = gr.TabItem("AUTO-GPT", elem_id="auto")
|
||||
with tab_auto:
|
||||
gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
|
||||
|
||||
with gr.TabItem("知识问答", elem_id="QA"):
|
||||
|
||||
tab_qa = gr.TabItem("知识问答", elem_id="QA")
|
||||
with tab_qa:
|
||||
mode = gr.Radio(["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话")
|
||||
vs_setting = gr.Accordion("配置知识库", open=False)
|
||||
mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
|
||||
@ -360,9 +380,7 @@ def build_single_model_ui():
|
||||
regenerate_btn = gr.Button(value="重新生成", interactive=False)
|
||||
clear_btn = gr.Button(value="清理", interactive=False)
|
||||
|
||||
|
||||
gr.Markdown(learn_more_markdown)
|
||||
|
||||
btn_list = [regenerate_btn, clear_btn]
|
||||
regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
|
||||
http_bot,
|
||||
@ -434,13 +452,33 @@ if __name__ == "__main__":
|
||||
"--model-list-mode", type=str, default="once", choices=["once", "reload"]
|
||||
)
|
||||
parser.add_argument("--share", default=False, action="store_true")
|
||||
parser.add_argument(
|
||||
"--moderate", action="store_true", help="Enable content moderation"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
logger.info(f"args: {args}")
|
||||
|
||||
dbs = get_database_list()
|
||||
|
||||
# Load plugins
|
||||
cfg = Config()
|
||||
|
||||
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
|
||||
|
||||
# Load the executable commands provided by plugins
|
||||
command_registry = CommandRegistry()
|
||||
command_categories = [
|
||||
"pilot.commands.audio_text",
|
||||
"pilot.commands.image_gen",
|
||||
]
|
||||
# Exclude disabled command categories
|
||||
command_categories = [
|
||||
x for x in command_categories if x not in cfg.disabled_command_categories
|
||||
]
|
||||
for command_category in command_categories:
|
||||
command_registry.import_commands(command_category)
|
||||
|
||||
cfg.command_registry = command_registry
|
||||
|
||||
|
||||
logger.info(args)
|
||||
demo = build_webdemo()
|
||||
demo.queue(
|
||||
|
3
pilot/speech/__init__.py
Normal file
3
pilot/speech/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from pilot.speech.say import say_text
|
||||
|
||||
__all__ = ["say_text"]
|
50
pilot/speech/base.py
Normal file
50
pilot/speech/base.py
Normal file
@ -0,0 +1,50 @@
|
||||
"""Base class for all voice classes."""
|
||||
import abc
|
||||
from threading import Lock
|
||||
|
||||
from pilot.singleton import AbstractSingleton
|
||||
|
||||
|
||||
class VoiceBase(AbstractSingleton):
|
||||
"""
|
||||
Base class for all voice classes.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initialize the voice class.
|
||||
"""
|
||||
self._url = None
|
||||
self._headers = None
|
||||
self._api_key = None
|
||||
self._voices = []
|
||||
self._mutex = Lock()
|
||||
self._setup()
|
||||
|
||||
def say(self, text: str, voice_index: int = 0) -> bool:
|
||||
"""
|
||||
Say the given text.
|
||||
|
||||
Args:
|
||||
text (str): The text to say.
|
||||
voice_index (int): The index of the voice to use.
|
||||
"""
|
||||
with self._mutex:
|
||||
return self._speech(text, voice_index)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _setup(self) -> None:
|
||||
"""
|
||||
Setup the voices, API key, etc.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _speech(self, text: str, voice_index: int = 0) -> bool:
|
||||
"""
|
||||
Play the given text.
|
||||
|
||||
Args:
|
||||
text (str): The text to play.
|
||||
"""
|
||||
pass
|
43
pilot/speech/brian.py
Normal file
43
pilot/speech/brian.py
Normal file
@ -0,0 +1,43 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
import requests
|
||||
from playsound import playsound
|
||||
|
||||
from pilot.speech.base import VoiceBase
|
||||
|
||||
|
||||
class BrianSpeech(VoiceBase):
|
||||
"""Brian speech module for autogpt"""
|
||||
|
||||
def _setup(self) -> None:
|
||||
"""Setup the voices, API key, etc."""
|
||||
pass
|
||||
|
||||
def _speech(self, text: str, _: int = 0) -> bool:
|
||||
"""Speak text using Brian with the streamelements API
|
||||
|
||||
Args:
|
||||
text (str): The text to speak
|
||||
|
||||
Returns:
|
||||
bool: True if the request was successful, False otherwise
|
||||
"""
|
||||
tts_url = (
|
||||
f"https://api.streamelements.com/kappa/v2/speech?voice=Brian&text={text}"
|
||||
)
|
||||
response = requests.get(tts_url)
|
||||
|
||||
if response.status_code == 200:
|
||||
with open("speech.mp3", "wb") as f:
|
||||
f.write(response.content)
|
||||
playsound("speech.mp3")
|
||||
os.remove("speech.mp3")
|
||||
return True
|
||||
else:
|
||||
logging.error(
|
||||
"Request failed with status code: %s, response content: %s",
|
||||
response.status_code,
|
||||
response.content,
|
||||
)
|
||||
return False
|
88
pilot/speech/eleven_labs.py
Normal file
88
pilot/speech/eleven_labs.py
Normal file
@ -0,0 +1,88 @@
|
||||
"""ElevenLabs speech module"""
|
||||
import os
|
||||
|
||||
import requests
|
||||
from playsound import playsound
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.speech.base import VoiceBase
|
||||
|
||||
PLACEHOLDERS = {"your-voice-id"}
|
||||
|
||||
|
||||
class ElevenLabsSpeech(VoiceBase):
|
||||
"""ElevenLabs speech class"""
|
||||
|
||||
def _setup(self) -> None:
|
||||
"""Set up the voices, API key, etc.
|
||||
|
||||
Returns:
|
||||
None: None
|
||||
"""
|
||||
|
||||
cfg = Config()
|
||||
default_voices = ["ErXwobaYiN019PkySvjV", "EXAVITQu4vr4xnSDxMaL"]
|
||||
voice_options = {
|
||||
"Rachel": "21m00Tcm4TlvDq8ikWAM",
|
||||
"Domi": "AZnzlk1XvdvUeBnXmlld",
|
||||
"Bella": "EXAVITQu4vr4xnSDxMaL",
|
||||
"Antoni": "ErXwobaYiN019PkySvjV",
|
||||
"Elli": "MF3mGyEYCl7XYWbV9V6O",
|
||||
"Josh": "TxGEqnHWrfWFTfGW9XjX",
|
||||
"Arnold": "VR6AewLTigWG4xSOukaG",
|
||||
"Adam": "pNInz6obpgDQGcFmaJgB",
|
||||
"Sam": "yoZ06aMxZJJ28mfd3POQ",
|
||||
}
|
||||
self._headers = {
|
||||
"Content-Type": "application/json",
|
||||
"xi-api-key": cfg.elevenlabs_api_key,
|
||||
}
|
||||
self._voices = default_voices.copy()
|
||||
if cfg.elevenlabs_voice_1_id in voice_options:
|
||||
cfg.elevenlabs_voice_1_id = voice_options[cfg.elevenlabs_voice_1_id]
|
||||
if cfg.elevenlabs_voice_2_id in voice_options:
|
||||
cfg.elevenlabs_voice_2_id = voice_options[cfg.elevenlabs_voice_2_id]
|
||||
self._use_custom_voice(cfg.elevenlabs_voice_1_id, 0)
|
||||
self._use_custom_voice(cfg.elevenlabs_voice_2_id, 1)
|
||||
|
||||
def _use_custom_voice(self, voice, voice_index) -> None:
|
||||
"""Use a custom voice if provided and not a placeholder
|
||||
|
||||
Args:
|
||||
voice (str): The voice ID
|
||||
voice_index (int): The voice index
|
||||
|
||||
Returns:
|
||||
None: None
|
||||
"""
|
||||
# Placeholder values that should be treated as empty
|
||||
if voice and voice not in PLACEHOLDERS:
|
||||
self._voices[voice_index] = voice
|
||||
|
||||
def _speech(self, text: str, voice_index: int = 0) -> bool:
|
||||
"""Speak text using elevenlabs.io's API
|
||||
|
||||
Args:
|
||||
text (str): The text to speak
|
||||
voice_index (int, optional): The voice to use. Defaults to 0.
|
||||
|
||||
Returns:
|
||||
bool: True if the request was successful, False otherwise
|
||||
"""
|
||||
from pilot.logs import logger
|
||||
|
||||
tts_url = (
|
||||
f"https://api.elevenlabs.io/v1/text-to-speech/{self._voices[voice_index]}"
|
||||
)
|
||||
response = requests.post(tts_url, headers=self._headers, json={"text": text})
|
||||
|
||||
if response.status_code == 200:
|
||||
with open("speech.mpeg", "wb") as f:
|
||||
f.write(response.content)
|
||||
playsound("speech.mpeg", True)
|
||||
os.remove("speech.mpeg")
|
||||
return True
|
        else:
            logger.warn(f"Request failed with status code: {response.status_code}")
            logger.info(f"Response content: {response.content}")
            return False
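For orientation only, here is a hedged usage sketch of the engine defined above. It assumes the ElevenLabs API key and voice names or IDs have already been placed in the environment-backed Config, and relies on VoiceBase.say() delegating to _speech(), as pilot/speech/say.py does below; none of this is part of the diff itself:

from pilot.speech.eleven_labs import ElevenLabsSpeech

# Assumes cfg.elevenlabs_api_key and the voice IDs are populated by Config()
# (e.g. from the .env file); placeholder IDs are ignored by _use_custom_voice().
engine = ElevenLabsSpeech()
ok = engine.say("Hello from DB-GPT", 0)  # voice_index 0 or 1 selects a configured voice
print("spoken:", ok)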
22
pilot/speech/gtts.py
Normal file
@ -0,0 +1,22 @@
""" GTTS Voice. """
import os

import gtts
from playsound import playsound

from pilot.speech.base import VoiceBase


class GTTSVoice(VoiceBase):
    """GTTS Voice."""

    def _setup(self) -> None:
        pass

    def _speech(self, text: str, _: int = 0) -> bool:
        """Play the given text."""
        tts = gtts.gTTS(text)
        tts.save("speech.mp3")
        playsound("speech.mp3", True)
        os.remove("speech.mp3")
        return True
21
pilot/speech/macos_tts.py
Normal file
@ -0,0 +1,21 @@
""" MacOS TTS Voice. """
import os

from pilot.speech.base import VoiceBase


class MacOSTTS(VoiceBase):
    """MacOS TTS Voice."""

    def _setup(self) -> None:
        pass

    def _speech(self, text: str, voice_index: int = 0) -> bool:
        """Play the given text."""
        if voice_index == 0:
            os.system(f'say "{text}"')
        elif voice_index == 1:
            os.system(f'say -v "Ava (Premium)" "{text}"')
        else:
            os.system(f'say -v Samantha "{text}"')
        return True
46
pilot/speech/say.py
Normal file
@ -0,0 +1,46 @@
""" Text to speech module """
import threading
from threading import Semaphore

from pilot.configs.config import Config
from pilot.speech.base import VoiceBase
from pilot.speech.brian import BrianSpeech
from pilot.speech.eleven_labs import ElevenLabsSpeech
from pilot.speech.gtts import GTTSVoice
from pilot.speech.macos_tts import MacOSTTS

_QUEUE_SEMAPHORE = Semaphore(
    1
)  # The amount of sounds to queue before blocking the main thread


def say_text(text: str, voice_index: int = 0) -> None:
    """Speak the given text using the given voice index"""
    cfg = Config()
    default_voice_engine, voice_engine = _get_voice_engine(cfg)

    def speak() -> None:
        success = voice_engine.say(text, voice_index)
        if not success:
            default_voice_engine.say(text)

        _QUEUE_SEMAPHORE.release()

    _QUEUE_SEMAPHORE.acquire(True)
    thread = threading.Thread(target=speak)
    thread.start()


def _get_voice_engine(config: Config) -> tuple[VoiceBase, VoiceBase]:
    """Get the voice engine to use for the given configuration"""
    default_voice_engine = GTTSVoice()
    if config.elevenlabs_api_key:
        voice_engine = ElevenLabsSpeech()
    elif config.use_mac_os_tts == "True":
        voice_engine = MacOSTTS()
    elif config.use_brian_tts == "True":
        voice_engine = BrianSpeech()
    else:
        voice_engine = GTTSVoice()

    return default_voice_engine, voice_engine
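As a brief, illustrative sketch of how the selector above is meant to be driven (not part of the diff): _get_voice_engine prefers ElevenLabs when an API key is configured, then macOS or Brian TTS when their flags are the string "True", and otherwise falls back to gTTS, while say_text speaks on a background thread gated by the semaphore:

from pilot.speech.say import say_text

# With no ElevenLabs key and no TTS flags set, this falls through to GTTSVoice.
say_text("DB-GPT server is ready")
# A second configured voice, if any, can be selected by index.
say_text("Switching voices", voice_index=1)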
109
pilot/utils.py
@ -3,6 +3,20 @@

import torch

import datetime
import logging
import logging.handlers
import os
import sys

import requests

from pilot.configs.model_config import LOGDIR

server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"

handler = None

def get_gpu_memory(max_gpus=None):
    gpu_memory = []
    num_gpus = (
@ -20,3 +34,98 @@ def get_gpu_memory(max_gpus=None):
            available_memory = total_memory - allocated_memory
            gpu_memory.append(available_memory)
    return gpu_memory


def build_logger(logger_name, logger_filename):
    global handler

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Set the format of root handlers
    if not logging.getLogger().handlers:
        logging.basicConfig(level=logging.INFO, encoding='utf-8')
    logging.getLogger().handlers[0].setFormatter(formatter)

    # Redirect stdout and stderr to loggers
    stdout_logger = logging.getLogger("stdout")
    stdout_logger.setLevel(logging.INFO)
    sl = StreamToLogger(stdout_logger, logging.INFO)
    sys.stdout = sl

    stderr_logger = logging.getLogger("stderr")
    stderr_logger.setLevel(logging.ERROR)
    sl = StreamToLogger(stderr_logger, logging.ERROR)
    sys.stderr = sl

    # Get logger
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    # Add a file handler for all loggers
    if handler is None:
        os.makedirs(LOGDIR, exist_ok=True)
        filename = os.path.join(LOGDIR, logger_filename)
        handler = logging.handlers.TimedRotatingFileHandler(
            filename, when='D', utc=True)
        handler.setFormatter(formatter)

        for name, item in logging.root.manager.loggerDict.items():
            if isinstance(item, logging.Logger):
                item.addHandler(handler)

    return logger


class StreamToLogger(object):
    """
    Fake file-like stream object that redirects writes to a logger instance.
    """
    def __init__(self, logger, log_level=logging.INFO):
        self.terminal = sys.stdout
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ''

    def __getattr__(self, attr):
        return getattr(self.terminal, attr)

    def write(self, buf):
        temp_linebuf = self.linebuf + buf
        self.linebuf = ''
        for line in temp_linebuf.splitlines(True):
            # From the io.TextIOWrapper docs:
            # On output, if newline is None, any '\n' characters written
            # are translated to the system default line separator.
            # By default sys.stdout.write() expects '\n' newlines and then
            # translates them so this is still cross platform.
            if line[-1] == '\n':
                encoded_message = line.encode('utf-8', 'ignore').decode('utf-8')
                self.logger.log(self.log_level, encoded_message.rstrip())
            else:
                self.linebuf += line

    def flush(self):
        if self.linebuf != '':
            encoded_message = self.linebuf.encode('utf-8', 'ignore').decode('utf-8')
            self.logger.log(self.log_level, encoded_message.rstrip())
            self.linebuf = ''


def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    import torch
    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)


def pretty_print_semaphore(semaphore):
    if semaphore is None:
        return "None"
    return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
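A minimal, illustrative usage sketch for the logging helpers added above (the logger name and file name here are made up; LOGDIR comes from pilot.configs.model_config):

from pilot.utils import build_logger, disable_torch_init

# Creates (or reuses) a timed-rotating file handler under LOGDIR and redirects
# stdout/stderr through StreamToLogger, so plain print() output also reaches the log.
logger = build_logger("webserver", "webserver.log")
logger.info("web server starting")

disable_torch_init()  # skip redundant default weight init before loading a checkpoint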
0
plugins/__PUT_PLUGIN_ZIPS_HERE__
Normal file
requirements.txt
@ -1,7 +1,5 @@
accelerate==0.16.0
torch==2.0.0
torchvision==0.13.1
torchaudio==0.12.1
accelerate==0.16.0
aiohttp==3.8.4
aiosignal==1.3.1
@ -20,7 +18,6 @@ importlib-resources==5.12.0
kiwisolver==1.4.4
matplotlib==3.7.0
multidict==6.0.4
openai==0.27.0
packaging==23.0
psutil==5.9.4
pycocotools==2.0.6
@ -47,11 +44,32 @@ pycocoevalcap
sentence-transformers
umap-learn
notebook
gradio==3.24.1
gradio==3.23
gradio-client==0.0.8
wandb
fschat==0.1.10
llama-index=0.5.27
llama-index==0.5.27
pymysql
unstructured==0.6.3
pytesseract==0.3.10
pytesseract==0.3.10

auto-gpt-plugin-template
pymdown-extensions
mkdocs
requests
gTTS==2.3.1

# Testing dependencies
pytest
asynctest
pytest-asyncio
pytest-benchmark
pytest-cov
pytest-integration
pytest-mock
vcrpy
pytest-recording
chromadb
markdown2
colorama
playsound
distro
25
run.sh
Normal file
@ -0,0 +1,25 @@
#!/bin/bash

function find_python_command() {
    if command -v python &> /dev/null
    then
        echo "python"
    elif command -v python3 &> /dev/null
    then
        echo "python3"
    else
        echo "Python not found. Please install python."
        exit 1
    fi
}

PYTHONCMD=$(find_python_command)

nohup $PYTHONCMD pilot/server/vicuna_server.py >> /root/server.log 2>&1 &
while [ `grep -c "Uvicorn running on" /root/server.log` -eq '0' ];do
    sleep 1s;
    echo "wait server running"
done
echo "server running"

$PYTHONCMD pilot/server/webserver.py
0
tests/__init__.py
Normal file
BIN
tests/unit/data/test_plugins/Auto-GPT-Plugin-Test-master.zip
Normal file
Binary file not shown.
135
tests/unit/test_plugins.py
Normal file
@ -0,0 +1,135 @@
import pytest
import os


from pilot.configs.config import Config
from pilot.plugins import (
    denylist_allowlist_check,
    inspect_zip_for_modules,
    scan_plugins,
)

PLUGINS_TEST_DIR = "tests/unit/data/test_plugins"
PLUGINS_TEST_DIR_TEMP = "data/test_plugins"
PLUGIN_TEST_ZIP_FILE = "Auto-GPT-Plugin-Test-master.zip"
PLUGIN_TEST_INIT_PY = "Auto-GPT-Plugin-Test-master/src/auto_gpt_vicuna/__init__.py"
PLUGIN_TEST_OPENAI = "https://weathergpt.vercel.app/"

def test_inspect_zip_for_modules():
    current_dir = os.getcwd()
    print(current_dir)
    result = inspect_zip_for_modules(str(f"{current_dir}/{PLUGINS_TEST_DIR_TEMP}/{PLUGIN_TEST_ZIP_FILE}"))
    assert result == [PLUGIN_TEST_INIT_PY]


@pytest.fixture
def mock_config_denylist_allowlist_check():
    class MockConfig:
        """Mock config object for testing the denylist_allowlist_check function"""

        plugins_denylist = ["BadPlugin"]
        plugins_allowlist = ["GoodPlugin"]
        authorise_key = "y"
        exit_key = "n"

    return MockConfig()


def test_denylist_allowlist_check_denylist(
    mock_config_denylist_allowlist_check, monkeypatch
):
    # Test that the function returns False when the plugin is in the denylist
    monkeypatch.setattr("builtins.input", lambda _: "y")
    assert not denylist_allowlist_check(
        "BadPlugin", mock_config_denylist_allowlist_check
    )


def test_denylist_allowlist_check_allowlist(
    mock_config_denylist_allowlist_check, monkeypatch
):
    # Test that the function returns True when the plugin is in the allowlist
    monkeypatch.setattr("builtins.input", lambda _: "y")
    assert denylist_allowlist_check("GoodPlugin", mock_config_denylist_allowlist_check)


def test_denylist_allowlist_check_user_input_yes(
    mock_config_denylist_allowlist_check, monkeypatch
):
    # Test that the function returns True when the user inputs "y"
    monkeypatch.setattr("builtins.input", lambda _: "y")
    assert denylist_allowlist_check(
        "UnknownPlugin", mock_config_denylist_allowlist_check
    )


def test_denylist_allowlist_check_user_input_no(
    mock_config_denylist_allowlist_check, monkeypatch
):
    # Test that the function returns False when the user inputs "n"
    monkeypatch.setattr("builtins.input", lambda _: "n")
    assert not denylist_allowlist_check(
        "UnknownPlugin", mock_config_denylist_allowlist_check
    )


def test_denylist_allowlist_check_user_input_invalid(
    mock_config_denylist_allowlist_check, monkeypatch
):
    # Test that the function returns False when the user inputs an invalid value
    monkeypatch.setattr("builtins.input", lambda _: "invalid")
    assert not denylist_allowlist_check(
        "UnknownPlugin", mock_config_denylist_allowlist_check
    )


@pytest.fixture
def config_with_plugins():
    """Mock config object for testing the scan_plugins function"""
    # Test that the function returns the correct number of plugins
    cfg = Config()
    cfg.plugins_dir = PLUGINS_TEST_DIR
    cfg.plugins_openai = ["https://weathergpt.vercel.app/"]
    return cfg


@pytest.fixture
def mock_config_openai_plugin():
    """Mock config object for testing the scan_plugins function"""

    class MockConfig:
        """Mock config object for testing the scan_plugins function"""
        current_dir = os.getcwd()
        plugins_dir = f"{current_dir}/{PLUGINS_TEST_DIR_TEMP}/"
        plugins_openai = [PLUGIN_TEST_OPENAI]
        plugins_denylist = ["AutoGPTPVicuna"]
        plugins_allowlist = [PLUGIN_TEST_OPENAI]

    return MockConfig()


def test_scan_plugins_openai(mock_config_openai_plugin):
    # Test that the function returns the correct number of plugins
    result = scan_plugins(mock_config_openai_plugin, debug=True)
    assert len(result) == 1


@pytest.fixture
def mock_config_generic_plugin():
    """Mock config object for testing the scan_plugins function"""

    # Test that the function returns the correct number of plugins
    class MockConfig:
        current_dir = os.getcwd()
        plugins_dir = f"{current_dir}/{PLUGINS_TEST_DIR_TEMP}/"
        plugins_openai = []
        plugins_denylist = []
        plugins_allowlist = ["AutoGPTPVicuna"]

    return MockConfig()


def test_scan_plugins_generic(mock_config_generic_plugin):
    # Test that the function returns the correct number of plugins
    result = scan_plugins(mock_config_generic_plugin, debug=True)
    assert len(result) == 1
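Outside the test suite, the same scan can be driven from the real Config; a brief, illustrative sketch based on how the tests above call it (it assumes cfg.plugins_dir and cfg.plugins_openai are populated from the .env settings):

from pilot.configs.config import Config
from pilot.plugins import scan_plugins

cfg = Config()
# scan_plugins() inspects the zips under cfg.plugins_dir (and any cfg.plugins_openai
# URLs) against the allow/deny lists and returns the loaded plugin instances.
plugins = scan_plugins(cfg, debug=True)
print(f"loaded {len(plugins)} plugin(s)")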