feat(model): Unified Deployment Mode with multi-model and add command line chat with model

This commit is contained in:
FangYin Cheng 2023-09-08 17:52:28 +08:00
parent 1f1da2618c
commit a8846c40aa
10 changed files with 388 additions and 96 deletions

View File: pilot/model/cli.py

@@ -4,7 +4,7 @@ import logging
import os
from typing import Callable, List, Type
from pilot.model.controller.registry import ModelRegistryClient
from pilot.model.controller.controller import ModelRegistryClient
from pilot.configs.model_config import LOGDIR
from pilot.model.base import WorkerApplyType
from pilot.model.parameter import (
@@ -26,18 +26,23 @@ logger = logging.getLogger("dbgpt_cli")
@click.option(
"--address",
type=str,
default=MODEL_CONTROLLER_ADDRESS,
default=None,
required=False,
show_default=True,
help=(
"Address of the Model Controller to connect to. "
"Just support light deploy model"
"Just support light deploy model, If the environment variable CONTROLLER_ADDRESS is configured, read from the environment variable"
),
)
def model_cli_group(address: str):
"""Clients that manage model serving"""
global MODEL_CONTROLLER_ADDRESS
MODEL_CONTROLLER_ADDRESS = address
if not address:
from pilot.utils.command_utils import _detect_controller_address
MODEL_CONTROLLER_ADDRESS = _detect_controller_address()
else:
MODEL_CONTROLLER_ADDRESS = address
@model_cli_group.command()
@@ -140,6 +145,26 @@ def restart(model_name: str, model_type: str):
)
@model_cli_group.command()
@click.option(
"--model_name",
type=str,
default=None,
required=True,
help=("The name of model"),
)
@click.option(
"--system",
type=str,
default=None,
required=False,
help=("System prompt"),
)
def chat(model_name: str, system: str):
"""Interact with your bot from the command line"""
_cli_chat(MODEL_CONTROLLER_ADDRESS, model_name, system)
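This gives the model CLI an interactive REPL on top of the worker manager. Assuming the console entry point is installed as `dbgpt` (the script name is not shown in this diff) and using a hypothetical model name, a session would start with something like:

dbgpt model chat --model_name vicuna-13b-v1.5 --system "You are a helpful assistant."

The controller address is resolved in order: the group's --address option, the CONTROLLER_ADDRESS environment variable, and finally local port detection (_detect_controller_address, defined in command_utils below).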
# @model_cli_group.command()
# @add_model_options
# def modify(address: str, model_name: str, model_type: str):
@@ -147,14 +172,21 @@ def restart(model_name: str, model_type: str):
# worker_apply(address, model_name, model_type, WorkerApplyType.UPDATE_PARAMS)
def _get_worker_manager(address: str):
from pilot.model.worker.manager import RemoteWorkerManager, WorkerApplyRequest
registry = ModelRegistryClient(address)
worker_manager = RemoteWorkerManager(registry)
return worker_manager
def worker_apply(
address: str, model_name: str, model_type: str, apply_type: WorkerApplyType
):
from pilot.model.worker.manager import RemoteWorkerManager, WorkerApplyRequest
from pilot.model.worker.manager import WorkerApplyRequest
loop = get_or_create_event_loop()
registry = ModelRegistryClient(address)
worker_manager = RemoteWorkerManager(registry)
worker_manager = _get_worker_manager(address)
apply_req = WorkerApplyRequest(
model=model_name, worker_type=model_type, apply_type=apply_type
)
@@ -162,6 +194,41 @@ def worker_apply(
print(res)
def _cli_chat(address: str, model_name: str, system_prompt: str = None):
loop = get_or_create_event_loop()
worker_manager = _get_worker_manager(address)
loop.run_until_complete(_chat_stream(worker_manager, model_name, system_prompt))
async def _chat_stream(worker_manager, model_name: str, system_prompt: str = None):
from pilot.model.worker.manager import PromptRequest
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
print(f"Chatbot started with model {model_name}. Type 'exit' to leave the chat.")
hist = []
previous_response = ""
if system_prompt:
hist.append(
ModelMessage(role=ModelMessageRoleType.SYSTEM, content=system_prompt)
)
while True:
previous_response = ""
user_input = input("\n\nYou: ")
if user_input.lower().strip() == "exit":
break
hist.append(ModelMessage(role=ModelMessageRoleType.HUMAN, content=user_input))
request = PromptRequest(messages=hist, model=model_name, prompt="", echo=False)
request = request.dict(exclude_none=True)
print("Bot: ", end="")
async for response in worker_manager.generate_stream(request):
incremental_output = response.text[len(previous_response) :]
print(incremental_output, end="", flush=True)
previous_response = response.text
hist.append(
ModelMessage(role=ModelMessageRoleType.AI, content=previous_response)
)
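Note that hist is resent in full on every turn, so the model always sees the entire conversation. As an illustration, after two turns with a system prompt set, the list would look roughly like this (contents are hypothetical):

hist = [
    ModelMessage(role=ModelMessageRoleType.SYSTEM, content="You are a helpful assistant."),
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="hi"),
    ModelMessage(role=ModelMessageRoleType.AI, content="Hello! How can I help you?"),
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="What is DB-GPT?"),
]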
def add_stop_server_options(func):
@click.option(
"--port",
@@ -249,3 +316,9 @@ def start_apiserver(**kwargs):
def stop_apiserver(**kwargs):
"""Start apiserver(TODO)"""
raise NotImplementedError
def _stop_all_model_server(**kwargs):
"""Stop all server"""
_stop_service("worker", "ModelWorker")
_stop_service("controller", "ModelController")

View File: pilot/model/controller/controller.py

@@ -1,3 +1,5 @@
from abc import ABC, abstractmethod
import logging
from typing import List
@@ -6,9 +8,33 @@ from pilot.model.base import ModelInstance
from pilot.model.parameter import ModelControllerParameters
from pilot.model.controller.registry import EmbeddedModelRegistry, ModelRegistry
from pilot.utils.parameter_utils import EnvArgumentParser
from pilot.utils.api_utils import _api_remote as api_remote
class ModelController:
class BaseModelController(ABC):
@abstractmethod
async def register_instance(self, instance: ModelInstance) -> bool:
"""Register a given model instance"""
@abstractmethod
async def deregister_instance(self, instance: ModelInstance) -> bool:
"""Deregister a given model instance."""
@abstractmethod
async def get_all_instances(
self, model_name: str, healthy_only: bool = False
) -> List[ModelInstance]:
"""Fetch all instances of a given model. Optionally, fetch only the healthy instances."""
@abstractmethod
async def send_heartbeat(self, instance: ModelInstance) -> bool:
"""Send a heartbeat for a given model instance. This can be used to verify if the instance is still alive and functioning."""
async def model_apply(self) -> bool:
raise NotImplementedError
class LocalModelController(BaseModelController):
def __init__(self, registry: ModelRegistry = None) -> None:
if not registry:
registry = EmbeddedModelRegistry()
@@ -27,22 +53,87 @@ class ModelController:
logging.info(
f"Get all instances with {model_name}, healthy_only: {healthy_only}"
)
return await self.registry.get_all_instances(model_name, healthy_only)
async def get_all_model_instances(self) -> List[ModelInstance]:
return await self.registry.get_all_model_instances()
if not model_name:
return await self.registry.get_all_model_instances()
else:
return await self.registry.get_all_instances(model_name, healthy_only)
async def send_heartbeat(self, instance: ModelInstance) -> bool:
return await self.registry.send_heartbeat(instance)
class _RemoteModelController(BaseModelController):
def __init__(self, base_url: str) -> None:
self.base_url = base_url
@api_remote(path="/api/controller/models", method="POST")
async def register_instance(self, instance: ModelInstance) -> bool:
pass
@api_remote(path="/api/controller/models", method="DELETE")
async def deregister_instance(self, instance: ModelInstance) -> bool:
pass
@api_remote(path="/api/controller/models")
async def get_all_instances(
self, model_name: str = None, healthy_only: bool = False
) -> List[ModelInstance]:
pass
@api_remote(path="/api/controller/heartbeat", method="POST")
async def send_heartbeat(self, instance: ModelInstance) -> bool:
pass
class ModelRegistryClient(_RemoteModelController, ModelRegistry):
async def get_all_model_instances(self) -> List[ModelInstance]:
return await self.get_all_instances()
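The method bodies above are just pass because api_remote synthesizes the HTTP call from the decorator arguments and the method signature; the stubs only carry typing information. The real implementation lives in pilot.utils.api_utils and is not part of this diff. A minimal sketch of the idea, assuming an httpx-based JSON round trip (the real _api_remote may differ):

import httpx
from functools import wraps

def api_remote(path: str, method: str = "GET"):
    # Sketch only: turn a stub coroutine into an HTTP request against
    # the instance's base_url attribute.
    def decorator(func):
        @wraps(func)
        async def wrapper(self, *args, **kwargs):
            url = self.base_url + path
            async with httpx.AsyncClient() as client:
                if method == "GET":
                    params = {k: v for k, v in kwargs.items() if v is not None}
                    response = await client.get(url, params=params)
                else:
                    # Assume a single pydantic model argument for POST/DELETE.
                    body = args[0].dict() if args else kwargs
                    response = await client.request(method, url, json=body)
            response.raise_for_status()
            return response.json()
        return wrapper
    return decorator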
class ModelControllerAdapter(BaseModelController):
def __init__(self, backend: BaseModelController = None) -> None:
self.backend = backend
async def register_instance(self, instance: ModelInstance) -> bool:
return await self.backend.register_instance(instance)
async def deregister_instance(self, instance: ModelInstance) -> bool:
return await self.backend.deregister_instance(instance)
async def get_all_instances(
self, model_name: str, healthy_only: bool = False
) -> List[ModelInstance]:
return await self.backend.get_all_instances(model_name, healthy_only)
async def send_heartbeat(self, instance: ModelInstance) -> bool:
return await self.backend.send_heartbeat(instance)
async def model_apply(self) -> bool:
# TODO: not implemented yet; once it is, delegate via:
# return await self.backend.model_apply()
raise NotImplementedError
router = APIRouter()
controller = ModelController()
controller = ModelControllerAdapter()
def initialize_controller(
app=None, remote_controller_addr: str = None, host: str = None, port: int = None
):
global controller
if remote_controller_addr:
controller.backend = _RemoteModelController(remote_controller_addr)
else:
controller.backend = LocalModelController()
if app:
app.include_router(router, prefix="/api")
else:
import uvicorn
app = FastAPI()
app.include_router(router, prefix="/api")
uvicorn.run(app, host=host, port=port, log_level="info")
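initialize_controller therefore covers the deployment shapes introduced in this commit: embed a local controller into an existing app (unified mode), proxy a remote controller (light mode), or run standalone under uvicorn. Hypothetical calls mirroring the branches above:

from fastapi import FastAPI

app = FastAPI()
# Unified deployment: the controller lives inside the existing process.
initialize_controller(app=app)
# Light deployment: forward controller calls to a remote instance.
# initialize_controller(app=app, remote_controller_addr="http://127.0.0.1:8000")
# Standalone controller server (blocks inside uvicorn.run):
# initialize_controller(host="0.0.0.0", port=8000)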
@router.post("/controller/models")
@@ -51,14 +142,13 @@ async def api_register_instance(request: ModelInstance):
@router.delete("/controller/models")
async def api_deregister_instance(request: ModelInstance):
return await controller.deregister_instance(request)
async def api_deregister_instance(model_name: str, host: str, port: int):
instance = ModelInstance(model_name=model_name, host=host, port=port)
return await controller.deregister_instance(instance)
@router.get("/controller/models")
async def api_get_all_instances(model_name: str = None, healthy_only: bool = False):
if not model_name:
return await controller.get_all_model_instances()
return await controller.get_all_instances(model_name, healthy_only=healthy_only)
@@ -68,18 +158,12 @@ async def api_model_heartbeat(request: ModelInstance):
def run_model_controller():
import uvicorn
parser = EnvArgumentParser()
env_prefix = "controller_"
controller_params: ModelControllerParameters = parser.parse_args_into_dataclass(
ModelControllerParameters, env_prefix=env_prefix
)
app = FastAPI()
app.include_router(router, prefix="/api")
uvicorn.run(
app, host=controller_params.host, port=controller_params.port, log_level="info"
)
initialize_controller(host=controller_params.host, port=controller_params.port)
if __name__ == "__main__":

View File: pilot/model/controller/registry.py

@@ -178,47 +178,11 @@ class EmbeddedModelRegistry(ModelRegistry):
instance.model_name, instance.host, instance.port, healthy_only=False
)
if not exist_ins:
return False
# register new instance from heartbeat
await self.register_instance(instance)
return True
ins = exist_ins[0]
ins.last_heartbeat = datetime.now()
ins.healthy = True
return True
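With this change a heartbeat doubles as registration: an instance the registry has never seen (for example, after a controller restart) is re-registered instead of being rejected. A worker-side loop that relies on this behavior could be as simple as the following sketch (the interval is an assumption; the actual scheduling lives in the WorkerManager):

import asyncio
from pilot.model.base import ModelInstance

async def heartbeat_loop(registry, instance: ModelInstance, interval: float = 10.0):
    # Each beat refreshes last_heartbeat; if the controller lost track of
    # the instance, the beat transparently re-registers it.
    while True:
        await registry.send_heartbeat(instance)
        await asyncio.sleep(interval)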
from pilot.utils.api_utils import _api_remote as api_remote
class ModelRegistryClient(ModelRegistry):
def __init__(self, base_url: str) -> None:
self.base_url = base_url
@api_remote(path="/api/controller/models", method="POST")
async def register_instance(self, instance: ModelInstance) -> bool:
pass
@api_remote(path="/api/controller/models", method="DELETE")
async def deregister_instance(self, instance: ModelInstance) -> bool:
pass
@api_remote(path="/api/controller/models")
async def get_all_instances(
self, model_name: str, healthy_only: bool = False
) -> List[ModelInstance]:
pass
@api_remote(path="/api/controller/models")
async def get_all_model_instances(self) -> List[ModelInstance]:
pass
@api_remote(path="/api/controller/models")
async def select_one_health_instance(self, model_name: str) -> ModelInstance:
instances = await self.get_all_instances(model_name, healthy_only=True)
instances = [i for i in instances if i.enabled]
if not instances:
return None
return random.choice(instances)
@api_remote(path="/api/controller/heartbeat", method="POST")
async def send_heartbeat(self, instance: ModelInstance) -> bool:
pass

View File: pilot/model/parameter.py

@@ -51,7 +51,7 @@ class ModelWorkerParameters(BaseParameters):
)
port: Optional[int] = field(
default=8000, metadata={"help": "Model worker deploy port"}
default=8001, metadata={"help": "Model worker deploy port"}
)
daemon: Optional[bool] = field(
default=False, metadata={"help": "Run Model Worker in background"}

View File: pilot/model/worker/default_worker.py

@@ -82,6 +82,7 @@ class DefaultModelWorker(ModelWorker):
def stop(self) -> None:
if not self.model:
logger.warn("Model has been stopped!!")
return
del self.model
del self.tokenizer
@@ -98,23 +99,24 @@
params, self.ml.model_path, prompt_template=self.ml.prompt_template
)
previous_response = ""
print("stream output:\n")
for output in self.generate_stream_func(
self.model, self.tokenizer, params, get_device(), self.context_len
):
# Please do not open the output in production!
# The gpt4all thread shares stdout with the parent process,
# and opening it may affect the frontend output.
if "windows" in platform.platform().lower():
# Do not print the model output, because it may contain Emoji, there is a problem with the GBK encoding
pass
else:
print("output: ", output)
incremental_output = output[len(previous_response) :]
# print("output: ", output)
print(incremental_output, end="", flush=True)
previous_response = output
# return some model context to dbgpt-server
model_output = ModelOutput(
text=output, error_code=0, model_context=model_context
)
yield model_output
print(f"\n\nfull stream output:\n{previous_response}")
except torch.cuda.CudaError:
model_output = ModelOutput(
text="**GPU OutOfMemory, Please Refresh.**", error_code=0

View File: pilot/model/worker/manager.py

@@ -225,7 +225,14 @@ class LocalWorkerManager(WorkerManager):
self, params: Dict, async_wrapper=None, **kwargs
) -> Iterator[ModelOutput]:
"""Generate stream result, chat scene"""
worker_run_data = await self._get_model(params)
try:
worker_run_data = await self._get_model(params)
except Exception as e:
yield ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=0,
)
return
async with worker_run_data.semaphore:
if worker_run_data.worker.support_async():
async for output in worker_run_data.worker.async_generate_stream(
@@ -244,7 +251,13 @@ class LocalWorkerManager(WorkerManager):
async def generate(self, params: Dict) -> ModelOutput:
"""Generate non stream result"""
worker_run_data = await self._get_model(params)
try:
worker_run_data = await self._get_model(params)
except Exception as e:
return ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=0,
)
async with worker_run_data.semaphore:
if worker_run_data.worker.support_async():
return await worker_run_data.worker.async_generate(params)
@@ -256,7 +269,10 @@ class LocalWorkerManager(WorkerManager):
async def embeddings(self, params: Dict) -> List[List[float]]:
"""Embed input"""
worker_run_data = await self._get_model(params, worker_type="text2vec")
try:
worker_run_data = await self._get_model(params, worker_type="text2vec")
except Exception as e:
raise e
async with worker_run_data.semaphore:
if worker_run_data.worker.support_async():
return await worker_run_data.worker.async_embeddings(params)
@@ -272,6 +288,8 @@ class LocalWorkerManager(WorkerManager):
apply_func = self._start_all_worker
elif apply_req.apply_type == WorkerApplyType.STOP:
apply_func = self._stop_all_worker
elif apply_req.apply_type == WorkerApplyType.RESTART:
apply_func = self._restart_all_worker
elif apply_req.apply_type == WorkerApplyType.UPDATE_PARAMS:
apply_func = self._update_all_worker_params
else:
@@ -298,10 +316,13 @@ class LocalWorkerManager(WorkerManager):
apply_req (WorkerApplyRequest): Worker apply request
apply_func (ApplyFunction): Function to apply to worker instances, now function is async function
"""
logger.info(f"Apply req: {apply_req}, apply_func: {apply_func}")
if apply_req:
worker_type = apply_req.worker_type.value
model_name = apply_req.model
worker_instances = await self.get_model_instances(worker_type, model_name)
worker_instances = await self.get_model_instances(
worker_type, model_name, healthy_only=False
)
if not worker_instances:
raise Exception(
f"No worker instance found for the model {model_name} worker type {worker_type}"
@@ -370,6 +391,12 @@ class LocalWorkerManager(WorkerManager):
message=f"Worker stopped successfully", timecost=timecost
)
async def _restart_all_worker(
self, apply_req: WorkerApplyRequest
) -> WorkerApplyOutput:
await self._stop_all_worker(apply_req)
return await self._start_all_worker(apply_req)
async def _update_all_worker_params(
self, apply_req: WorkerApplyRequest
) -> WorkerApplyOutput:
@@ -389,8 +416,7 @@ class LocalWorkerManager(WorkerManager):
timecost = time.time() - start_time
if need_restart:
logger.info("Model params update successfully, begin restart worker")
await self._stop_all_worker(apply_req)
await self._start_all_worker(apply_req)
await self._restart_all_worker(apply_req)
timecost = time.time() - start_time
message = f"Update worker params and restart successfully"
return WorkerApplyOutput(message=message, timecost=timecost)
@@ -500,8 +526,8 @@ async def generate_json_stream(params):
@router.post("/worker/generate_stream")
async def api_generate_stream(request: Request):
params = await request.json()
async def api_generate_stream(request: PromptRequest):
params = request.dict(exclude_none=True)
generator = generate_json_stream(params)
return StreamingResponse(generator)
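With the endpoint now accepting a typed PromptRequest, FastAPI validates the body before generation starts. A raw HTTP client could exercise it as in this sketch (host, port, model name, and the "human" role string are assumptions for illustration; the field names mirror PromptRequest as used above):

import asyncio
import httpx

async def stream_chat():
    payload = {
        "model": "vicuna-13b-v1.5",
        "prompt": "",
        "echo": False,
        "messages": [{"role": "human", "content": "hi"}],
    }
    url = "http://127.0.0.1:8001/api/worker/generate_stream"
    async with httpx.AsyncClient() as client:
        async with client.stream("POST", url, json=payload) as resp:
            # The server yields JSON chunks; print them as they arrive.
            async for chunk in resp.aiter_text():
                print(chunk, end="", flush=True)

asyncio.run(stream_chat())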
@@ -534,13 +560,19 @@ async def api_worker_parameter_descs(
return output
def _setup_fastapi(worker_params: ModelWorkerParameters):
app = FastAPI()
def _setup_fastapi(worker_params: ModelWorkerParameters, app=None):
if not app:
app = FastAPI()
if worker_params.standalone:
from pilot.model.controller.controller import router as controller_router
from pilot.model.controller.controller import initialize_controller
if not worker_params.controller_addr:
worker_params.controller_addr = f"http://127.0.0.1:{worker_params.port}"
logger.info(
f"Run WorkerManager with standalone mode, controller_addr: {worker_params.controller_addr}"
)
initialize_controller(app=app)
app.include_router(controller_router, prefix="/api")
@app.on_event("startup")
@@ -570,7 +602,7 @@ def _parse_worker_params(
)
worker_params.update_from(new_worker_params)
logger.info(f"Worker params: {worker_params}")
# logger.info(f"Worker params: {worker_params}")
return worker_params
@@ -583,7 +615,7 @@ def _create_local_model_manager(
)
return LocalWorkerManager()
else:
from pilot.model.controller.registry import ModelRegistryClient
from pilot.model.controller.controller import ModelRegistryClient
from pilot.utils.net_utils import _get_ip_address
client = ModelRegistryClient(worker_params.controller_addr)
@@ -600,6 +632,12 @@ def _create_local_model_manager(
)
return await client.register_instance(instance)
async def deregister_func(worker_run_data: WorkerRunData):
instance = ModelInstance(
model_name=worker_run_data.worker_key, host=host, port=port
)
return await client.deregister_instance(instance)
async def send_heartbeat_func(worker_run_data: WorkerRunData):
instance = ModelInstance(
model_name=worker_run_data.worker_key, host=host, port=port
@@ -607,7 +645,9 @@ def _create_local_model_manager(
return await client.send_heartbeat(instance)
return LocalWorkerManager(
register_func=register_func, send_heartbeat_func=send_heartbeat_func
register_func=register_func,
deregister_func=deregister_func,
send_heartbeat_func=send_heartbeat_func,
)
@@ -642,30 +682,60 @@ def initialize_worker_manager_in_client(
model_path: str = None,
run_locally: bool = True,
controller_addr: str = None,
local_port: int = 5000,
):
"""Initialize WorkerManager in client.
If run_locally is True:
1. Start ModelController
2. Start LocalWorkerManager
3. Start worker in LocalWorkerManager
4. Register worker to ModelController
otherwise:
1. Build ModelRegistryClient with controller address
2. Start RemoteWorkerManager
"""
global worker_manager
if not app:
raise Exception("app can't be None")
worker_params: ModelWorkerParameters = _parse_worker_params(
model_name=model_name, model_path=model_path, controller_addr=controller_addr
)
logger.info(f"Worker params: {worker_params}")
controller_addr = None
if run_locally:
worker_params.register = False
# TODO start ModelController
worker_params.standalone = True
worker_params.register = True
worker_params.port = local_port
logger.info(f"Worker params: {worker_params}")
_setup_fastapi(worker_params, app)
_start_local_worker(worker_manager, worker_params, True)
loop = asyncio.get_event_loop()
loop.run_until_complete(
worker_manager.worker_manager._start_all_worker(apply_req=None)
)
# loop = asyncio.get_event_loop()
# loop.run_until_complete(
# worker_manager.worker_manager._start_all_worker(apply_req=None)
# )
else:
from pilot.model.controller.registry import ModelRegistryClient
from pilot.model.controller.controller import (
initialize_controller,
ModelRegistryClient,
)
if not worker_params.controller_addr:
raise ValueError("Controller can`t be None")
controller_addr = worker_params.controller_addr
logger.info(f"Worker params: {worker_params}")
client = ModelRegistryClient(worker_params.controller_addr)
worker_manager.worker_manager = RemoteWorkerManager(client)
initialize_controller(
app=app, remote_controller_addr=worker_params.controller_addr
)
if include_router and app:
# mount WorkerManager router
app.include_router(router, prefix="/api")

View File: pilot/scripts/cli_scripts.py

@@ -51,8 +51,26 @@ def stop():
pass
@click.group()
def install():
"""Install dependencies, plugins, etc."""
pass
stop_all_func_list = []
@click.command(name="all")
def stop_all():
"""Stop all servers"""
for stop_func in stop_all_func_list:
stop_func()
cli.add_command(start)
cli.add_command(stop)
cli.add_command(install)
add_command_alias(stop_all, name="all", parent_group=stop)
try:
from pilot.model.cli import (
@@ -63,6 +81,7 @@ try:
stop_model_worker,
start_apiserver,
stop_apiserver,
_stop_all_model_server,
)
add_command_alias(model_cli_group, name="model", parent_group=cli)
@@ -73,15 +92,21 @@ try:
add_command_alias(stop_model_controller, name="controller", parent_group=stop)
add_command_alias(stop_model_worker, name="worker", parent_group=stop)
add_command_alias(stop_apiserver, name="apiserver", parent_group=stop)
stop_all_func_list.append(_stop_all_model_server)
except ImportError as e:
logging.warning(f"Integrating dbgpt model command line tool failed: {e}")
try:
from pilot.server._cli import start_webserver, stop_webserver
from pilot.server._cli import (
start_webserver,
stop_webserver,
_stop_all_dbgpt_server,
)
add_command_alias(start_webserver, name="webserver", parent_group=start)
add_command_alias(stop_webserver, name="webserver", parent_group=stop)
stop_all_func_list.append(_stop_all_dbgpt_server)
except ImportError as e:
logging.warning(f"Integrating dbgpt webserver command line tool failed: {e}")

View File: pilot/server/_cli.py

@@ -30,3 +30,7 @@ def start_webserver(**kwargs):
def stop_webserver(port: int):
"""Stop webserver(dbgpt_server.py)"""
_stop_service("webserver", "WebServer", port=port)
def _stop_all_dbgpt_server():
_stop_service("webserver", "WebServer")

View File: pilot/server/dbgpt_server.py

@@ -104,7 +104,10 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
if not param.light:
print("Model Unified Deployment Mode!")
initialize_worker_manager_in_client(
app=app, model_name=CFG.LLM_MODEL, model_path=model_path
app=app,
model_name=CFG.LLM_MODEL,
model_path=model_path,
local_port=param.port,
)
CFG.NEW_SERVER_MODE = True
@@ -116,6 +119,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
model_path=model_path,
run_locally=False,
controller_addr=CFG.MODEL_SERVER,
local_port=param.port,
)
CFG.SERVER_LIGHT_MODE = True

View File: pilot/utils/command_utils.py

@@ -6,15 +6,30 @@ import psutil
import platform
def _get_abspath_of_current_command(command_path: str):
if not command_path.endswith(".py"):
return command_path
# This implementation is very ugly
command_path = os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
"scripts",
"cli_scripts.py",
)
return command_path
def _run_current_with_daemon(name: str, log_file: str):
# Get all arguments except for --daemon
args = [arg for arg in sys.argv if arg != "--daemon" and arg != "-d"]
args[0] = _get_abspath_of_current_command(args[0])
daemon_cmd = [sys.executable] + args
daemon_cmd = " ".join(daemon_cmd)
daemon_cmd += f" > {log_file} 2>&1"
print(f"daemon cmd: {daemon_cmd}")
# Check the platform and set the appropriate flags or functions
if platform.system() == "Windows":
if "windows" in platform.system().lower():
process = subprocess.Popen(
daemon_cmd,
creationflags=subprocess.CREATE_NEW_PROCESS_GROUP,
@@ -50,7 +65,7 @@ def _run_current_with_gunicorn(app: str, config_path: str, kwargs: Dict):
app_env = EnvArgumentParser._kwargs_to_env_key_value(kwargs)
env_to_app.update(app_env)
cmd = f"uvicorn {app} --host 0.0.0.0 --port 5000"
if platform.system() == "Windows":
if "windows" in platform.system().lower():
raise Exception("Not support on windows")
else: # macOS, Linux, and other Unix-like systems
process = subprocess.Popen(cmd, shell=True, env=env_to_app)
@@ -89,3 +104,54 @@ def _stop_service(
if not_found:
print(f"{fullname} process not found.")
def _get_ports_by_cmdline_part(service_keys: List[str]) -> List[int]:
"""
Return a list of ports that are associated with processes that have all the service_keys in their cmdline.
Args:
service_keys (List[str]): List of strings that should all be present in the process's cmdline.
Returns:
List[int]: List of ports sorted with preference for 8000 and 5000, and then in ascending order.
"""
ports = []
for process in psutil.process_iter(attrs=["pid", "name", "connections", "cmdline"]):
try:
# Convert the cmdline list to a single string for easier checking
cmdline = " ".join(process.info["cmdline"])
# Check if all the service keys are present in the cmdline
if all(fragment in cmdline for fragment in service_keys):
for connection in process.info["connections"]:
if connection.status == psutil.CONN_LISTEN:
ports.append(connection.laddr.port)
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
pass
# Sort ports with preference for 8000 and 5000
ports.sort(key=lambda x: (x != 8000, x != 5000, x))
return ports
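A hypothetical call, locating a running dbgpt_server by fragments of its command line:

ports = _get_ports_by_cmdline_part(["python", "dbgpt_server"])
if ports:
    print(f"webserver detected on port {ports[0]}")  # e.g. 5000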
def _detect_controller_address() -> str:
controller_addr = os.getenv("CONTROLLER_ADDRESS")
if controller_addr:
return controller_addr
cmdline_fragments = [
["python", "start", "controller"],
["python", "controller"],
["python", "start", "webserver"],
["python", "dbgpt_server"],
]
for fragments in cmdline_fragments:
ports = _get_ports_by_cmdline_part(fragments)
if ports:
return f"http://127.0.0.1:{ports[0]}"
return f"http://127.0.0.1:8000"