refactor: The first refactored version for sdk release (#907)
Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
dbgpt/datasource/conn_spark.py (new file, 119 lines)
@@ -0,0 +1,119 @@
from typing import TYPE_CHECKING, Any, Optional

from dbgpt.datasource.base import BaseConnect

if TYPE_CHECKING:
    from pyspark.sql import SparkSession


class SparkConnect(BaseConnect):
    """Spark SQL connector.

    Spark supports operating on a variety of data sources through the
    DataFrame interface. A DataFrame can be transformed with relational
    operations, and registering it as a temporary view allows SQL queries
    to be run over its data.

    Supported datasource formats: parquet, jdbc, orc, libsvm, csv, text, json.
    """

    db_type: str = "spark"
    """db type"""
    driver: str = "spark"
    """db driver"""
    dialect: str = "sparksql"
    """db dialect"""

    def __init__(
        self,
        file_path: str,
        spark_session: Optional["SparkSession"] = None,
        engine_args: Optional[dict] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the connector: load the datasource at file_path into a
        Spark DataFrame and keep it on the instance.
        """
        from pyspark.sql import SparkSession

        # Reuse the caller's session if given; otherwise create (or attach to)
        # a session named "dbgpt_spark".
        self.spark_session = (
            spark_session or SparkSession.builder.appName("dbgpt_spark").getOrCreate()
        )
        self.path = file_path
        self.table_name = "temp"
        self.df = self.create_df(self.path)

    @classmethod
    def from_file_path(
        cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any
    ):
        """Build a SparkConnect from a datasource path."""
        try:
            return cls(file_path=file_path, engine_args=engine_args)
        except Exception as e:
            # Note: the error is swallowed and None is returned on failure.
            print(f"load spark datasource error: {e}")

    def create_df(self, path):
        """Create a Spark DataFrame from the datasource path.

        Supported formats: parquet, jdbc, orc, libsvm, csv, text, json.
        Reference: https://spark.apache.org/docs/latest/sql-data-sources-load-save-functions.html
        """
        # Infer the Spark format from the file extension; ".txt" maps to the
        # "text" format, e.g. "data/users.csv" is read with format "csv".
        extension = (
            "text" if path.rsplit(".", 1)[-1] == "txt" else path.rsplit(".", 1)[-1]
        )
        return self.spark_session.read.load(
            path, format=extension, inferSchema="true", header="true"
        )

    def run(self, sql):
        """Register the DataFrame as a temporary view and run a SQL query.

        Returns a list whose first element is the column names and whose
        remaining elements are the result rows.
        """
        print(f"spark sql to run is {sql}")
        self.df.createOrReplaceTempView(self.table_name)
        df = self.spark_session.sql(sql)
        # Take the column names from the schema; this works even when the
        # result set is empty.
        rows = [df.columns]
        for row in df.collect():
            rows.append(row)
        return rows

    def query_ex(self, sql):
        """Run a SQL query and return (field_names, rows)."""
        rows = self.run(sql)
        field_names = rows[0]
        return field_names, rows

    def get_indexes(self, table_name):
        """Get indexes of the specified table (not supported; returns "")."""
        return ""

    def get_show_create_table(self, table_name):
        """Get the CREATE TABLE statement of the specified table (stub)."""
        return "ans"

    def get_fields(self):
        """Get column metadata of the DataFrame as "(name: dtype)" pairs."""
        return ",".join([f"({name}: {dtype})" for name, dtype in self.df.dtypes])

    def get_users(self):
        return []

    def get_grants(self):
        return []

    def get_collation(self):
        """Get collation."""
        return "UTF-8"

    def get_charset(self):
        return "UTF-8"

    def get_db_list(self):
        return ["default"]

    def get_db_names(self):
        return ["default"]

    def get_database_list(self):
        return []

    def get_database_names(self):
        return []

    def table_simple_info(self):
        """Return the table name followed by its field list."""
        return f"{self.table_name}{self.get_fields()}"

    def get_table_comments(self, db_name):
        return ""