"""Spark Connector.""" import logging from typing import TYPE_CHECKING, Any, Optional from .base import BaseConnector if TYPE_CHECKING: from pyspark.sql import SparkSession logger = logging.getLogger(__name__) class SparkConnector(BaseConnector): """Spark Connector. Spark Connect supports operating on a variety of data sources through the DataFrame interface. A DataFrame can be operated on using relational transformations and can also be used to create a temporary view.Registering a DataFrame as a temporary view allows you to run SQL queries over its data. Datasource now support parquet, jdbc, orc, libsvm, csv, text, json. """ """db type""" db_type: str = "spark" """db driver""" driver: str = "spark" """db dialect""" dialect: str = "sparksql" def __init__( self, file_path: str, spark_session: Optional["SparkSession"] = None, **kwargs: Any, ) -> None: """Create a Spark Connector. Args: file_path: file path spark_session: spark session kwargs: other args """ from pyspark.sql import SparkSession self.spark_session = ( spark_session or SparkSession.builder.appName("dbgpt_spark").getOrCreate() ) self.path = file_path self.table_name = "temp" self.df = self.create_df(self.path) @classmethod def from_file_path( cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any ) -> "SparkConnector": """Create a new SparkConnector from file path.""" try: return cls(file_path=file_path, engine_args=engine_args, **kwargs) except Exception as e: logger.error("load spark datasource error" + str(e)) raise e def create_df(self, path): """Create a Spark DataFrame. Create a Spark DataFrame from Datasource path(now support parquet, jdbc, orc, libsvm, csv, text, json.). return: Spark DataFrame reference:https://spark.apache.org/docs/latest/sql-data-sources-load-save-functions.html """ extension = ( "text" if path.rsplit(".", 1)[-1] == "txt" else path.rsplit(".", 1)[-1] ) return self.spark_session.read.load( path, format=extension, inferSchema="true", header="true" ) def run(self, sql: str, fetch: str = "all"): """Execute sql command.""" logger.info(f"spark sql to run is {sql}") self.df.createOrReplaceTempView(self.table_name) df = self.spark_session.sql(sql) first_row = df.first() rows = [first_row.asDict().keys()] for row in df.collect(): rows.append(row) return rows def query_ex(self, sql: str): """Execute sql command.""" rows = self.run(sql) field_names = rows[0] return field_names, rows def get_indexes(self, table_name): """Get table indexes about specified table.""" return "" def get_show_create_table(self, table_name): """Get table show create table about specified table.""" return "ans" def get_fields(self, table_name: str): """Get column meta about dataframe. TODO: Support table_name. """ return ",".join([f"({name}: {dtype})" for name, dtype in self.df.dtypes]) def get_collation(self): """Get collation.""" return "UTF-8" def get_db_names(self): """Get database names.""" return ["default"] def get_database_names(self): """Get database names.""" return [] def table_simple_info(self): """Get table simple info.""" return f"{self.table_name}{self.get_fields()}" def get_table_comments(self, db_name): """Get table comments.""" return ""