# Mirror of https://github.com/csunny/DB-GPT.git
# Synced 2025-10-22 17:39:02 +00:00
# Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
# Co-authored-by: lcx01800250 <lcx01800250@alibaba-inc.com>
# Co-authored-by: licunxing <864255598@qq.com>
# Co-authored-by: Aralhi <xiaoping0501@gmail.com>
# Co-authored-by: xuyuan23 <643854343@qq.com>
# Co-authored-by: aries_ckt <916701291@qq.com>
# Co-authored-by: hzh97 <2976151305@qq.com>
# 110 lines, 3.6 KiB, Python
import logging
from contextlib import contextmanager
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

from dbgpt.agent.resource.resource_api import AgentResource

from .resource_api import ResourceClient, ResourceType
|
|
|
|
# Module-level logger; SqliteLoadClient.a_query logs each statement through it.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class ResourceDbClient(ResourceClient):
    """Resource client for database resources (``ResourceType.DB``).

    The ``a_*`` coroutines are hooks for concrete clients: this base class
    only raises :class:`NotImplementedError` naming the missing method, so
    subclasses must override the ones they support.
    """

    @property
    def type(self) -> ResourceType:
        """Resource type served by this client: always ``ResourceType.DB``."""
        return ResourceType.DB

    def get_data_type(self, resource: AgentResource) -> str:
        """Return the data type of *resource* as reported by the parent client."""
        return super().get_data_type(resource)

    async def get_data_introduce(
        self, resource: AgentResource, question: Optional[str] = None
    ) -> str:
        """Return the schema description used to introduce *resource*.

        Args:
            resource: Agent resource whose ``value`` identifies the database.
            question: Optional user question, forwarded unchanged to
                :meth:`a_get_schema_link`.

        Returns:
            The result of ``a_get_schema_link(resource.value, question)``.
        """
        return await self.a_get_schema_link(resource.value, question)

    async def a_get_schema_link(self, db: str, question: Optional[str] = None) -> str:
        """Return schema information for database *db*.

        Args:
            db: Database identifier.
            question: Optional user question a subclass may use to focus the
                schema description.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError(
            "The a_get_schema_link method should be implemented in a subclass."
        )

    async def a_query_to_df(self, dbe: str, sql: str):
        """Run *sql* against database *dbe* and return the result as a table.

        The parameter is named ``dbe`` rather than ``db`` as in the sibling
        methods; it is left unchanged so keyword callers keep working.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError(
            "The a_query_to_df method should be implemented in a subclass."
        )

    async def a_query(self, db: str, sql: str):
        """Run the query *sql* against database *db* and return its result.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError(
            "The a_query method should be implemented in a subclass."
        )

    async def a_run_sql(self, db: str, sql: str):
        """Execute the statement *sql* against database *db*.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError(
            "The a_run_sql method should be implemented in a subclass."
        )
|
|
|
|
|
|
class SqliteLoadClient(ResourceDbClient):
    """Database resource client that reads a local SQLite file.

    Every ``db`` argument is a filesystem path; it is appended to
    ``"sqlite:///"`` to form the SQLAlchemy URL.
    """

    # Imported at class scope so the ``connect`` annotation can refer to it;
    # as a side effect it is also exposed as ``SqliteLoadClient.Session``.
    from sqlalchemy.orm.session import Session

    def __init__(self):
        super().__init__()

    def get_data_type(self, resource: AgentResource) -> str:
        """Return ``"sqlite"``; *resource* is not inspected."""
        return "sqlite"

    @contextmanager
    def connect(self, db) -> Iterator[Session]:
        """Yield a SQLAlchemy session bound to the SQLite file at path *db*.

        The session is committed when the ``with`` body finishes normally. If
        the body raises, the session is rolled back and the exception is
        re-raised. The session is always closed, and the engine created for
        this call is disposed.
        """
        from sqlalchemy import create_engine
        from sqlalchemy.orm import sessionmaker

        # echo=True: SQLAlchemy logs every statement it emits.
        engine = create_engine("sqlite:///" + db, echo=True)
        session = sessionmaker(bind=engine)()
        try:
            yield session
            session.commit()
        except BaseException:
            # Equivalent to the former bare ``except:``: roll back on any
            # exception, then let it propagate.
            session.rollback()
            raise
        finally:
            session.close()
            # A fresh engine is built per call; dispose of it so its pooled
            # connection to the file is released rather than left to GC.
            engine.dispose()

    async def a_get_schema_link(
        self, db: str, question: Optional[str] = None
    ) -> List[str]:
        """Describe every table in *db* as ``"name(col1,col2,...);"``.

        NOTE(review): returns a list, although the base class declares ``str``.

        Args:
            db: Path to the SQLite database file.
            question: Accepted for interface compatibility; not used here.

        Returns:
            One string per ``sqlite_master`` row of type ``'table'``, listing
            that table's column names in column order.
        """
        from sqlalchemy import text

        with self.connect(db) as session:
            table_names = [
                row[0]
                for row in session.execute(
                    text("SELECT name FROM sqlite_master WHERE type='table'")
                ).fetchall()
            ]
            results = []
            for table_name in table_names:
                # Bind the table name as a parameter (table-valued pragma) rather
                # than splicing it into PRAGMA text, so names containing spaces,
                # quotes, keywords or ':' are handled correctly.
                column_rows = session.execute(
                    text(
                        "SELECT name FROM pragma_table_info(:table_name) "
                        "ORDER BY cid"
                    ),
                    {"table_name": table_name},
                ).fetchall()
                column_names = [column_row[0] for column_row in column_rows]
                results.append(f"{table_name}({','.join(column_names)});")
        return results

    async def a_query_to_df(self, db: str, sql: str):
        """Run *sql* against *db* and return the result as a pandas DataFrame.

        When :meth:`a_query` yields no result set (empty *sql*, or a statement
        that returns no rows), an empty DataFrame is returned instead of
        failing while unpacking the result.
        """
        import pandas as pd

        query_result = await self.a_query(db, sql)
        if not query_result:
            return pd.DataFrame()
        field_names, rows = query_result
        return pd.DataFrame(rows, columns=field_names)

    async def a_query(self, db: str, sql: str):
        """Execute *sql* against *db* and return its result set.

        Returns:
            ``(field_names, rows)`` when the statement returns rows;
            ``[]`` when *sql* is empty; ``None`` when the statement returns
            no rows. In every non-error case, ``connect`` commits afterwards.
        """
        from sqlalchemy import text

        with self.connect(db) as session:
            logger.info("Query[%s]", sql)
            if not sql:
                return []
            cursor = session.execute(text(sql))
            if not cursor.returns_rows:
                return None
            rows = cursor.fetchall()
            return tuple(cursor.keys()), rows

    async def a_run_sql(self, db: str, sql: str):
        """Do nothing and return ``None``.

        NOTE(review): unlike the base class, this neither raises
        NotImplementedError nor executes *sql*. Confirm whether callers
        expect it to run the statement (e.g. by delegating to ``a_query``).
        """
        return None
|