feat: update params

aries_ckt
2023-09-20 20:22:14 +08:00
parent 7c879e64ea
commit f3caba7d72
3 changed files with 136 additions and 137 deletions

View File

@@ -0,0 +1,123 @@
from typing import Optional, Any
from pyspark.sql import SparkSession, DataFrame
from sqlalchemy import text
from pilot.connections.base import BaseConnect
class SparkConnect(BaseConnect):
"""Spark Connect
Args:
Usage:
"""
"""db type"""
db_type: str = "spark"
"""db driver"""
driver: str = "spark"
"""db dialect"""
dialect: str = "sparksql"
def __init__(
self,
file_path: str,
spark_session: Optional[SparkSession] = None,
        engine_args: Optional[dict] = None,
**kwargs: Any
) -> None:
"""Initialize the Spark DataFrame from Datasource path
return: Spark DataFrame
"""
self.spark_session = (
spark_session or SparkSession.builder.appName("dbgpt").getOrCreate()
)
self.path = file_path
self.table_name = "temp"
self.df = self.create_df(self.path)
@classmethod
def from_file_path(
cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any
):
try:
return cls(file_path=file_path, engine_args=engine_args)
except Exception as e:
print("load spark datasource error" + str(e))
def create_df(self, path) -> DataFrame:
"""Create a Spark DataFrame from Datasource path
return: Spark DataFrame
"""
return self.spark_session.read.option("header", "true").csv(path)
def run(self, sql):
# self.log(f"llm ingestion sql query is :\n{sql}")
# self.df = self.create_df(self.path)
self.df.createOrReplaceTempView(self.table_name)
df = self.spark_session.sql(sql)
        # The first row holds the column names, followed by the collected data rows.
        rows = [df.columns]
        for row in df.collect():
            rows.append(row)
        return rows
def query_ex(self, sql):
rows = self.run(sql)
field_names = rows[0]
return field_names, rows
def get_indexes(self, table_name):
"""Get table indexes about specified table."""
return ""
def get_show_create_table(self, table_name):
"""Get table show create table about specified table."""
return "ans"
def get_fields(self):
"""Get column meta about dataframe."""
return ",".join([f"({name}: {dtype})" for name, dtype in self.df.dtypes])
def get_users(self):
return []
def get_grants(self):
return []
def get_collation(self):
"""Get collation."""
return "UTF-8"
def get_charset(self):
return "UTF-8"
def get_db_list(self):
return ["default"]
def get_db_names(self):
return ["default"]
def get_database_list(self):
return []
def get_database_names(self):
return []
def table_simple_info(self):
return f"{self.table_name}{self.get_fields()}"
def get_table_comments(self, db_name):
session = self._db_sessions()
cursor = session.execute(
text(
f"""SELECT table, comment FROM system.tables WHERE database = '{db_name}'""".format(
db_name
)
)
)
table_comments = cursor.fetchall()
return [
(table_comment[0], table_comment[1]) for table_comment in table_comments
]

View File

@@ -1,132 +0,0 @@
import re
from typing import Optional, Any
from sqlalchemy import text
from pyspark.sql import SparkSession, DataFrame
class SparkConnect:
"""Connect Spark
Args:
Usage:
"""
"""db type"""
# db_type: str = "spark"
"""db driver"""
driver: str = "spark"
"""db dialect"""
# db_dialect: str = "spark"
def __init__(
self,
spark_session: Optional[SparkSession] = None,
) -> None:
        self.spark_session = (
            spark_session
            or SparkSession.builder.appName("dbgpt")
            .master("local[*]")
            .config("spark.sql.catalogImplementation", "hive")
            .getOrCreate()
        )
    def create_df(self, path) -> DataFrame:
"""load path into spark"""
path = "/Users/chenketing/Downloads/Warehouse_and_Retail_Sales.csv"
return self.spark_session.read.csv(path)
def run(self, sql):
self.spark_session.sql(sql)
return sql.show()
def get_indexes(self, table_name):
"""Get table indexes about specified table."""
return ""
def get_show_create_table(self, table_name):
"""Get table show create table about specified table."""
session = self._db_sessions()
cursor = session.execute(text(f"SHOW CREATE TABLE {table_name}"))
ans = cursor.fetchall()
ans = ans[0][0]
ans = re.sub(r"\s*ENGINE\s*=\s*MergeTree\s*", " ", ans, flags=re.IGNORECASE)
ans = re.sub(
r"\s*DEFAULT\s*CHARSET\s*=\s*\w+\s*", " ", ans, flags=re.IGNORECASE
)
ans = re.sub(r"\s*SETTINGS\s*\s*\w+\s*", " ", ans, flags=re.IGNORECASE)
return ans
def get_fields(self, df: DataFrame):
"""Get column fields about specified table."""
return "\n".join([f"{name}: {dtype}" for name, dtype in df.dtypes])
def get_users(self):
return []
def get_grants(self):
return []
def get_collation(self):
"""Get collation."""
return "UTF-8"
def get_charset(self):
return "UTF-8"
def get_database_list(self):
return []
def get_database_names(self):
return []
def get_table_comments(self, db_name):
session = self._db_sessions()
cursor = session.execute(
text(
f"""SELECT table, comment FROM system.tables WHERE database = '{db_name}'""".format(
db_name
)
)
)
table_comments = cursor.fetchall()
return [
(table_comment[0], table_comment[1]) for table_comment in table_comments
]
if __name__ == "__main__":
# spark = SparkSession.builder \
# .appName("Spark-SQL and sqlalchemy") \
# .getOrCreate()
db_url = "spark://B-V0ECMD6R-0244.local:7077"
# engine = create_engine(db_url)
# engine = create_engine("sparksql:///?Server=127.0.0.1")
# sqlContext = SQLContext(spark)
# df = sqlContext.read.format("jdbc").option("url", db_url).option("dbtable",
# "person").option(
# "user", "username").option("password", "password").load()
spark = (
SparkSession.builder.appName("ckt")
.master("local[*]")
# .config("hive.metastore.uris", "thrift://localhost:9083")
# .config("spark.sql.warehouse.dir", "/Users/chenketing/myhive/")
.config("spark.sql.catalogImplementation", "hive")
.enableHiveSupport()
.getOrCreate()
)
# sqlContext.read.jdbc(url='jdbc:hive2://127.0.0.1:10000/default', table='pokes', properties=connProps)
# df = spark.read.format("jdbc").option("url", "jdbc:hive2://localhost:10000/dbgpt_test").option("dbtable", "dbgpt_test.dbgpt_table").option("driver", "org.apache.hive.jdbc.HiveDriver").load()
path = "/Users/chenketing/Downloads/Warehouse_and_Retail_Sales.csv"
df = spark.read.csv(path)
df.createOrReplaceTempView("warehouse1")
# df = spark.sql("show databases")
# spark.sql("CREATE TABLE IF NOT EXISTS default.db_test (id INT, name STRING)")
# df = spark.sql("DESC default.db_test")
# spark.sql("INSERT INTO default.db_test VALUES (1, 'Alice')")
# spark.sql("INSERT INTO default.db_test VALUES (2, 'Bob')")
# spark.sql("INSERT INTO default.db_test VALUES (3, 'Charlie')")
# df = spark.sql("select * from default.db_test")
print(spark.sql("SELECT * FROM warehouse1 limit 5").show())
# connection = engine.connect()
    # Execute the Spark SQL query
# result = connection.execute("SELECT * FROM warehouse")

View File

@@ -1,10 +1,11 @@
from typing import List
from fastapi import APIRouter
from pilot.component import ComponentType
from pilot.configs.config import Config
from pilot.model.base import ModelInstance, WorkerApplyType
from pilot.model.cluster import WorkerStartupRequest, WorkerManager
from pilot.model.cluster import WorkerStartupRequest
from pilot.openapi.api_view_model import Result
from pilot.server.llm_manage.request.request import ModelResponse
@@ -21,15 +22,21 @@ async def model_params():
worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
return Result.succ(await worker_manager.supported_models())
params = []
workers = await worker_manager.supported_models()
for worker in workers:
for model in worker.models:
model_dict = model.__dict__
model_dict["host"] = worker.host
model_dict["port"] = worker.port
params.append(model_dict)
    return Result.succ(params)
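
With this change, model_params() no longer returns the raw worker list; each supported model is flattened into its own dict and tagged with the owning worker's host and port. A sketch of the resulting payload shape follows; the per-model keys come from the model object's __dict__, so every field name other than host and port is an assumption.

# Illustrative shape of the list built above (field names other than host/port are assumed).
params = [
    {"model": "vicuna-13b-v1.5", "worker_type": "llm", "host": "127.0.0.1", "port": 8001},
    {"model": "chatglm2-6b", "worker_type": "llm", "host": "127.0.0.1", "port": 8002},
]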
if not worker_instance:
return Result.faild(code="E000X", msg=f"can not find worker manager")
except Exception as e:
return Result.faild(code="E000X", msg=f"model stop failed {e}")
@router.get("/v1/worker/model/list")
async def model_list():
print(f"/worker/model/list")
@@ -81,6 +88,7 @@ async def model_start(request: WorkerStartupRequest):
instances = await controller.get_all_instances(
model_name="WorkerManager@service", healthy_only=True
)
request.params = {}
worker_instance = None
for instance in instances:
if instance.host == request.host and instance.port == request.port:
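
model_start() now clears request.params before dispatch and scans the registered WorkerManager@service instances for one matching the requested host and port. The hunk ends mid-loop, so the following is only a hedged sketch of how the match presumably concludes, not the actual remainder of the function.

# Hedged sketch only: presumed continuation of the host/port match (not shown in this hunk).
for instance in instances:
    if instance.host == request.host and instance.port == request.port:
        worker_instance = instance  # assumed: keep the matching registration
        break
if not worker_instance:
    return Result.faild(code="E000X", msg="cannot find worker instance")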