refactor: The first refactored version for sdk release (#907)

Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
FangYin Cheng committed via GitHub, 2023-12-08 14:45:59 +08:00
parent e7e4aff667
commit cd725db1fb
573 changed files with 2094 additions and 3571 deletions


@@ -0,0 +1,119 @@
from typing import Any, Optional

from dbgpt.datasource.base import BaseConnect

class SparkConnect(BaseConnect):
    """Spark data source connector.

    Spark supports operating on a variety of data sources through the
    DataFrame interface. A DataFrame can be operated on using relational
    transformations and can also be used to create a temporary view.
    Registering a DataFrame as a temporary view allows you to run SQL
    queries over its data.
    Supported source formats: parquet, jdbc, orc, libsvm, csv, text, json.
    """

    """db type"""
    db_type: str = "spark"
    """db driver"""
    driver: str = "spark"
    """db dialect"""
    dialect: str = "sparksql"

    def __init__(
        self,
        file_path: str,
        spark_session: Optional["SparkSession"] = None,
        engine_args: Optional[dict] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize a Spark DataFrame from the given data source path."""
        from pyspark.sql import SparkSession

        self.spark_session = (
            spark_session or SparkSession.builder.appName("dbgpt_spark").getOrCreate()
        )
        self.path = file_path
        self.table_name = "temp"
        self.df = self.create_df(self.path)

    @classmethod
    def from_file_path(
        cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any
    ):
        try:
            return cls(file_path=file_path, engine_args=engine_args)
        except Exception as e:
            # Swallow the error and return None; callers must check the result.
            print("load spark datasource error: " + str(e))

    def create_df(self, path):
        """Create a Spark DataFrame from the data source path.

        Supported formats: parquet, jdbc, orc, libsvm, csv, text, json.
        Reference: https://spark.apache.org/docs/latest/sql-data-sources-load-save-functions.html
        """
        # Derive the format from the file extension; ".txt" maps to Spark's
        # "text" format.
        extension = (
            "text" if path.rsplit(".", 1)[-1] == "txt" else path.rsplit(".", 1)[-1]
        )
        return self.spark_session.read.load(
            path, format=extension, inferSchema="true", header="true"
        )

    def run(self, sql):
        """Register the DataFrame as a temp view and run the SQL over it.

        Returns the column names as the first row, followed by the data rows.
        """
        print(f"spark sql to run is {sql}")
        self.df.createOrReplaceTempView(self.table_name)
        df = self.spark_session.sql(sql)
        # Read column names from the schema rather than from df.first(),
        # which is None (and would crash) on an empty result.
        rows = [df.columns]
        for row in df.collect():
            rows.append(row)
        return rows

    def query_ex(self, sql):
        rows = self.run(sql)
        field_names = rows[0]
        # Skip the header row so it is not duplicated in the data rows.
        return field_names, rows[1:]

    def get_indexes(self, table_name):
        """Get table indexes for the specified table (not supported)."""
        return ""

    def get_show_create_table(self, table_name):
        """Get the CREATE TABLE statement for the specified table (stub)."""
        return "ans"

    def get_fields(self):
        """Get column metadata (name and type) for the DataFrame."""
        return ",".join([f"({name}: {dtype})" for name, dtype in self.df.dtypes])

    def get_users(self):
        """Spark has no user concept in this context; return an empty list."""
        return []

    def get_grants(self):
        """Spark has no grant concept in this context; return an empty list."""
        return []

    def get_collation(self):
        """Get collation."""
        return "UTF-8"

    def get_charset(self):
        """Get character set."""
        return "UTF-8"

    def get_db_list(self):
        return ["default"]

    def get_db_names(self):
        return ["default"]

    def get_database_list(self):
        return []

    def get_database_names(self):
        return []

    def table_simple_info(self):
        """Return a compact "table(columns)" summary of the temp view."""
        return f"{self.table_name}{self.get_fields()}"

    def get_table_comments(self, db_name):
        return ""