mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-13 21:21:08 +00:00
fix(conn): database conn fix special symbols (#898)
This commit is contained in:
Binary file not shown.
Before Width: | Height: | Size: 196 KiB After Width: | Height: | Size: 92 KiB |
@@ -11,6 +11,7 @@ from alembic import command
|
|||||||
from alembic.config import Config as AlembicConfig
|
from alembic.config import Config as AlembicConfig
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
|
from urllib.parse import quote_plus as urlquote
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -29,14 +30,7 @@ connection = sqlite3.connect(db_path)
|
|||||||
|
|
||||||
if CFG.LOCAL_DB_TYPE == "mysql":
|
if CFG.LOCAL_DB_TYPE == "mysql":
|
||||||
engine_temp = create_engine(
|
engine_temp = create_engine(
|
||||||
f"mysql+pymysql://"
|
f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}"
|
||||||
+ quote(CFG.LOCAL_DB_USER)
|
|
||||||
+ ":"
|
|
||||||
+ quote(CFG.LOCAL_DB_PASSWORD)
|
|
||||||
+ "@"
|
|
||||||
+ CFG.LOCAL_DB_HOST
|
|
||||||
+ ":"
|
|
||||||
+ str(CFG.LOCAL_DB_PORT)
|
|
||||||
)
|
)
|
||||||
# check and auto create mysqldatabase
|
# check and auto create mysqldatabase
|
||||||
try:
|
try:
|
||||||
@@ -51,15 +45,7 @@ if CFG.LOCAL_DB_TYPE == "mysql":
|
|||||||
logger.error(f"{db_name} not connect success!")
|
logger.error(f"{db_name} not connect success!")
|
||||||
|
|
||||||
engine = create_engine(
|
engine = create_engine(
|
||||||
f"mysql+pymysql://"
|
f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}/{db_name}"
|
||||||
+ quote(CFG.LOCAL_DB_USER)
|
|
||||||
+ ":"
|
|
||||||
+ quote(CFG.LOCAL_DB_PASSWORD)
|
|
||||||
+ "@"
|
|
||||||
+ CFG.LOCAL_DB_HOST
|
|
||||||
+ ":"
|
|
||||||
+ str(CFG.LOCAL_DB_PORT)
|
|
||||||
+ f"/{db_name}"
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
engine = create_engine(f"sqlite:///{db_path}")
|
engine = create_engine(f"sqlite:///{db_path}")
|
||||||
|
@@ -1,9 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from urllib.parse import quote
|
|
||||||
import warnings
|
import warnings
|
||||||
import sqlparse
|
import sqlparse
|
||||||
import regex as re
|
import regex as re
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
from urllib.parse import quote
|
||||||
|
from urllib.parse import quote_plus as urlquote
|
||||||
from typing import Any, Iterable, List, Optional
|
from typing import Any, Iterable, List, Optional
|
||||||
from pydantic import BaseModel, Field, root_validator, validator, Extra
|
from pydantic import BaseModel, Field, root_validator, validator, Extra
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
@@ -113,17 +114,7 @@ class RDBMSDatabase(BaseConnect):
|
|||||||
engine_args (Optional[dict]):other engine_args.
|
engine_args (Optional[dict]):other engine_args.
|
||||||
"""
|
"""
|
||||||
db_url: str = (
|
db_url: str = (
|
||||||
cls.driver
|
f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}"
|
||||||
+ "://"
|
|
||||||
+ quote(user)
|
|
||||||
+ ":"
|
|
||||||
+ quote(pwd)
|
|
||||||
+ "@"
|
|
||||||
+ host
|
|
||||||
+ ":"
|
|
||||||
+ str(port)
|
|
||||||
+ "/"
|
|
||||||
+ db_name
|
|
||||||
)
|
)
|
||||||
return cls.from_uri(db_url, engine_args, **kwargs)
|
return cls.from_uri(db_url, engine_args, **kwargs)
|
||||||
|
|
||||||
|
@@ -1,6 +1,8 @@
|
|||||||
import re
|
import re
|
||||||
from typing import Optional, Any
|
from typing import Optional, Any
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
|
from urllib.parse import quote
|
||||||
|
from urllib.parse import quote_plus as urlquote
|
||||||
|
|
||||||
from pilot.connections.rdbms.base import RDBMSDatabase
|
from pilot.connections.rdbms.base import RDBMSDatabase
|
||||||
|
|
||||||
@@ -30,17 +32,7 @@ class ClickhouseConnect(RDBMSDatabase):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> RDBMSDatabase:
|
) -> RDBMSDatabase:
|
||||||
db_url: str = (
|
db_url: str = (
|
||||||
cls.driver
|
f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}"
|
||||||
+ "://"
|
|
||||||
+ user
|
|
||||||
+ ":"
|
|
||||||
+ pwd
|
|
||||||
+ "@"
|
|
||||||
+ host
|
|
||||||
+ ":"
|
|
||||||
+ str(port)
|
|
||||||
+ "/"
|
|
||||||
+ db_name
|
|
||||||
)
|
)
|
||||||
return cls.from_uri(db_url, engine_args, **kwargs)
|
return cls.from_uri(db_url, engine_args, **kwargs)
|
||||||
|
|
||||||
|
@@ -82,25 +82,3 @@ class DuckDbConnect(RDBMSDatabase):
|
|||||||
|
|
||||||
results.append(f"{table_name}({','.join(table_colums)});")
|
results.append(f"{table_name}({','.join(table_colums)});")
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
engine = create_engine(
|
|
||||||
"duckdb:////Users/tuyang.yhj/Code/PycharmProjects/DB-GPT/pilot/mock_datas/db-gpt-test.db"
|
|
||||||
)
|
|
||||||
metadata = MetaData(engine)
|
|
||||||
|
|
||||||
results = (
|
|
||||||
engine.connect()
|
|
||||||
.execute("SELECT name FROM sqlite_master WHERE type='table'")
|
|
||||||
.fetchall()
|
|
||||||
)
|
|
||||||
|
|
||||||
print(str(results))
|
|
||||||
|
|
||||||
fields = []
|
|
||||||
results2 = engine.connect().execute(f"""PRAGMA table_info(user)""").fetchall()
|
|
||||||
for row_col in results2:
|
|
||||||
field_info = list(row_col)
|
|
||||||
fields.append(field_info[1])
|
|
||||||
print(str(fields))
|
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
from typing import Iterable, Optional, Any
|
from typing import Iterable, Optional, Any
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
from urllib.parse import quote_plus as urlquote
|
||||||
from pilot.connections.rdbms.base import RDBMSDatabase
|
from pilot.connections.rdbms.base import RDBMSDatabase
|
||||||
|
|
||||||
|
|
||||||
@@ -21,17 +22,7 @@ class PostgreSQLDatabase(RDBMSDatabase):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> RDBMSDatabase:
|
) -> RDBMSDatabase:
|
||||||
db_url: str = (
|
db_url: str = (
|
||||||
cls.driver
|
f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}"
|
||||||
+ "://"
|
|
||||||
+ quote(user)
|
|
||||||
+ ":"
|
|
||||||
+ quote(pwd)
|
|
||||||
+ "@"
|
|
||||||
+ host
|
|
||||||
+ ":"
|
|
||||||
+ str(port)
|
|
||||||
+ "/"
|
|
||||||
+ db_name
|
|
||||||
)
|
)
|
||||||
return cls.from_uri(db_url, engine_args, **kwargs)
|
return cls.from_uri(db_url, engine_args, **kwargs)
|
||||||
|
|
||||||
|
@@ -1,7 +1,7 @@
|
|||||||
from typing import Iterable, Optional, Any
|
from typing import Iterable, Optional, Any
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
import re
|
from urllib.parse import quote_plus as urlquote
|
||||||
from pilot.connections.rdbms.base import RDBMSDatabase
|
from pilot.connections.rdbms.base import RDBMSDatabase
|
||||||
from pilot.connections.rdbms.dialect.starrocks.sqlalchemy import *
|
from pilot.connections.rdbms.dialect.starrocks.sqlalchemy import *
|
||||||
|
|
||||||
@@ -23,7 +23,7 @@ class StarRocksConnect(RDBMSDatabase):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> RDBMSDatabase:
|
) -> RDBMSDatabase:
|
||||||
db_url: str = (
|
db_url: str = (
|
||||||
f"{cls.driver}://{quote(user)}:{quote(pwd)}@{host}:{str(port)}/{db_name}"
|
f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}"
|
||||||
)
|
)
|
||||||
return cls.from_uri(db_url, engine_args, **kwargs)
|
return cls.from_uri(db_url, engine_args, **kwargs)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user