mirror of https://github.com/csunny/DB-GPT.git
synced 2025-09-15 22:19:28 +00:00

Merge branch 'main' into TY_08_DEV_NEW
@@ -29,6 +29,7 @@ class DBType(Enum):
     Oracle = DbInfo("oracle")
     MSSQL = DbInfo("mssql")
     Postgresql = DbInfo("postgresql")
+    Clickhouse = DbInfo("clickhouse")
 
     def value(self):
         return self._value_.name
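
For context, the value() override works because each enum member's value is a DbInfo instance. A minimal self-contained sketch of the pattern (DbInfo's fields beyond name, and the of_db_type helper, are assumptions for illustration, not the repository's exact code):

from enum import Enum


class DbInfo:
    def __init__(self, name: str, is_file_db: bool = False):
        self.name = name
        self.is_file_db = is_file_db


class DBType(Enum):
    Mysql = DbInfo("mysql")
    Clickhouse = DbInfo("clickhouse")

    def value(self):
        # Each member's value is a DbInfo, so return its name string.
        return self._value_.name

    @staticmethod
    def of_db_type(db_type: str):
        # Hypothetical lookup: map a config string back to the member.
        for item in DBType:
            if item.value() == db_type:
                return item
        return None


assert DBType.of_db_type("clickhouse") is DBType.Clickhouse
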
@@ -41,6 +41,12 @@ LLM_MODEL_CONFIG = {
     "m3e-base": os.path.join(MODEL_PATH, "m3e-base"),
     # https://huggingface.co/moka-ai/m3e-base
     "m3e-large": os.path.join(MODEL_PATH, "m3e-large"),
+    # https://huggingface.co/BAAI/bge-large-en
+    "bge-large-en": os.path.join(MODEL_PATH, "bge-large-en"),
+    "bge-base-en": os.path.join(MODEL_PATH, "bge-base-en"),
+    # https://huggingface.co/BAAI/bge-large-zh
+    "bge-large-zh": os.path.join(MODEL_PATH, "bge-large-zh"),
+    "bge-base-zh": os.path.join(MODEL_PATH, "bge-base-zh"),
     "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
     "codegen2-1b": os.path.join(MODEL_PATH, "codegen2-1B"),
     "codet5p-2b": os.path.join(MODEL_PATH, "codet5p-2b"),
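
A quick sketch of how such a path map is typically consumed (the MODEL_PATH value and the lookup helper are illustrative assumptions, not code from the repository):

import os

MODEL_PATH = os.path.expanduser("~/models")  # placeholder root for the sketch
LLM_MODEL_CONFIG = {
    "bge-large-en": os.path.join(MODEL_PATH, "bge-large-en"),
    "bge-large-zh": os.path.join(MODEL_PATH, "bge-large-zh"),
}


def resolve_model_path(name: str) -> str:
    # Fail fast with the known keys instead of a bare KeyError.
    try:
        return LLM_MODEL_CONFIG[name]
    except KeyError:
        known = ", ".join(sorted(LLM_MODEL_CONFIG))
        raise ValueError(f"unknown model {name!r}; expected one of: {known}")
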
@@ -11,6 +11,7 @@ from pilot.connections.rdbms.conn_duckdb import DuckDbConnect
 from pilot.connections.rdbms.conn_sqlite import SQLiteConnect
 from pilot.connections.rdbms.conn_mssql import MSSQLConnect
 from pilot.connections.rdbms.base import RDBMSDatabase
+from pilot.connections.rdbms.conn_clickhouse import ClickhouseConnect
 from pilot.singleton import Singleton
 from pilot.common.sql_database import Database
 from pilot.connections.db_conn_info import DBConfig
pilot/connections/rdbms/conn_clickhouse.py (new file, 108 lines)
@@ -0,0 +1,108 @@
+import re
+from typing import Optional, Any
+from sqlalchemy import text
+
+from pilot.connections.rdbms.base import RDBMSDatabase
+
+
+class ClickhouseConnect(RDBMSDatabase):
+    """Connect to a Clickhouse database and fetch its metadata."""
+
+    """db type"""
+    db_type: str = "clickhouse"
+    """db driver"""
+    driver: str = "clickhouse"
+    """db dialect"""
+    db_dialect: str = "clickhouse"
+
+    @classmethod
+    def from_uri_db(
+        cls,
+        host: str,
+        port: int,
+        user: str,
+        pwd: str,
+        db_name: str,
+        engine_args: Optional[dict] = None,
+        **kwargs: Any,
+    ) -> RDBMSDatabase:
+        db_url: str = (
+            cls.driver
+            + "://"
+            + user
+            + ":"
+            + pwd
+            + "@"
+            + host
+            + ":"
+            + str(port)
+            + "/"
+            + db_name
+        )
+        return cls.from_uri(db_url, engine_args, **kwargs)
+
+    def get_indexes(self, table_name):
+        """Get table indexes about specified table."""
+        return ""
+
+    def get_show_create_table(self, table_name):
+        """Get the SHOW CREATE TABLE result for the specified table,
+        stripped of engine/charset/settings clauses."""
+        session = self._db_sessions()
+        cursor = session.execute(text(f"SHOW CREATE TABLE {table_name}"))
+        ans = cursor.fetchall()[0][0]
+        ans = re.sub(r"\s*ENGINE\s*=\s*MergeTree\s*", " ", ans, flags=re.IGNORECASE)
+        ans = re.sub(
+            r"\s*DEFAULT\s*CHARSET\s*=\s*\w+\s*", " ", ans, flags=re.IGNORECASE
+        )
+        ans = re.sub(r"\s*SETTINGS\s*\w+\s*", " ", ans, flags=re.IGNORECASE)
+        return ans
+
+    def get_fields(self, table_name):
+        """Get column fields about specified table."""
+        session = self._db_sessions()
+        cursor = session.execute(
+            text(
+                "SELECT name, type, default_expression, is_in_primary_key, comment "
+                f"FROM system.columns WHERE table='{table_name}'"
+            )
+        )
+        fields = cursor.fetchall()
+        return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]
+
+    def get_users(self):
+        return []
+
+    def get_grants(self):
+        return []
+
+    def get_collation(self):
+        """Get collation."""
+        return "UTF-8"
+
+    def get_charset(self):
+        return "UTF-8"
+
+    def get_database_list(self):
+        return []
+
+    def get_database_names(self):
+        return []
+
+    def get_table_comments(self, db_name):
+        session = self._db_sessions()
+        cursor = session.execute(
+            text(f"SELECT table, comment FROM system.tables WHERE database = '{db_name}'")
+        )
+        table_comments = cursor.fetchall()
+        return [
+            (table_comment[0], table_comment[1]) for table_comment in table_comments
+        ]
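
A hedged usage sketch of the new connector (host and credentials are placeholders; it assumes a ClickHouse SQLAlchemy dialect is installed so the "clickhouse://" URL scheme resolves, and that a server is reachable):

conn = ClickhouseConnect.from_uri_db(
    host="localhost",   # placeholder
    port=8123,          # placeholder; ClickHouse HTTP port
    user="default",
    pwd="",
    db_name="default",
)
# Inherited helpers from RDBMSDatabase plus the overrides above:
print(conn.get_table_comments("default"))
print(conn.get_show_create_table("some_table"))  # hypothetical table name
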
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
+import os
 from typing import Optional, Any, Iterable
 from sqlalchemy import create_engine, text
@@ -24,6 +25,9 @@ class SQLiteConnect(RDBMSDatabase):
         _engine_args = engine_args or {}
         _engine_args["connect_args"] = {"check_same_thread": False}
         # _engine_args["echo"] = True
+        directory = os.path.dirname(file_path)
+        if not os.path.exists(directory):
+            os.makedirs(directory)
         return cls(create_engine("sqlite:///" + file_path, **_engine_args), **kwargs)
 
     def get_indexes(self, table_name):
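
The check-then-create pair added above can race when two workers open the same new database concurrently; Python 3.2+ offers an idempotent one-liner (a sketch, with file_path as a placeholder):

import os

file_path = "/tmp/dbgpt/sqlite/demo.db"  # placeholder
# exist_ok=True makes the call a no-op when the directory already exists,
# avoiding the TOCTOU race in the exists()/makedirs() pair.
os.makedirs(os.path.dirname(file_path), exist_ok=True)
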
@@ -233,18 +233,10 @@ class GorillaAdapter(BaseLLMAdaper):
         return model, tokenizer
 
 
-class CodeGenAdapter(BaseLLMAdaper):
-    pass
-
-
-class StarCoderAdapter(BaseLLMAdaper):
-    pass
-
-
-class T5CodeAdapter(BaseLLMAdaper):
-    pass
-
-
 class KoalaLLMAdapter(BaseLLMAdaper):
     """Koala LLM Adapter which Based LLaMA"""
@@ -270,7 +262,7 @@ class GPT4AllAdapter(BaseLLMAdaper):
     """
 
     def match(self, model_path: str):
-        return "gpt4all" in model_path
+        return "gptj-6b" in model_path
 
     def loader(self, model_path: str, from_pretrained_kwargs: dict):
         import gpt4all
@@ -335,6 +335,20 @@ register_conv_template(
     )
 )
 
+# Alpaca default template
+register_conv_template(
+    Conversation(
+        name="alpaca",
+        system="Below is an instruction that describes a task. Write a response that appropriately completes the request.",
+        roles=("### Instruction", "### Response"),
+        messages=(),
+        offset=0,
+        sep_style=SeparatorStyle.ADD_COLON_TWO,
+        sep="\n\n",
+        sep2="</s>",
+    )
+)
+
 # Baichuan-13B-Chat template
 register_conv_template(
     # source: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/f5f47be2adbbdceb784f334d6fa1ca2c73e65097/modeling_baichuan.py#L507
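
For orientation, a prompt registered with these fields renders roughly as below (a simplification of fastchat-style Conversation.get_prompt(); ADD_COLON_TWO joins "role: message" pairs, alternating sep and sep2, and leaves a bare "role:" open for generation):

system = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request."
)
# One user turn plus the open assistant slot:
prompt = (
    system + "\n\n"
    + "### Instruction: What is DB-GPT?" + "\n\n"
    + "### Response:"
)
print(prompt)
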
@@ -1,23 +1,10 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
-import threading
-import sys
-import time
 
 
 def gpt4all_generate_stream(model, tokenizer, params, device, max_position_embeddings):
     stop = params.get("stop", "###")
     prompt = params["prompt"]
-    role, query = prompt.split(stop)[1].split(":")
+    role, query = prompt.split(stop)[0].split(":")
     print(f"gpt4all, role: {role}, query: {query}")
-
-    def worker():
-        model.generate(prompt=query, streaming=True)
-
-    t = threading.Thread(target=worker)
-    t.start()
-
-    while t.is_alive():
-        yield sys.stdout.output
-        time.sleep(0.01)
-    t.join()
+    yield model.generate(prompt=query, streaming=True)
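
With streaming=True, gpt4all's generate() returns an iterator of tokens, so the rewritten function yields that iterator once instead of polling stdout from a worker thread. A hedged consumption sketch (model loading elided; the prompt shape mirrors the "###"-separated format parsed above):

# `model` is assumed to be a loaded gpt4all model instance.
params = {"prompt": "user: hello###", "stop": "###"}
for out in gpt4all_generate_stream(model, None, params, "cpu", 2048):
    for token in out:  # `out` is itself the token iterator
        print(token, end="", flush=True)
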
@@ -110,7 +110,13 @@ async def db_connect_delete(db_name: str = None):
 
 @router.get("/v1/chat/db/support/type", response_model=Result[DbTypeInfo])
 async def db_support_types():
-    support_types = [DBType.Mysql, DBType.MSSQL, DBType.DuckDb]
+    support_types = [
+        DBType.Mysql,
+        DBType.MSSQL,
+        DBType.DuckDb,
+        DBType.SQLite,
+        DBType.Clickhouse,
+    ]
     db_type_infos = []
     for type in support_types:
         db_type_infos.append(
@@ -315,6 +321,7 @@ async def no_stream_generator(chat):
 
 async def stream_generator(chat):
     model_response = chat.stream_call()
+    msg = "[LLM_ERROR]: llm server has no output, maybe your prompt template is wrong."
    if not CFG.NEW_SERVER_MODE:
         for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
@@ -148,28 +148,6 @@ class ChatGLMChatAdapter(BaseChatAdpter):
         return chatglm_generate_stream
 
 
-class CodeT5ChatAdapter(BaseChatAdpter):
-    """Model chat adapter for CodeT5"""
-
-    def match(self, model_path: str):
-        return "codet5" in model_path
-
-    def get_generate_stream_func(self, model_path: str):
-        # TODO
-        pass
-
-
-class CodeGenChatAdapter(BaseChatAdpter):
-    """Model chat adapter for CodeGen"""
-
-    def match(self, model_path: str):
-        return "codegen" in model_path
-
-    def get_generate_stream_func(self, model_path: str):
-        # TODO
-        pass
-
-
 class GuanacoChatAdapter(BaseChatAdpter):
     """Model chat adapter for Guanaco"""
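
The adapters above follow a match-by-substring plus factory pattern; a minimal sketch of the dispatch loop (the registry names are assumptions based on the class shapes in this diff):

class BaseChatAdpter:
    def match(self, model_path: str) -> bool:
        return False

    def get_generate_stream_func(self, model_path: str):
        raise NotImplementedError


llm_model_chat_adapters = []


def register_llm_model_chat_adapter(cls):
    llm_model_chat_adapters.append(cls())


def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:
    # First registered adapter whose match() accepts the path wins,
    # which is why match() substrings must not shadow one another.
    for adapter in llm_model_chat_adapters:
        if adapter.match(model_path):
            return adapter
    raise ValueError(f"Invalid model adapter for {model_path}")
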
@@ -216,7 +194,7 @@ class GorillaChatAdapter(BaseChatAdpter):
 
 class GPT4AllChatAdapter(BaseChatAdpter):
     def match(self, model_path: str):
-        return "gpt4all" in model_path
+        return "gptj-6b" in model_path
 
     def get_generate_stream_func(self, model_path: str):
         from pilot.model.llm_out.gpt4all_llm import gpt4all_generate_stream
@@ -1,5 +1,6 @@
 import os
 import shutil
+import tempfile
 from tempfile import NamedTemporaryFile
 
 from fastapi import APIRouter, File, UploadFile, Form
@@ -130,29 +131,28 @@ async def document_upload(
         if doc_file:
             if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name)):
                 os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name))
-            with NamedTemporaryFile(
-                dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name), delete=False
-            ) as tmp:
+            # We can not move temp file in windows system when we open file in context of `with`
+            tmp_fd, tmp_path = tempfile.mkstemp(
+                dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name)
+            )
+            with os.fdopen(tmp_fd, "wb") as tmp:
                 tmp.write(await doc_file.read())
-                tmp_path = tmp.name
-                shutil.move(
-                    tmp_path,
-                    os.path.join(
-                        KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename
-                    ),
-                )
-                request = KnowledgeDocumentRequest()
-                request.doc_name = doc_name
-                request.doc_type = doc_type
-                request.content = os.path.join(
-                    KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename
-                )
-                return Result.succ(
-                    knowledge_space_service.create_knowledge_document(
-                        space=space_name, request=request
-                    )
-                )
-                # return Result.succ([])
+            shutil.move(
+                tmp_path,
+                os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename),
+            )
+            request = KnowledgeDocumentRequest()
+            request.doc_name = doc_name
+            request.doc_type = doc_type
+            request.content = os.path.join(
+                KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename
+            )
+            return Result.succ(
+                knowledge_space_service.create_knowledge_document(
+                    space=space_name, request=request
+                )
+            )
+            # return Result.succ([])
         return Result.faild(code="E000X", msg=f"doc_file is None")
     except Exception as e:
         return Result.faild(code="E000X", msg=f"document add error {e}")
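
The switch from NamedTemporaryFile to tempfile.mkstemp matters on Windows: a file opened inside a `with` block cannot be moved until its handle is closed. A minimal sketch of the safe pattern (paths are placeholders):

import os
import shutil
import tempfile

upload_dir = "/tmp/knowledge/demo-space"  # placeholder for the upload root/space dir
final_path = os.path.join(upload_dir, "report.pdf")  # placeholder filename

os.makedirs(upload_dir, exist_ok=True)
# mkstemp hands back a raw fd; closing it via os.fdopen's context manager
# releases the handle before the move, which Windows requires.
tmp_fd, tmp_path = tempfile.mkstemp(dir=upload_dir)
with os.fdopen(tmp_fd, "wb") as tmp:
    tmp.write(b"uploaded bytes here")
shutil.move(tmp_path, final_path)
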
@@ -6,6 +6,7 @@ import json
 import os
 import sys
 from typing import List
+import platform
 
 import uvicorn
 from fastapi import BackgroundTasks, FastAPI, Request
@@ -89,13 +90,18 @@ class ModelWorker:
         params, model_context = self.llm_chat_adapter.model_adaptation(
             params, self.ml.model_path, prompt_template=self.ml.prompt_template
         )
 
         for output in self.generate_stream_func(
             self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
         ):
             # Please do not open the output in production!
             # The gpt4all thread shares stdout with the parent process,
             # and opening it may affect the frontend output.
-            print("output: ", output)
+            if "windows" in platform.platform().lower():
+                # Do not print the model output, because it may contain Emoji, there is a problem with the GBK encoding
+                pass
+            else:
+                print("output: ", output)
             # return some model context to dgt-server
             ret = {"text": output, "error_code": 0, "model_context": model_context}
             yield json.dumps(ret).encode() + b"\0"
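
Each yielded chunk is a JSON object terminated by a NUL byte, matching the client-side read in stream_generator above (iter_lines with delimiter=b"\0"); a round-trip sketch:

import json

# Server side, as in ModelWorker.generate_stream:
payload = {"text": "hello", "error_code": 0, "model_context": {}}
chunk = json.dumps(payload).encode() + b"\0"

# Client side, mirroring the delimiter=b"\0" framing:
for raw in chunk.split(b"\0"):
    if raw:
        print(json.loads(raw)["text"])
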
@@ -1,8 +1,8 @@
 from pilot.vector_store.chroma_store import ChromaStore
 
-from pilot.vector_store.weaviate_store import WeaviateStore
+# from pilot.vector_store.weaviate_store import WeaviateStore
 
-connector = {"Chroma": ChromaStore, "Weaviate": WeaviateStore}
+connector = {"Chroma": ChromaStore}
 
 try:
     from pilot.vector_store.milvus_store import MilvusStore
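
The try/except import at the end of this file is the optional-backend pattern: a store is registered only when its client library imports cleanly. A hedged sketch of the idea (the update call is an assumption, not necessarily the repository's exact code):

connector = {"Chroma": ChromaStore}

try:
    from pilot.vector_store.milvus_store import MilvusStore

    # Register only when the milvus client is installed.
    connector.update({"Milvus": MilvusStore})
except ImportError:
    pass  # backend stays unavailable
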