mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-05 02:20:08 +00:00
feat(model): Unified Deployment Mode with multi-model and add command line chat with model
This commit is contained in:
parent
1f1da2618c
commit
a8846c40aa
@ -4,7 +4,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
from typing import Callable, List, Type
|
from typing import Callable, List, Type
|
||||||
|
|
||||||
from pilot.model.controller.registry import ModelRegistryClient
|
from pilot.model.controller.controller import ModelRegistryClient
|
||||||
from pilot.configs.model_config import LOGDIR
|
from pilot.configs.model_config import LOGDIR
|
||||||
from pilot.model.base import WorkerApplyType
|
from pilot.model.base import WorkerApplyType
|
||||||
from pilot.model.parameter import (
|
from pilot.model.parameter import (
|
||||||
@ -26,17 +26,22 @@ logger = logging.getLogger("dbgpt_cli")
|
|||||||
@click.option(
|
@click.option(
|
||||||
"--address",
|
"--address",
|
||||||
type=str,
|
type=str,
|
||||||
default=MODEL_CONTROLLER_ADDRESS,
|
default=None,
|
||||||
required=False,
|
required=False,
|
||||||
show_default=True,
|
show_default=True,
|
||||||
help=(
|
help=(
|
||||||
"Address of the Model Controller to connect to. "
|
"Address of the Model Controller to connect to. "
|
||||||
"Just support light deploy model"
|
"Just support light deploy model, If the environment variable CONTROLLER_ADDRESS is configured, read from the environment variable"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
def model_cli_group(address: str):
|
def model_cli_group(address: str):
|
||||||
"""Clients that manage model serving"""
|
"""Clients that manage model serving"""
|
||||||
global MODEL_CONTROLLER_ADDRESS
|
global MODEL_CONTROLLER_ADDRESS
|
||||||
|
if not address:
|
||||||
|
from pilot.utils.command_utils import _detect_controller_address
|
||||||
|
|
||||||
|
MODEL_CONTROLLER_ADDRESS = _detect_controller_address()
|
||||||
|
else:
|
||||||
MODEL_CONTROLLER_ADDRESS = address
|
MODEL_CONTROLLER_ADDRESS = address
|
||||||
|
|
||||||
|
|
||||||
@ -140,6 +145,26 @@ def restart(model_name: str, model_type: str):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@model_cli_group.command()
@click.option(
    "--model_name",
    type=str,
    default=None,
    required=True,
    help="The name of model",
)
@click.option(
    "--system",
    type=str,
    default=None,
    required=False,
    help="System prompt",
)
def chat(model_name: str, system: str):
    """Interact with your bot from the command line"""
    # The controller address is the module-level value set by the
    # ``model_cli_group`` callback before any sub-command runs.
    _cli_chat(MODEL_CONTROLLER_ADDRESS, model_name, system_prompt=system)
|
||||||
|
|
||||||
|
|
||||||
# @model_cli_group.command()
|
# @model_cli_group.command()
|
||||||
# @add_model_options
|
# @add_model_options
|
||||||
# def modify(address: str, model_name: str, model_type: str):
|
# def modify(address: str, model_name: str, model_type: str):
|
||||||
@ -147,14 +172,21 @@ def restart(model_name: str, model_type: str):
|
|||||||
# worker_apply(address, model_name, model_type, WorkerApplyType.UPDATE_PARAMS)
|
# worker_apply(address, model_name, model_type, WorkerApplyType.UPDATE_PARAMS)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_worker_manager(address: str):
    """Build a remote worker manager for the controller at ``address``.

    Args:
        address: Address of the model controller, passed unchanged to
            ``ModelRegistryClient``.

    Returns:
        A ``RemoteWorkerManager`` wrapping a registry client for ``address``.
    """
    # Function-scope import, as ``worker_apply`` in this module also does.
    # Only ``RemoteWorkerManager`` is needed here; the previously imported
    # ``WorkerApplyRequest`` was unused in this helper.
    from pilot.model.worker.manager import RemoteWorkerManager

    registry = ModelRegistryClient(address)
    return RemoteWorkerManager(registry)
|
||||||
|
|
||||||
|
|
||||||
def worker_apply(
|
def worker_apply(
|
||||||
address: str, model_name: str, model_type: str, apply_type: WorkerApplyType
|
address: str, model_name: str, model_type: str, apply_type: WorkerApplyType
|
||||||
):
|
):
|
||||||
from pilot.model.worker.manager import RemoteWorkerManager, WorkerApplyRequest
|
from pilot.model.worker.manager import WorkerApplyRequest
|
||||||
|
|
||||||
loop = get_or_create_event_loop()
|
loop = get_or_create_event_loop()
|
||||||
registry = ModelRegistryClient(address)
|
worker_manager = _get_worker_manager(address)
|
||||||
worker_manager = RemoteWorkerManager(registry)
|
|
||||||
apply_req = WorkerApplyRequest(
|
apply_req = WorkerApplyRequest(
|
||||||
model=model_name, worker_type=model_type, apply_type=apply_type
|
model=model_name, worker_type=model_type, apply_type=apply_type
|
||||||
)
|
)
|
||||||
@ -162,6 +194,41 @@ def worker_apply(
|
|||||||
print(res)
|
print(res)
|
||||||
|
|
||||||
|
|
||||||
|
def _cli_chat(address: str, model_name: str, system_prompt: str = None):
    """Run an interactive chat session against a deployed model.

    Obtains an event loop, builds a worker manager for the controller at
    ``address`` and drives ``_chat_stream`` on that loop until it finishes.

    Args:
        address: Address of the model controller.
        model_name: Name of the model to chat with.
        system_prompt: Optional system prompt forwarded to ``_chat_stream``.
    """
    loop = get_or_create_event_loop()
    # Single assignment; the original repeated the target
    # (``worker_manager = worker_manager = ...``).
    worker_manager = _get_worker_manager(address)
    loop.run_until_complete(_chat_stream(worker_manager, model_name, system_prompt))
|
||||||
|
|
||||||
|
|
||||||
|
async def _chat_stream(worker_manager, model_name: str, system_prompt: str = None):
    """Read user input in a loop and stream the model's replies to stdout.

    The conversation history is kept in memory for the lifetime of the call
    and sent in full with every request. The loop ends when the user types
    ``exit`` (case-insensitive, surrounding whitespace ignored), when stdin
    reaches end-of-file, or on Ctrl-C at the prompt.

    Args:
        worker_manager: Object exposing an async-iterable
            ``generate_stream(request_dict)``.
        model_name: Name of the model to send each request to.
        system_prompt: Optional system message placed first in the history.
    """
    from pilot.model.worker.manager import PromptRequest
    from pilot.scene.base_message import ModelMessage, ModelMessageRoleType

    print(f"Chatbot started with model {model_name}. Type 'exit' to leave the chat.")
    hist = []
    if system_prompt:
        hist.append(
            ModelMessage(role=ModelMessageRoleType.SYSTEM, content=system_prompt)
        )
    while True:
        try:
            user_input = input("\n\nYou: ")
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl-C at the prompt: leave the chat cleanly
            # instead of surfacing a traceback to the user.
            print()
            break
        if user_input.lower().strip() == "exit":
            break
        hist.append(ModelMessage(role=ModelMessageRoleType.HUMAN, content=user_input))
        request = PromptRequest(messages=hist, model=model_name, prompt="", echo=False)
        payload = request.dict(exclude_none=True)
        # Flush so the prompt is visible while waiting for the first chunk.
        print("Bot: ", end="", flush=True)
        # Each ``response.text`` is treated as the full reply so far, so only
        # the part beyond what was already printed is emitted.
        previous_response = ""
        async for response in worker_manager.generate_stream(payload):
            incremental_output = response.text[len(previous_response) :]
            print(incremental_output, end="", flush=True)
            previous_response = response.text
        hist.append(
            ModelMessage(role=ModelMessageRoleType.AI, content=previous_response)
        )
|
||||||
|
|
||||||
|
|
||||||
def add_stop_server_options(func):
|
def add_stop_server_options(func):
|
||||||
@click.option(
|
@click.option(
|
||||||
"--port",
|
"--port",
|
||||||
@ -249,3 +316,9 @@ def start_apiserver(**kwargs):
|
|||||||
def stop_apiserver(**kwargs):
|
def stop_apiserver(**kwargs):
|
||||||
"""Start apiserver(TODO)"""
|
"""Start apiserver(TODO)"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def _stop_all_model_server(**kwargs):
    """Stop the model worker service, then the model controller service.

    Keyword arguments are accepted but not used.
    """
    # Order is preserved from the original: worker first, controller second.
    stop_targets = (
        ("worker", "ModelWorker"),
        ("controller", "ModelController"),
    )
    for first_arg, second_arg in stop_targets:
        _stop_service(first_arg, second_arg)
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
@ -6,9 +8,33 @@ from pilot.model.base import ModelInstance
|
|||||||
from pilot.model.parameter import ModelControllerParameters
|
from pilot.model.parameter import ModelControllerParameters
|
||||||
from pilot.model.controller.registry import EmbeddedModelRegistry, ModelRegistry
|
from pilot.model.controller.registry import EmbeddedModelRegistry, ModelRegistry
|
||||||
from pilot.utils.parameter_utils import EnvArgumentParser
|
from pilot.utils.parameter_utils import EnvArgumentParser
|
||||||
|
from pilot.utils.api_utils import _api_remote as api_remote
|
||||||
|
|
||||||
|
|
||||||
class ModelController:
|
class BaseModelController(ABC):
    """Abstract interface for registering and looking up model instances.

    Concrete controllers must implement the four abstract coroutines below.
    ``model_apply`` is optional and raises ``NotImplementedError`` unless
    overridden.
    """

    @abstractmethod
    async def register_instance(self, instance: ModelInstance) -> bool:
        """Register ``instance`` with the controller."""

    @abstractmethod
    async def deregister_instance(self, instance: ModelInstance) -> bool:
        """Remove ``instance`` from the controller."""

    @abstractmethod
    async def get_all_instances(
        self, model_name: str, healthy_only: bool = False
    ) -> List[ModelInstance]:
        """Return the instances of ``model_name``.

        When ``healthy_only`` is true, only healthy instances are returned.
        """

    @abstractmethod
    async def send_heartbeat(self, instance: ModelInstance) -> bool:
        """Record a heartbeat for ``instance``.

        Heartbeats let the controller tell whether the instance is still
        alive and functioning.
        """

    async def model_apply(self) -> bool:
        # Not part of the required interface; subclasses may override.
        raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class LocalModelController(BaseModelController):
|
||||||
def __init__(self, registry: ModelRegistry = None) -> None:
|
def __init__(self, registry: ModelRegistry = None) -> None:
|
||||||
if not registry:
|
if not registry:
|
||||||
registry = EmbeddedModelRegistry()
|
registry = EmbeddedModelRegistry()
|
||||||
@ -27,22 +53,87 @@ class ModelController:
|
|||||||
logging.info(
|
logging.info(
|
||||||
f"Get all instances with {model_name}, healthy_only: {healthy_only}"
|
f"Get all instances with {model_name}, healthy_only: {healthy_only}"
|
||||||
)
|
)
|
||||||
return await self.registry.get_all_instances(model_name, healthy_only)
|
if not model_name:
|
||||||
|
|
||||||
async def get_all_model_instances(self) -> List[ModelInstance]:
|
|
||||||
return await self.registry.get_all_model_instances()
|
return await self.registry.get_all_model_instances()
|
||||||
|
else:
|
||||||
|
return await self.registry.get_all_instances(model_name, healthy_only)
|
||||||
|
|
||||||
async def send_heartbeat(self, instance: ModelInstance) -> bool:
|
async def send_heartbeat(self, instance: ModelInstance) -> bool:
|
||||||
return await self.registry.send_heartbeat(instance)
|
return await self.registry.send_heartbeat(instance)
|
||||||
|
|
||||||
|
|
||||||
|
class _RemoteModelController(BaseModelController):
    """Controller whose operations are declared as ``api_remote`` stubs.

    Every method is decorated with ``api_remote`` and given a path (and, for
    some, an HTTP method); the method bodies themselves are empty.
    NOTE(review): presumably the decorator performs the HTTP call against
    ``self.base_url`` and supplies the return value - confirm in
    ``pilot.utils.api_utils``.
    """

    def __init__(self, base_url: str) -> None:
        # Base URL of the remote controller; stored as-is.
        self.base_url = base_url

    # Declared as POST /api/controller/models.
    @api_remote(path="/api/controller/models", method="POST")
    async def register_instance(self, instance: ModelInstance) -> bool:
        pass

    # Declared as DELETE /api/controller/models.
    @api_remote(path="/api/controller/models", method="DELETE")
    async def deregister_instance(self, instance: ModelInstance) -> bool:
        pass

    # Same path with no explicit method. Unlike the abstract base signature,
    # ``model_name`` defaults to None here, so it can be called without
    # arguments.
    @api_remote(path="/api/controller/models")
    async def get_all_instances(
        self, model_name: str = None, healthy_only: bool = False
    ) -> List[ModelInstance]:
        pass

    # Declared as POST /api/controller/heartbeat.
    @api_remote(path="/api/controller/heartbeat", method="POST")
    async def send_heartbeat(self, instance: ModelInstance) -> bool:
        pass
|
||||||
|
|
||||||
|
|
||||||
|
class ModelRegistryClient(_RemoteModelController, ModelRegistry):
    """Remote controller client that also satisfies ``ModelRegistry``."""

    async def get_all_model_instances(self) -> List[ModelInstance]:
        """Return all instances by calling ``get_all_instances`` with no arguments."""
        every_instance = await self.get_all_instances()
        return every_instance
|
||||||
|
|
||||||
|
|
||||||
|
class ModelControllerAdapter(BaseModelController):
    """Controller that forwards every operation to a replaceable ``backend``.

    ``backend`` is ``None`` when the adapter is built without arguments. It
    must be assigned before any operation is awaited, otherwise the call
    fails with ``AttributeError``.
    """

    def __init__(self, backend: BaseModelController = None) -> None:
        self.backend = backend

    async def register_instance(self, instance: ModelInstance) -> bool:
        """Forward to ``backend.register_instance``."""
        delegate = self.backend
        return await delegate.register_instance(instance)

    async def deregister_instance(self, instance: ModelInstance) -> bool:
        """Forward to ``backend.deregister_instance``."""
        delegate = self.backend
        return await delegate.deregister_instance(instance)

    async def get_all_instances(
        self, model_name: str, healthy_only: bool = False
    ) -> List[ModelInstance]:
        """Forward to ``backend.get_all_instances`` with both arguments."""
        delegate = self.backend
        return await delegate.get_all_instances(model_name, healthy_only)

    async def send_heartbeat(self, instance: ModelInstance) -> bool:
        """Forward to ``backend.send_heartbeat``."""
        delegate = self.backend
        return await delegate.send_heartbeat(instance)

    async def model_apply(self) -> bool:
        """Forward to ``backend.model_apply``."""
        delegate = self.backend
        return await delegate.model_apply()
|
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
controller = ModelController()
|
controller = ModelControllerAdapter()
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_controller(
    app=None, remote_controller_addr: str = None, host: str = None, port: int = None
):
    """Select the controller backend and expose the controller router.

    Args:
        app: Existing application to mount the router on. When falsy, a new
            ``FastAPI`` app is created and served with uvicorn.
        remote_controller_addr: When truthy, the backend is a
            ``_RemoteModelController`` for this address; otherwise a
            ``LocalModelController`` is used.
        host: Host passed to ``uvicorn.run`` when no ``app`` is supplied.
        port: Port passed to ``uvicorn.run`` when no ``app`` is supplied.
    """
    # Only an attribute of the module-level ``controller`` is changed; the
    # name itself is never rebound, so no ``global`` statement is needed.
    controller.backend = (
        _RemoteModelController(remote_controller_addr)
        if remote_controller_addr
        else LocalModelController()
    )

    if app:
        # Caller owns the application: just mount the routes.
        app.include_router(router, prefix="/api")
        return

    # No application supplied: build one and serve it here.
    import uvicorn

    standalone_app = FastAPI()
    standalone_app.include_router(router, prefix="/api")
    uvicorn.run(standalone_app, host=host, port=port, log_level="info")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/controller/models")
|
@router.post("/controller/models")
|
||||||
@ -51,14 +142,13 @@ async def api_register_instance(request: ModelInstance):
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/controller/models")
|
@router.delete("/controller/models")
|
||||||
async def api_deregister_instance(request: ModelInstance):
|
async def api_deregister_instance(model_name: str, host: str, port: int):
|
||||||
return await controller.deregister_instance(request)
|
instance = ModelInstance(model_name=model_name, host=host, port=port)
|
||||||
|
return await controller.deregister_instance(instance)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/controller/models")
|
@router.get("/controller/models")
|
||||||
async def api_get_all_instances(model_name: str = None, healthy_only: bool = False):
|
async def api_get_all_instances(model_name: str = None, healthy_only: bool = False):
|
||||||
if not model_name:
|
|
||||||
return await controller.get_all_model_instances()
|
|
||||||
return await controller.get_all_instances(model_name, healthy_only=healthy_only)
|
return await controller.get_all_instances(model_name, healthy_only=healthy_only)
|
||||||
|
|
||||||
|
|
||||||
@ -68,18 +158,12 @@ async def api_model_heartbeat(request: ModelInstance):
|
|||||||
|
|
||||||
|
|
||||||
def run_model_controller():
|
def run_model_controller():
|
||||||
import uvicorn
|
|
||||||
|
|
||||||
parser = EnvArgumentParser()
|
parser = EnvArgumentParser()
|
||||||
env_prefix = "controller_"
|
env_prefix = "controller_"
|
||||||
controller_params: ModelControllerParameters = parser.parse_args_into_dataclass(
|
controller_params: ModelControllerParameters = parser.parse_args_into_dataclass(
|
||||||
ModelControllerParameters, env_prefix=env_prefix
|
ModelControllerParameters, env_prefix=env_prefix
|
||||||
)
|
)
|
||||||
app = FastAPI()
|
initialize_controller(host=controller_params.host, port=controller_params.port)
|
||||||
app.include_router(router, prefix="/api")
|
|
||||||
uvicorn.run(
|
|
||||||
app, host=controller_params.host, port=controller_params.port, log_level="info"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -178,47 +178,11 @@ class EmbeddedModelRegistry(ModelRegistry):
|
|||||||
instance.model_name, instance.host, instance.port, healthy_only=False
|
instance.model_name, instance.host, instance.port, healthy_only=False
|
||||||
)
|
)
|
||||||
if not exist_ins:
|
if not exist_ins:
|
||||||
return False
|
# register new install from heartbeat
|
||||||
|
self.register_instance(instance)
|
||||||
|
return True
|
||||||
|
|
||||||
ins = exist_ins[0]
|
ins = exist_ins[0]
|
||||||
ins.last_heartbeat = datetime.now()
|
ins.last_heartbeat = datetime.now()
|
||||||
ins.healthy = True
|
ins.healthy = True
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
from pilot.utils.api_utils import _api_remote as api_remote
|
|
||||||
|
|
||||||
|
|
||||||
class ModelRegistryClient(ModelRegistry):
|
|
||||||
def __init__(self, base_url: str) -> None:
|
|
||||||
self.base_url = base_url
|
|
||||||
|
|
||||||
@api_remote(path="/api/controller/models", method="POST")
|
|
||||||
async def register_instance(self, instance: ModelInstance) -> bool:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@api_remote(path="/api/controller/models", method="DELETE")
|
|
||||||
async def deregister_instance(self, instance: ModelInstance) -> bool:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@api_remote(path="/api/controller/models")
|
|
||||||
async def get_all_instances(
|
|
||||||
self, model_name: str, healthy_only: bool = False
|
|
||||||
) -> List[ModelInstance]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@api_remote(path="/api/controller/models")
|
|
||||||
async def get_all_model_instances(self) -> List[ModelInstance]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@api_remote(path="/api/controller/models")
|
|
||||||
async def select_one_health_instance(self, model_name: str) -> ModelInstance:
|
|
||||||
instances = await self.get_all_instances(model_name, healthy_only=True)
|
|
||||||
instances = [i for i in instances if i.enabled]
|
|
||||||
if not instances:
|
|
||||||
return None
|
|
||||||
return random.choice(instances)
|
|
||||||
|
|
||||||
@api_remote(path="/api/controller/heartbeat", method="POST")
|
|
||||||
async def send_heartbeat(self, instance: ModelInstance) -> bool:
|
|
||||||
pass
|
|
||||||
|
@ -51,7 +51,7 @@ class ModelWorkerParameters(BaseParameters):
|
|||||||
)
|
)
|
||||||
|
|
||||||
port: Optional[int] = field(
|
port: Optional[int] = field(
|
||||||
default=8000, metadata={"help": "Model worker deploy port"}
|
default=8001, metadata={"help": "Model worker deploy port"}
|
||||||
)
|
)
|
||||||
daemon: Optional[bool] = field(
|
daemon: Optional[bool] = field(
|
||||||
default=False, metadata={"help": "Run Model Worker in background"}
|
default=False, metadata={"help": "Run Model Worker in background"}
|
||||||
|
@ -82,6 +82,7 @@ class DefaultModelWorker(ModelWorker):
|
|||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self) -> None:
|
||||||
if not self.model:
|
if not self.model:
|
||||||
|
logger.warn("Model has been stopped!!")
|
||||||
return
|
return
|
||||||
del self.model
|
del self.model
|
||||||
del self.tokenizer
|
del self.tokenizer
|
||||||
@ -98,23 +99,24 @@ class DefaultModelWorker(ModelWorker):
|
|||||||
params, self.ml.model_path, prompt_template=self.ml.prompt_template
|
params, self.ml.model_path, prompt_template=self.ml.prompt_template
|
||||||
)
|
)
|
||||||
|
|
||||||
|
previous_response = ""
|
||||||
|
print("stream output:\n")
|
||||||
for output in self.generate_stream_func(
|
for output in self.generate_stream_func(
|
||||||
self.model, self.tokenizer, params, get_device(), self.context_len
|
self.model, self.tokenizer, params, get_device(), self.context_len
|
||||||
):
|
):
|
||||||
# Please do not open the output in production!
|
# Please do not open the output in production!
|
||||||
# The gpt4all thread shares stdout with the parent process,
|
# The gpt4all thread shares stdout with the parent process,
|
||||||
# and opening it may affect the frontend output.
|
# and opening it may affect the frontend output.
|
||||||
if "windows" in platform.platform().lower():
|
incremental_output = output[len(previous_response) :]
|
||||||
# Do not print the model output, because it may contain Emoji, there is a problem with the GBK encoding
|
# print("output: ", output)
|
||||||
pass
|
print(incremental_output, end="", flush=True)
|
||||||
else:
|
previous_response = output
|
||||||
print("output: ", output)
|
|
||||||
# return some model context to dgt-server
|
# return some model context to dgt-server
|
||||||
model_output = ModelOutput(
|
model_output = ModelOutput(
|
||||||
text=output, error_code=0, model_context=model_context
|
text=output, error_code=0, model_context=model_context
|
||||||
)
|
)
|
||||||
yield model_output
|
yield model_output
|
||||||
|
print(f"\n\nfull stream output:\n{previous_response}")
|
||||||
except torch.cuda.CudaError:
|
except torch.cuda.CudaError:
|
||||||
model_output = ModelOutput(
|
model_output = ModelOutput(
|
||||||
text="**GPU OutOfMemory, Please Refresh.**", error_code=0
|
text="**GPU OutOfMemory, Please Refresh.**", error_code=0
|
||||||
|
@ -225,7 +225,14 @@ class LocalWorkerManager(WorkerManager):
|
|||||||
self, params: Dict, async_wrapper=None, **kwargs
|
self, params: Dict, async_wrapper=None, **kwargs
|
||||||
) -> Iterator[ModelOutput]:
|
) -> Iterator[ModelOutput]:
|
||||||
"""Generate stream result, chat scene"""
|
"""Generate stream result, chat scene"""
|
||||||
|
try:
|
||||||
worker_run_data = await self._get_model(params)
|
worker_run_data = await self._get_model(params)
|
||||||
|
except Exception as e:
|
||||||
|
yield ModelOutput(
|
||||||
|
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
|
||||||
|
error_code=0,
|
||||||
|
)
|
||||||
|
return
|
||||||
async with worker_run_data.semaphore:
|
async with worker_run_data.semaphore:
|
||||||
if worker_run_data.worker.support_async():
|
if worker_run_data.worker.support_async():
|
||||||
async for outout in worker_run_data.worker.async_generate_stream(
|
async for outout in worker_run_data.worker.async_generate_stream(
|
||||||
@ -244,7 +251,13 @@ class LocalWorkerManager(WorkerManager):
|
|||||||
|
|
||||||
async def generate(self, params: Dict) -> ModelOutput:
|
async def generate(self, params: Dict) -> ModelOutput:
|
||||||
"""Generate non stream result"""
|
"""Generate non stream result"""
|
||||||
|
try:
|
||||||
worker_run_data = await self._get_model(params)
|
worker_run_data = await self._get_model(params)
|
||||||
|
except Exception as e:
|
||||||
|
return ModelOutput(
|
||||||
|
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
|
||||||
|
error_code=0,
|
||||||
|
)
|
||||||
async with worker_run_data.semaphore:
|
async with worker_run_data.semaphore:
|
||||||
if worker_run_data.worker.support_async():
|
if worker_run_data.worker.support_async():
|
||||||
return await worker_run_data.worker.async_generate(params)
|
return await worker_run_data.worker.async_generate(params)
|
||||||
@ -256,7 +269,10 @@ class LocalWorkerManager(WorkerManager):
|
|||||||
|
|
||||||
async def embeddings(self, params: Dict) -> List[List[float]]:
|
async def embeddings(self, params: Dict) -> List[List[float]]:
|
||||||
"""Embed input"""
|
"""Embed input"""
|
||||||
|
try:
|
||||||
worker_run_data = await self._get_model(params, worker_type="text2vec")
|
worker_run_data = await self._get_model(params, worker_type="text2vec")
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
async with worker_run_data.semaphore:
|
async with worker_run_data.semaphore:
|
||||||
if worker_run_data.worker.support_async():
|
if worker_run_data.worker.support_async():
|
||||||
return await worker_run_data.worker.async_embeddings(params)
|
return await worker_run_data.worker.async_embeddings(params)
|
||||||
@ -272,6 +288,8 @@ class LocalWorkerManager(WorkerManager):
|
|||||||
apply_func = self._start_all_worker
|
apply_func = self._start_all_worker
|
||||||
elif apply_req.apply_type == WorkerApplyType.STOP:
|
elif apply_req.apply_type == WorkerApplyType.STOP:
|
||||||
apply_func = self._stop_all_worker
|
apply_func = self._stop_all_worker
|
||||||
|
elif apply_req.apply_type == WorkerApplyType.RESTART:
|
||||||
|
apply_func = self._restart_all_worker
|
||||||
elif apply_req.apply_type == WorkerApplyType.UPDATE_PARAMS:
|
elif apply_req.apply_type == WorkerApplyType.UPDATE_PARAMS:
|
||||||
apply_func = self._update_all_worker_params
|
apply_func = self._update_all_worker_params
|
||||||
else:
|
else:
|
||||||
@ -298,10 +316,13 @@ class LocalWorkerManager(WorkerManager):
|
|||||||
apply_req (WorkerApplyRequest): Worker apply request
|
apply_req (WorkerApplyRequest): Worker apply request
|
||||||
apply_func (ApplyFunction): Function to apply to worker instances, now function is async function
|
apply_func (ApplyFunction): Function to apply to worker instances, now function is async function
|
||||||
"""
|
"""
|
||||||
|
logger.info(f"Apply req: {apply_req}, apply_func: {apply_func}")
|
||||||
if apply_req:
|
if apply_req:
|
||||||
worker_type = apply_req.worker_type.value
|
worker_type = apply_req.worker_type.value
|
||||||
model_name = apply_req.model
|
model_name = apply_req.model
|
||||||
worker_instances = await self.get_model_instances(worker_type, model_name)
|
worker_instances = await self.get_model_instances(
|
||||||
|
worker_type, model_name, healthy_only=False
|
||||||
|
)
|
||||||
if not worker_instances:
|
if not worker_instances:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"No worker instance found for the model {model_name} worker type {worker_type}"
|
f"No worker instance found for the model {model_name} worker type {worker_type}"
|
||||||
@ -370,6 +391,12 @@ class LocalWorkerManager(WorkerManager):
|
|||||||
message=f"Worker stopped successfully", timecost=timecost
|
message=f"Worker stopped successfully", timecost=timecost
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _restart_all_worker(
    self, apply_req: WorkerApplyRequest
) -> WorkerApplyOutput:
    """Stop the workers selected by ``apply_req``, then start them again.

    Returns:
        The output of the start step; the output of the stop step is discarded.
    """
    await self._stop_all_worker(apply_req)
    start_output = await self._start_all_worker(apply_req)
    return start_output
|
||||||
|
|
||||||
async def _update_all_worker_params(
|
async def _update_all_worker_params(
|
||||||
self, apply_req: WorkerApplyRequest
|
self, apply_req: WorkerApplyRequest
|
||||||
) -> WorkerApplyOutput:
|
) -> WorkerApplyOutput:
|
||||||
@ -389,8 +416,7 @@ class LocalWorkerManager(WorkerManager):
|
|||||||
timecost = time.time() - start_time
|
timecost = time.time() - start_time
|
||||||
if need_restart:
|
if need_restart:
|
||||||
logger.info("Model params update successfully, begin restart worker")
|
logger.info("Model params update successfully, begin restart worker")
|
||||||
await self._stop_all_worker(apply_req)
|
await self._restart_all_worker(apply_req)
|
||||||
await self._start_all_worker(apply_req)
|
|
||||||
timecost = time.time() - start_time
|
timecost = time.time() - start_time
|
||||||
message = f"Update worker params and restart successfully"
|
message = f"Update worker params and restart successfully"
|
||||||
return WorkerApplyOutput(message=message, timecost=timecost)
|
return WorkerApplyOutput(message=message, timecost=timecost)
|
||||||
@ -500,8 +526,8 @@ async def generate_json_stream(params):
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/worker/generate_stream")
|
@router.post("/worker/generate_stream")
|
||||||
async def api_generate_stream(request: Request):
|
async def api_generate_stream(request: PromptRequest):
|
||||||
params = await request.json()
|
params = request.dict(exclude_none=True)
|
||||||
generator = generate_json_stream(params)
|
generator = generate_json_stream(params)
|
||||||
return StreamingResponse(generator)
|
return StreamingResponse(generator)
|
||||||
|
|
||||||
@ -534,13 +560,19 @@ async def api_worker_parameter_descs(
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def _setup_fastapi(worker_params: ModelWorkerParameters):
|
def _setup_fastapi(worker_params: ModelWorkerParameters, app=None):
|
||||||
|
if not app:
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
if worker_params.standalone:
|
if worker_params.standalone:
|
||||||
from pilot.model.controller.controller import router as controller_router
|
from pilot.model.controller.controller import router as controller_router
|
||||||
|
from pilot.model.controller.controller import initialize_controller
|
||||||
|
|
||||||
if not worker_params.controller_addr:
|
if not worker_params.controller_addr:
|
||||||
worker_params.controller_addr = f"http://127.0.0.1:{worker_params.port}"
|
worker_params.controller_addr = f"http://127.0.0.1:{worker_params.port}"
|
||||||
|
logger.info(
|
||||||
|
f"Run WorkerManager with standalone mode, controller_addr: {worker_params.controller_addr}"
|
||||||
|
)
|
||||||
|
initialize_controller(app=app)
|
||||||
app.include_router(controller_router, prefix="/api")
|
app.include_router(controller_router, prefix="/api")
|
||||||
|
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
@ -570,7 +602,7 @@ def _parse_worker_params(
|
|||||||
)
|
)
|
||||||
worker_params.update_from(new_worker_params)
|
worker_params.update_from(new_worker_params)
|
||||||
|
|
||||||
logger.info(f"Worker params: {worker_params}")
|
# logger.info(f"Worker params: {worker_params}")
|
||||||
return worker_params
|
return worker_params
|
||||||
|
|
||||||
|
|
||||||
@ -583,7 +615,7 @@ def _create_local_model_manager(
|
|||||||
)
|
)
|
||||||
return LocalWorkerManager()
|
return LocalWorkerManager()
|
||||||
else:
|
else:
|
||||||
from pilot.model.controller.registry import ModelRegistryClient
|
from pilot.model.controller.controller import ModelRegistryClient
|
||||||
from pilot.utils.net_utils import _get_ip_address
|
from pilot.utils.net_utils import _get_ip_address
|
||||||
|
|
||||||
client = ModelRegistryClient(worker_params.controller_addr)
|
client = ModelRegistryClient(worker_params.controller_addr)
|
||||||
@ -600,6 +632,12 @@ def _create_local_model_manager(
|
|||||||
)
|
)
|
||||||
return await client.register_instance(instance)
|
return await client.register_instance(instance)
|
||||||
|
|
||||||
|
async def deregister_func(worker_run_data: WorkerRunData):
    """Deregister the worker described by ``worker_run_data`` via ``client``.

    ``host``, ``port`` and ``client`` come from the enclosing scope.
    """
    departing = ModelInstance(
        model_name=worker_run_data.worker_key,
        host=host,
        port=port,
    )
    outcome = await client.deregister_instance(departing)
    return outcome
|
||||||
|
|
||||||
async def send_heartbeat_func(worker_run_data: WorkerRunData):
|
async def send_heartbeat_func(worker_run_data: WorkerRunData):
|
||||||
instance = ModelInstance(
|
instance = ModelInstance(
|
||||||
model_name=worker_run_data.worker_key, host=host, port=port
|
model_name=worker_run_data.worker_key, host=host, port=port
|
||||||
@ -607,7 +645,9 @@ def _create_local_model_manager(
|
|||||||
return await client.send_heartbeat(instance)
|
return await client.send_heartbeat(instance)
|
||||||
|
|
||||||
return LocalWorkerManager(
|
return LocalWorkerManager(
|
||||||
register_func=register_func, send_heartbeat_func=send_heartbeat_func
|
register_func=register_func,
|
||||||
|
deregister_func=deregister_func,
|
||||||
|
send_heartbeat_func=send_heartbeat_func,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -642,30 +682,60 @@ def initialize_worker_manager_in_client(
|
|||||||
model_path: str = None,
|
model_path: str = None,
|
||||||
run_locally: bool = True,
|
run_locally: bool = True,
|
||||||
controller_addr: str = None,
|
controller_addr: str = None,
|
||||||
|
local_port: int = 5000,
|
||||||
):
|
):
|
||||||
|
"""Initialize WorkerManager in client.
|
||||||
|
If run_locally is True:
|
||||||
|
1. Start ModelController
|
||||||
|
2. Start LocalWorkerManager
|
||||||
|
3. Start worker in LocalWorkerManager
|
||||||
|
4. Register worker to ModelController
|
||||||
|
|
||||||
|
otherwise:
|
||||||
|
1. Build ModelRegistryClient with controller address
|
||||||
|
2. Start RemoteWorkerManager
|
||||||
|
|
||||||
|
"""
|
||||||
global worker_manager
|
global worker_manager
|
||||||
|
|
||||||
|
if not app:
|
||||||
|
raise Exception("app can't be None")
|
||||||
|
|
||||||
worker_params: ModelWorkerParameters = _parse_worker_params(
|
worker_params: ModelWorkerParameters = _parse_worker_params(
|
||||||
model_name=model_name, model_path=model_path, controller_addr=controller_addr
|
model_name=model_name, model_path=model_path, controller_addr=controller_addr
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Worker params: {worker_params}")
|
controller_addr = None
|
||||||
if run_locally:
|
if run_locally:
|
||||||
worker_params.register = False
|
# TODO start ModelController
|
||||||
|
worker_params.standalone = True
|
||||||
|
worker_params.register = True
|
||||||
|
worker_params.port = local_port
|
||||||
|
logger.info(f"Worker params: {worker_params}")
|
||||||
|
_setup_fastapi(worker_params, app)
|
||||||
_start_local_worker(worker_manager, worker_params, True)
|
_start_local_worker(worker_manager, worker_params, True)
|
||||||
loop = asyncio.get_event_loop()
|
# loop = asyncio.get_event_loop()
|
||||||
loop.run_until_complete(
|
# loop.run_until_complete(
|
||||||
worker_manager.worker_manager._start_all_worker(apply_req=None)
|
# worker_manager.worker_manager._start_all_worker(apply_req=None)
|
||||||
)
|
# )
|
||||||
else:
|
else:
|
||||||
from pilot.model.controller.registry import ModelRegistryClient
|
from pilot.model.controller.controller import (
|
||||||
|
initialize_controller,
|
||||||
|
ModelRegistryClient,
|
||||||
|
)
|
||||||
|
|
||||||
if not worker_params.controller_addr:
|
if not worker_params.controller_addr:
|
||||||
raise ValueError("Controller can`t be None")
|
raise ValueError("Controller can`t be None")
|
||||||
|
controller_addr = worker_params.controller_addr
|
||||||
|
logger.info(f"Worker params: {worker_params}")
|
||||||
client = ModelRegistryClient(worker_params.controller_addr)
|
client = ModelRegistryClient(worker_params.controller_addr)
|
||||||
worker_manager.worker_manager = RemoteWorkerManager(client)
|
worker_manager.worker_manager = RemoteWorkerManager(client)
|
||||||
|
initialize_controller(
|
||||||
|
app=app, remote_controller_addr=worker_params.controller_addr
|
||||||
|
)
|
||||||
|
|
||||||
if include_router and app:
|
if include_router and app:
|
||||||
|
# mount WorkerManager router
|
||||||
app.include_router(router, prefix="/api")
|
app.include_router(router, prefix="/api")
|
||||||
|
|
||||||
|
|
||||||
|
@ -51,8 +51,26 @@ def stop():
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@click.group()
def install():
    """Install dependencies, plugins, etc."""
    # Intentionally empty: this group only acts as a namespace; the docstring
    # above doubles as the help text click shows for it, so keep it as is.
|
||||||
|
|
||||||
|
|
||||||
|
stop_all_func_list = []
|
||||||
|
|
||||||
|
|
||||||
|
@click.command(name="all")
def stop_all():
    """Stop all servers"""
    # stop_all_func_list is filled at import time by the optional integrations
    # (model servers, webserver), so only the installed ones are stopped here.
    for stopper in stop_all_func_list:
        stopper()
|
||||||
|
|
||||||
|
|
||||||
cli.add_command(start)
|
cli.add_command(start)
|
||||||
cli.add_command(stop)
|
cli.add_command(stop)
|
||||||
|
cli.add_command(install)
|
||||||
|
add_command_alias(stop_all, name="all", parent_group=stop)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from pilot.model.cli import (
|
from pilot.model.cli import (
|
||||||
@ -63,6 +81,7 @@ try:
|
|||||||
stop_model_worker,
|
stop_model_worker,
|
||||||
start_apiserver,
|
start_apiserver,
|
||||||
stop_apiserver,
|
stop_apiserver,
|
||||||
|
_stop_all_model_server,
|
||||||
)
|
)
|
||||||
|
|
||||||
add_command_alias(model_cli_group, name="model", parent_group=cli)
|
add_command_alias(model_cli_group, name="model", parent_group=cli)
|
||||||
@ -73,15 +92,21 @@ try:
|
|||||||
add_command_alias(stop_model_controller, name="controller", parent_group=stop)
|
add_command_alias(stop_model_controller, name="controller", parent_group=stop)
|
||||||
add_command_alias(stop_model_worker, name="worker", parent_group=stop)
|
add_command_alias(stop_model_worker, name="worker", parent_group=stop)
|
||||||
add_command_alias(stop_apiserver, name="apiserver", parent_group=stop)
|
add_command_alias(stop_apiserver, name="apiserver", parent_group=stop)
|
||||||
|
stop_all_func_list.append(_stop_all_model_server)
|
||||||
|
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logging.warning(f"Integrating dbgpt model command line tool failed: {e}")
|
logging.warning(f"Integrating dbgpt model command line tool failed: {e}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from pilot.server._cli import start_webserver, stop_webserver
|
from pilot.server._cli import (
|
||||||
|
start_webserver,
|
||||||
|
stop_webserver,
|
||||||
|
_stop_all_dbgpt_server,
|
||||||
|
)
|
||||||
|
|
||||||
add_command_alias(start_webserver, name="webserver", parent_group=start)
|
add_command_alias(start_webserver, name="webserver", parent_group=start)
|
||||||
add_command_alias(stop_webserver, name="webserver", parent_group=stop)
|
add_command_alias(stop_webserver, name="webserver", parent_group=stop)
|
||||||
|
stop_all_func_list.append(_stop_all_dbgpt_server)
|
||||||
|
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logging.warning(f"Integrating dbgpt webserver command line tool failed: {e}")
|
logging.warning(f"Integrating dbgpt webserver command line tool failed: {e}")
|
||||||
|
@ -30,3 +30,7 @@ def start_webserver(**kwargs):
|
|||||||
def stop_webserver(port: int):
|
def stop_webserver(port: int):
|
||||||
"""Stop webserver(dbgpt_server.py)"""
|
"""Stop webserver(dbgpt_server.py)"""
|
||||||
_stop_service("webserver", "WebServer", port=port)
|
_stop_service("webserver", "WebServer", port=port)
|
||||||
|
|
||||||
|
|
||||||
|
def _stop_all_dbgpt_server():
    """Stop the webserver service without naming a port.

    Unlike stop_webserver, no ``port`` is forwarded to ``_stop_service``.
    NOTE(review): presumably this matches webserver processes on any port;
    confirm against ``_stop_service``.
    """
    _stop_service("webserver", "WebServer")
|
||||||
|
@ -104,7 +104,10 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
|
|||||||
if not param.light:
|
if not param.light:
|
||||||
print("Model Unified Deployment Mode!")
|
print("Model Unified Deployment Mode!")
|
||||||
initialize_worker_manager_in_client(
|
initialize_worker_manager_in_client(
|
||||||
app=app, model_name=CFG.LLM_MODEL, model_path=model_path
|
app=app,
|
||||||
|
model_name=CFG.LLM_MODEL,
|
||||||
|
model_path=model_path,
|
||||||
|
local_port=param.port,
|
||||||
)
|
)
|
||||||
|
|
||||||
CFG.NEW_SERVER_MODE = True
|
CFG.NEW_SERVER_MODE = True
|
||||||
@ -116,6 +119,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
|
|||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
run_locally=False,
|
run_locally=False,
|
||||||
controller_addr=CFG.MODEL_SERVER,
|
controller_addr=CFG.MODEL_SERVER,
|
||||||
|
local_port=param.port,
|
||||||
)
|
)
|
||||||
CFG.SERVER_LIGHT_MODE = True
|
CFG.SERVER_LIGHT_MODE = True
|
||||||
|
|
||||||
|
@ -6,15 +6,30 @@ import psutil
|
|||||||
import platform
|
import platform
|
||||||
|
|
||||||
|
|
||||||
|
def _get_abspath_of_current_command(command_path: str):
    """Return the command to re-run for the current invocation.

    Args:
        command_path: The command as it was invoked (callers pass argv[0]).

    Returns:
        ``command_path`` unchanged when it does not end with ``.py``;
        otherwise the absolute path ``<parent of this file's directory>/
        scripts/cli_scripts.py``, regardless of which ``.py`` file was given.
    """
    if command_path.endswith(".py"):
        # Any ".py" entry point is mapped to the CLI script located relative
        # to this module (the original author flags this as an ugly hack).
        parent_of_module_dir = os.path.dirname(
            os.path.dirname(os.path.abspath(__file__))
        )
        return os.path.join(parent_of_module_dir, "scripts", "cli_scripts.py")
    return command_path
|
||||||
|
|
||||||
|
|
||||||
def _run_current_with_daemon(name: str, log_file: str):
|
def _run_current_with_daemon(name: str, log_file: str):
|
||||||
# Get all arguments except for --daemon
|
# Get all arguments except for --daemon
|
||||||
args = [arg for arg in sys.argv if arg != "--daemon" and arg != "-d"]
|
args = [arg for arg in sys.argv if arg != "--daemon" and arg != "-d"]
|
||||||
|
args[0] = _get_abspath_of_current_command(args[0])
|
||||||
|
|
||||||
daemon_cmd = [sys.executable] + args
|
daemon_cmd = [sys.executable] + args
|
||||||
daemon_cmd = " ".join(daemon_cmd)
|
daemon_cmd = " ".join(daemon_cmd)
|
||||||
daemon_cmd += f" > {log_file} 2>&1"
|
daemon_cmd += f" > {log_file} 2>&1"
|
||||||
|
|
||||||
|
print(f"daemon cmd: {daemon_cmd}")
|
||||||
# Check the platform and set the appropriate flags or functions
|
# Check the platform and set the appropriate flags or functions
|
||||||
if platform.system() == "Windows":
|
if "windows" in platform.system().lower():
|
||||||
process = subprocess.Popen(
|
process = subprocess.Popen(
|
||||||
daemon_cmd,
|
daemon_cmd,
|
||||||
creationflags=subprocess.CREATE_NEW_PROCESS_GROUP,
|
creationflags=subprocess.CREATE_NEW_PROCESS_GROUP,
|
||||||
@ -50,7 +65,7 @@ def _run_current_with_gunicorn(app: str, config_path: str, kwargs: Dict):
|
|||||||
app_env = EnvArgumentParser._kwargs_to_env_key_value(kwargs)
|
app_env = EnvArgumentParser._kwargs_to_env_key_value(kwargs)
|
||||||
env_to_app.update(app_env)
|
env_to_app.update(app_env)
|
||||||
cmd = f"uvicorn {app} --host 0.0.0.0 --port 5000"
|
cmd = f"uvicorn {app} --host 0.0.0.0 --port 5000"
|
||||||
if platform.system() == "Windows":
|
if "windows" in platform.system().lower():
|
||||||
raise Exception("Not support on windows")
|
raise Exception("Not support on windows")
|
||||||
else: # macOS, Linux, and other Unix-like systems
|
else: # macOS, Linux, and other Unix-like systems
|
||||||
process = subprocess.Popen(cmd, shell=True, env=env_to_app)
|
process = subprocess.Popen(cmd, shell=True, env=env_to_app)
|
||||||
@ -89,3 +104,54 @@ def _stop_service(
|
|||||||
|
|
||||||
if not_found:
|
if not_found:
|
||||||
print(f"{fullname} process not found.")
|
print(f"{fullname} process not found.")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_ports_by_cmdline_part(service_keys: List[str]) -> List[int]:
    """
    Return the listening ports of processes whose cmdline contains all the service_keys.

    Args:
        service_keys (List[str]): List of strings that should all be present in the process's cmdline.

    Returns:
        List[int]: Unique ports sorted with preference for 8000 and 5000, and then in ascending order.
    """
    # A set, because one process may listen on the same port on several
    # addresses and each port should be reported once.
    ports = set()

    for process in psutil.process_iter(attrs=["pid", "name", "connections", "cmdline"]):
        try:
            # process_iter() stores its ad_value (None by default) for any
            # attribute it was not allowed to read, so fall back to empty
            # values instead of raising TypeError on join()/iteration.
            cmdline = " ".join(process.info["cmdline"] or [])

            # Skip processes that do not contain every service key
            if not all(fragment in cmdline for fragment in service_keys):
                continue

            for connection in process.info["connections"] or []:
                if connection.status == psutil.CONN_LISTEN:
                    ports.add(connection.laddr.port)
        except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
            # Best-effort scan: processes that vanish or cannot be inspected
            # are simply ignored.
            continue

    # False sorts before True, so 8000 comes first, then 5000, then the rest
    # in ascending order.
    return sorted(ports, key=lambda port: (port != 8000, port != 5000, port))
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_controller_address() -> str:
    """Work out the address of the Model Controller.

    Lookup order:
        1. The ``CONTROLLER_ADDRESS`` environment variable, when non-empty.
        2. The first listening port of a local process whose command line
           matches one of the known fragment groups (tried in the order
           listed below), reported as ``http://127.0.0.1:<port>``.
        3. The default ``http://127.0.0.1:8000``.

    Returns:
        str: The controller address.
    """
    controller_addr = os.getenv("CONTROLLER_ADDRESS")
    if controller_addr:
        return controller_addr

    # Tried in order; the first group that yields a listening port wins.
    cmdline_fragments = [
        ["python", "start", "controller"],
        ["python", "controller"],
        ["python", "start", "webserver"],
        ["python", "dbgpt_server"],
    ]

    for fragments in cmdline_fragments:
        ports = _get_ports_by_cmdline_part(fragments)
        if ports:
            return f"http://127.0.0.1:{ports[0]}"

    # Plain literal: the original used an f-string with no placeholders.
    return "http://127.0.0.1:8000"
|
||||||
|
Loading…
Reference in New Issue
Block a user