diff --git a/pilot/model/cli.py b/pilot/model/cli.py index 5f8a23def..4b981e504 100644 --- a/pilot/model/cli.py +++ b/pilot/model/cli.py @@ -4,7 +4,7 @@ import logging import os from typing import Callable, List, Type -from pilot.model.controller.registry import ModelRegistryClient +from pilot.model.controller.controller import ModelRegistryClient from pilot.configs.model_config import LOGDIR from pilot.model.base import WorkerApplyType from pilot.model.parameter import ( @@ -26,18 +26,23 @@ logger = logging.getLogger("dbgpt_cli") @click.option( "--address", type=str, - default=MODEL_CONTROLLER_ADDRESS, + default=None, required=False, show_default=True, help=( "Address of the Model Controller to connect to. " - "Just support light deploy model" + "Just support light deploy model, If the environment variable CONTROLLER_ADDRESS is configured, read from the environment variable" ), ) def model_cli_group(address: str): """Clients that manage model serving""" global MODEL_CONTROLLER_ADDRESS - MODEL_CONTROLLER_ADDRESS = address + if not address: + from pilot.utils.command_utils import _detect_controller_address + + MODEL_CONTROLLER_ADDRESS = _detect_controller_address() + else: + MODEL_CONTROLLER_ADDRESS = address @model_cli_group.command() @@ -140,6 +145,26 @@ def restart(model_name: str, model_type: str): ) +@model_cli_group.command() +@click.option( + "--model_name", + type=str, + default=None, + required=True, + help=("The name of model"), +) +@click.option( + "--system", + type=str, + default=None, + required=False, + help=("System prompt"), +) +def chat(model_name: str, system: str): + """Interact with your bot from the command line""" + _cli_chat(MODEL_CONTROLLER_ADDRESS, model_name, system) + + # @model_cli_group.command() # @add_model_options # def modify(address: str, model_name: str, model_type: str): @@ -147,14 +172,21 @@ def restart(model_name: str, model_type: str): # worker_apply(address, model_name, model_type, WorkerApplyType.UPDATE_PARAMS) +def _get_worker_manager(address: str): + from pilot.model.worker.manager import RemoteWorkerManager, WorkerApplyRequest + + registry = ModelRegistryClient(address) + worker_manager = RemoteWorkerManager(registry) + return worker_manager + + def worker_apply( address: str, model_name: str, model_type: str, apply_type: WorkerApplyType ): - from pilot.model.worker.manager import RemoteWorkerManager, WorkerApplyRequest + from pilot.model.worker.manager import WorkerApplyRequest loop = get_or_create_event_loop() - registry = ModelRegistryClient(address) - worker_manager = RemoteWorkerManager(registry) + worker_manager = _get_worker_manager(address) apply_req = WorkerApplyRequest( model=model_name, worker_type=model_type, apply_type=apply_type ) @@ -162,6 +194,41 @@ def worker_apply( print(res) +def _cli_chat(address: str, model_name: str, system_prompt: str = None): + loop = get_or_create_event_loop() + worker_manager = worker_manager = _get_worker_manager(address) + loop.run_until_complete(_chat_stream(worker_manager, model_name, system_prompt)) + + +async def _chat_stream(worker_manager, model_name: str, system_prompt: str = None): + from pilot.model.worker.manager import PromptRequest + from pilot.scene.base_message import ModelMessage, ModelMessageRoleType + + print(f"Chatbot started with model {model_name}. Type 'exit' to leave the chat.") + hist = [] + previous_response = "" + if system_prompt: + hist.append( + ModelMessage(role=ModelMessageRoleType.SYSTEM, content=system_prompt) + ) + while True: + previous_response = "" + user_input = input("\n\nYou: ") + if user_input.lower().strip() == "exit": + break + hist.append(ModelMessage(role=ModelMessageRoleType.HUMAN, content=user_input)) + request = PromptRequest(messages=hist, model=model_name, prompt="", echo=False) + request = request.dict(exclude_none=True) + print("Bot: ", end="") + async for response in worker_manager.generate_stream(request): + incremental_output = response.text[len(previous_response) :] + print(incremental_output, end="", flush=True) + previous_response = response.text + hist.append( + ModelMessage(role=ModelMessageRoleType.AI, content=previous_response) + ) + + def add_stop_server_options(func): @click.option( "--port", @@ -249,3 +316,9 @@ def start_apiserver(**kwargs): def stop_apiserver(**kwargs): """Start apiserver(TODO)""" raise NotImplementedError + + +def _stop_all_model_server(**kwargs): + """Stop all server""" + _stop_service("worker", "ModelWorker") + _stop_service("controller", "ModelController") diff --git a/pilot/model/controller/controller.py b/pilot/model/controller/controller.py index 51b61826e..bb19e31df 100644 --- a/pilot/model/controller/controller.py +++ b/pilot/model/controller/controller.py @@ -1,3 +1,5 @@ +from abc import ABC, abstractmethod + import logging from typing import List @@ -6,9 +8,33 @@ from pilot.model.base import ModelInstance from pilot.model.parameter import ModelControllerParameters from pilot.model.controller.registry import EmbeddedModelRegistry, ModelRegistry from pilot.utils.parameter_utils import EnvArgumentParser +from pilot.utils.api_utils import _api_remote as api_remote -class ModelController: +class BaseModelController(ABC): + @abstractmethod + async def register_instance(self, instance: ModelInstance) -> bool: + """Register a given model instance""" + + @abstractmethod + async def deregister_instance(self, instance: ModelInstance) -> bool: + """Deregister a given model instance.""" + + @abstractmethod + async def get_all_instances( + self, model_name: str, healthy_only: bool = False + ) -> List[ModelInstance]: + """Fetch all instances of a given model. Optionally, fetch only the healthy instances.""" + + @abstractmethod + async def send_heartbeat(self, instance: ModelInstance) -> bool: + """Send a heartbeat for a given model instance. This can be used to verify if the instance is still alive and functioning.""" + + async def model_apply(self) -> bool: + raise NotImplementedError + + +class LocalModelController(BaseModelController): def __init__(self, registry: ModelRegistry = None) -> None: if not registry: registry = EmbeddedModelRegistry() @@ -27,22 +53,87 @@ class ModelController: logging.info( f"Get all instances with {model_name}, healthy_only: {healthy_only}" ) - return await self.registry.get_all_instances(model_name, healthy_only) - - async def get_all_model_instances(self) -> List[ModelInstance]: - return await self.registry.get_all_model_instances() + if not model_name: + return await self.registry.get_all_model_instances() + else: + return await self.registry.get_all_instances(model_name, healthy_only) async def send_heartbeat(self, instance: ModelInstance) -> bool: return await self.registry.send_heartbeat(instance) + +class _RemoteModelController(BaseModelController): + def __init__(self, base_url: str) -> None: + self.base_url = base_url + + @api_remote(path="/api/controller/models", method="POST") + async def register_instance(self, instance: ModelInstance) -> bool: + pass + + @api_remote(path="/api/controller/models", method="DELETE") + async def deregister_instance(self, instance: ModelInstance) -> bool: + pass + + @api_remote(path="/api/controller/models") + async def get_all_instances( + self, model_name: str = None, healthy_only: bool = False + ) -> List[ModelInstance]: + pass + + @api_remote(path="/api/controller/heartbeat", method="POST") + async def send_heartbeat(self, instance: ModelInstance) -> bool: + pass + + +class ModelRegistryClient(_RemoteModelController, ModelRegistry): + async def get_all_model_instances(self) -> List[ModelInstance]: + return await self.get_all_instances() + + +class ModelControllerAdapter(BaseModelController): + def __init__(self, backend: BaseModelController = None) -> None: + self.backend = backend + + async def register_instance(self, instance: ModelInstance) -> bool: + return await self.backend.register_instance(instance) + + async def deregister_instance(self, instance: ModelInstance) -> bool: + return await self.backend.deregister_instance(instance) + + async def get_all_instances( + self, model_name: str, healthy_only: bool = False + ) -> List[ModelInstance]: + return await self.backend.get_all_instances(model_name, healthy_only) + + async def send_heartbeat(self, instance: ModelInstance) -> bool: + return await self.backend.send_heartbeat(instance) + async def model_apply(self) -> bool: - # TODO - raise NotImplementedError + return await self.backend.model_apply() router = APIRouter() -controller = ModelController() +controller = ModelControllerAdapter() + + +def initialize_controller( + app=None, remote_controller_addr: str = None, host: str = None, port: int = None +): + global controller + if remote_controller_addr: + controller.backend = _RemoteModelController(remote_controller_addr) + else: + controller.backend = LocalModelController() + + if app: + app.include_router(router, prefix="/api") + else: + import uvicorn + + app = FastAPI() + app.include_router(router, prefix="/api") + uvicorn.run(app, host=host, port=port, log_level="info") @router.post("/controller/models") @@ -51,14 +142,13 @@ async def api_register_instance(request: ModelInstance): @router.delete("/controller/models") -async def api_deregister_instance(request: ModelInstance): - return await controller.deregister_instance(request) +async def api_deregister_instance(model_name: str, host: str, port: int): + instance = ModelInstance(model_name=model_name, host=host, port=port) + return await controller.deregister_instance(instance) @router.get("/controller/models") async def api_get_all_instances(model_name: str = None, healthy_only: bool = False): - if not model_name: - return await controller.get_all_model_instances() return await controller.get_all_instances(model_name, healthy_only=healthy_only) @@ -68,18 +158,12 @@ async def api_model_heartbeat(request: ModelInstance): def run_model_controller(): - import uvicorn - parser = EnvArgumentParser() env_prefix = "controller_" controller_params: ModelControllerParameters = parser.parse_args_into_dataclass( ModelControllerParameters, env_prefix=env_prefix ) - app = FastAPI() - app.include_router(router, prefix="/api") - uvicorn.run( - app, host=controller_params.host, port=controller_params.port, log_level="info" - ) + initialize_controller(host=controller_params.host, port=controller_params.port) if __name__ == "__main__": diff --git a/pilot/model/controller/registry.py b/pilot/model/controller/registry.py index 445e68c26..e5e9b0618 100644 --- a/pilot/model/controller/registry.py +++ b/pilot/model/controller/registry.py @@ -178,47 +178,11 @@ class EmbeddedModelRegistry(ModelRegistry): instance.model_name, instance.host, instance.port, healthy_only=False ) if not exist_ins: - return False + # register new install from heartbeat + self.register_instance(instance) + return True ins = exist_ins[0] ins.last_heartbeat = datetime.now() ins.healthy = True return True - - -from pilot.utils.api_utils import _api_remote as api_remote - - -class ModelRegistryClient(ModelRegistry): - def __init__(self, base_url: str) -> None: - self.base_url = base_url - - @api_remote(path="/api/controller/models", method="POST") - async def register_instance(self, instance: ModelInstance) -> bool: - pass - - @api_remote(path="/api/controller/models", method="DELETE") - async def deregister_instance(self, instance: ModelInstance) -> bool: - pass - - @api_remote(path="/api/controller/models") - async def get_all_instances( - self, model_name: str, healthy_only: bool = False - ) -> List[ModelInstance]: - pass - - @api_remote(path="/api/controller/models") - async def get_all_model_instances(self) -> List[ModelInstance]: - pass - - @api_remote(path="/api/controller/models") - async def select_one_health_instance(self, model_name: str) -> ModelInstance: - instances = await self.get_all_instances(model_name, healthy_only=True) - instances = [i for i in instances if i.enabled] - if not instances: - return None - return random.choice(instances) - - @api_remote(path="/api/controller/heartbeat", method="POST") - async def send_heartbeat(self, instance: ModelInstance) -> bool: - pass diff --git a/pilot/model/parameter.py b/pilot/model/parameter.py index c140f7e35..a4067dab0 100644 --- a/pilot/model/parameter.py +++ b/pilot/model/parameter.py @@ -51,7 +51,7 @@ class ModelWorkerParameters(BaseParameters): ) port: Optional[int] = field( - default=8000, metadata={"help": "Model worker deploy port"} + default=8001, metadata={"help": "Model worker deploy port"} ) daemon: Optional[bool] = field( default=False, metadata={"help": "Run Model Worker in background"} diff --git a/pilot/model/worker/default_worker.py b/pilot/model/worker/default_worker.py index 331be53ec..157745ffe 100644 --- a/pilot/model/worker/default_worker.py +++ b/pilot/model/worker/default_worker.py @@ -82,6 +82,7 @@ class DefaultModelWorker(ModelWorker): def stop(self) -> None: if not self.model: + logger.warn("Model has been stopped!!") return del self.model del self.tokenizer @@ -98,23 +99,24 @@ class DefaultModelWorker(ModelWorker): params, self.ml.model_path, prompt_template=self.ml.prompt_template ) + previous_response = "" + print("stream output:\n") for output in self.generate_stream_func( self.model, self.tokenizer, params, get_device(), self.context_len ): # Please do not open the output in production! # The gpt4all thread shares stdout with the parent process, # and opening it may affect the frontend output. - if "windows" in platform.platform().lower(): - # Do not print the model output, because it may contain Emoji, there is a problem with the GBK encoding - pass - else: - print("output: ", output) + incremental_output = output[len(previous_response) :] + # print("output: ", output) + print(incremental_output, end="", flush=True) + previous_response = output # return some model context to dgt-server model_output = ModelOutput( text=output, error_code=0, model_context=model_context ) yield model_output - + print(f"\n\nfull stream output:\n{previous_response}") except torch.cuda.CudaError: model_output = ModelOutput( text="**GPU OutOfMemory, Please Refresh.**", error_code=0 diff --git a/pilot/model/worker/manager.py b/pilot/model/worker/manager.py index cb76f2529..76367cc79 100644 --- a/pilot/model/worker/manager.py +++ b/pilot/model/worker/manager.py @@ -225,7 +225,14 @@ class LocalWorkerManager(WorkerManager): self, params: Dict, async_wrapper=None, **kwargs ) -> Iterator[ModelOutput]: """Generate stream result, chat scene""" - worker_run_data = await self._get_model(params) + try: + worker_run_data = await self._get_model(params) + except Exception as e: + yield ModelOutput( + text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", + error_code=0, + ) + return async with worker_run_data.semaphore: if worker_run_data.worker.support_async(): async for outout in worker_run_data.worker.async_generate_stream( @@ -244,7 +251,13 @@ class LocalWorkerManager(WorkerManager): async def generate(self, params: Dict) -> ModelOutput: """Generate non stream result""" - worker_run_data = await self._get_model(params) + try: + worker_run_data = await self._get_model(params) + except Exception as e: + return ModelOutput( + text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", + error_code=0, + ) async with worker_run_data.semaphore: if worker_run_data.worker.support_async(): return await worker_run_data.worker.async_generate(params) @@ -256,7 +269,10 @@ class LocalWorkerManager(WorkerManager): async def embeddings(self, params: Dict) -> List[List[float]]: """Embed input""" - worker_run_data = await self._get_model(params, worker_type="text2vec") + try: + worker_run_data = await self._get_model(params, worker_type="text2vec") + except Exception as e: + raise e async with worker_run_data.semaphore: if worker_run_data.worker.support_async(): return await worker_run_data.worker.async_embeddings(params) @@ -272,6 +288,8 @@ class LocalWorkerManager(WorkerManager): apply_func = self._start_all_worker elif apply_req.apply_type == WorkerApplyType.STOP: apply_func = self._stop_all_worker + elif apply_req.apply_type == WorkerApplyType.RESTART: + apply_func = self._restart_all_worker elif apply_req.apply_type == WorkerApplyType.UPDATE_PARAMS: apply_func = self._update_all_worker_params else: @@ -298,10 +316,13 @@ class LocalWorkerManager(WorkerManager): apply_req (WorkerApplyRequest): Worker apply request apply_func (ApplyFunction): Function to apply to worker instances, now function is async function """ + logger.info(f"Apply req: {apply_req}, apply_func: {apply_func}") if apply_req: worker_type = apply_req.worker_type.value model_name = apply_req.model - worker_instances = await self.get_model_instances(worker_type, model_name) + worker_instances = await self.get_model_instances( + worker_type, model_name, healthy_only=False + ) if not worker_instances: raise Exception( f"No worker instance found for the model {model_name} worker type {worker_type}" @@ -370,6 +391,12 @@ class LocalWorkerManager(WorkerManager): message=f"Worker stopped successfully", timecost=timecost ) + async def _restart_all_worker( + self, apply_req: WorkerApplyRequest + ) -> WorkerApplyOutput: + await self._stop_all_worker(apply_req) + return await self._start_all_worker(apply_req) + async def _update_all_worker_params( self, apply_req: WorkerApplyRequest ) -> WorkerApplyOutput: @@ -389,8 +416,7 @@ class LocalWorkerManager(WorkerManager): timecost = time.time() - start_time if need_restart: logger.info("Model params update successfully, begin restart worker") - await self._stop_all_worker(apply_req) - await self._start_all_worker(apply_req) + await self._restart_all_worker(apply_req) timecost = time.time() - start_time message = f"Update worker params and restart successfully" return WorkerApplyOutput(message=message, timecost=timecost) @@ -500,8 +526,8 @@ async def generate_json_stream(params): @router.post("/worker/generate_stream") -async def api_generate_stream(request: Request): - params = await request.json() +async def api_generate_stream(request: PromptRequest): + params = request.dict(exclude_none=True) generator = generate_json_stream(params) return StreamingResponse(generator) @@ -534,13 +560,19 @@ async def api_worker_parameter_descs( return output -def _setup_fastapi(worker_params: ModelWorkerParameters): - app = FastAPI() +def _setup_fastapi(worker_params: ModelWorkerParameters, app=None): + if not app: + app = FastAPI() if worker_params.standalone: from pilot.model.controller.controller import router as controller_router + from pilot.model.controller.controller import initialize_controller if not worker_params.controller_addr: worker_params.controller_addr = f"http://127.0.0.1:{worker_params.port}" + logger.info( + f"Run WorkerManager with standalone mode, controller_addr: {worker_params.controller_addr}" + ) + initialize_controller(app=app) app.include_router(controller_router, prefix="/api") @app.on_event("startup") @@ -570,7 +602,7 @@ def _parse_worker_params( ) worker_params.update_from(new_worker_params) - logger.info(f"Worker params: {worker_params}") + # logger.info(f"Worker params: {worker_params}") return worker_params @@ -583,7 +615,7 @@ def _create_local_model_manager( ) return LocalWorkerManager() else: - from pilot.model.controller.registry import ModelRegistryClient + from pilot.model.controller.controller import ModelRegistryClient from pilot.utils.net_utils import _get_ip_address client = ModelRegistryClient(worker_params.controller_addr) @@ -600,6 +632,12 @@ def _create_local_model_manager( ) return await client.register_instance(instance) + async def deregister_func(worker_run_data: WorkerRunData): + instance = ModelInstance( + model_name=worker_run_data.worker_key, host=host, port=port + ) + return await client.deregister_instance(instance) + async def send_heartbeat_func(worker_run_data: WorkerRunData): instance = ModelInstance( model_name=worker_run_data.worker_key, host=host, port=port @@ -607,7 +645,9 @@ def _create_local_model_manager( return await client.send_heartbeat(instance) return LocalWorkerManager( - register_func=register_func, send_heartbeat_func=send_heartbeat_func + register_func=register_func, + deregister_func=deregister_func, + send_heartbeat_func=send_heartbeat_func, ) @@ -642,30 +682,60 @@ def initialize_worker_manager_in_client( model_path: str = None, run_locally: bool = True, controller_addr: str = None, + local_port: int = 5000, ): + """Initialize WorkerManager in client. + If run_locally is True: + 1. Start ModelController + 2. Start LocalWorkerManager + 3. Start worker in LocalWorkerManager + 4. Register worker to ModelController + + otherwise: + 1. Build ModelRegistryClient with controller address + 2. Start RemoteWorkerManager + + """ global worker_manager + if not app: + raise Exception("app can't be None") + worker_params: ModelWorkerParameters = _parse_worker_params( model_name=model_name, model_path=model_path, controller_addr=controller_addr ) - logger.info(f"Worker params: {worker_params}") + controller_addr = None if run_locally: - worker_params.register = False + # TODO start ModelController + worker_params.standalone = True + worker_params.register = True + worker_params.port = local_port + logger.info(f"Worker params: {worker_params}") + _setup_fastapi(worker_params, app) _start_local_worker(worker_manager, worker_params, True) - loop = asyncio.get_event_loop() - loop.run_until_complete( - worker_manager.worker_manager._start_all_worker(apply_req=None) - ) + # loop = asyncio.get_event_loop() + # loop.run_until_complete( + # worker_manager.worker_manager._start_all_worker(apply_req=None) + # ) else: - from pilot.model.controller.registry import ModelRegistryClient + from pilot.model.controller.controller import ( + initialize_controller, + ModelRegistryClient, + ) if not worker_params.controller_addr: raise ValueError("Controller can`t be None") + controller_addr = worker_params.controller_addr + logger.info(f"Worker params: {worker_params}") client = ModelRegistryClient(worker_params.controller_addr) worker_manager.worker_manager = RemoteWorkerManager(client) + initialize_controller( + app=app, remote_controller_addr=worker_params.controller_addr + ) if include_router and app: + # mount WorkerManager router app.include_router(router, prefix="/api") diff --git a/pilot/scripts/cli_scripts.py b/pilot/scripts/cli_scripts.py index 6a0bcbb5e..a0a7f029e 100644 --- a/pilot/scripts/cli_scripts.py +++ b/pilot/scripts/cli_scripts.py @@ -51,8 +51,26 @@ def stop(): pass +@click.group() +def install(): + """Install dependencies, plugins, etc.""" + pass + + +stop_all_func_list = [] + + +@click.command(name="all") +def stop_all(): + """Stop all servers""" + for stop_func in stop_all_func_list: + stop_func() + + cli.add_command(start) cli.add_command(stop) +cli.add_command(install) +add_command_alias(stop_all, name="all", parent_group=stop) try: from pilot.model.cli import ( @@ -63,6 +81,7 @@ try: stop_model_worker, start_apiserver, stop_apiserver, + _stop_all_model_server, ) add_command_alias(model_cli_group, name="model", parent_group=cli) @@ -73,15 +92,21 @@ try: add_command_alias(stop_model_controller, name="controller", parent_group=stop) add_command_alias(stop_model_worker, name="worker", parent_group=stop) add_command_alias(stop_apiserver, name="apiserver", parent_group=stop) + stop_all_func_list.append(_stop_all_model_server) except ImportError as e: logging.warning(f"Integrating dbgpt model command line tool failed: {e}") try: - from pilot.server._cli import start_webserver, stop_webserver + from pilot.server._cli import ( + start_webserver, + stop_webserver, + _stop_all_dbgpt_server, + ) add_command_alias(start_webserver, name="webserver", parent_group=start) add_command_alias(stop_webserver, name="webserver", parent_group=stop) + stop_all_func_list.append(_stop_all_dbgpt_server) except ImportError as e: logging.warning(f"Integrating dbgpt webserver command line tool failed: {e}") diff --git a/pilot/server/_cli.py b/pilot/server/_cli.py index 90993d624..f9f135808 100644 --- a/pilot/server/_cli.py +++ b/pilot/server/_cli.py @@ -30,3 +30,7 @@ def start_webserver(**kwargs): def stop_webserver(port: int): """Stop webserver(dbgpt_server.py)""" _stop_service("webserver", "WebServer", port=port) + + +def _stop_all_dbgpt_server(): + _stop_service("webserver", "WebServer") diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py index 440dd20e3..809743518 100644 --- a/pilot/server/dbgpt_server.py +++ b/pilot/server/dbgpt_server.py @@ -104,7 +104,10 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None): if not param.light: print("Model Unified Deployment Mode!") initialize_worker_manager_in_client( - app=app, model_name=CFG.LLM_MODEL, model_path=model_path + app=app, + model_name=CFG.LLM_MODEL, + model_path=model_path, + local_port=param.port, ) CFG.NEW_SERVER_MODE = True @@ -116,6 +119,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None): model_path=model_path, run_locally=False, controller_addr=CFG.MODEL_SERVER, + local_port=param.port, ) CFG.SERVER_LIGHT_MODE = True diff --git a/pilot/utils/command_utils.py b/pilot/utils/command_utils.py index f29aff3e1..815d13d5b 100644 --- a/pilot/utils/command_utils.py +++ b/pilot/utils/command_utils.py @@ -6,15 +6,30 @@ import psutil import platform +def _get_abspath_of_current_command(command_path: str): + if not command_path.endswith(".py"): + return command_path + # This implementation is very ugly + command_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))), + "scripts", + "cli_scripts.py", + ) + return command_path + + def _run_current_with_daemon(name: str, log_file: str): # Get all arguments except for --daemon args = [arg for arg in sys.argv if arg != "--daemon" and arg != "-d"] + args[0] = _get_abspath_of_current_command(args[0]) + daemon_cmd = [sys.executable] + args daemon_cmd = " ".join(daemon_cmd) daemon_cmd += f" > {log_file} 2>&1" + print(f"daemon cmd: {daemon_cmd}") # Check the platform and set the appropriate flags or functions - if platform.system() == "Windows": + if "windows" in platform.system().lower(): process = subprocess.Popen( daemon_cmd, creationflags=subprocess.CREATE_NEW_PROCESS_GROUP, @@ -50,7 +65,7 @@ def _run_current_with_gunicorn(app: str, config_path: str, kwargs: Dict): app_env = EnvArgumentParser._kwargs_to_env_key_value(kwargs) env_to_app.update(app_env) cmd = f"uvicorn {app} --host 0.0.0.0 --port 5000" - if platform.system() == "Windows": + if "windows" in platform.system().lower(): raise Exception("Not support on windows") else: # macOS, Linux, and other Unix-like systems process = subprocess.Popen(cmd, shell=True, env=env_to_app) @@ -89,3 +104,54 @@ def _stop_service( if not_found: print(f"{fullname} process not found.") + + +def _get_ports_by_cmdline_part(service_keys: List[str]) -> List[int]: + """ + Return a list of ports that are associated with processes that have all the service_keys in their cmdline. + + Args: + service_keys (List[str]): List of strings that should all be present in the process's cmdline. + + Returns: + List[int]: List of ports sorted with preference for 8000 and 5000, and then in ascending order. + """ + ports = [] + + for process in psutil.process_iter(attrs=["pid", "name", "connections", "cmdline"]): + try: + # Convert the cmdline list to a single string for easier checking + cmdline = " ".join(process.info["cmdline"]) + + # Check if all the service keys are present in the cmdline + if all(fragment in cmdline for fragment in service_keys): + for connection in process.info["connections"]: + if connection.status == psutil.CONN_LISTEN: + ports.append(connection.laddr.port) + except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess): + pass + + # Sort ports with preference for 8000 and 5000 + ports.sort(key=lambda x: (x != 8000, x != 5000, x)) + + return ports + + +def _detect_controller_address() -> str: + controller_addr = os.getenv("CONTROLLER_ADDRESS") + if controller_addr: + return controller_addr + + cmdline_fragments = [ + ["python", "start", "controller"], + ["python", "controller"], + ["python", "start", "webserver"], + ["python", "dbgpt_server"], + ] + + for fragments in cmdline_fragments: + ports = _get_ports_by_cmdline_part(fragments) + if ports: + return f"http://127.0.0.1:{ports[0]}" + + return f"http://127.0.0.1:8000"