mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-05 11:01:09 +00:00
feat:add spark
This commit is contained in:
135
pilot/connections/rdbms/conn_spark.py
Normal file
135
pilot/connections/rdbms/conn_spark.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
import re
|
||||||
|
from typing import Optional, Any
|
||||||
|
|
||||||
|
from pyspark import SQLContext
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from pilot.connections.rdbms.base import RDBMSDatabase
|
||||||
|
from pyspark.sql import SparkSession, DataFrame
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
|
||||||
|
|
||||||
|
class SparkConnect:
|
||||||
|
"""Connect Spark
|
||||||
|
Args:
|
||||||
|
Usage:
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""db type"""
|
||||||
|
# db_type: str = "spark"
|
||||||
|
"""db driver"""
|
||||||
|
driver: str = "spark"
|
||||||
|
"""db dialect"""
|
||||||
|
# db_dialect: str = "spark"
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
spark_session: Optional[SparkSession] = None,
|
||||||
|
) -> None:
|
||||||
|
self.spark_session = spark_session or SparkSession.builder.appName("dbgpt").master("local[*]").config("spark.sql.catalogImplementation", "hive").getOrCreate()
|
||||||
|
|
||||||
|
def create_df(self, path)-> DataFrame:
|
||||||
|
"""load path into spark"""
|
||||||
|
path = "/Users/chenketing/Downloads/Warehouse_and_Retail_Sales.csv"
|
||||||
|
return self.spark_session.read.csv(path)
|
||||||
|
|
||||||
|
def run(self, sql):
|
||||||
|
self.spark_session.sql(sql)
|
||||||
|
return sql.show()
|
||||||
|
|
||||||
|
|
||||||
|
def get_indexes(self, table_name):
|
||||||
|
"""Get table indexes about specified table."""
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def get_show_create_table(self, table_name):
|
||||||
|
"""Get table show create table about specified table."""
|
||||||
|
session = self._db_sessions()
|
||||||
|
cursor = session.execute(text(f"SHOW CREATE TABLE {table_name}"))
|
||||||
|
ans = cursor.fetchall()
|
||||||
|
ans = ans[0][0]
|
||||||
|
ans = re.sub(r"\s*ENGINE\s*=\s*MergeTree\s*", " ", ans, flags=re.IGNORECASE)
|
||||||
|
ans = re.sub(
|
||||||
|
r"\s*DEFAULT\s*CHARSET\s*=\s*\w+\s*", " ", ans, flags=re.IGNORECASE
|
||||||
|
)
|
||||||
|
ans = re.sub(r"\s*SETTINGS\s*\s*\w+\s*", " ", ans, flags=re.IGNORECASE)
|
||||||
|
return ans
|
||||||
|
|
||||||
|
def get_fields(self, df: DataFrame):
|
||||||
|
"""Get column fields about specified table."""
|
||||||
|
return "\n".join([f"{name}: {dtype}" for name, dtype in df.dtypes])
|
||||||
|
|
||||||
|
def get_users(self):
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_grants(self):
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_collation(self):
|
||||||
|
"""Get collation."""
|
||||||
|
return "UTF-8"
|
||||||
|
|
||||||
|
def get_charset(self):
|
||||||
|
return "UTF-8"
|
||||||
|
|
||||||
|
def get_database_list(self):
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_database_names(self):
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_table_comments(self, db_name):
|
||||||
|
session = self._db_sessions()
|
||||||
|
cursor = session.execute(
|
||||||
|
text(
|
||||||
|
f"""SELECT table, comment FROM system.tables WHERE database = '{db_name}'""".format(
|
||||||
|
db_name
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
table_comments = cursor.fetchall()
|
||||||
|
return [
|
||||||
|
(table_comment[0], table_comment[1]) for table_comment in table_comments
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# spark = SparkSession.builder \
|
||||||
|
# .appName("Spark-SQL and sqlalchemy") \
|
||||||
|
# .getOrCreate()
|
||||||
|
db_url = "spark://B-V0ECMD6R-0244.local:7077"
|
||||||
|
# engine = create_engine(db_url)
|
||||||
|
# engine = create_engine("sparksql:///?Server=127.0.0.1")
|
||||||
|
# sqlContext = SQLContext(spark)
|
||||||
|
# df = sqlContext.read.format("jdbc").option("url", db_url).option("dbtable",
|
||||||
|
# "person").option(
|
||||||
|
# "user", "username").option("password", "password").load()
|
||||||
|
spark = (
|
||||||
|
SparkSession.builder.appName("ckt")
|
||||||
|
.master("local[*]")
|
||||||
|
# .config("hive.metastore.uris", "thrift://localhost:9083")
|
||||||
|
# .config("spark.sql.warehouse.dir", "/Users/chenketing/myhive/")
|
||||||
|
.config("spark.sql.catalogImplementation", "hive")
|
||||||
|
.enableHiveSupport()
|
||||||
|
.getOrCreate()
|
||||||
|
)
|
||||||
|
# sqlContext.read.jdbc(url='jdbc:hive2://127.0.0.1:10000/default', table='pokes', properties=connProps)
|
||||||
|
|
||||||
|
# df = spark.read.format("jdbc").option("url", "jdbc:hive2://localhost:10000/dbgpt_test").option("dbtable", "dbgpt_test.dbgpt_table").option("driver", "org.apache.hive.jdbc.HiveDriver").load()
|
||||||
|
|
||||||
|
path = "/Users/chenketing/Downloads/Warehouse_and_Retail_Sales.csv"
|
||||||
|
df = spark.read.csv(path)
|
||||||
|
df.createOrReplaceTempView("warehouse1")
|
||||||
|
# df = spark.sql("show databases")
|
||||||
|
# spark.sql("CREATE TABLE IF NOT EXISTS default.db_test (id INT, name STRING)")
|
||||||
|
# df = spark.sql("DESC default.db_test")
|
||||||
|
# spark.sql("INSERT INTO default.db_test VALUES (1, 'Alice')")
|
||||||
|
# spark.sql("INSERT INTO default.db_test VALUES (2, 'Bob')")
|
||||||
|
# spark.sql("INSERT INTO default.db_test VALUES (3, 'Charlie')")
|
||||||
|
# df = spark.sql("select * from default.db_test")
|
||||||
|
print(spark.sql("SELECT * FROM warehouse1 limit 5").show())
|
||||||
|
|
||||||
|
# connection = engine.connect()
|
||||||
|
|
||||||
|
# 执行Spark SQL查询
|
||||||
|
# result = connection.execute("SELECT * FROM warehouse")
|
Reference in New Issue
Block a user