mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-31 15:47:05 +00:00
feat(model): Unified Deployment Mode with multi-model and add command line chat with model
This commit is contained in:
parent
1f1da2618c
commit
a8846c40aa
@ -4,7 +4,7 @@ import logging
|
||||
import os
|
||||
from typing import Callable, List, Type
|
||||
|
||||
from pilot.model.controller.registry import ModelRegistryClient
|
||||
from pilot.model.controller.controller import ModelRegistryClient
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
from pilot.model.base import WorkerApplyType
|
||||
from pilot.model.parameter import (
|
||||
@ -26,18 +26,23 @@ logger = logging.getLogger("dbgpt_cli")
|
||||
@click.option(
|
||||
"--address",
|
||||
type=str,
|
||||
default=MODEL_CONTROLLER_ADDRESS,
|
||||
default=None,
|
||||
required=False,
|
||||
show_default=True,
|
||||
help=(
|
||||
"Address of the Model Controller to connect to. "
|
||||
"Just support light deploy model"
|
||||
"Just support light deploy model, If the environment variable CONTROLLER_ADDRESS is configured, read from the environment variable"
|
||||
),
|
||||
)
|
||||
def model_cli_group(address: str):
|
||||
"""Clients that manage model serving"""
|
||||
global MODEL_CONTROLLER_ADDRESS
|
||||
MODEL_CONTROLLER_ADDRESS = address
|
||||
if not address:
|
||||
from pilot.utils.command_utils import _detect_controller_address
|
||||
|
||||
MODEL_CONTROLLER_ADDRESS = _detect_controller_address()
|
||||
else:
|
||||
MODEL_CONTROLLER_ADDRESS = address
|
||||
|
||||
|
||||
@model_cli_group.command()
|
||||
@ -140,6 +145,26 @@ def restart(model_name: str, model_type: str):
|
||||
)
|
||||
|
||||
|
||||
@model_cli_group.command()
|
||||
@click.option(
|
||||
"--model_name",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help=("The name of model"),
|
||||
)
|
||||
@click.option(
|
||||
"--system",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help=("System prompt"),
|
||||
)
|
||||
def chat(model_name: str, system: str):
|
||||
"""Interact with your bot from the command line"""
|
||||
_cli_chat(MODEL_CONTROLLER_ADDRESS, model_name, system)
|
||||
|
||||
|
||||
# @model_cli_group.command()
|
||||
# @add_model_options
|
||||
# def modify(address: str, model_name: str, model_type: str):
|
||||
@ -147,14 +172,21 @@ def restart(model_name: str, model_type: str):
|
||||
# worker_apply(address, model_name, model_type, WorkerApplyType.UPDATE_PARAMS)
|
||||
|
||||
|
||||
def _get_worker_manager(address: str):
|
||||
from pilot.model.worker.manager import RemoteWorkerManager, WorkerApplyRequest
|
||||
|
||||
registry = ModelRegistryClient(address)
|
||||
worker_manager = RemoteWorkerManager(registry)
|
||||
return worker_manager
|
||||
|
||||
|
||||
def worker_apply(
|
||||
address: str, model_name: str, model_type: str, apply_type: WorkerApplyType
|
||||
):
|
||||
from pilot.model.worker.manager import RemoteWorkerManager, WorkerApplyRequest
|
||||
from pilot.model.worker.manager import WorkerApplyRequest
|
||||
|
||||
loop = get_or_create_event_loop()
|
||||
registry = ModelRegistryClient(address)
|
||||
worker_manager = RemoteWorkerManager(registry)
|
||||
worker_manager = _get_worker_manager(address)
|
||||
apply_req = WorkerApplyRequest(
|
||||
model=model_name, worker_type=model_type, apply_type=apply_type
|
||||
)
|
||||
@ -162,6 +194,41 @@ def worker_apply(
|
||||
print(res)
|
||||
|
||||
|
||||
def _cli_chat(address: str, model_name: str, system_prompt: str = None):
|
||||
loop = get_or_create_event_loop()
|
||||
worker_manager = worker_manager = _get_worker_manager(address)
|
||||
loop.run_until_complete(_chat_stream(worker_manager, model_name, system_prompt))
|
||||
|
||||
|
||||
async def _chat_stream(worker_manager, model_name: str, system_prompt: str = None):
|
||||
from pilot.model.worker.manager import PromptRequest
|
||||
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
|
||||
|
||||
print(f"Chatbot started with model {model_name}. Type 'exit' to leave the chat.")
|
||||
hist = []
|
||||
previous_response = ""
|
||||
if system_prompt:
|
||||
hist.append(
|
||||
ModelMessage(role=ModelMessageRoleType.SYSTEM, content=system_prompt)
|
||||
)
|
||||
while True:
|
||||
previous_response = ""
|
||||
user_input = input("\n\nYou: ")
|
||||
if user_input.lower().strip() == "exit":
|
||||
break
|
||||
hist.append(ModelMessage(role=ModelMessageRoleType.HUMAN, content=user_input))
|
||||
request = PromptRequest(messages=hist, model=model_name, prompt="", echo=False)
|
||||
request = request.dict(exclude_none=True)
|
||||
print("Bot: ", end="")
|
||||
async for response in worker_manager.generate_stream(request):
|
||||
incremental_output = response.text[len(previous_response) :]
|
||||
print(incremental_output, end="", flush=True)
|
||||
previous_response = response.text
|
||||
hist.append(
|
||||
ModelMessage(role=ModelMessageRoleType.AI, content=previous_response)
|
||||
)
|
||||
|
||||
|
||||
def add_stop_server_options(func):
|
||||
@click.option(
|
||||
"--port",
|
||||
@ -249,3 +316,9 @@ def start_apiserver(**kwargs):
|
||||
def stop_apiserver(**kwargs):
|
||||
"""Start apiserver(TODO)"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _stop_all_model_server(**kwargs):
|
||||
"""Stop all server"""
|
||||
_stop_service("worker", "ModelWorker")
|
||||
_stop_service("controller", "ModelController")
|
||||
|
@ -1,3 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
@ -6,9 +8,33 @@ from pilot.model.base import ModelInstance
|
||||
from pilot.model.parameter import ModelControllerParameters
|
||||
from pilot.model.controller.registry import EmbeddedModelRegistry, ModelRegistry
|
||||
from pilot.utils.parameter_utils import EnvArgumentParser
|
||||
from pilot.utils.api_utils import _api_remote as api_remote
|
||||
|
||||
|
||||
class ModelController:
|
||||
class BaseModelController(ABC):
|
||||
@abstractmethod
|
||||
async def register_instance(self, instance: ModelInstance) -> bool:
|
||||
"""Register a given model instance"""
|
||||
|
||||
@abstractmethod
|
||||
async def deregister_instance(self, instance: ModelInstance) -> bool:
|
||||
"""Deregister a given model instance."""
|
||||
|
||||
@abstractmethod
|
||||
async def get_all_instances(
|
||||
self, model_name: str, healthy_only: bool = False
|
||||
) -> List[ModelInstance]:
|
||||
"""Fetch all instances of a given model. Optionally, fetch only the healthy instances."""
|
||||
|
||||
@abstractmethod
|
||||
async def send_heartbeat(self, instance: ModelInstance) -> bool:
|
||||
"""Send a heartbeat for a given model instance. This can be used to verify if the instance is still alive and functioning."""
|
||||
|
||||
async def model_apply(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class LocalModelController(BaseModelController):
|
||||
def __init__(self, registry: ModelRegistry = None) -> None:
|
||||
if not registry:
|
||||
registry = EmbeddedModelRegistry()
|
||||
@ -27,22 +53,87 @@ class ModelController:
|
||||
logging.info(
|
||||
f"Get all instances with {model_name}, healthy_only: {healthy_only}"
|
||||
)
|
||||
return await self.registry.get_all_instances(model_name, healthy_only)
|
||||
|
||||
async def get_all_model_instances(self) -> List[ModelInstance]:
|
||||
return await self.registry.get_all_model_instances()
|
||||
if not model_name:
|
||||
return await self.registry.get_all_model_instances()
|
||||
else:
|
||||
return await self.registry.get_all_instances(model_name, healthy_only)
|
||||
|
||||
async def send_heartbeat(self, instance: ModelInstance) -> bool:
|
||||
return await self.registry.send_heartbeat(instance)
|
||||
|
||||
|
||||
class _RemoteModelController(BaseModelController):
|
||||
def __init__(self, base_url: str) -> None:
|
||||
self.base_url = base_url
|
||||
|
||||
@api_remote(path="/api/controller/models", method="POST")
|
||||
async def register_instance(self, instance: ModelInstance) -> bool:
|
||||
pass
|
||||
|
||||
@api_remote(path="/api/controller/models", method="DELETE")
|
||||
async def deregister_instance(self, instance: ModelInstance) -> bool:
|
||||
pass
|
||||
|
||||
@api_remote(path="/api/controller/models")
|
||||
async def get_all_instances(
|
||||
self, model_name: str = None, healthy_only: bool = False
|
||||
) -> List[ModelInstance]:
|
||||
pass
|
||||
|
||||
@api_remote(path="/api/controller/heartbeat", method="POST")
|
||||
async def send_heartbeat(self, instance: ModelInstance) -> bool:
|
||||
pass
|
||||
|
||||
|
||||
class ModelRegistryClient(_RemoteModelController, ModelRegistry):
|
||||
async def get_all_model_instances(self) -> List[ModelInstance]:
|
||||
return await self.get_all_instances()
|
||||
|
||||
|
||||
class ModelControllerAdapter(BaseModelController):
|
||||
def __init__(self, backend: BaseModelController = None) -> None:
|
||||
self.backend = backend
|
||||
|
||||
async def register_instance(self, instance: ModelInstance) -> bool:
|
||||
return await self.backend.register_instance(instance)
|
||||
|
||||
async def deregister_instance(self, instance: ModelInstance) -> bool:
|
||||
return await self.backend.deregister_instance(instance)
|
||||
|
||||
async def get_all_instances(
|
||||
self, model_name: str, healthy_only: bool = False
|
||||
) -> List[ModelInstance]:
|
||||
return await self.backend.get_all_instances(model_name, healthy_only)
|
||||
|
||||
async def send_heartbeat(self, instance: ModelInstance) -> bool:
|
||||
return await self.backend.send_heartbeat(instance)
|
||||
|
||||
async def model_apply(self) -> bool:
|
||||
# TODO
|
||||
raise NotImplementedError
|
||||
return await self.backend.model_apply()
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
controller = ModelController()
|
||||
controller = ModelControllerAdapter()
|
||||
|
||||
|
||||
def initialize_controller(
|
||||
app=None, remote_controller_addr: str = None, host: str = None, port: int = None
|
||||
):
|
||||
global controller
|
||||
if remote_controller_addr:
|
||||
controller.backend = _RemoteModelController(remote_controller_addr)
|
||||
else:
|
||||
controller.backend = LocalModelController()
|
||||
|
||||
if app:
|
||||
app.include_router(router, prefix="/api")
|
||||
else:
|
||||
import uvicorn
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(router, prefix="/api")
|
||||
uvicorn.run(app, host=host, port=port, log_level="info")
|
||||
|
||||
|
||||
@router.post("/controller/models")
|
||||
@ -51,14 +142,13 @@ async def api_register_instance(request: ModelInstance):
|
||||
|
||||
|
||||
@router.delete("/controller/models")
|
||||
async def api_deregister_instance(request: ModelInstance):
|
||||
return await controller.deregister_instance(request)
|
||||
async def api_deregister_instance(model_name: str, host: str, port: int):
|
||||
instance = ModelInstance(model_name=model_name, host=host, port=port)
|
||||
return await controller.deregister_instance(instance)
|
||||
|
||||
|
||||
@router.get("/controller/models")
|
||||
async def api_get_all_instances(model_name: str = None, healthy_only: bool = False):
|
||||
if not model_name:
|
||||
return await controller.get_all_model_instances()
|
||||
return await controller.get_all_instances(model_name, healthy_only=healthy_only)
|
||||
|
||||
|
||||
@ -68,18 +158,12 @@ async def api_model_heartbeat(request: ModelInstance):
|
||||
|
||||
|
||||
def run_model_controller():
|
||||
import uvicorn
|
||||
|
||||
parser = EnvArgumentParser()
|
||||
env_prefix = "controller_"
|
||||
controller_params: ModelControllerParameters = parser.parse_args_into_dataclass(
|
||||
ModelControllerParameters, env_prefix=env_prefix
|
||||
)
|
||||
app = FastAPI()
|
||||
app.include_router(router, prefix="/api")
|
||||
uvicorn.run(
|
||||
app, host=controller_params.host, port=controller_params.port, log_level="info"
|
||||
)
|
||||
initialize_controller(host=controller_params.host, port=controller_params.port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -178,47 +178,11 @@ class EmbeddedModelRegistry(ModelRegistry):
|
||||
instance.model_name, instance.host, instance.port, healthy_only=False
|
||||
)
|
||||
if not exist_ins:
|
||||
return False
|
||||
# register new install from heartbeat
|
||||
self.register_instance(instance)
|
||||
return True
|
||||
|
||||
ins = exist_ins[0]
|
||||
ins.last_heartbeat = datetime.now()
|
||||
ins.healthy = True
|
||||
return True
|
||||
|
||||
|
||||
from pilot.utils.api_utils import _api_remote as api_remote
|
||||
|
||||
|
||||
class ModelRegistryClient(ModelRegistry):
|
||||
def __init__(self, base_url: str) -> None:
|
||||
self.base_url = base_url
|
||||
|
||||
@api_remote(path="/api/controller/models", method="POST")
|
||||
async def register_instance(self, instance: ModelInstance) -> bool:
|
||||
pass
|
||||
|
||||
@api_remote(path="/api/controller/models", method="DELETE")
|
||||
async def deregister_instance(self, instance: ModelInstance) -> bool:
|
||||
pass
|
||||
|
||||
@api_remote(path="/api/controller/models")
|
||||
async def get_all_instances(
|
||||
self, model_name: str, healthy_only: bool = False
|
||||
) -> List[ModelInstance]:
|
||||
pass
|
||||
|
||||
@api_remote(path="/api/controller/models")
|
||||
async def get_all_model_instances(self) -> List[ModelInstance]:
|
||||
pass
|
||||
|
||||
@api_remote(path="/api/controller/models")
|
||||
async def select_one_health_instance(self, model_name: str) -> ModelInstance:
|
||||
instances = await self.get_all_instances(model_name, healthy_only=True)
|
||||
instances = [i for i in instances if i.enabled]
|
||||
if not instances:
|
||||
return None
|
||||
return random.choice(instances)
|
||||
|
||||
@api_remote(path="/api/controller/heartbeat", method="POST")
|
||||
async def send_heartbeat(self, instance: ModelInstance) -> bool:
|
||||
pass
|
||||
|
@ -51,7 +51,7 @@ class ModelWorkerParameters(BaseParameters):
|
||||
)
|
||||
|
||||
port: Optional[int] = field(
|
||||
default=8000, metadata={"help": "Model worker deploy port"}
|
||||
default=8001, metadata={"help": "Model worker deploy port"}
|
||||
)
|
||||
daemon: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Run Model Worker in background"}
|
||||
|
@ -82,6 +82,7 @@ class DefaultModelWorker(ModelWorker):
|
||||
|
||||
def stop(self) -> None:
|
||||
if not self.model:
|
||||
logger.warn("Model has been stopped!!")
|
||||
return
|
||||
del self.model
|
||||
del self.tokenizer
|
||||
@ -98,23 +99,24 @@ class DefaultModelWorker(ModelWorker):
|
||||
params, self.ml.model_path, prompt_template=self.ml.prompt_template
|
||||
)
|
||||
|
||||
previous_response = ""
|
||||
print("stream output:\n")
|
||||
for output in self.generate_stream_func(
|
||||
self.model, self.tokenizer, params, get_device(), self.context_len
|
||||
):
|
||||
# Please do not open the output in production!
|
||||
# The gpt4all thread shares stdout with the parent process,
|
||||
# and opening it may affect the frontend output.
|
||||
if "windows" in platform.platform().lower():
|
||||
# Do not print the model output, because it may contain Emoji, there is a problem with the GBK encoding
|
||||
pass
|
||||
else:
|
||||
print("output: ", output)
|
||||
incremental_output = output[len(previous_response) :]
|
||||
# print("output: ", output)
|
||||
print(incremental_output, end="", flush=True)
|
||||
previous_response = output
|
||||
# return some model context to dgt-server
|
||||
model_output = ModelOutput(
|
||||
text=output, error_code=0, model_context=model_context
|
||||
)
|
||||
yield model_output
|
||||
|
||||
print(f"\n\nfull stream output:\n{previous_response}")
|
||||
except torch.cuda.CudaError:
|
||||
model_output = ModelOutput(
|
||||
text="**GPU OutOfMemory, Please Refresh.**", error_code=0
|
||||
|
@ -225,7 +225,14 @@ class LocalWorkerManager(WorkerManager):
|
||||
self, params: Dict, async_wrapper=None, **kwargs
|
||||
) -> Iterator[ModelOutput]:
|
||||
"""Generate stream result, chat scene"""
|
||||
worker_run_data = await self._get_model(params)
|
||||
try:
|
||||
worker_run_data = await self._get_model(params)
|
||||
except Exception as e:
|
||||
yield ModelOutput(
|
||||
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
|
||||
error_code=0,
|
||||
)
|
||||
return
|
||||
async with worker_run_data.semaphore:
|
||||
if worker_run_data.worker.support_async():
|
||||
async for outout in worker_run_data.worker.async_generate_stream(
|
||||
@ -244,7 +251,13 @@ class LocalWorkerManager(WorkerManager):
|
||||
|
||||
async def generate(self, params: Dict) -> ModelOutput:
|
||||
"""Generate non stream result"""
|
||||
worker_run_data = await self._get_model(params)
|
||||
try:
|
||||
worker_run_data = await self._get_model(params)
|
||||
except Exception as e:
|
||||
return ModelOutput(
|
||||
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
|
||||
error_code=0,
|
||||
)
|
||||
async with worker_run_data.semaphore:
|
||||
if worker_run_data.worker.support_async():
|
||||
return await worker_run_data.worker.async_generate(params)
|
||||
@ -256,7 +269,10 @@ class LocalWorkerManager(WorkerManager):
|
||||
|
||||
async def embeddings(self, params: Dict) -> List[List[float]]:
|
||||
"""Embed input"""
|
||||
worker_run_data = await self._get_model(params, worker_type="text2vec")
|
||||
try:
|
||||
worker_run_data = await self._get_model(params, worker_type="text2vec")
|
||||
except Exception as e:
|
||||
raise e
|
||||
async with worker_run_data.semaphore:
|
||||
if worker_run_data.worker.support_async():
|
||||
return await worker_run_data.worker.async_embeddings(params)
|
||||
@ -272,6 +288,8 @@ class LocalWorkerManager(WorkerManager):
|
||||
apply_func = self._start_all_worker
|
||||
elif apply_req.apply_type == WorkerApplyType.STOP:
|
||||
apply_func = self._stop_all_worker
|
||||
elif apply_req.apply_type == WorkerApplyType.RESTART:
|
||||
apply_func = self._restart_all_worker
|
||||
elif apply_req.apply_type == WorkerApplyType.UPDATE_PARAMS:
|
||||
apply_func = self._update_all_worker_params
|
||||
else:
|
||||
@ -298,10 +316,13 @@ class LocalWorkerManager(WorkerManager):
|
||||
apply_req (WorkerApplyRequest): Worker apply request
|
||||
apply_func (ApplyFunction): Function to apply to worker instances, now function is async function
|
||||
"""
|
||||
logger.info(f"Apply req: {apply_req}, apply_func: {apply_func}")
|
||||
if apply_req:
|
||||
worker_type = apply_req.worker_type.value
|
||||
model_name = apply_req.model
|
||||
worker_instances = await self.get_model_instances(worker_type, model_name)
|
||||
worker_instances = await self.get_model_instances(
|
||||
worker_type, model_name, healthy_only=False
|
||||
)
|
||||
if not worker_instances:
|
||||
raise Exception(
|
||||
f"No worker instance found for the model {model_name} worker type {worker_type}"
|
||||
@ -370,6 +391,12 @@ class LocalWorkerManager(WorkerManager):
|
||||
message=f"Worker stopped successfully", timecost=timecost
|
||||
)
|
||||
|
||||
async def _restart_all_worker(
|
||||
self, apply_req: WorkerApplyRequest
|
||||
) -> WorkerApplyOutput:
|
||||
await self._stop_all_worker(apply_req)
|
||||
return await self._start_all_worker(apply_req)
|
||||
|
||||
async def _update_all_worker_params(
|
||||
self, apply_req: WorkerApplyRequest
|
||||
) -> WorkerApplyOutput:
|
||||
@ -389,8 +416,7 @@ class LocalWorkerManager(WorkerManager):
|
||||
timecost = time.time() - start_time
|
||||
if need_restart:
|
||||
logger.info("Model params update successfully, begin restart worker")
|
||||
await self._stop_all_worker(apply_req)
|
||||
await self._start_all_worker(apply_req)
|
||||
await self._restart_all_worker(apply_req)
|
||||
timecost = time.time() - start_time
|
||||
message = f"Update worker params and restart successfully"
|
||||
return WorkerApplyOutput(message=message, timecost=timecost)
|
||||
@ -500,8 +526,8 @@ async def generate_json_stream(params):
|
||||
|
||||
|
||||
@router.post("/worker/generate_stream")
|
||||
async def api_generate_stream(request: Request):
|
||||
params = await request.json()
|
||||
async def api_generate_stream(request: PromptRequest):
|
||||
params = request.dict(exclude_none=True)
|
||||
generator = generate_json_stream(params)
|
||||
return StreamingResponse(generator)
|
||||
|
||||
@ -534,13 +560,19 @@ async def api_worker_parameter_descs(
|
||||
return output
|
||||
|
||||
|
||||
def _setup_fastapi(worker_params: ModelWorkerParameters):
|
||||
app = FastAPI()
|
||||
def _setup_fastapi(worker_params: ModelWorkerParameters, app=None):
|
||||
if not app:
|
||||
app = FastAPI()
|
||||
if worker_params.standalone:
|
||||
from pilot.model.controller.controller import router as controller_router
|
||||
from pilot.model.controller.controller import initialize_controller
|
||||
|
||||
if not worker_params.controller_addr:
|
||||
worker_params.controller_addr = f"http://127.0.0.1:{worker_params.port}"
|
||||
logger.info(
|
||||
f"Run WorkerManager with standalone mode, controller_addr: {worker_params.controller_addr}"
|
||||
)
|
||||
initialize_controller(app=app)
|
||||
app.include_router(controller_router, prefix="/api")
|
||||
|
||||
@app.on_event("startup")
|
||||
@ -570,7 +602,7 @@ def _parse_worker_params(
|
||||
)
|
||||
worker_params.update_from(new_worker_params)
|
||||
|
||||
logger.info(f"Worker params: {worker_params}")
|
||||
# logger.info(f"Worker params: {worker_params}")
|
||||
return worker_params
|
||||
|
||||
|
||||
@ -583,7 +615,7 @@ def _create_local_model_manager(
|
||||
)
|
||||
return LocalWorkerManager()
|
||||
else:
|
||||
from pilot.model.controller.registry import ModelRegistryClient
|
||||
from pilot.model.controller.controller import ModelRegistryClient
|
||||
from pilot.utils.net_utils import _get_ip_address
|
||||
|
||||
client = ModelRegistryClient(worker_params.controller_addr)
|
||||
@ -600,6 +632,12 @@ def _create_local_model_manager(
|
||||
)
|
||||
return await client.register_instance(instance)
|
||||
|
||||
async def deregister_func(worker_run_data: WorkerRunData):
|
||||
instance = ModelInstance(
|
||||
model_name=worker_run_data.worker_key, host=host, port=port
|
||||
)
|
||||
return await client.deregister_instance(instance)
|
||||
|
||||
async def send_heartbeat_func(worker_run_data: WorkerRunData):
|
||||
instance = ModelInstance(
|
||||
model_name=worker_run_data.worker_key, host=host, port=port
|
||||
@ -607,7 +645,9 @@ def _create_local_model_manager(
|
||||
return await client.send_heartbeat(instance)
|
||||
|
||||
return LocalWorkerManager(
|
||||
register_func=register_func, send_heartbeat_func=send_heartbeat_func
|
||||
register_func=register_func,
|
||||
deregister_func=deregister_func,
|
||||
send_heartbeat_func=send_heartbeat_func,
|
||||
)
|
||||
|
||||
|
||||
@ -642,30 +682,60 @@ def initialize_worker_manager_in_client(
|
||||
model_path: str = None,
|
||||
run_locally: bool = True,
|
||||
controller_addr: str = None,
|
||||
local_port: int = 5000,
|
||||
):
|
||||
"""Initialize WorkerManager in client.
|
||||
If run_locally is True:
|
||||
1. Start ModelController
|
||||
2. Start LocalWorkerManager
|
||||
3. Start worker in LocalWorkerManager
|
||||
4. Register worker to ModelController
|
||||
|
||||
otherwise:
|
||||
1. Build ModelRegistryClient with controller address
|
||||
2. Start RemoteWorkerManager
|
||||
|
||||
"""
|
||||
global worker_manager
|
||||
|
||||
if not app:
|
||||
raise Exception("app can't be None")
|
||||
|
||||
worker_params: ModelWorkerParameters = _parse_worker_params(
|
||||
model_name=model_name, model_path=model_path, controller_addr=controller_addr
|
||||
)
|
||||
|
||||
logger.info(f"Worker params: {worker_params}")
|
||||
controller_addr = None
|
||||
if run_locally:
|
||||
worker_params.register = False
|
||||
# TODO start ModelController
|
||||
worker_params.standalone = True
|
||||
worker_params.register = True
|
||||
worker_params.port = local_port
|
||||
logger.info(f"Worker params: {worker_params}")
|
||||
_setup_fastapi(worker_params, app)
|
||||
_start_local_worker(worker_manager, worker_params, True)
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(
|
||||
worker_manager.worker_manager._start_all_worker(apply_req=None)
|
||||
)
|
||||
# loop = asyncio.get_event_loop()
|
||||
# loop.run_until_complete(
|
||||
# worker_manager.worker_manager._start_all_worker(apply_req=None)
|
||||
# )
|
||||
else:
|
||||
from pilot.model.controller.registry import ModelRegistryClient
|
||||
from pilot.model.controller.controller import (
|
||||
initialize_controller,
|
||||
ModelRegistryClient,
|
||||
)
|
||||
|
||||
if not worker_params.controller_addr:
|
||||
raise ValueError("Controller can`t be None")
|
||||
controller_addr = worker_params.controller_addr
|
||||
logger.info(f"Worker params: {worker_params}")
|
||||
client = ModelRegistryClient(worker_params.controller_addr)
|
||||
worker_manager.worker_manager = RemoteWorkerManager(client)
|
||||
initialize_controller(
|
||||
app=app, remote_controller_addr=worker_params.controller_addr
|
||||
)
|
||||
|
||||
if include_router and app:
|
||||
# mount WorkerManager router
|
||||
app.include_router(router, prefix="/api")
|
||||
|
||||
|
||||
|
@ -51,8 +51,26 @@ def stop():
|
||||
pass
|
||||
|
||||
|
||||
@click.group()
|
||||
def install():
|
||||
"""Install dependencies, plugins, etc."""
|
||||
pass
|
||||
|
||||
|
||||
stop_all_func_list = []
|
||||
|
||||
|
||||
@click.command(name="all")
|
||||
def stop_all():
|
||||
"""Stop all servers"""
|
||||
for stop_func in stop_all_func_list:
|
||||
stop_func()
|
||||
|
||||
|
||||
cli.add_command(start)
|
||||
cli.add_command(stop)
|
||||
cli.add_command(install)
|
||||
add_command_alias(stop_all, name="all", parent_group=stop)
|
||||
|
||||
try:
|
||||
from pilot.model.cli import (
|
||||
@ -63,6 +81,7 @@ try:
|
||||
stop_model_worker,
|
||||
start_apiserver,
|
||||
stop_apiserver,
|
||||
_stop_all_model_server,
|
||||
)
|
||||
|
||||
add_command_alias(model_cli_group, name="model", parent_group=cli)
|
||||
@ -73,15 +92,21 @@ try:
|
||||
add_command_alias(stop_model_controller, name="controller", parent_group=stop)
|
||||
add_command_alias(stop_model_worker, name="worker", parent_group=stop)
|
||||
add_command_alias(stop_apiserver, name="apiserver", parent_group=stop)
|
||||
stop_all_func_list.append(_stop_all_model_server)
|
||||
|
||||
except ImportError as e:
|
||||
logging.warning(f"Integrating dbgpt model command line tool failed: {e}")
|
||||
|
||||
try:
|
||||
from pilot.server._cli import start_webserver, stop_webserver
|
||||
from pilot.server._cli import (
|
||||
start_webserver,
|
||||
stop_webserver,
|
||||
_stop_all_dbgpt_server,
|
||||
)
|
||||
|
||||
add_command_alias(start_webserver, name="webserver", parent_group=start)
|
||||
add_command_alias(stop_webserver, name="webserver", parent_group=stop)
|
||||
stop_all_func_list.append(_stop_all_dbgpt_server)
|
||||
|
||||
except ImportError as e:
|
||||
logging.warning(f"Integrating dbgpt webserver command line tool failed: {e}")
|
||||
|
@ -30,3 +30,7 @@ def start_webserver(**kwargs):
|
||||
def stop_webserver(port: int):
|
||||
"""Stop webserver(dbgpt_server.py)"""
|
||||
_stop_service("webserver", "WebServer", port=port)
|
||||
|
||||
|
||||
def _stop_all_dbgpt_server():
|
||||
_stop_service("webserver", "WebServer")
|
||||
|
@ -104,7 +104,10 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
|
||||
if not param.light:
|
||||
print("Model Unified Deployment Mode!")
|
||||
initialize_worker_manager_in_client(
|
||||
app=app, model_name=CFG.LLM_MODEL, model_path=model_path
|
||||
app=app,
|
||||
model_name=CFG.LLM_MODEL,
|
||||
model_path=model_path,
|
||||
local_port=param.port,
|
||||
)
|
||||
|
||||
CFG.NEW_SERVER_MODE = True
|
||||
@ -116,6 +119,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
|
||||
model_path=model_path,
|
||||
run_locally=False,
|
||||
controller_addr=CFG.MODEL_SERVER,
|
||||
local_port=param.port,
|
||||
)
|
||||
CFG.SERVER_LIGHT_MODE = True
|
||||
|
||||
|
@ -6,15 +6,30 @@ import psutil
|
||||
import platform
|
||||
|
||||
|
||||
def _get_abspath_of_current_command(command_path: str):
|
||||
if not command_path.endswith(".py"):
|
||||
return command_path
|
||||
# This implementation is very ugly
|
||||
command_path = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
|
||||
"scripts",
|
||||
"cli_scripts.py",
|
||||
)
|
||||
return command_path
|
||||
|
||||
|
||||
def _run_current_with_daemon(name: str, log_file: str):
|
||||
# Get all arguments except for --daemon
|
||||
args = [arg for arg in sys.argv if arg != "--daemon" and arg != "-d"]
|
||||
args[0] = _get_abspath_of_current_command(args[0])
|
||||
|
||||
daemon_cmd = [sys.executable] + args
|
||||
daemon_cmd = " ".join(daemon_cmd)
|
||||
daemon_cmd += f" > {log_file} 2>&1"
|
||||
|
||||
print(f"daemon cmd: {daemon_cmd}")
|
||||
# Check the platform and set the appropriate flags or functions
|
||||
if platform.system() == "Windows":
|
||||
if "windows" in platform.system().lower():
|
||||
process = subprocess.Popen(
|
||||
daemon_cmd,
|
||||
creationflags=subprocess.CREATE_NEW_PROCESS_GROUP,
|
||||
@ -50,7 +65,7 @@ def _run_current_with_gunicorn(app: str, config_path: str, kwargs: Dict):
|
||||
app_env = EnvArgumentParser._kwargs_to_env_key_value(kwargs)
|
||||
env_to_app.update(app_env)
|
||||
cmd = f"uvicorn {app} --host 0.0.0.0 --port 5000"
|
||||
if platform.system() == "Windows":
|
||||
if "windows" in platform.system().lower():
|
||||
raise Exception("Not support on windows")
|
||||
else: # macOS, Linux, and other Unix-like systems
|
||||
process = subprocess.Popen(cmd, shell=True, env=env_to_app)
|
||||
@ -89,3 +104,54 @@ def _stop_service(
|
||||
|
||||
if not_found:
|
||||
print(f"{fullname} process not found.")
|
||||
|
||||
|
||||
def _get_ports_by_cmdline_part(service_keys: List[str]) -> List[int]:
|
||||
"""
|
||||
Return a list of ports that are associated with processes that have all the service_keys in their cmdline.
|
||||
|
||||
Args:
|
||||
service_keys (List[str]): List of strings that should all be present in the process's cmdline.
|
||||
|
||||
Returns:
|
||||
List[int]: List of ports sorted with preference for 8000 and 5000, and then in ascending order.
|
||||
"""
|
||||
ports = []
|
||||
|
||||
for process in psutil.process_iter(attrs=["pid", "name", "connections", "cmdline"]):
|
||||
try:
|
||||
# Convert the cmdline list to a single string for easier checking
|
||||
cmdline = " ".join(process.info["cmdline"])
|
||||
|
||||
# Check if all the service keys are present in the cmdline
|
||||
if all(fragment in cmdline for fragment in service_keys):
|
||||
for connection in process.info["connections"]:
|
||||
if connection.status == psutil.CONN_LISTEN:
|
||||
ports.append(connection.laddr.port)
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
|
||||
pass
|
||||
|
||||
# Sort ports with preference for 8000 and 5000
|
||||
ports.sort(key=lambda x: (x != 8000, x != 5000, x))
|
||||
|
||||
return ports
|
||||
|
||||
|
||||
def _detect_controller_address() -> str:
|
||||
controller_addr = os.getenv("CONTROLLER_ADDRESS")
|
||||
if controller_addr:
|
||||
return controller_addr
|
||||
|
||||
cmdline_fragments = [
|
||||
["python", "start", "controller"],
|
||||
["python", "controller"],
|
||||
["python", "start", "webserver"],
|
||||
["python", "dbgpt_server"],
|
||||
]
|
||||
|
||||
for fragments in cmdline_fragments:
|
||||
ports = _get_ports_by_cmdline_part(fragments)
|
||||
if ports:
|
||||
return f"http://127.0.0.1:{ports[0]}"
|
||||
|
||||
return f"http://127.0.0.1:8000"
|
||||
|
Loading…
Reference in New Issue
Block a user