diff --git a/pilot/connections/conn_spark.py b/pilot/connections/conn_spark.py
new file mode 100644
index 000000000..ac404f5c9
--- /dev/null
+++ b/pilot/connections/conn_spark.py
@@ -0,0 +1,123 @@
+from typing import Optional, Any
+from pyspark.sql import SparkSession, DataFrame
+from sqlalchemy import text
+
+from pilot.connections.base import BaseConnect
+
+
+class SparkConnect(BaseConnect):
+    """Spark Connect
+    Args:
+    Usage:
+    """
+
+    """db type"""
+    db_type: str = "spark"
+    """db driver"""
+    driver: str = "spark"
+    """db dialect"""
+    dialect: str = "sparksql"
+
+    def __init__(
+        self,
+        file_path: str,
+        spark_session: Optional[SparkSession] = None,
+        engine_args: Optional[dict] = None,
+        **kwargs: Any
+    ) -> None:
+        """Initialize the Spark DataFrame from Datasource path
+        return: Spark DataFrame
+        """
+        self.spark_session = (
+            spark_session or SparkSession.builder.appName("dbgpt").getOrCreate()
+        )
+        self.path = file_path
+        self.table_name = "temp"
+        self.df = self.create_df(self.path)
+
+    @classmethod
+    def from_file_path(
+        cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any
+    ):
+        try:
+            return cls(file_path=file_path, engine_args=engine_args)
+
+        except Exception as e:
+            print("load spark datasource error" + str(e))
+
+    def create_df(self, path) -> DataFrame:
+        """Create a Spark DataFrame from Datasource path
+        return: Spark DataFrame
+        """
+        return self.spark_session.read.option("header", "true").csv(path)
+
+    def run(self, sql):
+        # self.log(f"llm ingestion sql query is :\n{sql}")
+        # self.df = self.create_df(self.path)
+        self.df.createOrReplaceTempView(self.table_name)
+        df = self.spark_session.sql(sql)
+        first_row = df.first()
+        rows = [first_row.asDict().keys()]
+        for row in df.collect():
+            rows.append(row)
+        return rows
+
+    def query_ex(self, sql):
+        rows = self.run(sql)
+        field_names = rows[0]
+        return field_names, rows
+
+    def get_indexes(self, table_name):
+        """Get table indexes about specified table."""
+        return ""
+
+    def get_show_create_table(self, table_name):
+        """Get table show create table about specified table."""
+
+        return "ans"
+
+    def get_fields(self):
+        """Get column meta about dataframe."""
+        return ",".join([f"({name}: {dtype})" for name, dtype in self.df.dtypes])
+
+    def get_users(self):
+        return []
+
+    def get_grants(self):
+        return []
+
+    def get_collation(self):
+        """Get collation."""
+        return "UTF-8"
+
+    def get_charset(self):
+        return "UTF-8"
+
+    def get_db_list(self):
+        return ["default"]
+
+    def get_db_names(self):
+        return ["default"]
+
+    def get_database_list(self):
+        return []
+
+    def get_database_names(self):
+        return []
+
+    def table_simple_info(self):
+        return f"{self.table_name}{self.get_fields()}"
+
+    def get_table_comments(self, db_name):
+        session = self._db_sessions()
+        cursor = session.execute(
+            text(
+                f"""SELECT table, comment FROM system.tables WHERE database = '{db_name}'""".format(
+                    db_name
+                )
+            )
+        )
+        table_comments = cursor.fetchall()
+        return [
+            (table_comment[0], table_comment[1]) for table_comment in table_comments
+        ]
diff --git a/pilot/connections/rdbms/conn_spark.py b/pilot/connections/rdbms/conn_spark.py
deleted file mode 100644
index aca5f9f4a..000000000
--- a/pilot/connections/rdbms/conn_spark.py
+++ /dev/null
@@ -1,132 +0,0 @@
-import re
-from typing import Optional, Any
-
-from sqlalchemy import text
-
-from pyspark.sql import SparkSession, DataFrame
-
-
-
-class SparkConnect:
-    """Connect Spark
-    Args:
-    Usage:
-    """
-
-    """db type"""
-    # db_type: str = "spark"
-    """db driver"""
-    driver: str = "spark"
-    """db dialect"""
-    # db_dialect: str = "spark"
-    def __init__(
-        self,
-        spark_session: Optional[SparkSession] = None,
-    ) -> None:
-        self.spark_session = spark_session or SparkSession.builder.appName("dbgpt").master("local[*]").config("spark.sql.catalogImplementation", "hive").getOrCreate()
-
-    def create_df(self, path)-> DataFrame:
-        """load path into spark"""
-        path = "/Users/chenketing/Downloads/Warehouse_and_Retail_Sales.csv"
-        return self.spark_session.read.csv(path)
-
-    def run(self, sql):
-        self.spark_session.sql(sql)
-        return sql.show()
-
-
-    def get_indexes(self, table_name):
-        """Get table indexes about specified table."""
-        return ""
-
-    def get_show_create_table(self, table_name):
-        """Get table show create table about specified table."""
-        session = self._db_sessions()
-        cursor = session.execute(text(f"SHOW CREATE TABLE {table_name}"))
-        ans = cursor.fetchall()
-        ans = ans[0][0]
-        ans = re.sub(r"\s*ENGINE\s*=\s*MergeTree\s*", " ", ans, flags=re.IGNORECASE)
-        ans = re.sub(
-            r"\s*DEFAULT\s*CHARSET\s*=\s*\w+\s*", " ", ans, flags=re.IGNORECASE
-        )
-        ans = re.sub(r"\s*SETTINGS\s*\s*\w+\s*", " ", ans, flags=re.IGNORECASE)
-        return ans
-
-    def get_fields(self, df: DataFrame):
-        """Get column fields about specified table."""
-        return "\n".join([f"{name}: {dtype}" for name, dtype in df.dtypes])
-
-    def get_users(self):
-        return []
-
-    def get_grants(self):
-        return []
-
-    def get_collation(self):
-        """Get collation."""
-        return "UTF-8"
-
-    def get_charset(self):
-        return "UTF-8"
-
-    def get_database_list(self):
-        return []
-
-    def get_database_names(self):
-        return []
-
-    def get_table_comments(self, db_name):
-        session = self._db_sessions()
-        cursor = session.execute(
-            text(
-                f"""SELECT table, comment FROM system.tables WHERE database = '{db_name}'""".format(
-                    db_name
-                )
-            )
-        )
-        table_comments = cursor.fetchall()
-        return [
-            (table_comment[0], table_comment[1]) for table_comment in table_comments
-        ]
-
-
-if __name__ == "__main__":
-    # spark = SparkSession.builder \
-    #     .appName("Spark-SQL and sqlalchemy") \
-    #     .getOrCreate()
-    db_url = "spark://B-V0ECMD6R-0244.local:7077"
-    # engine = create_engine(db_url)
-    # engine = create_engine("sparksql:///?Server=127.0.0.1")
-    # sqlContext = SQLContext(spark)
-    # df = sqlContext.read.format("jdbc").option("url", db_url).option("dbtable",
-    #     "person").option(
-    #     "user", "username").option("password", "password").load()
-    spark = (
-        SparkSession.builder.appName("ckt")
-        .master("local[*]")
-        # .config("hive.metastore.uris", "thrift://localhost:9083")
-        # .config("spark.sql.warehouse.dir", "/Users/chenketing/myhive/")
-        .config("spark.sql.catalogImplementation", "hive")
-        .enableHiveSupport()
-        .getOrCreate()
-    )
-    # sqlContext.read.jdbc(url='jdbc:hive2://127.0.0.1:10000/default', table='pokes', properties=connProps)
-
-    # df = spark.read.format("jdbc").option("url", "jdbc:hive2://localhost:10000/dbgpt_test").option("dbtable", "dbgpt_test.dbgpt_table").option("driver", "org.apache.hive.jdbc.HiveDriver").load()
-
-    path = "/Users/chenketing/Downloads/Warehouse_and_Retail_Sales.csv"
-    df = spark.read.csv(path)
-    df.createOrReplaceTempView("warehouse1")
-    # df = spark.sql("show databases")
-    # spark.sql("CREATE TABLE IF NOT EXISTS default.db_test (id INT, name STRING)")
-    # df = spark.sql("DESC default.db_test")
-    # spark.sql("INSERT INTO default.db_test VALUES (1, 'Alice')")
-    # spark.sql("INSERT INTO default.db_test VALUES (2, 'Bob')")
-    # spark.sql("INSERT INTO default.db_test VALUES (3, 'Charlie')")
-    # df = spark.sql("select * from default.db_test")
-    print(spark.sql("SELECT * FROM warehouse1 limit 5").show())
-
-    # connection = engine.connect()
-
-    # Execute the Spark SQL query
-    # result = connection.execute("SELECT * FROM warehouse")
diff --git a/pilot/server/llm_manage/api.py b/pilot/server/llm_manage/api.py
index e1b45282c..4c0adeabc 100644
--- a/pilot/server/llm_manage/api.py
+++ b/pilot/server/llm_manage/api.py
@@ -1,10 +1,11 @@
+from typing import List
+
 from fastapi import APIRouter
 
 from pilot.component import ComponentType
 from pilot.configs.config import Config
 
