diff --git a/pilot/connections/conn_spark.py b/pilot/connections/conn_spark.py index d46be65f1..d89ca9399 100644 --- a/pilot/connections/conn_spark.py +++ b/pilot/connections/conn_spark.py @@ -17,12 +17,11 @@ class SparkConnect(BaseConnect): driver: str = "spark" """db dialect""" dialect: str = "sparksql" - from pyspark.sql import SparkSession, DataFrame def __init__( self, file_path: str, - spark_session: Optional[SparkSession] = None, + spark_session: Optional = None, engine_args: Optional[dict] = None, **kwargs: Any, ) -> None: @@ -30,6 +29,7 @@ class SparkConnect(BaseConnect): return: Spark DataFrame """ from pyspark.sql import SparkSession + self.spark_session = ( spark_session or SparkSession.builder.appName("dbgpt_spark").getOrCreate() ) @@ -47,7 +47,7 @@ class SparkConnect(BaseConnect): except Exception as e: print("load spark datasource error" + str(e)) - def create_df(self, path) -> DataFrame: + def create_df(self, path): """Create a Spark DataFrame from Datasource path(now support parquet, jdbc, orc, libsvm, csv, text, json.). return: Spark DataFrame reference:https://spark.apache.org/docs/latest/sql-data-sources-load-save-functions.html