mirror of https://github.com/csunny/DB-GPT.git, synced 2025-09-15 14:11:14 +00:00
feat:update params
pilot/connections/conn_spark.py (new file, 123 lines)
@@ -0,0 +1,123 @@
from typing import Optional, Any

from pyspark.sql import SparkSession, DataFrame
from sqlalchemy import text

from pilot.connections.base import BaseConnect


class SparkConnect(BaseConnect):
    """Spark Connect

    Args:
    Usage:
    """

    """db type"""
    db_type: str = "spark"
    """db driver"""
    driver: str = "spark"
    """db dialect"""
    dialect: str = "sparksql"

    def __init__(
        self,
        file_path: str,
        spark_session: Optional[SparkSession] = None,
        engine_args: Optional[dict] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the Spark DataFrame from the datasource path.

        return: Spark DataFrame
        """
        self.spark_session = (
            spark_session or SparkSession.builder.appName("dbgpt").getOrCreate()
        )
        self.path = file_path
        self.table_name = "temp"
        self.df = self.create_df(self.path)

    @classmethod
    def from_file_path(
        cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any
    ):
        try:
            return cls(file_path=file_path, engine_args=engine_args)
        except Exception as e:
            print("load spark datasource error: " + str(e))

    def create_df(self, path) -> DataFrame:
        """Create a Spark DataFrame from the datasource path.

        return: Spark DataFrame
        """
        # Read the CSV with its first line treated as the header row.
        return self.spark_session.read.option("header", "true").csv(path)

    def run(self, sql):
        # self.log(f"llm ingestion sql query is :\n{sql}")
        # Register the DataFrame as a temp view so it can be queried with SQL.
        self.df.createOrReplaceTempView(self.table_name)
        df = self.spark_session.sql(sql)
        # First element is the list of column names; assumes the query returns rows.
        first_row = df.first()
        rows = [list(first_row.asDict().keys())]
        for row in df.collect():
            rows.append(row)
        return rows

    def query_ex(self, sql):
        rows = self.run(sql)
        field_names = rows[0]
        return field_names, rows

    def get_indexes(self, table_name):
        """Get table indexes about specified table."""
        return ""

    def get_show_create_table(self, table_name):
        """Get table show create table about specified table."""
        return "ans"

    def get_fields(self):
        """Get column metadata about the dataframe."""
        return ",".join([f"({name}: {dtype})" for name, dtype in self.df.dtypes])

    def get_users(self):
        return []

    def get_grants(self):
        return []

    def get_collation(self):
        """Get collation."""
        return "UTF-8"

    def get_charset(self):
        return "UTF-8"

    def get_db_list(self):
        return ["default"]

    def get_db_names(self):
        return ["default"]

    def get_database_list(self):
        return []

    def get_database_names(self):
        return []

    def table_simple_info(self):
        return f"{self.table_name}{self.get_fields()}"

    def get_table_comments(self, db_name):
        session = self._db_sessions()
        cursor = session.execute(
            text(
                f"SELECT table, comment FROM system.tables WHERE database = '{db_name}'"
            )
        )
        table_comments = cursor.fetchall()
        return [
            (table_comment[0], table_comment[1]) for table_comment in table_comments
        ]
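A minimal usage sketch for the new connector (the CSV path and query below are illustrative, not part of the commit): SparkConnect loads the file into a DataFrame, registers it as the temp view "temp", and run()/query_ex() execute Spark SQL against it.

    # Hypothetical example, not part of the commit.
    from pilot.connections.conn_spark import SparkConnect

    conn = SparkConnect(file_path="/tmp/example.csv")  # illustrative path
    field_names, rows = conn.query_ex("SELECT * FROM temp LIMIT 5")
    print(field_names)    # header row: column names of the CSV
    for row in rows[1:]:  # remaining entries are pyspark Row objects
        print(row)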
@@ -1,132 +0,0 @@
import re
from typing import Optional, Any

from sqlalchemy import text

from pyspark.sql import SparkSession, DataFrame


class SparkConnect:
    """Connect Spark

    Args:
    Usage:
    """

    """db type"""
    # db_type: str = "spark"
    """db driver"""
    driver: str = "spark"
    """db dialect"""
    # db_dialect: str = "spark"

    def __init__(
        self,
        spark_session: Optional[SparkSession] = None,
    ) -> None:
        self.spark_session = (
            spark_session
            or SparkSession.builder.appName("dbgpt")
            .master("local[*]")
            .config("spark.sql.catalogImplementation", "hive")
            .getOrCreate()
        )

    def create_df(self, path) -> DataFrame:
        """load path into spark"""
        path = "/Users/chenketing/Downloads/Warehouse_and_Retail_Sales.csv"
        return self.spark_session.read.csv(path)

    def run(self, sql):
        self.spark_session.sql(sql)
        return sql.show()

    def get_indexes(self, table_name):
        """Get table indexes about specified table."""
        return ""

    def get_show_create_table(self, table_name):
        """Get table show create table about specified table."""
        session = self._db_sessions()
        cursor = session.execute(text(f"SHOW CREATE TABLE {table_name}"))
        ans = cursor.fetchall()
        ans = ans[0][0]
        ans = re.sub(r"\s*ENGINE\s*=\s*MergeTree\s*", " ", ans, flags=re.IGNORECASE)
        ans = re.sub(
            r"\s*DEFAULT\s*CHARSET\s*=\s*\w+\s*", " ", ans, flags=re.IGNORECASE
        )
        ans = re.sub(r"\s*SETTINGS\s*\s*\w+\s*", " ", ans, flags=re.IGNORECASE)
        return ans

    def get_fields(self, df: DataFrame):
        """Get column fields about specified table."""
        return "\n".join([f"{name}: {dtype}" for name, dtype in df.dtypes])

    def get_users(self):
        return []

    def get_grants(self):
        return []

    def get_collation(self):
        """Get collation."""
        return "UTF-8"

    def get_charset(self):
        return "UTF-8"

    def get_database_list(self):
        return []

    def get_database_names(self):
        return []

    def get_table_comments(self, db_name):
        session = self._db_sessions()
        cursor = session.execute(
            text(
                f"""SELECT table, comment FROM system.tables WHERE database = '{db_name}'""".format(
                    db_name
                )
            )
        )
        table_comments = cursor.fetchall()
        return [
            (table_comment[0], table_comment[1]) for table_comment in table_comments
        ]


if __name__ == "__main__":
    # spark = SparkSession.builder \
    #     .appName("Spark-SQL and sqlalchemy") \
    #     .getOrCreate()
    db_url = "spark://B-V0ECMD6R-0244.local:7077"
    # engine = create_engine(db_url)
    # engine = create_engine("sparksql:///?Server=127.0.0.1")
    # sqlContext = SQLContext(spark)
    # df = sqlContext.read.format("jdbc").option("url", db_url).option("dbtable",
    #     "person").option(
    #     "user", "username").option("password", "password").load()
    spark = (
        SparkSession.builder.appName("ckt")
        .master("local[*]")
        # .config("hive.metastore.uris", "thrift://localhost:9083")
        # .config("spark.sql.warehouse.dir", "/Users/chenketing/myhive/")
        .config("spark.sql.catalogImplementation", "hive")
        .enableHiveSupport()
        .getOrCreate()
    )
    # sqlContext.read.jdbc(url='jdbc:hive2://127.0.0.1:10000/default', table='pokes', properties=connProps)

    # df = spark.read.format("jdbc").option("url", "jdbc:hive2://localhost:10000/dbgpt_test").option("dbtable", "dbgpt_test.dbgpt_table").option("driver", "org.apache.hive.jdbc.HiveDriver").load()

    path = "/Users/chenketing/Downloads/Warehouse_and_Retail_Sales.csv"
    df = spark.read.csv(path)
    df.createOrReplaceTempView("warehouse1")
    # df = spark.sql("show databases")
    # spark.sql("CREATE TABLE IF NOT EXISTS default.db_test (id INT, name STRING)")
    # df = spark.sql("DESC default.db_test")
    # spark.sql("INSERT INTO default.db_test VALUES (1, 'Alice')")
    # spark.sql("INSERT INTO default.db_test VALUES (2, 'Bob')")
    # spark.sql("INSERT INTO default.db_test VALUES (3, 'Charlie')")
    # df = spark.sql("select * from default.db_test")
    print(spark.sql("SELECT * FROM warehouse1 limit 5").show())

    # connection = engine.connect()
    # Execute the Spark SQL query
    # result = connection.execute("SELECT * FROM warehouse")
@@ -1,10 +1,11 @@
from typing import List

from fastapi import APIRouter

from pilot.component import ComponentType
from pilot.configs.config import Config
from pilot.model.base import ModelInstance, WorkerApplyType

from pilot.model.cluster import WorkerStartupRequest, WorkerManager
from pilot.model.cluster import WorkerStartupRequest
from pilot.openapi.api_view_model import Result

from pilot.server.llm_manage.request.request import ModelResponse
@@ -21,15 +22,21 @@ async def model_params():
    worker_manager = CFG.SYSTEM_APP.get_component(
        ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
    ).create()
    return Result.succ(await worker_manager.supported_models())
    params = []
    workers = await worker_manager.supported_models()
    for worker in workers:
        for model in worker.models:
            model_dict = model.__dict__
            model_dict["host"] = worker.host
            model_dict["port"] = worker.port
            params.append(model_dict)
    return Result.succ(params)
        if not worker_instance:
            return Result.faild(code="E000X", msg=f"can not find worker manager")
    except Exception as e:
        return Result.faild(code="E000X", msg=f"model stop failed {e}")


@router.get("/v1/worker/model/list")
async def model_list():
    print(f"/worker/model/list")
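With this change model_params() no longer returns the worker objects directly; each supported model is flattened into a dict and annotated with the serving worker's host and port before being wrapped in Result.succ(...). A hedged sketch of the resulting payload (field names other than host and port come from model.__dict__ and are illustrative):

    # Hypothetical response shape for a single worker serving one model.
    [
        {
            "model": "vicuna-13b-v1.5",  # illustrative model attribute
            "host": "127.0.0.1",         # worker.host
            "port": 8001,                # worker.port
        },
    ]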
@@ -81,6 +88,7 @@ async def model_start(request: WorkerStartupRequest):
    instances = await controller.get_all_instances(
        model_name="WorkerManager@service", healthy_only=True
    )
    request.params = {}
    worker_instance = None
    for instance in instances:
        if instance.host == request.host and instance.port == request.port:
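The hunk is cut off here, but the loop's intent is visible: select the registered WorkerManager@service instance whose host and port match the startup request. A hedged sketch of how that selection typically completes (the assignment and break are assumptions, not shown in the diff):

    # Hypothetical continuation of the matching loop.
    worker_instance = None
    for instance in instances:
        if instance.host == request.host and instance.port == request.port:
            worker_instance = instance  # remember the matching worker
            break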