Merge branch 'eosphoros-ai:main' into main

commit d6a623ac59 by lozzo, 2023-09-25 17:03:32 +08:00, committed by GitHub (GPG Key ID: 4AEE18F83AFDEB23; no known key found for this signature in database)
53 changed files with 489 additions and 146 deletions


@@ -73,6 +73,7 @@ body:
        - label: Chat Excel
        - label: Chat DB
        - label: Chat Knowledge
        - label: Model Management
        - label: Dashboard
        - label: Plugins

Binary file not shown. (Image changed: 102 KiB before, 226 KiB after.)


@@ -26,7 +26,7 @@ extensions = [
    "sphinx.ext.autosummary",
    "sphinx.ext.napoleon",
    "sphinx.ext.viewcode",
    "sphinxcontrib.autodoc_pydantic",
    # "sphinxcontrib.autodoc_pydantic",
    "myst_nb",
    "sphinx_copybutton",
    "sphinx_panels",
@@ -34,14 +34,14 @@ extensions = [
]
source_suffix = [".ipynb", ".html", ".md", ".rst"]
autodoc_pydantic_model_show_json = False
autodoc_pydantic_field_list_validators = False
autodoc_pydantic_config_members = False
autodoc_pydantic_model_show_config_summary = False
autodoc_pydantic_model_show_validator_members = False
autodoc_pydantic_model_show_field_summary = False
autodoc_pydantic_model_members = False
autodoc_pydantic_model_undoc_members = False
# autodoc_pydantic_model_show_json = False
# autodoc_pydantic_field_list_validators = False
# autodoc_pydantic_config_members = False
# autodoc_pydantic_model_show_config_summary = False
# autodoc_pydantic_model_show_validator_members = False
# autodoc_pydantic_model_show_field_summary = False
# autodoc_pydantic_model_members = False
# autodoc_pydantic_model_undoc_members = False
templates_path = ["_templates"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]


@@ -13,4 +13,5 @@ toml
myst_nb
sphinx_copybutton
pydata-sphinx-theme==0.13.1
pydantic-settings
furo


@@ -0,0 +1,112 @@
from typing import Optional, Any

from pyspark.sql import SparkSession, DataFrame
from sqlalchemy import text

from pilot.connections.base import BaseConnect


class SparkConnect(BaseConnect):
    """Spark Connect
    Args:
    Usage:
    """

    """db type"""
    db_type: str = "spark"
    """db driver"""
    driver: str = "spark"
    """db dialect"""
    dialect: str = "sparksql"

    def __init__(
        self,
        file_path: str,
        spark_session: Optional[SparkSession] = None,
        engine_args: Optional[dict] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the Spark DataFrame from Datasource path
        return: Spark DataFrame
        """
        self.spark_session = (
            spark_session or SparkSession.builder.appName("dbgpt").getOrCreate()
        )
        self.path = file_path
        self.table_name = "temp"
        self.df = self.create_df(self.path)

    @classmethod
    def from_file_path(
        cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any
    ):
        try:
            return cls(file_path=file_path, engine_args=engine_args)
        except Exception as e:
            print("load spark datasource error" + str(e))

    def create_df(self, path) -> DataFrame:
        """Create a Spark DataFrame from Datasource path
        return: Spark DataFrame
        """
        return self.spark_session.read.option("header", "true").csv(path)

    def run(self, sql):
        # self.log(f"llm ingestion sql query is :\n{sql}")
        # self.df = self.create_df(self.path)
        self.df.createOrReplaceTempView(self.table_name)
        df = self.spark_session.sql(sql)
        first_row = df.first()
        rows = [first_row.asDict().keys()]
        for row in df.collect():
            rows.append(row)
        return rows

    def query_ex(self, sql):
        rows = self.run(sql)
        field_names = rows[0]
        return field_names, rows

    def get_indexes(self, table_name):
        """Get table indexes about specified table."""
        return ""

    def get_show_create_table(self, table_name):
        """Get table show create table about specified table."""
        return "ans"

    def get_fields(self):
        """Get column meta about dataframe."""
        return ",".join([f"({name}: {dtype})" for name, dtype in self.df.dtypes])

    def get_users(self):
        return []

    def get_grants(self):
        return []

    def get_collation(self):
        """Get collation."""
        return "UTF-8"

    def get_charset(self):
        return "UTF-8"

    def get_db_list(self):
        return ["default"]

    def get_db_names(self):
        return ["default"]

    def get_database_list(self):
        return []

    def get_database_names(self):
        return []

    def table_simple_info(self):
        return f"{self.table_name}{self.get_fields()}"

    def get_table_comments(self, db_name):
        return ""


@@ -8,7 +8,7 @@ from pilot.model.loader import ModelLoader, _get_model_real_path
from pilot.model.parameter import ModelParameters
from pilot.model.cluster.worker_base import ModelWorker
from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
from pilot.utils.model_utils import _clear_torch_cache
from pilot.utils.model_utils import _clear_model_cache
from pilot.utils.parameter_utils import EnvArgumentParser

logger = logging.getLogger(__name__)
@@ -87,7 +87,7 @@ class DefaultModelWorker(ModelWorker):
        del self.tokenizer
        self.model = None
        self.tokenizer = None
        _clear_torch_cache(self._model_params.device)
        _clear_model_cache(self._model_params.device)

    def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
        torch_imported = False


@@ -11,7 +11,7 @@ from pilot.model.parameter import (
)
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.embedding.loader import EmbeddingLoader
from pilot.utils.model_utils import _clear_torch_cache
from pilot.utils.model_utils import _clear_model_cache
from pilot.utils.parameter_utils import EnvArgumentParser

logger = logging.getLogger(__name__)
@@ -79,7 +79,7 @@ class EmbeddingsModelWorker(ModelWorker):
            return
        del self._embeddings_impl
        self._embeddings_impl = None
        _clear_torch_cache(self._model_params.device)
        _clear_model_cache(self._model_params.device)

    def generate_stream(self, params: Dict):
        """Generate stream result, chat scene"""


@@ -1,17 +1,18 @@
import asyncio
import itertools
import json
import os
import sys
import random
import time
import logging
import os
import random
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict
from typing import Awaitable, Callable, Dict, Iterator, List, Optional

from fastapi import APIRouter, FastAPI
from fastapi.responses import StreamingResponse

from pilot.component import SystemApp
from pilot.model.base import (
    ModelInstance,
@@ -20,16 +21,16 @@ from pilot.model.base import (
    WorkerApplyType,
    WorkerSupportedModel,
)
from pilot.model.cluster.registry import ModelRegistry
from pilot.model.llm_utils import list_supported_models
from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.base import *
from pilot.model.cluster.manager_base import (
    WorkerManager,
    WorkerRunData,
    WorkerManagerFactory,
    WorkerRunData,
)
from pilot.model.cluster.base import *
from pilot.model.cluster.registry import ModelRegistry
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.llm_utils import list_supported_models
from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
from pilot.utils.parameter_utils import (
    EnvArgumentParser,
    ParameterDescription,
@@ -639,6 +640,10 @@ def _setup_fastapi(worker_params: ModelWorkerParameters, app=None):
    )
    if not worker_params.controller_addr:
        # if we have http_proxy or https_proxy in env, the server can not start
        # so set it to empty here
        os.environ["http_proxy"] = ""
        os.environ["https_proxy"] = ""
        worker_params.controller_addr = f"http://127.0.0.1:{worker_params.port}"
    logger.info(
        f"Run WorkerManager with standalone mode, controller_addr: {worker_params.controller_addr}"


@@ -18,6 +18,7 @@ from pilot.logs import logger

def _check_multi_gpu_or_4bit_quantization(model_params: ModelParameters):
    # TODO: vicuna-v1.5 8-bit quantization info is slow
    # TODO: support wizardlm quantization, see: https://huggingface.co/WizardLM/WizardLM-13B-V1.2/discussions/5
    # TODO: support internlm quantization
    model_name = model_params.model_name.lower()
    supported_models = ["llama", "baichuan", "vicuna"]
    return any(m in model_name for m in supported_models)
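
A quick sketch of what the check above returns for a few model names (the outcome follows directly from the supported_models list; the model names are examples, not part of the diff):

# Mirrors the check in _check_multi_gpu_or_4bit_quantization; model names are examples.
supported_models = ["llama", "baichuan", "vicuna"]
for name in ["vicuna-13b-v1.5", "baichuan2-13b-chat", "internlm-chat-7b"]:
    print(name, any(m in name.lower() for m in supported_models))
# vicuna-13b-v1.5 True, baichuan2-13b-chat True, internlm-chat-7b False
# (hence the new TODO about internlm quantization support)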


@@ -1,5 +1,6 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Optional
@@ -246,6 +247,11 @@ class ProxyModelParameters(BaseModelParameters):
    proxy_api_key: str = field(
        metadata={"tags": "privacy", "help": "The api key of current proxy LLM"},
    )
    http_proxy: Optional[str] = field(
        default=os.environ.get("http_proxy") or os.environ.get("https_proxy"),
        metadata={"help": "The http or https proxy to use openai"},
    )
    proxyllm_backend: Optional[str] = field(
        default=None,
        metadata={


@@ -26,6 +26,9 @@ from pilot.openapi.api_view_model import (
    ConversationVo,
    MessageVo,
    ChatSceneVo,
    ChatCompletionResponseStreamChoice,
    DeltaMessage,
    ChatCompletionStreamResponse,
)
from pilot.connections.db_conn_info import DBConfig, DbTypeInfo
from pilot.configs.config import Config
@@ -383,7 +386,7 @@ async def chat_completions(dialogue: ConversationVo = Body()):
        )
    else:
        return StreamingResponse(
            stream_generator(chat),
            stream_generator(chat, dialogue.incremental, dialogue.model_name),
            headers=headers,
            media_type="text/plain",
        )
@@ -421,19 +424,48 @@ async def no_stream_generator(chat):
        yield f"data: {msg}\n\n"


async def stream_generator(chat):
async def stream_generator(chat, incremental: bool, model_name: str):
    """Generate streaming responses

    Our goal is to generate an openai-compatible streaming responses.
    Currently, the incremental response is compatible, and the full response will be transformed in the future.

    Args:
        chat (BaseChat): Chat instance.
        incremental (bool): Used to control whether the content is returned incrementally or in full each time.
        model_name (str): The model name

    Yields:
        _type_: streaming responses
    """
    msg = "[LLM_ERROR]: llm server has no output, maybe your prompt template is wrong."

    stream_id = f"chatcmpl-{str(uuid.uuid1())}"
    previous_response = ""
    async for chunk in chat.stream_call():
        if chunk:
            msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
                chunk, chat.skip_echo_len
            )
            msg = msg.replace("\n", "\\n")
            yield f"data:{msg}\n\n"
            msg = msg.replace("\ufffd", "")
            if incremental:
                incremental_output = msg[len(previous_response) :]
                choice_data = ChatCompletionResponseStreamChoice(
                    index=0,
                    delta=DeltaMessage(role="assistant", content=incremental_output),
                )
                chunk = ChatCompletionStreamResponse(
                    id=stream_id, choices=[choice_data], model=model_name
                )
                yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
            else:
                # TODO generate an openai-compatible streaming responses
                msg = msg.replace("\n", "\\n")
                yield f"data:{msg}\n\n"
            previous_response = msg
        await asyncio.sleep(0.02)
    if incremental:
        yield "data: [DONE]\n\n"
    chat.current_message.add_ai_message(msg)
    chat.current_message.add_view_message(msg)
    chat.memory.append(chat.current_message)
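
A client-side sketch of consuming the new incremental stream, not part of the diff: the base URL, port, and request field names are assumptions based on ConversationVo and the router prefix.

# Hypothetical client; URL, port, and payload field names are assumptions.
import json
import requests

payload = {
    "user_input": "Hello",
    "model_name": "vicuna-13b-v1.5",
    "incremental": True,
}
with requests.post(
    "http://127.0.0.1:5000/api/v1/chat/completions", json=payload, stream=True
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        delta = chunk["choices"][0]["delta"].get("content", "")
        print(delta, end="", flush=True)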


@@ -1,5 +1,7 @@
from pydantic import BaseModel, Field
from typing import TypeVar, Generic, Any
from typing import TypeVar, Generic, Any, Optional, Literal, List
import uuid
import time

T = TypeVar("T")
@@ -59,6 +61,11 @@ class ConversationVo(BaseModel):
    """
    model_name: str = None

    """Used to control whether the content is returned incrementally or in full each time.
    If this parameter is not provided, the default is full return.
    """
    incremental: bool = False


class MessageVo(BaseModel):
    """
@@ -83,3 +90,21 @@ class MessageVo(BaseModel):
    model_name
    """
    model_name: str


class DeltaMessage(BaseModel):
    role: Optional[str] = None
    content: Optional[str] = None


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionStreamResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{str(uuid.uuid1())}")
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
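
For illustration (not part of the diff), here is one streamed chunk built from these models and serialized the same way the server does in stream_generator; the id and model values are placeholders.

# Sketch: build one delta chunk and serialize it as the server does.
choice = ChatCompletionResponseStreamChoice(
    index=0, delta=DeltaMessage(role="assistant", content="Hel")
)
chunk = ChatCompletionStreamResponse(
    id="chatcmpl-example", model="vicuna-13b-v1.5", choices=[choice]
)
print(chunk.json(exclude_unset=True, ensure_ascii=False))
# An OpenAI-style delta chunk, e.g. {"id": "chatcmpl-example", "model": "vicuna-13b-v1.5",
# "choices": [{"index": 0, "delta": {"role": "assistant", "content": "Hel"}}]}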


@@ -24,6 +24,7 @@ from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from pilot.server.knowledge.api import router as knowledge_router
from pilot.server.prompt.api import router as prompt_router
from pilot.server.llm_manage.api import router as llm_manage_api
from pilot.openapi.api_v1.api_v1 import router as api_v1
@@ -73,6 +74,7 @@ app.add_middleware(
app.include_router(api_v1, prefix="/api")
app.include_router(knowledge_router, prefix="/api")
app.include_router(api_editor_route_v1, prefix="/api")
app.include_router(llm_manage_api, prefix="/api")
app.include_router(api_fb_v1, prefix="/api")
# app.include_router(api_v1)


@@ -0,0 +1,108 @@
from typing import List

from fastapi import APIRouter

from pilot.component import ComponentType
from pilot.configs.config import Config
from pilot.model.cluster import WorkerStartupRequest, WorkerManagerFactory
from pilot.openapi.api_view_model import Result
from pilot.server.llm_manage.request.request import ModelResponse

CFG = Config()

router = APIRouter()


@router.get("/v1/worker/model/params")
async def model_params():
    print(f"/worker/model/params")
    try:
        from pilot.model.cluster import WorkerManagerFactory

        worker_manager = CFG.SYSTEM_APP.get_component(
            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
        ).create()
        params = []
        workers = await worker_manager.supported_models()
        for worker in workers:
            for model in worker.models:
                model_dict = model.__dict__
                model_dict["host"] = worker.host
                model_dict["port"] = worker.port
                params.append(model_dict)
        return Result.succ(params)
        if not worker_instance:
            return Result.faild(code="E000X", msg=f"can not find worker manager")
    except Exception as e:
        return Result.faild(code="E000X", msg=f"model stop failed {e}")


@router.get("/v1/worker/model/list")
async def model_list():
    print(f"/worker/model/list")
    try:
        from pilot.model.cluster.controller.controller import BaseModelController

        controller = CFG.SYSTEM_APP.get_component(
            ComponentType.MODEL_CONTROLLER, BaseModelController
        )
        responses = []
        managers = await controller.get_all_instances(
            model_name="WorkerManager@service", healthy_only=True
        )
        manager_map = dict(map(lambda manager: (manager.host, manager), managers))
        models = await controller.get_all_instances()
        for model in models:
            worker_name, worker_type = model.model_name.split("@")
            if worker_type == "llm" or worker_type == "text2vec":
                response = ModelResponse(
                    model_name=worker_name,
                    model_type=worker_type,
                    host=model.host,
                    port=model.port,
                    healthy=model.healthy,
                    check_healthy=model.check_healthy,
                    last_heartbeat=model.last_heartbeat,
                    prompt_template=model.prompt_template,
                )
                response.manager_host = model.host if manager_map[model.host] else None
                response.manager_port = (
                    manager_map[model.host].port if manager_map[model.host] else None
                )
                responses.append(response)
        return Result.succ(responses)
    except Exception as e:
        return Result.faild(code="E000X", msg=f"space list error {e}")


@router.post("/v1/worker/model/stop")
async def model_stop(request: WorkerStartupRequest):
    print(f"/v1/worker/model/stop:")
    try:
        from pilot.model.cluster.controller.controller import BaseModelController

        worker_manager = CFG.SYSTEM_APP.get_component(
            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
        ).create()
        if not worker_manager:
            return Result.faild(code="E000X", msg=f"can not find worker manager")
        request.params = {}
        return Result.succ(await worker_manager.model_shutdown(request))
    except Exception as e:
        return Result.faild(code="E000X", msg=f"model stop failed {e}")


@router.post("/v1/worker/model/start")
async def model_start(request: WorkerStartupRequest):
    print(f"/v1/worker/model/start:")
    try:
        worker_manager = CFG.SYSTEM_APP.get_component(
            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
        ).create()
        if not worker_manager:
            return Result.faild(code="E000X", msg=f"can not find worker manager")
        return Result.succ(await worker_manager.model_startup(request))
    except Exception as e:
        return Result.faild(code="E000X", msg=f"model start failed {e}")
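
A sketch of exercising the new model-management routes over HTTP, not part of the diff: the host/port and the shape of Result's JSON envelope (a "data" key) are assumptions.

# Hypothetical client; base URL and the "data" envelope key are assumptions.
import requests

base = "http://127.0.0.1:5000/api"

models = requests.get(f"{base}/v1/worker/model/list").json()
print(models.get("data"))    # list of ModelResponse entries: name, type, host, port, health

params = requests.get(f"{base}/v1/worker/model/params").json()
print(params.get("data"))    # per-model startup parameters gathered from the workers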


@@ -0,0 +1,28 @@
from dataclasses import dataclass


@dataclass
class ModelResponse:
    """ModelRequest"""

    """model_name: model_name"""
    model_name: str = None
    """model_type: model_type"""
    model_type: str = None
    """host: host"""
    host: str = None
    """port: port"""
    port: int = None
    """manager_host: manager_host"""
    manager_host: str = None
    """manager_port: manager_port"""
    manager_port: int = None
    """healthy: healthy"""
    healthy: bool = True
    """check_healthy: check_healthy"""
    check_healthy: bool = True
    prompt_template: str = None
    last_heartbeat: str = None
    stream_api: str = None
    nostream_api: str = None

File diff suppressed because one or more lines are too long (2 files).


@@ -1 +0,0 @@
self.__BUILD_MANIFEST=function(s,a,t,c,e,d,f,n,u,i,b){return{__rewrites:{beforeFiles:[],afterFiles:[],fallback:[]},"/":[a,"static/chunks/66-791bb03098dc9265.js","static/chunks/707-109d4fec9e26030d.js","static/chunks/pages/index-d5aba6bbbc1d8aaa.js"],"/_error":["static/chunks/pages/_error-dee72aff9b2e2c12.js"],"/chat":["static/chunks/pages/chat-a9adfc18f61cb676.js"],"/database":[s,t,d,c,"static/chunks/46-2a716444a56f6f08.js","static/chunks/847-4335b5938375e331.js","static/chunks/pages/database-ddf0a72485646c52.js"],"/datastores":[e,s,f,a,n,u,"static/chunks/241-4117dd68a591b7fa.js","static/chunks/pages/datastores-4fb48131988df037.js"],"/datastores/documents":[e,"static/chunks/75fc9c18-a784766a129ec5fb.js",s,f,t,a,d,c,n,i,b,u,"static/chunks/749-f876c99e30a851b8.js","static/chunks/pages/datastores/documents-7312ed2d9409617f.js"],"/datastores/documents/chunklist":[e,s,t,c,i,b,"static/chunks/pages/datastores/documents/chunklist-4ae606926d192018.js"],sortedPages:["/","/_app","/_error","/chat","/database","/datastores","/datastores/documents","/datastores/documents/chunklist"]}}("static/chunks/566-31b5bf29f3e84615.js","static/chunks/913-b5bc9815149e2ad5.js","static/chunks/902-c56acea399c45e57.js","static/chunks/455-5c8f2c8bda9b4b83.js","static/chunks/29107295-90b90cb30c825230.js","static/chunks/625-63aa85328eed0b3e.js","static/chunks/556-26ffce13383f774a.js","static/chunks/939-126a01b0d827f3b4.js","static/chunks/589-8dfb35868cafc00b.js","static/chunks/289-06c0d9f538f77a71.js","static/chunks/34-4756f8547fff0eaf.js"),self.__BUILD_MANIFEST_CB&&self.__BUILD_MANIFEST_CB();

File diff suppressed because one or more lines are too long (15 files).


@@ -0,0 +1 @@
self.__BUILD_MANIFEST=function(s,a,c,t,e,d,n,f,u,i,h,k,b,j,r){return{__rewrites:{beforeFiles:[],afterFiles:[],fallback:[]},"/":[d,t,"static/chunks/707-109d4fec9e26030d.js","static/chunks/pages/index-899dec6259d37a62.js"],"/_error":["static/chunks/pages/_error-dee72aff9b2e2c12.js"],"/chat":["static/chunks/pages/chat-2df8214a2d24741c.js"],"/database":[s,a,c,n,f,u,"static/chunks/184-0a4f2ea3f379be28.js","static/chunks/pages/database-ebf51747fd365187.js"],"/datastores":[e,s,i,t,h,k,"static/chunks/241-4117dd68a591b7fa.js","static/chunks/pages/datastores-4fb48131988df037.js"],"/datastores/documents":[e,b,s,a,c,i,t,n,h,j,r,k,"static/chunks/749-f876c99e30a851b8.js","static/chunks/pages/datastores/documents-20cfa4fca7908a8d.js"],"/datastores/documents/chunklist":[e,s,a,c,j,r,"static/chunks/pages/datastores/documents/chunklist-4ae606926d192018.js"],"/models":[b,s,a,c,d,f,u,"static/chunks/147-1c86c44f1f0eb632.js","static/chunks/pages/models-11145708f29a00e1.js"],sortedPages:["/","/_app","/_error","/chat","/database","/datastores","/datastores/documents","/datastores/documents/chunklist","/models"]}}("static/chunks/566-31b5bf29f3e84615.js","static/chunks/902-c56acea399c45e57.js","static/chunks/455-ca34b6460502160b.js","static/chunks/913-b5bc9815149e2ad5.js","static/chunks/29107295-90b90cb30c825230.js","static/chunks/66-791bb03098dc9265.js","static/chunks/625-63aa85328eed0b3e.js","static/chunks/46-2a716444a56f6f08.js","static/chunks/631-b73b692b8c702e06.js","static/chunks/556-26ffce13383f774a.js","static/chunks/939-126a01b0d827f3b4.js","static/chunks/589-8dfb35868cafc00b.js","static/chunks/75fc9c18-a784766a129ec5fb.js","static/chunks/289-06c0d9f538f77a71.js","static/chunks/34-cd5c494fe56733f7.js"),self.__BUILD_MANIFEST_CB&&self.__BUILD_MANIFEST_CB();

File diff suppressed because one or more lines are too long (5 files).

Binary file not shown. (New image: 12 KiB.)

File diff suppressed because one or more lines are too long (2 files).

Binary file not shown. (New image: 4.8 KiB.)


@@ -1,10 +1,22 @@
import logging

logger = logging.getLogger(__name__)


def _clear_model_cache(device="cuda"):
    try:
        # clear torch cache
        import torch

        _clear_torch_cache(device)
    except ImportError:
        logger.warn("Torch not installed, skip clear torch cache")
    # TODO clear other cache


def _clear_torch_cache(device="cuda"):
    import gc
    import torch
    import gc

    gc.collect()
    if device != "cpu":
@@ -14,14 +26,14 @@ def _clear_torch_cache(device="cuda"):
                empty_cache()
            except Exception as e:
                logging.warn(f"Clear mps torch cache error, {str(e)}")
                logger.warn(f"Clear mps torch cache error, {str(e)}")
        elif torch.has_cuda:
            device_count = torch.cuda.device_count()
            for device_id in range(device_count):
                cuda_device = f"cuda:{device_id}"
                logging.info(f"Clear torch cache of device: {cuda_device}")
                logger.info(f"Clear torch cache of device: {cuda_device}")
                with torch.cuda.device(cuda_device):
                    torch.cuda.empty_cache()
                    torch.cuda.ipc_collect()
        else:
            logging.info("No cuda or mps, not support clear torch cache yet")
            logger.info("No cuda or mps, not support clear torch cache yet")