-from pilot.model.base import ModelInstance, WorkerApplyType
-from pilot.model.cluster import WorkerStartupRequest, WorkerManager
+from pilot.model.cluster import WorkerStartupRequest
 from pilot.openapi.api_view_model import Result
 
 from pilot.server.llm_manage.request.request import ModelResponse
@@ -21,15 +22,21 @@ async def model_params():
         worker_manager = CFG.SYSTEM_APP.get_component(
             ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
         ).create()
-        return Result.succ(await worker_manager.supported_models())
+        params = []
+        workers = await worker_manager.supported_models()
+        for worker in workers:
+            for model in worker.models:
+                model_dict = model.__dict__
+                model_dict["host"] = worker.host
+                model_dict["port"] = worker.port
+                params.append(model_dict)
+        return Result.succ(params)
         if not worker_instance:
             return Result.faild(code="E000X", msg=f"can not find worker manager")
     except Exception as e:
         return Result.faild(code="E000X", msg=f"model stop failed {e}")
 
 
-
-
 @router.get("/v1/worker/model/list")
 async def model_list():
     print(f"/worker/model/list")
@@ -81,6 +88,7 @@ async def model_start(request: WorkerStartupRequest):
         instances = await controller.get_all_instances(
             model_name="WorkerManager@service", healthy_only=True
         )
+        request.params = {}
         worker_instance = None
         for instance in instances:
             if instance.host == request.host and instance.port == request.port: