mirror of https://github.com/csunny/DB-GPT.git
synced 2025-07-31 15:47:05 +00:00

Merge branch 'eosphoros-ai:main' into main

This commit is contained in commit d6a623ac59.

.github/ISSUE_TEMPLATE/bug-report.yml (vendored, 1 line changed)
@@ -73,6 +73,7 @@ body:
        - label: Chat Excel
        - label: Chat DB
        - label: Chat Knowledge
        - label: Model Management
        - label: Dashboard
        - label: Plugins
Binary file not shown. Before: 102 KiB. After: 226 KiB.
docs/conf.py (18 lines changed)
@@ -26,7 +26,7 @@ extensions = [
    "sphinx.ext.autosummary",
    "sphinx.ext.napoleon",
    "sphinx.ext.viewcode",
-    "sphinxcontrib.autodoc_pydantic",
+    # "sphinxcontrib.autodoc_pydantic",
    "myst_nb",
    "sphinx_copybutton",
    "sphinx_panels",
@@ -34,14 +34,14 @@ extensions = [
]
source_suffix = [".ipynb", ".html", ".md", ".rst"]

-autodoc_pydantic_model_show_json = False
-autodoc_pydantic_field_list_validators = False
-autodoc_pydantic_config_members = False
-autodoc_pydantic_model_show_config_summary = False
-autodoc_pydantic_model_show_validator_members = False
-autodoc_pydantic_model_show_field_summary = False
-autodoc_pydantic_model_members = False
-autodoc_pydantic_model_undoc_members = False
+# autodoc_pydantic_model_show_json = False
+# autodoc_pydantic_field_list_validators = False
+# autodoc_pydantic_config_members = False
+# autodoc_pydantic_model_show_config_summary = False
+# autodoc_pydantic_model_show_validator_members = False
+# autodoc_pydantic_model_show_field_summary = False
+# autodoc_pydantic_model_members = False
+# autodoc_pydantic_model_undoc_members = False

templates_path = ["_templates"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
@@ -13,4 +13,5 @@ toml
myst_nb
sphinx_copybutton
pydata-sphinx-theme==0.13.1
pydantic-settings
furo
pilot/connections/conn_spark.py (new file, 112 lines)
@@ -0,0 +1,112 @@
from typing import Optional, Any

from pyspark.sql import SparkSession, DataFrame
from sqlalchemy import text

from pilot.connections.base import BaseConnect


class SparkConnect(BaseConnect):
    """Spark connection.

    Wraps a SparkSession and registers a file-based datasource as a temporary
    view so that generated SQL can be executed against it.
    """

    """db type"""
    db_type: str = "spark"
    """db driver"""
    driver: str = "spark"
    """db dialect"""
    dialect: str = "sparksql"

    def __init__(
        self,
        file_path: str,
        spark_session: Optional[SparkSession] = None,
        engine_args: Optional[dict] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize a Spark DataFrame from the datasource path."""
        self.spark_session = (
            spark_session or SparkSession.builder.appName("dbgpt").getOrCreate()
        )
        self.path = file_path
        self.table_name = "temp"
        self.df = self.create_df(self.path)

    @classmethod
    def from_file_path(
        cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any
    ):
        try:
            return cls(file_path=file_path, engine_args=engine_args)
        except Exception as e:
            print("load spark datasource error" + str(e))

    def create_df(self, path) -> DataFrame:
        """Create a Spark DataFrame from the datasource path.

        Returns:
            Spark DataFrame
        """
        return self.spark_session.read.option("header", "true").csv(path)

    def run(self, sql):
        # self.log(f"llm ingestion sql query is :\n{sql}")
        # self.df = self.create_df(self.path)
        self.df.createOrReplaceTempView(self.table_name)
        df = self.spark_session.sql(sql)
        first_row = df.first()
        rows = [first_row.asDict().keys()]
        for row in df.collect():
            rows.append(row)
        return rows

    def query_ex(self, sql):
        rows = self.run(sql)
        field_names = rows[0]
        return field_names, rows

    def get_indexes(self, table_name):
        """Get table indexes about specified table."""
        return ""

    def get_show_create_table(self, table_name):
        """Get table show create table about specified table."""
        return "ans"

    def get_fields(self):
        """Get column metadata about the dataframe."""
        return ",".join([f"({name}: {dtype})" for name, dtype in self.df.dtypes])

    def get_users(self):
        return []

    def get_grants(self):
        return []

    def get_collation(self):
        """Get collation."""
        return "UTF-8"

    def get_charset(self):
        return "UTF-8"

    def get_db_list(self):
        return ["default"]

    def get_db_names(self):
        return ["default"]

    def get_database_list(self):
        return []

    def get_database_names(self):
        return []

    def table_simple_info(self):
        return f"{self.table_name}{self.get_fields()}"

    def get_table_comments(self, db_name):
        return ""
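For orientation, a minimal usage sketch of the new connector (the CSV path below is illustrative and not part of the commit):

# Hedged sketch: load a CSV through SparkConnect and query the "temp" view it registers.
from pilot.connections.conn_spark import SparkConnect

conn = SparkConnect.from_file_path("/tmp/example.csv")  # illustrative path
field_names, rows = conn.query_ex("SELECT * FROM temp LIMIT 5")
print(field_names)
print(conn.table_simple_info())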
@@ -8,7 +8,7 @@ from pilot.model.loader import ModelLoader, _get_model_real_path
from pilot.model.parameter import ModelParameters
from pilot.model.cluster.worker_base import ModelWorker
from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
-from pilot.utils.model_utils import _clear_torch_cache
+from pilot.utils.model_utils import _clear_model_cache
from pilot.utils.parameter_utils import EnvArgumentParser

logger = logging.getLogger(__name__)
@@ -87,7 +87,7 @@ class DefaultModelWorker(ModelWorker):
        del self.tokenizer
        self.model = None
        self.tokenizer = None
-        _clear_torch_cache(self._model_params.device)
+        _clear_model_cache(self._model_params.device)

    def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
        torch_imported = False
@@ -11,7 +11,7 @@ from pilot.model.parameter import (
)
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.embedding.loader import EmbeddingLoader
-from pilot.utils.model_utils import _clear_torch_cache
+from pilot.utils.model_utils import _clear_model_cache
from pilot.utils.parameter_utils import EnvArgumentParser

logger = logging.getLogger(__name__)
@@ -79,7 +79,7 @@ class EmbeddingsModelWorker(ModelWorker):
            return
        del self._embeddings_impl
        self._embeddings_impl = None
-        _clear_torch_cache(self._model_params.device)
+        _clear_model_cache(self._model_params.device)

    def generate_stream(self, params: Dict):
        """Generate stream result, chat scene"""
@@ -1,17 +1,18 @@
import asyncio
import itertools
import json
import os
import sys
import random
import time
import logging
import os
import random
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict
from typing import Awaitable, Callable, Dict, Iterator, List, Optional

from fastapi import APIRouter, FastAPI
from fastapi.responses import StreamingResponse

from pilot.component import SystemApp
from pilot.model.base import (
    ModelInstance,
@@ -20,16 +21,16 @@ from pilot.model.base import (
    WorkerApplyType,
    WorkerSupportedModel,
)
from pilot.model.cluster.registry import ModelRegistry
from pilot.model.llm_utils import list_supported_models
from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.base import *
from pilot.model.cluster.manager_base import (
    WorkerManager,
    WorkerRunData,
    WorkerManagerFactory,
    WorkerRunData,
)
from pilot.model.cluster.base import *
from pilot.model.cluster.registry import ModelRegistry
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.llm_utils import list_supported_models
from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
from pilot.utils.parameter_utils import (
    EnvArgumentParser,
    ParameterDescription,
@@ -639,6 +640,10 @@ def _setup_fastapi(worker_params: ModelWorkerParameters, app=None):
    )

    if not worker_params.controller_addr:
+        # if we have http_proxy or https_proxy in env, the server can not start
+        # so set it to empty here
+        os.environ["http_proxy"] = ""
+        os.environ["https_proxy"] = ""
        worker_params.controller_addr = f"http://127.0.0.1:{worker_params.port}"
        logger.info(
            f"Run WorkerManager with standalone mode, controller_addr: {worker_params.controller_addr}"
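A short illustration of why the standalone path clears these variables (an assumption for context, not part of the commit): HTTP clients such as requests and httpx read http_proxy/https_proxy from the environment, so a stale proxy setting would also capture loopback calls and the worker could not reach the controller it just started on its own port.

# Hedged sketch: clear proxy variables before the worker talks to itself.
import os

os.environ["http_proxy"] = ""
os.environ["https_proxy"] = ""
# After clearing, a local call like the following can succeed
# (URL and path are illustrative):
# requests.get("http://127.0.0.1:8000/api/controller/models")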
@@ -18,6 +18,7 @@ from pilot.logs import logger
def _check_multi_gpu_or_4bit_quantization(model_params: ModelParameters):
    # TODO: vicuna-v1.5 8-bit quantization info is slow
    # TODO: support wizardlm quantization, see: https://huggingface.co/WizardLM/WizardLM-13B-V1.2/discussions/5
    # TODO: support internlm quantization
    model_name = model_params.model_name.lower()
    supported_models = ["llama", "baichuan", "vicuna"]
    return any(m in model_name for m in supported_models)
@@ -1,5 +1,6 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
+import os
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Optional
@@ -246,6 +247,11 @@ class ProxyModelParameters(BaseModelParameters):
    proxy_api_key: str = field(
        metadata={"tags": "privacy", "help": "The api key of current proxy LLM"},
    )
+    http_proxy: Optional[str] = field(
+        default=os.environ.get("http_proxy") or os.environ.get("https_proxy"),
+        metadata={"help": "The http or https proxy to use openai"},
+    )
+
    proxyllm_backend: Optional[str] = field(
        default=None,
        metadata={
@@ -26,6 +26,9 @@ from pilot.openapi.api_view_model import (
    ConversationVo,
    MessageVo,
    ChatSceneVo,
+    ChatCompletionResponseStreamChoice,
+    DeltaMessage,
+    ChatCompletionStreamResponse,
)
from pilot.connections.db_conn_info import DBConfig, DbTypeInfo
from pilot.configs.config import Config
@@ -383,7 +386,7 @@ async def chat_completions(dialogue: ConversationVo = Body()):
        )
    else:
        return StreamingResponse(
-            stream_generator(chat),
+            stream_generator(chat, dialogue.incremental, dialogue.model_name),
            headers=headers,
            media_type="text/plain",
        )
@@ -421,19 +424,48 @@ async def no_stream_generator(chat):
        yield f"data: {msg}\n\n"


-async def stream_generator(chat):
+async def stream_generator(chat, incremental: bool, model_name: str):
+    """Generate streaming responses
+
+    Our goal is to generate an openai-compatible streaming responses.
+    Currently, the incremental response is compatible, and the full response will be transformed in the future.
+
+    Args:
+        chat (BaseChat): Chat instance.
+        incremental (bool): Used to control whether the content is returned incrementally or in full each time.
+        model_name (str): The model name
+
+    Yields:
+        _type_: streaming responses
+    """
    msg = "[LLM_ERROR]: llm server has no output, maybe your prompt template is wrong."

+    stream_id = f"chatcmpl-{str(uuid.uuid1())}"
+    previous_response = ""
    async for chunk in chat.stream_call():
        if chunk:
            msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
                chunk, chat.skip_echo_len
            )

-            msg = msg.replace("\n", "\\n")
-            yield f"data:{msg}\n\n"
+            msg = msg.replace("\ufffd", "")
+            if incremental:
+                incremental_output = msg[len(previous_response) :]
+                choice_data = ChatCompletionResponseStreamChoice(
+                    index=0,
+                    delta=DeltaMessage(role="assistant", content=incremental_output),
+                )
+                chunk = ChatCompletionStreamResponse(
+                    id=stream_id, choices=[choice_data], model=model_name
+                )
+                yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+            else:
+                # TODO generate an openai-compatible streaming responses
+                msg = msg.replace("\n", "\\n")
+                yield f"data:{msg}\n\n"
+            previous_response = msg
            await asyncio.sleep(0.02)

+    if incremental:
+        yield "data: [DONE]\n\n"
    chat.current_message.add_ai_message(msg)
    chat.current_message.add_view_message(msg)
    chat.memory.append(chat.current_message)
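For reference, with incremental=True the endpoint now emits OpenAI-style SSE lines shaped roughly like the following (content and model name are illustrative; fields left at their defaults, such as id, created, and finish_reason, may be omitted because of exclude_unset=True):

data: {"model": "vicuna-13b-v1.5", "choices": [{"index": 0, "delta": {"role": "assistant", "content": "Hello"}}]}
data: [DONE]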
@@ -1,5 +1,7 @@
from pydantic import BaseModel, Field
-from typing import TypeVar, Generic, Any
+from typing import TypeVar, Generic, Any, Optional, Literal, List
+import uuid
+import time

T = TypeVar("T")

@@ -59,6 +61,11 @@ class ConversationVo(BaseModel):
    """
    model_name: str = None
+
+    """Used to control whether the content is returned incrementally or in full each time.
+    If this parameter is not provided, the default is full return.
+    """
+    incremental: bool = False


class MessageVo(BaseModel):
    """
@@ -83,3 +90,21 @@ class MessageVo(BaseModel):
    model_name
    """
    model_name: str
+
+
+class DeltaMessage(BaseModel):
+    role: Optional[str] = None
+    content: Optional[str] = None
+
+
+class ChatCompletionResponseStreamChoice(BaseModel):
+    index: int
+    delta: DeltaMessage
+    finish_reason: Optional[Literal["stop", "length"]] = None
+
+
+class ChatCompletionStreamResponse(BaseModel):
+    id: str = Field(default_factory=lambda: f"chatcmpl-{str(uuid.uuid1())}")
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseStreamChoice]
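A small sketch of how the new view models fit together (values illustrative; this mirrors the stream_generator hunk above):

choice = ChatCompletionResponseStreamChoice(
    index=0, delta=DeltaMessage(role="assistant", content="Hello")
)
chunk = ChatCompletionStreamResponse(model="vicuna-13b-v1.5", choices=[choice])
print(chunk.json(exclude_unset=True, ensure_ascii=False))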
@@ -24,6 +24,7 @@ from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from pilot.server.knowledge.api import router as knowledge_router
from pilot.server.prompt.api import router as prompt_router
+from pilot.server.llm_manage.api import router as llm_manage_api


from pilot.openapi.api_v1.api_v1 import router as api_v1
@@ -73,6 +74,7 @@ app.add_middleware(
app.include_router(api_v1, prefix="/api")
app.include_router(knowledge_router, prefix="/api")
app.include_router(api_editor_route_v1, prefix="/api")
+app.include_router(llm_manage_api, prefix="/api")
app.include_router(api_fb_v1, prefix="/api")

# app.include_router(api_v1)
pilot/server/llm_manage/api.py (new file, 108 lines)
@@ -0,0 +1,108 @@
from typing import List

from fastapi import APIRouter

from pilot.component import ComponentType
from pilot.configs.config import Config

from pilot.model.cluster import WorkerStartupRequest, WorkerManagerFactory
from pilot.openapi.api_view_model import Result

from pilot.server.llm_manage.request.request import ModelResponse

CFG = Config()
router = APIRouter()


@router.get("/v1/worker/model/params")
async def model_params():
    print(f"/worker/model/params")
    try:
        from pilot.model.cluster import WorkerManagerFactory

        worker_manager = CFG.SYSTEM_APP.get_component(
            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
        ).create()
        if not worker_manager:
            return Result.faild(code="E000X", msg=f"can not find worker manager")
        params = []
        workers = await worker_manager.supported_models()
        for worker in workers:
            for model in worker.models:
                model_dict = model.__dict__
                model_dict["host"] = worker.host
                model_dict["port"] = worker.port
                params.append(model_dict)
        return Result.succ(params)
    except Exception as e:
        return Result.faild(code="E000X", msg=f"model params error {e}")


@router.get("/v1/worker/model/list")
async def model_list():
    print(f"/worker/model/list")
    try:
        from pilot.model.cluster.controller.controller import BaseModelController

        controller = CFG.SYSTEM_APP.get_component(
            ComponentType.MODEL_CONTROLLER, BaseModelController
        )
        responses = []
        managers = await controller.get_all_instances(
            model_name="WorkerManager@service", healthy_only=True
        )
        manager_map = dict(map(lambda manager: (manager.host, manager), managers))
        models = await controller.get_all_instances()
        for model in models:
            worker_name, worker_type = model.model_name.split("@")
            if worker_type == "llm" or worker_type == "text2vec":
                response = ModelResponse(
                    model_name=worker_name,
                    model_type=worker_type,
                    host=model.host,
                    port=model.port,
                    healthy=model.healthy,
                    check_healthy=model.check_healthy,
                    last_heartbeat=model.last_heartbeat,
                    prompt_template=model.prompt_template,
                )
                response.manager_host = model.host if manager_map[model.host] else None
                response.manager_port = (
                    manager_map[model.host].port if manager_map[model.host] else None
                )
                responses.append(response)
        return Result.succ(responses)
    except Exception as e:
        return Result.faild(code="E000X", msg=f"model list error {e}")


@router.post("/v1/worker/model/stop")
async def model_stop(request: WorkerStartupRequest):
    print(f"/v1/worker/model/stop:")
    try:
        from pilot.model.cluster.controller.controller import BaseModelController

        worker_manager = CFG.SYSTEM_APP.get_component(
            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
        ).create()
        if not worker_manager:
            return Result.faild(code="E000X", msg=f"can not find worker manager")
        request.params = {}
        return Result.succ(await worker_manager.model_shutdown(request))
    except Exception as e:
        return Result.faild(code="E000X", msg=f"model stop failed {e}")


@router.post("/v1/worker/model/start")
async def model_start(request: WorkerStartupRequest):
    print(f"/v1/worker/model/start:")
    try:
        worker_manager = CFG.SYSTEM_APP.get_component(
            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
        ).create()
        if not worker_manager:
            return Result.faild(code="E000X", msg=f"can not find worker manager")
        return Result.succ(await worker_manager.model_startup(request))
    except Exception as e:
        return Result.faild(code="E000X", msg=f"model start failed {e}")
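Once the server is running, the new management endpoints can be exercised as sketched below (host and port are illustrative; the router is mounted under the /api prefix, as shown in the app.include_router hunk above):

import requests

base = "http://127.0.0.1:5000/api"  # illustrative host/port
print(requests.get(f"{base}/v1/worker/model/list").json())
print(requests.get(f"{base}/v1/worker/model/params").json())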
pilot/server/llm_manage/request/request.py (new file, 28 lines)
@@ -0,0 +1,28 @@
from dataclasses import dataclass


@dataclass
class ModelResponse:
    """ModelResponse"""

    """model_name: model_name"""
    model_name: str = None
    """model_type: model_type"""
    model_type: str = None
    """host: host"""
    host: str = None
    """port: port"""
    port: int = None
    """manager_host: manager_host"""
    manager_host: str = None
    """manager_port: manager_port"""
    manager_port: int = None
    """healthy: healthy"""
    healthy: bool = True

    """check_healthy: check_healthy"""
    check_healthy: bool = True
    prompt_template: str = None
    last_heartbeat: str = None
    stream_api: str = None
    nostream_api: str = None
File diff suppressed because one or more lines are too long.
@@ -1 +0,0 @@
self.__BUILD_MANIFEST=function(s,a,t,c,e,d,f,n,u,i,b){return{__rewrites:{beforeFiles:[],afterFiles:[],fallback:[]},"/":[a,"static/chunks/66-791bb03098dc9265.js","static/chunks/707-109d4fec9e26030d.js","static/chunks/pages/index-d5aba6bbbc1d8aaa.js"],"/_error":["static/chunks/pages/_error-dee72aff9b2e2c12.js"],"/chat":["static/chunks/pages/chat-a9adfc18f61cb676.js"],"/database":[s,t,d,c,"static/chunks/46-2a716444a56f6f08.js","static/chunks/847-4335b5938375e331.js","static/chunks/pages/database-ddf0a72485646c52.js"],"/datastores":[e,s,f,a,n,u,"static/chunks/241-4117dd68a591b7fa.js","static/chunks/pages/datastores-4fb48131988df037.js"],"/datastores/documents":[e,"static/chunks/75fc9c18-a784766a129ec5fb.js",s,f,t,a,d,c,n,i,b,u,"static/chunks/749-f876c99e30a851b8.js","static/chunks/pages/datastores/documents-7312ed2d9409617f.js"],"/datastores/documents/chunklist":[e,s,t,c,i,b,"static/chunks/pages/datastores/documents/chunklist-4ae606926d192018.js"],sortedPages:["/","/_app","/_error","/chat","/database","/datastores","/datastores/documents","/datastores/documents/chunklist"]}}("static/chunks/566-31b5bf29f3e84615.js","static/chunks/913-b5bc9815149e2ad5.js","static/chunks/902-c56acea399c45e57.js","static/chunks/455-5c8f2c8bda9b4b83.js","static/chunks/29107295-90b90cb30c825230.js","static/chunks/625-63aa85328eed0b3e.js","static/chunks/556-26ffce13383f774a.js","static/chunks/939-126a01b0d827f3b4.js","static/chunks/589-8dfb35868cafc00b.js","static/chunks/289-06c0d9f538f77a71.js","static/chunks/34-4756f8547fff0eaf.js"),self.__BUILD_MANIFEST_CB&&self.__BUILD_MANIFEST_CB();
File diff suppressed because one or more lines are too long (repeated for multiple files).
@@ -0,0 +1 @@
self.__BUILD_MANIFEST=function(s,a,c,t,e,d,n,f,u,i,h,k,b,j,r){return{__rewrites:{beforeFiles:[],afterFiles:[],fallback:[]},"/":[d,t,"static/chunks/707-109d4fec9e26030d.js","static/chunks/pages/index-899dec6259d37a62.js"],"/_error":["static/chunks/pages/_error-dee72aff9b2e2c12.js"],"/chat":["static/chunks/pages/chat-2df8214a2d24741c.js"],"/database":[s,a,c,n,f,u,"static/chunks/184-0a4f2ea3f379be28.js","static/chunks/pages/database-ebf51747fd365187.js"],"/datastores":[e,s,i,t,h,k,"static/chunks/241-4117dd68a591b7fa.js","static/chunks/pages/datastores-4fb48131988df037.js"],"/datastores/documents":[e,b,s,a,c,i,t,n,h,j,r,k,"static/chunks/749-f876c99e30a851b8.js","static/chunks/pages/datastores/documents-20cfa4fca7908a8d.js"],"/datastores/documents/chunklist":[e,s,a,c,j,r,"static/chunks/pages/datastores/documents/chunklist-4ae606926d192018.js"],"/models":[b,s,a,c,d,f,u,"static/chunks/147-1c86c44f1f0eb632.js","static/chunks/pages/models-11145708f29a00e1.js"],sortedPages:["/","/_app","/_error","/chat","/database","/datastores","/datastores/documents","/datastores/documents/chunklist","/models"]}}("static/chunks/566-31b5bf29f3e84615.js","static/chunks/902-c56acea399c45e57.js","static/chunks/455-ca34b6460502160b.js","static/chunks/913-b5bc9815149e2ad5.js","static/chunks/29107295-90b90cb30c825230.js","static/chunks/66-791bb03098dc9265.js","static/chunks/625-63aa85328eed0b3e.js","static/chunks/46-2a716444a56f6f08.js","static/chunks/631-b73b692b8c702e06.js","static/chunks/556-26ffce13383f774a.js","static/chunks/939-126a01b0d827f3b4.js","static/chunks/589-8dfb35868cafc00b.js","static/chunks/75fc9c18-a784766a129ec5fb.js","static/chunks/289-06c0d9f538f77a71.js","static/chunks/34-cd5c494fe56733f7.js"),self.__BUILD_MANIFEST_CB&&self.__BUILD_MANIFEST_CB();
File diff suppressed because one or more lines are too long (repeated for multiple files).
pilot/server/static/icons/spark.png (new binary file, not shown; 12 KiB)
File diff suppressed because one or more lines are too long
pilot/server/static/models/index.html (new file, 1 line)
File diff suppressed because one or more lines are too long
pilot/server/static/models/internlm.png (new binary file, not shown; 4.8 KiB)
@@ -1,10 +1,22 @@
import logging

logger = logging.getLogger(__name__)


def _clear_model_cache(device="cuda"):
    try:
        # clear torch cache
        import torch

        _clear_torch_cache(device)
    except ImportError:
        logger.warn("Torch not installed, skip clear torch cache")
    # TODO clear other cache


def _clear_torch_cache(device="cuda"):
    import gc

    import torch
    import gc

    gc.collect()
    if device != "cpu":
@@ -14,14 +26,14 @@ def _clear_torch_cache(device="cuda"):

            empty_cache()
        except Exception as e:
-            logging.warn(f"Clear mps torch cache error, {str(e)}")
+            logger.warn(f"Clear mps torch cache error, {str(e)}")
    elif torch.has_cuda:
        device_count = torch.cuda.device_count()
        for device_id in range(device_count):
            cuda_device = f"cuda:{device_id}"
-            logging.info(f"Clear torch cache of device: {cuda_device}")
+            logger.info(f"Clear torch cache of device: {cuda_device}")
            with torch.cuda.device(cuda_device):
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()
    else:
-        logging.info("No cuda or mps, not support clear torch cache yet")
+        logger.info("No cuda or mps, not support clear torch cache yet")
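A minimal sketch of the intended call pattern (mirrors the DefaultModelWorker.stop change above; the device string is illustrative):

from pilot.utils.model_utils import _clear_model_cache

# Drop references first, then let the utility free framework-level caches;
# it degrades to a warning when torch is not installed.
model = None
tokenizer = None
_clear_model_cache("cuda")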