fix(conn): database conn fix special symbols (#898)

This commit is contained in:
magic.chen
2023-12-06 16:51:47 +08:00
committed by GitHub
parent 54e2aa1dbd
commit afad4ffd32
7 changed files with 13 additions and 75 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 196 KiB

After

Width:  |  Height:  |  Size: 92 KiB

View File

@@ -11,6 +11,7 @@ from alembic import command
from alembic.config import Config as AlembicConfig
from urllib.parse import quote
from pilot.configs.config import Config
from urllib.parse import quote_plus as urlquote
logger = logging.getLogger(__name__)
@@ -29,14 +30,7 @@ connection = sqlite3.connect(db_path)
if CFG.LOCAL_DB_TYPE == "mysql":
engine_temp = create_engine(
f"mysql+pymysql://"
+ quote(CFG.LOCAL_DB_USER)
+ ":"
+ quote(CFG.LOCAL_DB_PASSWORD)
+ "@"
+ CFG.LOCAL_DB_HOST
+ ":"
+ str(CFG.LOCAL_DB_PORT)
f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}"
)
# check and auto create mysqldatabase
try:
@@ -51,15 +45,7 @@ if CFG.LOCAL_DB_TYPE == "mysql":
logger.error(f"{db_name} not connect success!")
engine = create_engine(
f"mysql+pymysql://"
+ quote(CFG.LOCAL_DB_USER)
+ ":"
+ quote(CFG.LOCAL_DB_PASSWORD)
+ "@"
+ CFG.LOCAL_DB_HOST
+ ":"
+ str(CFG.LOCAL_DB_PORT)
+ f"/{db_name}"
f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}/{db_name}"
)
else:
engine = create_engine(f"sqlite:///{db_path}")

View File

@@ -1,9 +1,10 @@
from __future__ import annotations
from urllib.parse import quote
import warnings
import sqlparse
import regex as re
import pandas as pd
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote
from typing import Any, Iterable, List, Optional
from pydantic import BaseModel, Field, root_validator, validator, Extra
from abc import ABC, abstractmethod
@@ -113,17 +114,7 @@ class RDBMSDatabase(BaseConnect):
engine_args (Optional[dict]):other engine_args.
"""
db_url: str = (
cls.driver
+ "://"
+ quote(user)
+ ":"
+ quote(pwd)
+ "@"
+ host
+ ":"
+ str(port)
+ "/"
+ db_name
f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}"
)
return cls.from_uri(db_url, engine_args, **kwargs)

View File

@@ -1,6 +1,8 @@
import re
from typing import Optional, Any
from sqlalchemy import text
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote
from pilot.connections.rdbms.base import RDBMSDatabase
@@ -30,17 +32,7 @@ class ClickhouseConnect(RDBMSDatabase):
**kwargs: Any,
) -> RDBMSDatabase:
db_url: str = (
cls.driver
+ "://"
+ user
+ ":"
+ pwd
+ "@"
+ host
+ ":"
+ str(port)
+ "/"
+ db_name
f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}"
)
return cls.from_uri(db_url, engine_args, **kwargs)

View File

@@ -82,25 +82,3 @@ class DuckDbConnect(RDBMSDatabase):
results.append(f"{table_name}({','.join(table_colums)});")
return results
if __name__ == "__main__":
engine = create_engine(
"duckdb:////Users/tuyang.yhj/Code/PycharmProjects/DB-GPT/pilot/mock_datas/db-gpt-test.db"
)
metadata = MetaData(engine)
results = (
engine.connect()
.execute("SELECT name FROM sqlite_master WHERE type='table'")
.fetchall()
)
print(str(results))
fields = []
results2 = engine.connect().execute(f"""PRAGMA table_info(user)""").fetchall()
for row_col in results2:
field_info = list(row_col)
fields.append(field_info[1])
print(str(fields))

View File

@@ -1,6 +1,7 @@
from typing import Iterable, Optional, Any
from sqlalchemy import text
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote
from pilot.connections.rdbms.base import RDBMSDatabase
@@ -21,17 +22,7 @@ class PostgreSQLDatabase(RDBMSDatabase):
**kwargs: Any,
) -> RDBMSDatabase:
db_url: str = (
cls.driver
+ "://"
+ quote(user)
+ ":"
+ quote(pwd)
+ "@"
+ host
+ ":"
+ str(port)
+ "/"
+ db_name
f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}"
)
return cls.from_uri(db_url, engine_args, **kwargs)

View File

@@ -1,7 +1,7 @@
from typing import Iterable, Optional, Any
from sqlalchemy import text
from urllib.parse import quote
import re
from urllib.parse import quote_plus as urlquote
from pilot.connections.rdbms.base import RDBMSDatabase
from pilot.connections.rdbms.dialect.starrocks.sqlalchemy import *
@@ -23,7 +23,7 @@ class StarRocksConnect(RDBMSDatabase):
**kwargs: Any,
) -> RDBMSDatabase:
db_url: str = (
f"{cls.driver}://{quote(user)}:{quote(pwd)}@{host}:{str(port)}/{db_name}"
f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}"
)
return cls.from_uri(db_url, engine_args, **kwargs)