fix(conn): database conn fix special symbols (#898)

This commit is contained in:
magic.chen
2023-12-06 16:51:47 +08:00
committed by GitHub
parent 54e2aa1dbd
commit afad4ffd32
7 changed files with 13 additions and 75 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 196 KiB

After

Width:  |  Height:  |  Size: 92 KiB

View File

@@ -11,6 +11,7 @@ from alembic import command
from alembic.config import Config as AlembicConfig from alembic.config import Config as AlembicConfig
from urllib.parse import quote from urllib.parse import quote
from pilot.configs.config import Config from pilot.configs.config import Config
from urllib.parse import quote_plus as urlquote
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -29,14 +30,7 @@ connection = sqlite3.connect(db_path)
if CFG.LOCAL_DB_TYPE == "mysql": if CFG.LOCAL_DB_TYPE == "mysql":
engine_temp = create_engine( engine_temp = create_engine(
f"mysql+pymysql://" f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}"
+ quote(CFG.LOCAL_DB_USER)
+ ":"
+ quote(CFG.LOCAL_DB_PASSWORD)
+ "@"
+ CFG.LOCAL_DB_HOST
+ ":"
+ str(CFG.LOCAL_DB_PORT)
) )
# check and auto create mysqldatabase # check and auto create mysqldatabase
try: try:
@@ -51,15 +45,7 @@ if CFG.LOCAL_DB_TYPE == "mysql":
logger.error(f"{db_name} not connect success!") logger.error(f"{db_name} not connect success!")
engine = create_engine( engine = create_engine(
f"mysql+pymysql://" f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}/{db_name}"
+ quote(CFG.LOCAL_DB_USER)
+ ":"
+ quote(CFG.LOCAL_DB_PASSWORD)
+ "@"
+ CFG.LOCAL_DB_HOST
+ ":"
+ str(CFG.LOCAL_DB_PORT)
+ f"/{db_name}"
) )
else: else:
engine = create_engine(f"sqlite:///{db_path}") engine = create_engine(f"sqlite:///{db_path}")

View File

@@ -1,9 +1,10 @@
from __future__ import annotations from __future__ import annotations
from urllib.parse import quote
import warnings import warnings
import sqlparse import sqlparse
import regex as re import regex as re
import pandas as pd import pandas as pd
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote
from typing import Any, Iterable, List, Optional from typing import Any, Iterable, List, Optional
from pydantic import BaseModel, Field, root_validator, validator, Extra from pydantic import BaseModel, Field, root_validator, validator, Extra
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@@ -113,17 +114,7 @@ class RDBMSDatabase(BaseConnect):
engine_args (Optional[dict]):other engine_args. engine_args (Optional[dict]):other engine_args.
""" """
db_url: str = ( db_url: str = (
cls.driver f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}"
+ "://"
+ quote(user)
+ ":"
+ quote(pwd)
+ "@"
+ host
+ ":"
+ str(port)
+ "/"
+ db_name
) )
return cls.from_uri(db_url, engine_args, **kwargs) return cls.from_uri(db_url, engine_args, **kwargs)

View File

@@ -1,6 +1,8 @@
import re import re
from typing import Optional, Any from typing import Optional, Any
from sqlalchemy import text from sqlalchemy import text
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote
from pilot.connections.rdbms.base import RDBMSDatabase from pilot.connections.rdbms.base import RDBMSDatabase
@@ -30,17 +32,7 @@ class ClickhouseConnect(RDBMSDatabase):
**kwargs: Any, **kwargs: Any,
) -> RDBMSDatabase: ) -> RDBMSDatabase:
db_url: str = ( db_url: str = (
cls.driver f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}"
+ "://"
+ user
+ ":"
+ pwd
+ "@"
+ host
+ ":"
+ str(port)
+ "/"
+ db_name
) )
return cls.from_uri(db_url, engine_args, **kwargs) return cls.from_uri(db_url, engine_args, **kwargs)

View File

@@ -82,25 +82,3 @@ class DuckDbConnect(RDBMSDatabase):
results.append(f"{table_name}({','.join(table_colums)});") results.append(f"{table_name}({','.join(table_colums)});")
return results return results
if __name__ == "__main__":
engine = create_engine(
"duckdb:////Users/tuyang.yhj/Code/PycharmProjects/DB-GPT/pilot/mock_datas/db-gpt-test.db"
)
metadata = MetaData(engine)
results = (
engine.connect()
.execute("SELECT name FROM sqlite_master WHERE type='table'")
.fetchall()
)
print(str(results))
fields = []
results2 = engine.connect().execute(f"""PRAGMA table_info(user)""").fetchall()
for row_col in results2:
field_info = list(row_col)
fields.append(field_info[1])
print(str(fields))

View File

@@ -1,6 +1,7 @@
from typing import Iterable, Optional, Any from typing import Iterable, Optional, Any
from sqlalchemy import text from sqlalchemy import text
from urllib.parse import quote from urllib.parse import quote
from urllib.parse import quote_plus as urlquote
from pilot.connections.rdbms.base import RDBMSDatabase from pilot.connections.rdbms.base import RDBMSDatabase
@@ -21,17 +22,7 @@ class PostgreSQLDatabase(RDBMSDatabase):
**kwargs: Any, **kwargs: Any,
) -> RDBMSDatabase: ) -> RDBMSDatabase:
db_url: str = ( db_url: str = (
cls.driver f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}"
+ "://"
+ quote(user)
+ ":"
+ quote(pwd)
+ "@"
+ host
+ ":"
+ str(port)
+ "/"
+ db_name
) )
return cls.from_uri(db_url, engine_args, **kwargs) return cls.from_uri(db_url, engine_args, **kwargs)

View File

@@ -1,7 +1,7 @@
from typing import Iterable, Optional, Any from typing import Iterable, Optional, Any
from sqlalchemy import text from sqlalchemy import text
from urllib.parse import quote from urllib.parse import quote
import re from urllib.parse import quote_plus as urlquote
from pilot.connections.rdbms.base import RDBMSDatabase from pilot.connections.rdbms.base import RDBMSDatabase
from pilot.connections.rdbms.dialect.starrocks.sqlalchemy import * from pilot.connections.rdbms.dialect.starrocks.sqlalchemy import *
@@ -23,7 +23,7 @@ class StarRocksConnect(RDBMSDatabase):
**kwargs: Any, **kwargs: Any,
) -> RDBMSDatabase: ) -> RDBMSDatabase:
db_url: str = ( db_url: str = (
f"{cls.driver}://{quote(user)}:{quote(pwd)}@{host}:{str(port)}/{db_name}" f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}"
) )
return cls.from_uri(db_url, engine_args, **kwargs) return cls.from_uri(db_url, engine_args, **kwargs)