diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py index 05f9ffdcb..8ec0075ea 100644 --- a/pilot/model/adapter.py +++ b/pilot/model/adapter.py @@ -12,19 +12,13 @@ from transformers import ( AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, - BitsAndBytesConfig, ) from pilot.model.parameter import ModelParameters, LlamaCppModelParameters from pilot.configs.model_config import DEVICE from pilot.configs.config import Config from pilot.logs import logger -bnb_config = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_compute_dtype="bfloat16", - bnb_4bit_use_double_quant=False, -) + CFG = Config() @@ -203,6 +197,14 @@ class FalconAdapater(BaseLLMAdaper): tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) if CFG.QLoRA: + from transformers import BitsAndBytesConfig + + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype="bfloat16", + bnb_4bit_use_double_quant=False, + ) model = AutoModelForCausalLM.from_pretrained( model_path, load_in_4bit=True, # quantize diff --git a/pilot/model/base.py b/pilot/model/base.py index ba8190ea3..c81279886 100644 --- a/pilot/model/base.py +++ b/pilot/model/base.py @@ -1,7 +1,10 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from typing import TypedDict +from enum import Enum +from typing import TypedDict, Optional, Dict +from dataclasses import dataclass +from datetime import datetime class Message(TypedDict): @@ -9,3 +12,39 @@ class Message(TypedDict): role: str content: str + + +@dataclass +class ModelInstance: + """Model instance info""" + + model_name: str + host: str + port: int + weight: Optional[float] = 1.0 + check_healthy: Optional[bool] = True + healthy: Optional[bool] = False + enabled: Optional[bool] = True + prompt_template: Optional[str] = None + last_heartbeat: Optional[datetime] = None + + +class WorkerApplyType(str, Enum): + START = "start" + STOP = "stop" + RESTART = "restart" + UPDATE_PARAMS = "update_params" + + +@dataclass +class ModelOutput: + text: str + error_code: int + model_context: Dict = None + + +@dataclass +class WorkerApplyOutput: + message: str + # The seconds cost to apply some action to worker instances + timecost: Optional[int] = -1 diff --git a/pilot/model/cli.py b/pilot/model/cli.py new file mode 100644 index 000000000..6109d11ab --- /dev/null +++ b/pilot/model/cli.py @@ -0,0 +1,151 @@ +import click +import functools + +from pilot.model.controller.registry import ModelRegistryClient +from pilot.model.worker.manager import ( + RemoteWorkerManager, + WorkerApplyRequest, + WorkerApplyType, +) +from pilot.utils import get_or_create_event_loop + + +@click.group("model") +def model_cli_group(): + pass + + +@model_cli_group.command() +@click.option( + "--address", + type=str, + default="http://127.0.0.1:8000", + required=False, + help=( + "Address of the Model Controller to connect to." 
+ "Just support light deploy model" + ), +) +@click.option( + "--model-name", type=str, default=None, required=False, help=("The name of model") +) +@click.option( + "--model-type", type=str, default="llm", required=False, help=("The type of model") +) +def list(address: str, model_name: str, model_type: str): + """List model instances""" + from prettytable import PrettyTable + + loop = get_or_create_event_loop() + registry = ModelRegistryClient(address) + + if not model_name: + instances = loop.run_until_complete(registry.get_all_model_instances()) + else: + if not model_type: + model_type = "llm" + register_model_name = f"{model_name}@{model_type}" + instances = loop.run_until_complete( + registry.get_all_instances(register_model_name) + ) + table = PrettyTable() + + table.field_names = [ + "Model Name", + "Model Type", + "Host", + "Port", + "Healthy", + "Enabled", + "Prompt Template", + "Last Heartbeat", + ] + for instance in instances: + model_name, model_type = instance.model_name.split("@") + table.add_row( + [ + model_name, + model_type, + instance.host, + instance.port, + instance.healthy, + instance.enabled, + instance.prompt_template, + instance.last_heartbeat, + ] + ) + + print(table) + + +def add_model_options(func): + @click.option( + "--address", + type=str, + default="http://127.0.0.1:8000", + required=False, + help=( + "Address of the Model Controller to connect to." + "Just support light deploy model" + ), + ) + @click.option( + "--model-name", + type=str, + default=None, + required=True, + help=("The name of model"), + ) + @click.option( + "--model-type", + type=str, + default="llm", + required=False, + help=("The type of model"), + ) + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + +@model_cli_group.command() +@add_model_options +def stop(address: str, model_name: str, model_type: str): + """Stop model instances""" + worker_apply(address, model_name, model_type, WorkerApplyType.STOP) + + +@model_cli_group.command() +@add_model_options +def start(address: str, model_name: str, model_type: str): + """Start model instances""" + worker_apply(address, model_name, model_type, WorkerApplyType.START) + + +@model_cli_group.command() +@add_model_options +def restart(address: str, model_name: str, model_type: str): + """Restart model instances""" + worker_apply(address, model_name, model_type, WorkerApplyType.RESTART) + + +# @model_cli_group.command() +# @add_model_options +# def modify(address: str, model_name: str, model_type: str): +# """Restart model instances""" +# worker_apply(address, model_name, model_type, WorkerApplyType.UPDATE_PARAMS) + + +def worker_apply( + address: str, model_name: str, model_type: str, apply_type: WorkerApplyType +): + loop = get_or_create_event_loop() + registry = ModelRegistryClient(address) + worker_manager = RemoteWorkerManager(registry) + apply_req = WorkerApplyRequest( + model=model_name, worker_type=model_type, apply_type=apply_type + ) + res = loop.run_until_complete(worker_manager.worker_apply(apply_req)) + print(res) diff --git a/pilot/model/controller/__init__.py b/pilot/model/controller/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/model/controller/controller.py b/pilot/model/controller/controller.py new file mode 100644 index 000000000..84d4dfb29 --- /dev/null +++ b/pilot/model/controller/controller.py @@ -0,0 +1,65 @@ +import logging +from typing import List + +from fastapi import APIRouter +from pilot.model.base import ModelInstance +from 
pilot.model.controller.registry import EmbeddedModelRegistry, ModelRegistry + + +class ModelController: + def __init__(self, registry: ModelRegistry = None) -> None: + if not registry: + registry = EmbeddedModelRegistry() + self.registry = registry + self.deployment = None + + async def register_instance(self, instance: ModelInstance) -> bool: + return await self.registry.register_instance(instance) + + async def deregister_instance(self, instance: ModelInstance) -> bool: + return await self.registry.deregister_instance(instance) + + async def get_all_instances( + self, model_name: str, healthy_only: bool = False + ) -> List[ModelInstance]: + logging.info( + f"Get all instances with {model_name}, healthy_only: {healthy_only}" + ) + return await self.registry.get_all_instances(model_name, healthy_only) + + async def get_all_model_instances(self) -> List[ModelInstance]: + return await self.registry.get_all_model_instances() + + async def send_heartbeat(self, instance: ModelInstance) -> bool: + return await self.registry.send_heartbeat(instance) + + async def model_apply(self) -> bool: + # TODO + raise NotImplementedError + + +router = APIRouter() + +controller = ModelController() + + +@router.post("/controller/models") +async def api_register_instance(request: ModelInstance): + return await controller.register_instance(request) + + +@router.delete("/controller/models") +async def api_deregister_instance(request: ModelInstance): + return await controller.deregister_instance(request) + + +@router.get("/controller/models") +async def api_get_all_instances(model_name: str = None, healthy_only: bool = False): + if not model_name: + return await controller.get_all_model_instances() + return await controller.get_all_instances(model_name, healthy_only=healthy_only) + + +@router.post("/controller/heartbeat") +async def api_model_heartbeat(request: ModelInstance): + return await controller.send_heartbeat(request) diff --git a/pilot/model/controller/ray_controller.py b/pilot/model/controller/ray_controller.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/model/controller/registry.py b/pilot/model/controller/registry.py new file mode 100644 index 000000000..445e68c26 --- /dev/null +++ b/pilot/model/controller/registry.py @@ -0,0 +1,224 @@ +import random +import threading +import time +from abc import ABC, abstractmethod +from collections import defaultdict +from datetime import datetime, timedelta +from typing import Dict, List, Tuple +import itertools + +from pilot.model.base import ModelInstance + + +class ModelRegistry(ABC): + """ + Abstract base class for a model registry. It provides an interface + for registering, deregistering, fetching instances, and sending heartbeats + for instances. + """ + + @abstractmethod + async def register_instance(self, instance: ModelInstance) -> bool: + """ + Register a given model instance. + + Args: + - instance (ModelInstance): The instance of the model to register. + + Returns: + - bool: True if registration is successful, False otherwise. + """ + pass + + @abstractmethod + async def deregister_instance(self, instance: ModelInstance) -> bool: + """ + Deregister a given model instance. + + Args: + - instance (ModelInstance): The instance of the model to deregister. + + Returns: + - bool: True if deregistration is successful, False otherwise. + """ + + @abstractmethod + async def get_all_instances( + self, model_name: str, healthy_only: bool = False + ) -> List[ModelInstance]: + """ + Fetch all instances of a given model. 
Optionally, fetch only the healthy instances. + + Args: + - model_name (str): Name of the model to fetch instances for. + - healthy_only (bool, optional): If set to True, fetches only the healthy instances. + Defaults to False. + + Returns: + - List[ModelInstance]: A list of instances for the given model. + """ + + @abstractmethod + async def get_all_model_instances(self) -> List[ModelInstance]: + """ + Fetch all instances of all models + + Returns: + - List[ModelInstance]: A list of instances for the all models. + """ + + async def select_one_health_instance(self, model_name: str) -> ModelInstance: + """ + Selects one healthy and enabled instance for a given model. + + Args: + - model_name (str): Name of the model. + + Returns: + - ModelInstance: One randomly selected healthy and enabled instance, or None if no such instance exists. + """ + instances = await self.get_all_instances(model_name, healthy_only=True) + instances = [i for i in instances if i.enabled] + if not instances: + return None + return random.choice(instances) + + @abstractmethod + async def send_heartbeat(self, instance: ModelInstance) -> bool: + """ + Send a heartbeat for a given model instance. This can be used to + verify if the instance is still alive and functioning. + + Args: + - instance (ModelInstance): The instance of the model to send a heartbeat for. + + Returns: + - bool: True if heartbeat is successful, False otherwise. + """ + + +class EmbeddedModelRegistry(ModelRegistry): + def __init__( + self, heartbeat_interval_secs: int = 60, heartbeat_timeout_secs: int = 120 + ): + self.registry: Dict[str, List[ModelInstance]] = defaultdict(list) + self.heartbeat_interval_secs = heartbeat_interval_secs + self.heartbeat_timeout_secs = heartbeat_timeout_secs + self.heartbeat_thread = threading.Thread(target=self._heartbeat_checker) + self.heartbeat_thread.daemon = True + self.heartbeat_thread.start() + + def _get_instances( + self, model_name: str, host: str, port: int, healthy_only: bool = False + ) -> Tuple[List[ModelInstance], List[ModelInstance]]: + instances = self.registry[model_name] + if healthy_only: + instances = [ins for ins in instances if ins.healthy == True] + exist_ins = [ins for ins in instances if ins.host == host and ins.port == port] + return instances, exist_ins + + def _heartbeat_checker(self): + while True: + for instances in self.registry.values(): + for instance in instances: + if ( + instance.check_healthy + and datetime.now() - instance.last_heartbeat + > timedelta(seconds=self.heartbeat_timeout_secs) + ): + instance.healthy = False + time.sleep(self.heartbeat_interval_secs) + + async def register_instance(self, instance: ModelInstance) -> bool: + model_name = instance.model_name.strip() + host = instance.host.strip() + port = instance.port + + instances, exist_ins = self._get_instances( + model_name, host, port, healthy_only=False + ) + if exist_ins: + # One exist instance at most + ins = exist_ins[0] + # Update instance + ins.weight = instance.weight + ins.healthy = True + ins.prompt_template = instance.prompt_template + ins.last_heartbeat = datetime.now() + else: + instance.healthy = True + instance.last_heartbeat = datetime.now() + instances.append(instance) + return True + + async def deregister_instance(self, instance: ModelInstance) -> bool: + model_name = instance.model_name.strip() + host = instance.host.strip() + port = instance.port + _, exist_ins = self._get_instances(model_name, host, port, healthy_only=False) + if exist_ins: + ins = exist_ins[0] + ins.healthy = False + return True 
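+        # Note: deregistration is a soft delete, the instance is only marked unhealthy
+        # and stays in the registry list. A minimal usage sketch of the embedded
+        # registry (model name, host and port below are illustrative):
+        #
+        #     registry = EmbeddedModelRegistry(heartbeat_interval_secs=20, heartbeat_timeout_secs=60)
+        #     instance = ModelInstance(model_name="vicuna-13b@llm", host="127.0.0.1", port=8000)
+        #     await registry.register_instance(instance)    # marks it healthy, stamps last_heartbeat
+        #     await registry.send_heartbeat(instance)       # refreshes last_heartbeat
+        #     healthy = await registry.get_all_instances("vicuna-13b@llm", healthy_only=True)
+        #     await registry.deregister_instance(instance)  # healthy -> False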
+ + async def get_all_instances( + self, model_name: str, healthy_only: bool = False + ) -> List[ModelInstance]: + instances = self.registry[model_name] + if healthy_only: + instances = [ins for ins in instances if ins.healthy == True] + return instances + + async def get_all_model_instances(self) -> List[ModelInstance]: + print(self.registry) + return list(itertools.chain(*self.registry.values())) + + async def send_heartbeat(self, instance: ModelInstance) -> bool: + _, exist_ins = self._get_instances( + instance.model_name, instance.host, instance.port, healthy_only=False + ) + if not exist_ins: + return False + + ins = exist_ins[0] + ins.last_heartbeat = datetime.now() + ins.healthy = True + return True + + +from pilot.utils.api_utils import _api_remote as api_remote + + +class ModelRegistryClient(ModelRegistry): + def __init__(self, base_url: str) -> None: + self.base_url = base_url + + @api_remote(path="/api/controller/models", method="POST") + async def register_instance(self, instance: ModelInstance) -> bool: + pass + + @api_remote(path="/api/controller/models", method="DELETE") + async def deregister_instance(self, instance: ModelInstance) -> bool: + pass + + @api_remote(path="/api/controller/models") + async def get_all_instances( + self, model_name: str, healthy_only: bool = False + ) -> List[ModelInstance]: + pass + + @api_remote(path="/api/controller/models") + async def get_all_model_instances(self) -> List[ModelInstance]: + pass + + @api_remote(path="/api/controller/models") + async def select_one_health_instance(self, model_name: str) -> ModelInstance: + instances = await self.get_all_instances(model_name, healthy_only=True) + instances = [i for i in instances if i.enabled] + if not instances: + return None + return random.choice(instances) + + @api_remote(path="/api/controller/heartbeat", method="POST") + async def send_heartbeat(self, instance: ModelInstance) -> bool: + pass diff --git a/pilot/model/controller/tests/__init__.py b/pilot/model/controller/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/model/controller/tests/test_registry.py b/pilot/model/controller/tests/test_registry.py new file mode 100644 index 000000000..b69b08508 --- /dev/null +++ b/pilot/model/controller/tests/test_registry.py @@ -0,0 +1,148 @@ +import pytest +from datetime import datetime, timedelta + +import asyncio +from unittest.mock import patch +from pilot.model.base import ModelInstance +from pilot.model.controller.registry import ModelRegistry, EmbeddedModelRegistry + + +@pytest.fixture +def model_registry(): + return EmbeddedModelRegistry() + + +@pytest.fixture +def model_instance(): + return ModelInstance( + model_name="test_model", + ip="192.168.1.1", + port=5000, + ) + + +# Async function to test the registry +@pytest.mark.asyncio +async def test_register_instance(model_registry, model_instance): + """ + Test if an instance can be registered correctly + """ + assert await model_registry.register_instance(model_instance) == True + assert len(model_registry.registry[model_instance.model_name]) == 1 + + +@pytest.mark.asyncio +async def test_deregister_instance(model_registry, model_instance): + """ + Test if an instance can be deregistered correctly + """ + await model_registry.register_instance(model_instance) + assert await model_registry.deregister_instance(model_instance) == True + assert not model_registry.registry[model_instance.model_name][0].healthy + + +@pytest.mark.asyncio +async def test_get_all_instances(model_registry, model_instance): + """ + 
Test if all instances can be retrieved, with and without the healthy_only filter + """ + await model_registry.register_instance(model_instance) + assert len(await model_registry.get_all_instances(model_instance.model_name)) == 1 + assert ( + len( + await model_registry.get_all_instances( + model_instance.model_name, healthy_only=True + ) + ) + == 1 + ) + model_instance.healthy = False + assert ( + len( + await model_registry.get_all_instances( + model_instance.model_name, healthy_only=True + ) + ) + == 0 + ) + + +@pytest.mark.asyncio +async def test_select_one_health_instance(model_registry, model_instance): + """ + Test if a single healthy instance can be selected + """ + await model_registry.register_instance(model_instance) + selected_instance = await model_registry.select_one_health_instance( + model_instance.model_name + ) + assert selected_instance is not None + assert selected_instance.healthy + assert selected_instance.enabled + + +@pytest.mark.asyncio +async def test_send_heartbeat(model_registry, model_instance): + """ + Test if a heartbeat can be sent and that it correctly updates the last_heartbeat timestamp + """ + await model_registry.register_instance(model_instance) + last_heartbeat = datetime.now() - timedelta(seconds=10) + model_instance.last_heartbeat = last_heartbeat + assert ( + await model_registry.send_heartbeat( + model_instance.model_name, model_instance.ip, model_instance.port + ) + == True + ) + assert ( + model_registry.registry[model_instance.model_name][0].last_heartbeat + > last_heartbeat + ) + assert model_registry.registry[model_instance.model_name][0].healthy == True + + +@pytest.mark.asyncio +async def test_heartbeat_timeout(model_registry, model_instance): + """ + Test if an instance is marked as unhealthy when the heartbeat is not sent within the timeout + """ + model_registry = EmbeddedModelRegistry(1, 1) + await model_registry.register_instance(model_instance) + model_registry.registry[model_instance.model_name][ + 0 + ].last_heartbeat = datetime.now() - timedelta( + seconds=model_registry.heartbeat_timeout_secs + 1 + ) + await asyncio.sleep(model_registry.heartbeat_interval_secs + 1) + assert not model_registry.registry[model_instance.model_name][0].healthy + + +@pytest.mark.asyncio +async def test_multiple_instances(model_registry, model_instance): + """ + Test if multiple instances of the same model are handled correctly + """ + model_instance2 = ModelInstance( + model_name="test_model", + ip="192.168.1.2", + port=5000, + ) + await model_registry.register_instance(model_instance) + await model_registry.register_instance(model_instance2) + assert len(await model_registry.get_all_instances(model_instance.model_name)) == 2 + + +@pytest.mark.asyncio +async def test_same_model_name_different_ip_port(model_registry): + """ + Test if instances with the same model name but different IP and port are handled correctly + """ + instance1 = ModelInstance(model_name="test_model", ip="192.168.1.1", port=5000) + instance2 = ModelInstance(model_name="test_model", ip="192.168.1.2", port=6000) + await model_registry.register_instance(instance1) + await model_registry.register_instance(instance2) + instances = await model_registry.get_all_instances("test_model") + assert len(instances) == 2 + assert instances[0].ip != instances[1].ip + assert instances[0].port != instances[1].port diff --git a/pilot/model/loader.py b/pilot/model/loader.py index 170f7c460..e4478450a 100644 --- a/pilot/model/loader.py +++ b/pilot/model/loader.py @@ -60,7 +60,7 @@ def 
_get_model_real_path(model_name, default_model_path) -> str: return _genenv_ignoring_key_case("model_path", default_value=default_model_path) -class ModelLoader(metaclass=Singleton): +class ModelLoader: """Model loader is a class for model load Args: model_path @@ -68,17 +68,11 @@ class ModelLoader(metaclass=Singleton): TODO: multi model support. """ - kwargs = {} - def __init__(self, model_path: str, model_name: str = None) -> None: self.device = DEVICE self.model_path = model_path self.model_name = model_name self.prompt_template: str = None - self.kwargs = { - "torch_dtype": torch.float16, - "device_map": "auto", - } # TODO multi gpu support def loader( @@ -121,6 +115,18 @@ class ModelLoader(metaclass=Singleton): else: raise Exception(f"Unkown model type {model_type}") + def loader_with_params(self, model_params: ModelParameters): + llm_adapter = get_llm_model_adapter(self.model_name, self.model_path) + model_type = llm_adapter.model_type() + self.prompt_template = model_params.prompt_template + logger.info(f"model_params:\n{model_params}") + if model_type == ModelType.HF: + return huggingface_loader(llm_adapter, model_params) + elif model_type == ModelType.LLAMA_CPP: + return llamacpp_loader(llm_adapter, model_params) + else: + raise Exception(f"Unkown model type {model_type}") + def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParameters): device = model_params.device diff --git a/pilot/model/parameter.py b/pilot/model/parameter.py index 6ede45d9b..baa646711 100644 --- a/pilot/model/parameter.py +++ b/pilot/model/parameter.py @@ -1,9 +1,10 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import argparse import os - -from typing import Any, Optional, Type from dataclasses import dataclass, field, fields +from enum import Enum +from typing import Any, Dict, List, Optional, Type, Union from pilot.model.conversation import conv_templates @@ -20,11 +21,28 @@ def _genenv_ignoring_key_case(env_key: str, env_prefix: str = None, default_valu class EnvArgumentParser: + @staticmethod + def get_env_prefix(env_key: str) -> str: + if not env_key: + return None + env_key = env_key.replace("-", "_") + return env_key + "_" + def parse_args_into_dataclass( - self, dataclass_type: Type, env_prefix: str = None, **kwargs + self, + dataclass_type: Type, + env_prefix: str = None, + command_args: List[str] = None, + **kwargs, ) -> Any: + """Parse parameters from environment variables and command lines and populate them into data class""" + parser = argparse.ArgumentParser() for field in fields(dataclass_type): env_var_value = _genenv_ignoring_key_case(field.name, env_prefix) + if not env_var_value: + # Read without env prefix + env_var_value = _genenv_ignoring_key_case(field.name) + if env_var_value: env_var_value = env_var_value.strip() if field.type is int or field.type == Optional[int]: @@ -37,17 +55,241 @@ class EnvArgumentParser: pass else: raise ValueError(f"Unsupported parameter type {field.type}") - kwargs[field.name] = env_var_value + if not env_var_value: + env_var_value = kwargs.get(field.name) + if not env_var_value: + env_var_value = field.default + + # Add a command-line argument for this field + help_text = field.metadata.get("help", "") + valid_values = field.metadata.get("valid_values", None) + parser.add_argument( + f"--{field.name}", + type=self._get_argparse_type(field.type), + help=help_text, + choices=valid_values, + default=env_var_value, + ) + + # Parse the command-line arguments + cmd_args, cmd_argv = parser.parse_known_args(args=command_args) + 
print(f"cmd_args: {cmd_args}") + for field in fields(dataclass_type): + # cmd_line_value = getattr(cmd_args, field.name) + if field.name in cmd_args: + cmd_line_value = getattr(cmd_args, field.name) + if cmd_line_value is not None: + kwargs[field.name] = cmd_line_value + return dataclass_type(**kwargs) + @staticmethod + def _get_argparse_type(field_type: Type) -> Type: + # Return the appropriate type for argparse to use based on the field type + if field_type is int or field_type == Optional[int]: + return int + elif field_type is float or field_type == Optional[float]: + return float + elif field_type is bool or field_type == Optional[bool]: + return bool + elif field_type is str or field_type == Optional[str]: + return str + else: + raise ValueError(f"Unsupported parameter type {field_type}") + + @staticmethod + def _get_argparse_type_str(field_type: Type) -> str: + argparse_type = EnvArgumentParser._get_argparse_type(field_type) + if argparse_type is int: + return "int" + elif argparse_type is float: + return "float" + elif argparse_type is bool: + return "bool" + else: + return "str" + @dataclass -class ModelParameters: - device: str = field(metadata={"help": "Device to run model"}) - model_name: str = field(metadata={"help": "Model name"}) - model_path: str = field(metadata={"help": "Model path"}) +class ParameterDescription: + param_name: str + param_type: str + description: str + default_value: Optional[Any] + valid_values: Optional[List[Any]] + + +def _get_parameter_descriptions(dataclass_type: Type) -> List[ParameterDescription]: + descriptions = [] + for field in fields(dataclass_type): + descriptions.append( + ParameterDescription( + param_name=field.name, + param_type=EnvArgumentParser._get_argparse_type_str(field.type), + description=field.metadata.get("help", None), + default_value=field.default, # TODO handle dataclasses._MISSING_TYPE + valid_values=field.metadata.get("valid_values", None), + ) + ) + return descriptions + + +class WorkerType(str, Enum): + LLM = "llm" + TEXT2VEC = "text2vec" + + @staticmethod + def values(): + return [item.value for item in WorkerType] + + +@dataclass +class BaseParameters: + def update_from(self, source: Union["BaseParameters", dict]) -> bool: + """ + Update the attributes of this object using the values from another object (of the same or parent type) or a dictionary. + Only update if the new value is different from the current value and the field is not marked as "fixed" in metadata. + + Args: + source (Union[BaseParameters, dict]): The source to update from. Can be another object of the same type or a dictionary. + + Returns: + bool: True if at least one field was updated, otherwise False. 
+ """ + updated = False # Flag to indicate whether any field was updated + if isinstance(source, (BaseParameters, dict)): + for field_info in fields(self): + # Check if the field has a "fixed" tag in metadata + tags = field_info.metadata.get("tags") + tags = [] if not tags else tags.split(",") + if tags and "fixed" in tags: + continue # skip this field + # Get the new value from source (either another BaseParameters object or a dict) + new_value = ( + getattr(source, field_info.name) + if isinstance(source, BaseParameters) + else source.get(field_info.name, None) + ) + + # If the new value is not None and different from the current value, update the field and set the flag + if new_value is not None and new_value != getattr( + self, field_info.name + ): + setattr(self, field_info.name, new_value) + updated = True + else: + raise ValueError( + "Source must be an instance of BaseParameters (or its derived class) or a dictionary." + ) + + return updated + + def __str__(self) -> str: + class_name = self.__class__.__name__ + parameters = [ + f"\n\n=========================== {class_name} ===========================\n" + ] + for field_info in fields(self): + value = getattr(self, field_info.name) + parameters.append(f"{field_info.name}: {value}") + parameters.append( + "\n======================================================================\n\n" + ) + return "\n".join(parameters) + + +@dataclass +class ModelWorkerParameters(BaseParameters): + model_name: str = field(metadata={"help": "Model name", "tags": "fixed"}) + model_path: str = field(metadata={"help": "Model path", "tags": "fixed"}) + worker_type: Optional[str] = field( + default=None, + metadata={"valid_values": WorkerType.values(), "help": "Worker type"}, + ) + worker_class: Optional[str] = field( + default=None, + metadata={ + "help": "Model worker deploy host, pilot.model.worker.default_worker.DefaultModelWorker" + }, + ) + host: Optional[str] = field( + default="0.0.0.0", metadata={"help": "Model worker deploy host"} + ) + + port: Optional[int] = field( + default=8000, metadata={"help": "Model worker deploy port"} + ) + limit_model_concurrency: Optional[int] = field( + default=5, metadata={"help": "Model concurrency limit"} + ) + standalone: Optional[bool] = field( + default=False, + metadata={"help": "Standalone mode. If True, embedded Run ModelController"}, + ) + register: Optional[bool] = field( + default=True, metadata={"help": "Register current worker to model controller"} + ) + worker_register_host: Optional[str] = field( + default=None, + metadata={ + "help": "The ip address of current worker to register to ModelController. If None, the address is automatically determined" + }, + ) + controller_addr: Optional[str] = field( + default=None, metadata={"help": "The Model controller address to register"} + ) + send_heartbeat: Optional[bool] = field( + default=True, metadata={"help": "Send heartbeat to model controller"} + ) + heartbeat_interval: Optional[int] = field( + default=20, metadata={"help": "The interval for sending heartbeats (seconds)"} + ) + + +@dataclass +class EmbeddingModelParameters(BaseParameters): + model_name: str = field(metadata={"help": "Model name", "tags": "fixed"}) + model_path: str = field(metadata={"help": "Model path", "tags": "fixed"}) + device: Optional[str] = field( + default=None, + metadata={ + "help": "Device to run model. 
If None, the device is automatically determined" + }, + ) + + normalize_embeddings: Optional[bool] = field( + default=None, + metadata={ + "help": "Determines whether the model's embeddings should be normalized." + }, + ) + + def build_kwargs(self, **kwargs) -> Dict: + model_kwargs, encode_kwargs = None, None + if self.device: + model_kwargs = {"device": self.device} + if self.normalize_embeddings: + encode_kwargs = {"normalize_embeddings": self.normalize_embeddings} + if model_kwargs: + kwargs["model_kwargs"] = model_kwargs + if encode_kwargs: + kwargs["encode_kwargs"] = encode_kwargs + return kwargs + + +@dataclass +class ModelParameters(BaseParameters): + model_name: str = field(metadata={"help": "Model name", "tags": "fixed"}) + model_path: str = field(metadata={"help": "Model path", "tags": "fixed"}) + device: Optional[str] = field( + default=None, + metadata={ + "help": "Device to run model. If None, the device is automatically determined" + }, + ) model_type: Optional[str] = field( - default="huggingface", metadata={"help": "Model type, huggingface or llama.cpp"} + default="huggingface", + metadata={"help": "Model type, huggingface or llama.cpp", "tags": "fixed"}, ) prompt_template: Optional[str] = field( default=None, @@ -91,7 +333,6 @@ class ModelParameters: default=True, metadata={"help": "Nested quantization, only valid when load_4bit=True"}, ) - # "bfloat16", "float16", "float32" compute_dtype: Optional[str] = field( default=None, metadata={ diff --git a/pilot/model/worker/__init__.py b/pilot/model/worker/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/model/worker/base.py b/pilot/model/worker/base.py new file mode 100644 index 000000000..dfb0186fb --- /dev/null +++ b/pilot/model/worker/base.py @@ -0,0 +1,114 @@ +from abc import ABC, abstractmethod +from typing import Dict, Iterator, List, Type + +from pilot.model.base import ModelOutput +from pilot.model.parameter import ( + ModelParameters, + ParameterDescription, + WorkerType, + _get_parameter_descriptions, +) + + +class ModelWorker(ABC): + """ + Abstract representation of a Model Worker responsible for model interaction, startup, and shutdown. Supports 'llm' and 'text2vec' models. + """ + + def worker_type(self) -> WorkerType: + """Return the type of worker as LLM.""" + return WorkerType.LLM + + def model_param_class(self) -> Type: + """Return the class representing model parameters.""" + return ModelParameters + + def support_async(self) -> bool: + """Whether support async, if True, invoke async_generate_stream, async_generate and async_embeddings instead of generate_stream, generate and embeddings""" + return False + + @abstractmethod + def parse_parameters(self, command_args: List[str] = None) -> ModelParameters: + """Parse the parameters using the provided command arguments. + + Args: + command_args (List[str]): The command-line arguments. Default is sys.argv[1:]. 
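+
+        Returns:
+            ModelParameters: The parsed parameters. Implementations typically read them
+            from environment variables (optionally prefixed with the model name) and from
+            command-line flags named after the parameter fields, e.g. ``--device``.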
+ """ + + @abstractmethod + def load_worker(self, model_name: str, model_path: str, **kwargs) -> None: + """Load the worker with the specified model name and path.""" + + @abstractmethod + def start( + self, model_params: ModelParameters = None, command_args: List[str] = None + ) -> None: + """Start the model worker""" + + @abstractmethod + def stop(self) -> None: + """Stop the model worker and clean up all the resources used.""" + + def restart( + self, model_params: ModelParameters = None, command_args: List[str] = None + ) -> None: + """Restart the model worker.""" + self.stop() + self.start(model_params, command_args) + + def parameter_descriptions(self) -> List[ParameterDescription]: + """Fetch the parameter configuration information for the current model.""" + param_cls = self.model_param_class() + return _get_parameter_descriptions(param_cls) + + @abstractmethod + def generate_stream(self, params: Dict) -> Iterator[ModelOutput]: + """Generate a stream based on provided parameters. + + Args: + params (Dict): Parameters matching the PromptRequest data class format. Example: + { + "messages": [{"role": "user", "content": "Hello world"}], # List of ModelMessage objects + "model": "vicuna-13b-v1.5", + "prompt": "Hello world", + "temperature": 0.7, # Optional; float value between 0 and 1 + "max_new_tokens": 2048, # Optional; max number of new tokens for the output + "stop": "#", # Optional; stopping condition for the output + "echo": True # Optional; whether to echo the input in the output + } + + Returns: + Iterator[ModelOutput]: Stream of model outputs. + """ + + async def async_generate_stream(self, params: Dict) -> Iterator[ModelOutput]: + """Asynchronously generate a stream based on provided parameters.""" + raise NotImplementedError + + @abstractmethod + def generate(self, params: Dict) -> ModelOutput: + """Generate output (non-stream) based on provided parameters.""" + + async def async_generate(self, params: Dict) -> ModelOutput: + """Asynchronously generate output (non-stream) based on provided parameters.""" + raise NotImplementedError + + @abstractmethod + def embeddings(self, params: Dict) -> List[List[float]]: + """ + Return embeddings for the given input parameters. + + Args: + params (Dict): Parameters matching the EmbeddingsRequest data class format. Example: + { + "model": "text2vec-large-chinese", + "input": ["Hello world", "DB-GPT is amazing"] + } + + Returns: + List[List[float]]: List of embeddings corresponding to each input string. 
+ """ + + async def async_embeddings(self, params: Dict) -> List[List[float]]: + """Return embeddings asynchronously for the given input parameters.""" + raise NotImplementedError diff --git a/pilot/model/worker/default_worker.py b/pilot/model/worker/default_worker.py new file mode 100644 index 000000000..deea90191 --- /dev/null +++ b/pilot/model/worker/default_worker.py @@ -0,0 +1,133 @@ +import logging +import platform +from typing import Dict, Iterator, List + +import torch +from pilot.configs.model_config import DEVICE +from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper +from pilot.model.base import ModelOutput +from pilot.model.loader import ModelLoader, _get_model_real_path +from pilot.model.parameter import EnvArgumentParser, ModelParameters +from pilot.model.worker.base import ModelWorker +from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter +from pilot.utils.model_utils import _clear_torch_cache + +logger = logging.getLogger("model_worker") + + +class DefaultModelWorker(ModelWorker): + def __init__(self) -> None: + self.model = None + self.tokenizer = None + self._model_params = None + self.llm_adapter: BaseLLMAdaper = None + self.llm_chat_adapter: BaseChatAdpter = None + + def load_worker(self, model_name: str, model_path: str, **kwargs) -> None: + if model_path.endswith("/"): + model_path = model_path[:-1] + model_path = _get_model_real_path(model_name, model_path) + self.model_name = model_name + self.model_path = model_path + + self.llm_adapter = get_llm_model_adapter(self.model_name, self.model_path) + model_type = self.llm_adapter.model_type() + self.param_cls = self.llm_adapter.model_param_class(model_type) + + self.llm_chat_adapter = get_llm_chat_adapter(self.model_name, self.model_path) + self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func( + self.model_path + ) + + self.ml: ModelLoader = ModelLoader( + model_path=self.model_path, model_name=self.model_name + ) + # TODO read context len from model config + self.context_len = 2048 + + def model_param_class(self) -> ModelParameters: + return self.param_cls + + def parse_parameters(self, command_args: List[str] = None) -> ModelParameters: + param_cls = self.model_param_class() + model_args = EnvArgumentParser() + env_prefix = EnvArgumentParser.get_env_prefix(self.model_name) + model_type = self.llm_adapter.model_type() + model_params: ModelParameters = model_args.parse_args_into_dataclass( + param_cls, + env_prefix=env_prefix, + command_args=command_args, + model_name=self.model_name, + model_path=self.model_path, + model_type=model_type, + ) + if not model_params.device: + model_params.device = DEVICE + logger.info( + f"[DefaultModelWorker] Parameters of device is None, use {model_params.device}" + ) + return model_params + + def start( + self, model_params: ModelParameters = None, command_args: List[str] = None + ) -> None: + if not model_params: + model_params = self.parse_parameters(command_args) + self._model_params = model_params + logger.info(f"Begin load model, model params: {model_params}") + self.model, self.tokenizer = self.ml.loader_with_params(model_params) + + def stop(self) -> None: + if not self.model: + return + del self.model + del self.tokenizer + self.model = None + self.tokenizer = None + _clear_torch_cache(self._model_params.device) + + def generate_stream(self, params: Dict) -> Iterator[ModelOutput]: + try: + # params adaptation + params, model_context = self.llm_chat_adapter.model_adaptation( + params, self.ml.model_path, 
prompt_template=self.ml.prompt_template + ) + + for output in self.generate_stream_func( + self.model, self.tokenizer, params, DEVICE, self.context_len + ): + # Please do not open the output in production! + # The gpt4all thread shares stdout with the parent process, + # and opening it may affect the frontend output. + if "windows" in platform.platform().lower(): + # Do not print the model output, because it may contain Emoji, there is a problem with the GBK encoding + pass + else: + print("output: ", output) + # return some model context to dgt-server + model_output = ModelOutput( + text=output, error_code=0, model_context=model_context + ) + yield model_output + + except torch.cuda.CudaError: + model_output = ModelOutput( + text="**GPU OutOfMemory, Please Refresh.**", error_code=0 + ) + yield model_output + except Exception as e: + model_output = ModelOutput( + text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", + error_code=0, + ) + yield model_output + + def generate(self, params: Dict) -> ModelOutput: + """Generate non stream result""" + output = None + for out in self.generate_stream(params): + output = out + return output + + def embeddings(self, params: Dict) -> List[List[float]]: + raise NotImplementedError diff --git a/pilot/model/worker/embedding_worker.py b/pilot/model/worker/embedding_worker.py new file mode 100644 index 000000000..0f011dc6b --- /dev/null +++ b/pilot/model/worker/embedding_worker.py @@ -0,0 +1,100 @@ +import logging +from typing import Dict, List, Type + +from pilot.configs.model_config import DEVICE +from pilot.model.loader import _get_model_real_path +from pilot.model.parameter import ( + EmbeddingModelParameters, + EnvArgumentParser, + WorkerType, +) +from pilot.model.worker.base import ModelWorker +from pilot.utils.model_utils import _clear_torch_cache + +logger = logging.getLogger("model_worker") + + +class EmbeddingsModelWorker(ModelWorker): + def __init__(self) -> None: + try: + from langchain.embeddings import HuggingFaceEmbeddings + from langchain.embeddings.base import Embeddings + except ImportError as exc: + raise ImportError( + "Could not import langchain.embeddings.HuggingFaceEmbeddings python package. " + "Please install it with `pip install langchain`." 
+ ) from exc + self.embeddings: Embeddings = None + self._model_params = None + + def load_worker(self, model_name: str, model_path: str, **kwargs) -> None: + if model_path.endswith("/"): + model_path = model_path[:-1] + model_path = _get_model_real_path(model_name, model_path) + + self.model_name = model_name + self.model_path = model_path + + def worker_type(self) -> WorkerType: + return WorkerType.TEXT2VEC + + def model_param_class(self) -> Type: + return EmbeddingModelParameters + + def parse_parameters( + self, command_args: List[str] = None + ) -> EmbeddingModelParameters: + param_cls = self.model_param_class() + model_args = EnvArgumentParser() + env_prefix = EnvArgumentParser.get_env_prefix(self.model_name) + model_params: EmbeddingModelParameters = model_args.parse_args_into_dataclass( + param_cls, + env_prefix=env_prefix, + command_args=command_args, + model_name=self.model_name, + model_path=self.model_path, + ) + if not model_params.device: + model_params.device = DEVICE + logger.info( + f"[EmbeddingsModelWorker] Parameters of device is None, use {model_params.device}" + ) + return model_params + + def start( + self, + model_params: EmbeddingModelParameters = None, + command_args: List[str] = None, + ) -> None: + """Start model worker""" + from langchain.embeddings import HuggingFaceEmbeddings + + if not model_params: + model_params = self.parse_parameters(command_args) + self._model_params = model_params + + kwargs = model_params.build_kwargs(model_name=model_params.model_path) + logger.info(f"Start HuggingFaceEmbeddings with kwargs: {kwargs}") + self.embeddings = HuggingFaceEmbeddings(**kwargs) + + def __del__(self): + self.stop() + + def stop(self) -> None: + if not self.embeddings: + return + del self.embeddings + self.embeddings = None + _clear_torch_cache(self._model_params.device) + + def generate_stream(self, params: Dict): + """Generate stream result, chat scene""" + raise NotImplementedError("Not supported generate_stream for embeddings model") + + def generate(self, params: Dict): + """Generate non stream result""" + raise NotImplementedError("Not supported generate for embeddings model") + + def embeddings(self, params: Dict) -> List[List[float]]: + input: List[str] = params["input"] + return self.embeddings.embed_documents(input) diff --git a/pilot/model/worker/manager.py b/pilot/model/worker/manager.py new file mode 100644 index 000000000..3d18088e9 --- /dev/null +++ b/pilot/model/worker/manager.py @@ -0,0 +1,705 @@ +import asyncio +import httpx +import itertools +import json +import os +import random +import time +from abc import ABC, abstractmethod +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import asdict, dataclass, field +from datetime import datetime +from typing import Awaitable, Callable, Dict, Iterator, List, Optional + +import uvicorn +from fastapi import APIRouter, FastAPI, Request +from fastapi.responses import StreamingResponse +from pilot.configs.model_config import LOGDIR +from pilot.model.base import ( + ModelInstance, + ModelOutput, + WorkerApplyType, + WorkerApplyOutput, +) +from pilot.model.controller.registry import ModelRegistry +from pilot.model.parameter import ( + EnvArgumentParser, + ModelParameters, + ModelWorkerParameters, + WorkerType, + ParameterDescription, +) +from pilot.model.worker.base import ModelWorker +from pilot.scene.base_message import ModelMessage +from pilot.utils import build_logger +from pydantic import BaseModel + +logger = build_logger("model_worker", LOGDIR + "/model_worker.log") + + 
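+# A rough sketch of how the pieces in this module are wired together (model name,
+# path and port below are illustrative):
+#
+#   Start a standalone worker process that also embeds the model controller:
+#       run_worker_manager(model_name="vicuna-13b", model_path="/data/models/vicuna-13b", standalone=True, port=8000)
+#
+#   From another process, drive the registered workers through the controller:
+#       registry = ModelRegistryClient("http://127.0.0.1:8000")
+#       worker_manager = RemoteWorkerManager(registry)
+#       out = await worker_manager.generate(
+#           {"model": "vicuna-13b", "messages": [{"role": "user", "content": "Hello"}], "prompt": "Hello"}
+#       )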
+class PromptRequest(BaseModel): + messages: List[ModelMessage] + model: str + prompt: str = None + temperature: float = None + max_new_tokens: int = None + stop: str = None + echo: bool = True + + +class EmbeddingsRequest(BaseModel): + model: str + input: List[str] + + +class WorkerApplyRequest(BaseModel): + model: str + apply_type: WorkerApplyType + worker_type: WorkerType = WorkerType.LLM + params: Dict = None + apply_user: str = None + + +class WorkerParameterRequest(BaseModel): + model: str + worker_type: WorkerType = WorkerType.LLM + + +@dataclass +class WorkerRunData: + worker_key: str + worker: ModelWorker + worker_params: ModelWorkerParameters + model_params: ModelParameters + stop_event: asyncio.Event + semaphore: asyncio.Semaphore = None + command_args: List[str] = None + _heartbeat_future: Optional[Future] = None + _last_heartbeat: Optional[datetime] = None + + +RegisterFunc = Callable[[WorkerRunData], Awaitable[None]] +DeregisterFunc = Callable[[WorkerRunData], Awaitable[None]] +SendHeartbeatFunc = Callable[[WorkerRunData], Awaitable[None]] +ApplyFunction = Callable[[WorkerRunData], Awaitable[None]] + + +class WorkerManager(ABC): + @abstractmethod + async def get_model_instances( + self, worker_type: str, model_name: str, healthy_only: bool = True + ) -> List[WorkerRunData]: + """Get model instances by worker type and model name""" + + @abstractmethod + async def select_one_instanes( + self, worker_type: str, model_name: str, healthy_only: bool = True + ) -> WorkerRunData: + """Select one instances""" + + @abstractmethod + async def generate_stream(self, params: Dict, **kwargs) -> Iterator[ModelOutput]: + """Generate stream result, chat scene""" + + @abstractmethod + async def generate(self, params: Dict) -> ModelOutput: + """Generate non stream result""" + + @abstractmethod + async def embeddings(self, params: Dict) -> List[List[float]]: + """Embed input""" + + @abstractmethod + async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput: + """Worker apply""" + + @abstractmethod + async def parameter_descriptions( + self, worker_type: str, model_name: str + ) -> List[ParameterDescription]: + """Get parameter descriptions of model""" + + +async def _async_heartbeat_sender( + worker_run_data: WorkerRunData, send_heartbeat_func: SendHeartbeatFunc +): + while not worker_run_data.stop_event.is_set(): + try: + await send_heartbeat_func(worker_run_data) + except Exception as e: + logger.warn(f"Send heartbeat func error: {str(e)}") + finally: + await asyncio.sleep(worker_run_data.worker_params.heartbeat_interval) + + +class LocalWorkerManager(WorkerManager): + def __init__( + self, + register_func: RegisterFunc = None, + deregister_func: DeregisterFunc = None, + send_heartbeat_func: SendHeartbeatFunc = None, + model_registry: ModelRegistry = None, + ) -> None: + self.workers: Dict[str, List[WorkerRunData]] = dict() + self.executor = ThreadPoolExecutor(max_workers=os.cpu_count() * 5) + self.register_func = register_func + self.deregister_func = deregister_func + self.send_heartbeat_func = send_heartbeat_func + self.model_registry = model_registry + + def _worker_key(self, worker_type: str, model_name: str) -> str: + return f"{model_name}@{worker_type}" + + def add_worker( + self, + worker: ModelWorker, + worker_params: ModelWorkerParameters, + embedded_mod: bool = True, + command_args: List[str] = None, + ): + if not command_args: + import sys + + command_args = sys.argv[1:] + worker.load_worker(**asdict(worker_params)) + + if not worker_params.worker_type: + 
worker_params.worker_type = worker.worker_type() + worker_key = self._worker_key( + worker_params.worker_type, worker_params.model_name + ) + host = worker_params.host + port = worker_params.port + instances = self.workers.get(worker_key) + if not instances: + instances = [] + self.workers[worker_key] = instances + logger.info(f"Init empty instances list for {worker_key}") + # Load model params from persistent storage + model_params = worker.parse_parameters(command_args=command_args) + worker_run_data = WorkerRunData( + worker_key=worker_key, + worker=worker, + worker_params=worker_params, + model_params=model_params, + stop_event=asyncio.Event(), + semaphore=asyncio.Semaphore(worker_params.limit_model_concurrency), + command_args=command_args, + ) + if not embedded_mod: + exist_instances = [ + ins for ins in instances if ins.worker_params.host == host and ins.worker_params.port == port + ] + if not exist_instances: + instances.append(worker_run_data) + else: + instances.append(worker_run_data) + async def get_model_instances( + self, worker_type: str, model_name: str, healthy_only: bool = True + ) -> List[WorkerRunData]: + worker_key = self._worker_key(worker_type, model_name) + return self.workers.get(worker_key) + async def select_one_instanes( + self, worker_type: str, model_name: str, healthy_only: bool = True + ) -> WorkerRunData: + worker_instances = await self.get_model_instances( + worker_type, model_name, healthy_only + ) + if not worker_instances: + raise Exception( + f"Could not find worker instances for model name {model_name} and worker type {worker_type}" + ) + worker_run_data = random.choice(worker_instances) + return worker_run_data + async def _get_model(self, params: Dict, worker_type: str = "llm") -> WorkerRunData: + model = params.get("model") + if not model: + raise Exception("Model name can not be empty") + return await self.select_one_instanes(worker_type, model, healthy_only=True) + async def generate_stream( + self, params: Dict, async_wrapper=None, **kwargs + ) -> Iterator[ModelOutput]: + """Generate stream result, chat scene""" + worker_run_data = await self._get_model(params) + async with worker_run_data.semaphore: + if worker_run_data.worker.support_async(): + async for output in worker_run_data.worker.async_generate_stream( + params + ): + yield output + else: + if not async_wrapper: + from starlette.concurrency import iterate_in_threadpool + async_wrapper = iterate_in_threadpool + async for output in async_wrapper( + worker_run_data.worker.generate_stream(params) + ): + yield output + async def generate(self, params: Dict) -> ModelOutput: + """Generate non stream result""" + worker_run_data = await self._get_model(params) + async with worker_run_data.semaphore: + if worker_run_data.worker.support_async(): + return await worker_run_data.worker.async_generate(params) + else: + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self.executor, worker_run_data.worker.generate, params + ) + async def embeddings(self, params: Dict) -> List[List[float]]: + """Embed input""" + worker_run_data = await self._get_model(params, worker_type="text2vec") + async with worker_run_data.semaphore: + if worker_run_data.worker.support_async(): + return await worker_run_data.worker.async_embeddings(params) + else: + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self.executor, worker_run_data.worker.embeddings, params + ) + async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput: + apply_func: 
Callable[[WorkerApplyRequest], Awaitable[str]] = None + if apply_req.apply_type == WorkerApplyType.START: + apply_func = self._start_all_worker + elif apply_req.apply_type == WorkerApplyType.STOP: + apply_func = self._stop_all_worker + elif apply_req.apply_type == WorkerApplyType.UPDATE_PARAMS: + apply_func = self._update_all_worker_params + else: + raise ValueError(f"Unsupported apply type {apply_req.apply_type}") + return await apply_func(apply_req) + + async def parameter_descriptions( + self, worker_type: str, model_name: str + ) -> List[ParameterDescription]: + worker_instances = await self.get_model_instances(worker_type, model_name) + if not worker_instances: + raise Exception( + f"Not worker instances for model name {model_name} worker type {worker_type}" + ) + worker_run_data = worker_instances[0] + return worker_run_data.worker.parameter_descriptions() + + async def _apply_worker( + self, apply_req: WorkerApplyRequest, apply_func: ApplyFunction + ) -> None: + """Apply function to worker instances in parallel + + Args: + apply_req (WorkerApplyRequest): Worker apply request + apply_func (ApplyFunction): Function to apply to worker instances, now function is async function + """ + if apply_req: + worker_type = apply_req.worker_type.value + model_name = apply_req.model + worker_instances = await self.get_model_instances(worker_type, model_name) + if not worker_instances: + raise Exception( + f"No worker instance found for the model {model_name} worker type {worker_type}" + ) + else: + # Apply to all workers + worker_instances = list(itertools.chain(*self.workers.values())) + logger.info(f"Apply to all workers: {worker_instances}") + return await asyncio.gather( + *(apply_func(worker) for worker in worker_instances) + ) + + async def _start_all_worker( + self, apply_req: WorkerApplyRequest + ) -> WorkerApplyOutput: + start_time = time.time() + logger.info(f"Begin start all worker, apply_req: {apply_req}") + + async def _start_worker(worker_run_data: WorkerRunData): + worker_run_data.worker.start( + worker_run_data.model_params, worker_run_data.command_args + ) + worker_run_data.stop_event.clear() + if worker_run_data.worker_params.register and self.register_func: + # Register worker to controller + await self.register_func(worker_run_data) + if ( + worker_run_data.worker_params.send_heartbeat + and self.send_heartbeat_func + ): + asyncio.create_task( + _async_heartbeat_sender( + worker_run_data, self.send_heartbeat_func + ) + ) + + await self._apply_worker(apply_req, _start_worker) + timecost = time.time() - start_time + return WorkerApplyOutput( + message=f"Worker started successfully", timecost=timecost + ) + + async def _stop_all_worker( + self, apply_req: WorkerApplyRequest + ) -> WorkerApplyOutput: + start_time = time.time() + + async def _stop_worker(worker_run_data: WorkerRunData): + worker_run_data.worker.stop() + # Set stop event + worker_run_data.stop_event.set() + if worker_run_data._heartbeat_future: + # Wait thread finish + worker_run_data._heartbeat_future.result() + worker_run_data._heartbeat_future = None + if ( + worker_run_data.worker_params.register + and self.register_func + and self.deregister_func + ): + await self.deregister_func(worker_run_data) + + await self._apply_worker(apply_req, _stop_worker) + timecost = time.time() - start_time + return WorkerApplyOutput( + message=f"Worker stopped successfully", timecost=timecost + ) + + async def _update_all_worker_params( + self, apply_req: WorkerApplyRequest + ) -> WorkerApplyOutput: + start_time = time.time() + 
need_restart = False + + async def update_params(worker_run_data: WorkerRunData): + nonlocal need_restart + new_params = apply_req.params + if not new_params: + return + if worker_run_data.model_params.update_from(new_params): + need_restart = True + + await self._apply_worker(apply_req, update_params) + message = f"Update worker params successfully" + timecost = time.time() - start_time + if need_restart: + logger.info("Model params update successfully, begin restart worker") + await self._stop_all_worker(apply_req) + await self._start_all_worker(apply_req) + timecost = time.time() - start_time + message = f"Update worker params and restart successfully" + return WorkerApplyOutput(message=message, timecost=timecost) + + +class RemoteWorkerManager(LocalWorkerManager): + def __init__(self, model_registry: ModelRegistry = None) -> None: + super().__init__(model_registry=model_registry) + + async def get_model_instances( + self, worker_type: str, model_name: str, healthy_only: bool = True + ) -> List[WorkerRunData]: + from pilot.model.worker.remote_worker import RemoteModelWorker + + worker_key = self._worker_key(worker_type, model_name) + instances: List[ModelInstance] = await self.model_registry.get_all_instances( + worker_key, healthy_only + ) + worker_instances = [] + for ins in instances: + worker = RemoteModelWorker() + worker.load_worker(model_name, model_name, host=ins.host, port=ins.port) + wr = WorkerRunData( + worker_key=ins.model_name, + worker=worker, + worker_params=None, + model_params=None, + stop_event=asyncio.Event(), + semaphore=asyncio.Semaphore(100), # Not limit in client + ) + worker_instances.append(wr) + return worker_instances + + async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput: + async def _remote_apply_func(worker_run_data: WorkerRunData): + worker_addr = worker_run_data.worker.worker_addr + async with httpx.AsyncClient() as client: + response = await client.post( + worker_addr + "/apply", + headers=worker_run_data.worker.headers, + json=apply_req.dict(), + timeout=worker_run_data.worker.timeout, + ) + if response.status_code == 200: + output = WorkerApplyOutput(**response.json()) + logger.info(f"worker_apply success: {output}") + else: + output = WorkerApplyOutput(message=response.text) + logger.warn(f"worker_apply failed: {output}") + return output + + results = await self._apply_worker(apply_req, _remote_apply_func) + if results: + return results[0] + + +class WorkerManagerAdapter(WorkerManager): + def __init__(self, worker_manager: WorkerManager = None) -> None: + self.worker_manager = worker_manager + + async def get_model_instances( + self, worker_type: str, model_name: str, healthy_only: bool = True + ) -> List[WorkerRunData]: + return await self.worker_manager.get_model_instances( + worker_type, model_name, healthy_only + ) + + async def select_one_instanes( + self, worker_type: str, model_name: str, healthy_only: bool = True + ) -> WorkerRunData: + return await self.worker_manager.select_one_instanes( + worker_type, model_name, healthy_only + ) + + async def generate_stream(self, params: Dict, **kwargs) -> Iterator[ModelOutput]: + async for output in self.worker_manager.generate_stream(params, **kwargs): + yield output + + async def generate(self, params: Dict) -> ModelOutput: + return await self.worker_manager.generate(params) + + async def embeddings(self, params: Dict) -> List[List[float]]: + return await self.worker_manager.embeddings(params) + + async def worker_apply(self, apply_req: WorkerApplyRequest) -> 
WorkerApplyOutput:
+        return await self.worker_manager.worker_apply(apply_req)
+
+    async def parameter_descriptions(
+        self, worker_type: str, model_name: str
+    ) -> List[ParameterDescription]:
+        return await self.worker_manager.parameter_descriptions(worker_type, model_name)
+
+
+worker_manager = WorkerManagerAdapter()
+router = APIRouter()
+
+
+async def generate_json_stream(params):
+    from starlette.concurrency import iterate_in_threadpool
+
+    async for output in worker_manager.generate_stream(
+        params, async_wrapper=iterate_in_threadpool
+    ):
+        yield json.dumps(asdict(output), ensure_ascii=False).encode() + b"\0"
+
+
+@router.post("/worker/generate_stream")
+async def api_generate_stream(request: Request):
+    params = await request.json()
+    generator = generate_json_stream(params)
+    return StreamingResponse(generator)
+
+
+@router.post("/worker/generate")
+async def api_generate(request: PromptRequest):
+    params = request.dict(exclude_none=True)
+    output = await worker_manager.generate(params)
+    return output
+
+
+@router.post("/worker/embeddings")
+async def api_embeddings(request: EmbeddingsRequest):
+    params = request.dict(exclude_none=True)
+    output = await worker_manager.embeddings(params)
+    return output
+
+
+@router.post("/worker/apply")
+async def api_worker_apply(request: WorkerApplyRequest):
+    output = await worker_manager.worker_apply(request)
+    return output
+
+
+@router.get("/worker/parameter/descriptions")
+async def api_worker_parameter_descs(
+    model: str, worker_type: str = WorkerType.LLM.value
+):
+    output = await worker_manager.parameter_descriptions(worker_type, model)
+    return output
+
+
+def _setup_fastapi(worker_params: ModelWorkerParameters):
+    app = FastAPI()
+    if worker_params.standalone:
+        from pilot.model.controller.controller import router as controller_router
+
+        if not worker_params.controller_addr:
+            worker_params.controller_addr = f"http://127.0.0.1:{worker_params.port}"
+        app.include_router(controller_router, prefix="/api")
+
+    @app.on_event("startup")
+    async def startup_event():
+        asyncio.create_task(
+            worker_manager.worker_manager._start_all_worker(apply_req=None)
+        )
+
+    return app
+
+
+def _parse_worker_params(
+    model_name: str = None, model_path: str = None, **kwargs
+) -> ModelWorkerParameters:
+    worker_args = EnvArgumentParser()
+    worker_params: ModelWorkerParameters = worker_args.parse_args_into_dataclass(
+        ModelWorkerParameters, model_name=model_name, model_path=model_path, **kwargs
+    )
+    env_prefix = EnvArgumentParser.get_env_prefix(worker_params.model_name)
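# Client-side sketch for the "\0"-delimited stream produced by generate_json_stream
# above. The base URL and payload fields are illustrative; see PromptRequest for the
# real request schema.
import asyncio
import json
import httpx

async def _demo_stream(base_url: str = "http://127.0.0.1:8000") -> None:
    payload = {"model": "vicuna-13b", "prompt": "Hello", "temperature": 0.7}
    buffer = b""
    async with httpx.AsyncClient() as client:
        async with client.stream(
            "POST", f"{base_url}/api/worker/generate_stream", json=payload, timeout=120
        ) as response:
            async for raw_chunk in response.aiter_raw():
                buffer += raw_chunk
                while b"\0" in buffer:
                    chunk, buffer = buffer.split(b"\0", 1)
                    if chunk:
                        # Each chunk is one JSON-encoded ModelOutput dict.
                        print(json.loads(chunk.decode()))

# asyncio.run(_demo_stream())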
+    # Read the parameters again, using the model name as the env var prefix
+    new_worker_params = worker_args.parse_args_into_dataclass(
+        ModelWorkerParameters,
+        env_prefix=env_prefix,
+        model_name=worker_params.model_name,
+        model_path=worker_params.model_path,
+        **kwargs,
+    )
+    worker_params.update_from(new_worker_params)
+
+    logger.info(f"Worker params: {worker_params}")
+    return worker_params
+
+
+def _create_local_model_manager(
+    worker_params: ModelWorkerParameters,
+) -> LocalWorkerManager:
+    if not worker_params.register or not worker_params.controller_addr:
+        logger.info(
+            f"Not registering the current worker to the controller, register: {worker_params.register}, controller_addr: {worker_params.controller_addr}"
+        )
+        return LocalWorkerManager()
+    else:
+        from pilot.model.controller.registry import ModelRegistryClient
+        from pilot.utils.net_utils import _get_ip_address
+
+        client = ModelRegistryClient(worker_params.controller_addr)
+        host = _get_ip_address()
+        port = worker_params.port
+
+        async def register_func(worker_run_data: WorkerRunData):
+            instance = ModelInstance(
+                model_name=worker_run_data.worker_key, host=host, port=port
+            )
+            return await client.register_instance(instance)
+
+        async def send_heartbeat_func(worker_run_data: WorkerRunData):
+            instance = ModelInstance(
+                model_name=worker_run_data.worker_key, host=host, port=port
+            )
+            return await client.send_heartbeat(instance)
+
+        return LocalWorkerManager(
+            register_func=register_func, send_heartbeat_func=send_heartbeat_func
+        )
+
+
+def _start_local_worker(
+    worker_manager: WorkerManagerAdapter,
+    worker_params: ModelWorkerParameters,
+    embedded_mod=True,
+):
+    from pilot.utils.module_utils import import_from_checked_string
+
+    if worker_params.worker_class:
+        worker_cls = import_from_checked_string(worker_params.worker_class, ModelWorker)
+        logger.info(
+            f"Import worker class from {worker_params.worker_class} successfully"
+        )
+        worker: ModelWorker = worker_cls()
+    else:
+        from pilot.model.worker.default_worker import DefaultModelWorker
+
+        worker = DefaultModelWorker()
+
+    worker_manager.worker_manager = _create_local_model_manager(worker_params)
+    worker_manager.worker_manager.add_worker(
+        worker, worker_params, embedded_mod=embedded_mod
+    )
+
+
+def initialize_worker_manager_in_client(
+    app=None,
+    include_router: bool = True,
+    model_name: str = None,
+    model_path: str = None,
+    run_locally: bool = True,
+    controller_addr: str = None,
+):
+    global worker_manager
+
+    worker_params: ModelWorkerParameters = _parse_worker_params(
+        model_name=model_name, model_path=model_path, controller_addr=controller_addr
+    )
+
+    logger.info(f"Worker params: {worker_params}")
+    if run_locally:
+        worker_params.register = False
+        _start_local_worker(worker_manager, worker_params, True)
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(
+            worker_manager.worker_manager._start_all_worker(apply_req=None)
+        )
+    else:
+        from pilot.model.controller.registry import ModelRegistryClient
+
+        if not worker_params.controller_addr:
+            raise ValueError("The controller address can't be None")
+        client = ModelRegistryClient(worker_params.controller_addr)
+        worker_manager.worker_manager = RemoteWorkerManager(client)
+
+    if include_router and app:
+        app.include_router(router, prefix="/api")
+
+
+def run_worker_manager(
+    app=None,
+    include_router: bool = True,
+    model_name: str = None,
+    model_path: str = None,
+    standalone: bool = False,
+    port: int = None,
+):
+    global worker_manager
+
+    worker_params: ModelWorkerParameters = _parse_worker_params(
+        model_name=model_name, model_path=model_path, standalone=standalone, port=port
+    )
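# A sketch of the two ways this module is typically wired up. The model name and
# path are placeholders; run_worker_manager and initialize_worker_manager_in_client
# are the functions defined in this file.
from fastapi import FastAPI

def _demo_standalone_service():
    # Independent model service: serves /api/worker/* and, because
    # standalone=True, the controller routes as well.
    run_worker_manager(
        model_name="vicuna-13b",
        model_path="/data/models/vicuna-13b-v1.5",
        standalone=True,
        port=8000,
    )

def _demo_embedded_in_webserver() -> FastAPI:
    # Embed the worker manager into an existing FastAPI app and run the model
    # in the same process (run_locally defaults to True).
    app = FastAPI()
    initialize_worker_manager_in_client(
        app=app, model_name="vicuna-13b", model_path="/data/models/vicuna-13b-v1.5"
    )
    return app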
+ + embedded_mod = True + if not app: + # Run worker manager independently + embedded_mod = False + app = _setup_fastapi(worker_params) + _start_local_worker(worker_manager, worker_params, embedded_mod=False) + else: + _start_local_worker(worker_manager, worker_params, embedded_mod=False) + loop = asyncio.get_event_loop() + loop.run_until_complete( + worker_manager.worker_manager._start_all_worker(apply_req=None) + ) + + if include_router: + app.include_router(router, prefix="/api") + + if not embedded_mod: + uvicorn.run( + app, host=worker_params.host, port=worker_params.port, log_level="info" + ) + + +if __name__ == "__main__": + run_worker_manager() diff --git a/pilot/model/worker/ray_worker.py b/pilot/model/worker/ray_worker.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/model/worker/remote_worker.py b/pilot/model/worker/remote_worker.py new file mode 100644 index 000000000..8712da0ca --- /dev/null +++ b/pilot/model/worker/remote_worker.py @@ -0,0 +1,98 @@ +import json +from typing import Dict, Iterator, List + +import httpx +from pilot.model.base import ModelOutput +from pilot.model.parameter import ModelParameters +from pilot.model.worker.base import ModelWorker + + +class RemoteModelWorker(ModelWorker): + def __init__(self) -> None: + self.headers = {} + self.timeout = 60 + self.host = None + self.port = None + + @property + def worker_addr(self) -> str: + return f"http://{self.host}:{self.port}/api/worker" + + def support_async(self) -> bool: + return True + + def parse_parameters(self, command_args: List[str] = None) -> ModelParameters: + return None + + def load_worker(self, model_name: str, model_path: str, **kwargs): + self.host = kwargs.get("host") + self.port = kwargs.get("port") + + def start( + self, model_params: ModelParameters = None, command_args: List[str] = None + ) -> None: + """Start model worker""" + pass + # raise NotImplementedError("Remote model worker not support start methods") + + def stop(self) -> None: + raise NotImplementedError("Remote model worker not support stop methods") + + def generate_stream(self, params: Dict) -> Iterator[ModelOutput]: + """Generate stream""" + raise NotImplementedError + + async def async_generate_stream(self, params: Dict) -> Iterator[ModelOutput]: + """Asynchronous generate stream""" + print(f"Send async_generate_stream, params: {params}") + async with httpx.AsyncClient() as client: + delimiter = b"\0" + buffer = b"" + async with client.stream( + "POST", + self.worker_addr + "/generate_stream", + headers=self.headers, + json=params, + timeout=self.timeout, + ) as response: + async for raw_chunk in response.aiter_raw(): + buffer += raw_chunk + while delimiter in buffer: + chunk, buffer = buffer.split(delimiter, 1) + if not chunk: + continue + chunk = chunk.decode() + data = json.loads(chunk) + yield ModelOutput(**data) + + def generate(self, params: Dict) -> ModelOutput: + """Generate non stream""" + raise NotImplementedError + + async def async_generate(self, params: Dict) -> ModelOutput: + """Asynchronous generate non stream""" + print(f"Send async_generate_stream, params: {params}") + + async with httpx.AsyncClient() as client: + response = await client.post( + self.worker_addr + "/generate", + headers=self.headers, + json=params, + timeout=self.timeout, + ) + return ModelOutput(**response.json()) + + def embeddings(self, params: Dict) -> List[List[float]]: + """Get embeddings for input""" + raise NotImplementedError + + async def async_embeddings(self, params: Dict) -> List[List[float]]: + 
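# Client-side sketch: point a RemoteModelWorker at an already running worker
# process and call it asynchronously. Host, port and params are illustrative.
import asyncio

async def _demo_remote_generate():
    worker = RemoteModelWorker()
    worker.load_worker("vicuna-13b", "vicuna-13b", host="127.0.0.1", port=8001)
    # async_generate posts the params dict to the worker's /generate endpoint
    # and returns a ModelOutput.
    return await worker.async_generate(
        {"model": "vicuna-13b", "prompt": "Hello", "temperature": 0.7}
    )

# print(asyncio.run(_demo_remote_generate()))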
"""Asynchronous get embeddings for input""" + async with httpx.AsyncClient() as client: + response = await client.post( + self.worker_addr + "/embeddings", + headers=self.headers, + json=params, + timeout=self.timeout, + ) + return response.json() diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index 5930c464d..b8d1dc3b7 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -311,33 +311,23 @@ async def chat_completions(dialogue: ConversationVo = Body()): async def no_stream_generator(chat): - msg = chat.nostream_call() + msg = await chat.nostream_call() msg = msg.replace("\n", "\\n") yield f"data: {msg}\n\n" async def stream_generator(chat): - model_response = chat.stream_call() msg = "[LLM_ERROR]: llm server has no output, maybe your prompt template is wrong." - if not CFG.NEW_SERVER_MODE: - for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"): - if chunk: - msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex( - chunk, chat.skip_echo_len - ) - msg = msg.replace("\n", "\\n") - yield f"data:{msg}\n\n" - await asyncio.sleep(0.02) - else: - for chunk in model_response: - if chunk: - msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex( - chunk, chat.skip_echo_len - ) - msg = msg.replace("\n", "\\n") - yield f"data:{msg}\n\n" - await asyncio.sleep(0.02) + async for chunk in chat.stream_call(): + if chunk: + msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex( + chunk, chat.skip_echo_len + ) + + msg = msg.replace("\n", "\\n") + yield f"data:{msg}\n\n" + await asyncio.sleep(0.02) chat.current_message.add_ai_message(msg) chat.current_message.add_view_message(msg) diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py index 1ae14b506..5cd86573a 100644 --- a/pilot/out_parser/base.py +++ b/pilot/out_parser/base.py @@ -1,26 +1,19 @@ from __future__ import annotations + import json - -from abc import ABC, abstractmethod -from typing import ( - Any, - Dict, - Generic, - List, - NamedTuple, - Optional, - Sequence, - TypeVar, - Union, -) -from pilot.utils import build_logger import re +from abc import ABC, abstractmethod +from dataclasses import asdict +from typing import Any, Dict, TypeVar, Union -from pydantic import BaseModel, Extra, Field, root_validator -from pilot.configs.model_config import LOGDIR from pilot.configs.config import Config +from pilot.configs.model_config import LOGDIR +from pilot.model.base import ModelOutput +from pilot.utils import build_logger T = TypeVar("T") +ResponseTye = Union[str, bytes, ModelOutput] + logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") CFG = Config() @@ -46,11 +39,8 @@ class BaseOutputParser(ABC): code = sep.join(blocks) return code - def parse_model_stream_resp_ex(self, chunk, skip_echo_len): - if b"\0" in chunk: - chunk = chunk.replace(b"\0", b"") - data = json.loads(chunk.decode()) - + def parse_model_stream_resp_ex(self, chunk: ResponseTye, skip_echo_len): + data = _parse_model_response(chunk) """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode. 
""" model_context = data.get("model_context") @@ -103,8 +93,8 @@ class BaseOutputParser(ABC): output = data["text"] + f" (error_code: {data['error_code']})" yield output - def parse_model_nostream_resp(self, response, sep: str): - resp_obj_ex = json.loads(response) + def parse_model_nostream_resp(self, response: ResponseTye, sep: str): + resp_obj_ex = _parse_model_response(response) if isinstance(resp_obj_ex, str): resp_obj_ex = json.loads(resp_obj_ex) if resp_obj_ex["error_code"] == 0: @@ -240,3 +230,17 @@ class BaseOutputParser(ABC): output_parser_dict = super().dict() output_parser_dict["_type"] = self._type return output_parser_dict + + +def _parse_model_response(response: ResponseTye): + if isinstance(response, ModelOutput): + resp_obj_ex = asdict(response) + elif isinstance(response, str): + resp_obj_ex = json.loads(response) + elif isinstance(response, bytes): + if b"\0" in response: + response = response.replace(b"\0", b"") + resp_obj_ex = json.loads(response.decode()) + else: + raise ValueError(f"Unsupported response type {type(response)}") + return resp_obj_ex diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 9bb5ddc9b..88f457935 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -28,10 +28,7 @@ from pilot.memory.chat_history.mem_history import MemHistoryMemory from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory from pilot.configs.model_config import LOGDIR, DATASETS_DIR -from pilot.utils import ( - build_logger, - server_error_msg, -) +from pilot.utils import build_logger, server_error_msg, get_or_create_event_loop from pilot.scene.base_message import ( BaseMessage, SystemMessage, @@ -158,7 +155,7 @@ class BaseChat(ABC): } return payload - def stream_call(self): + async def stream_call(self): # TODO Retry when server connection error payload = self.__call_base() @@ -166,19 +163,10 @@ class BaseChat(ABC): logger.info(f"Requert: \n{payload}") ai_response_text = "" try: - if not CFG.NEW_SERVER_MODE: - response = requests.post( - urljoin(CFG.MODEL_SERVER, "generate_stream"), - headers=headers, - json=payload, - stream=True, - timeout=120, - ) - return response - else: - from pilot.server.llmserver import worker + from pilot.model.worker.manager import worker_manager - return worker.generate_stream_gate(payload) + async for output in worker_manager.generate_stream(payload): + yield output except Exception as e: print(traceback.format_exc()) logger.error("model response parase faild!" 
+ str(e)) @@ -188,34 +176,19 @@ class BaseChat(ABC): ### store current conversation self.memory.append(self.current_message) - def nostream_call(self): + async def nostream_call(self): payload = self.__call_base() logger.info(f"Requert: \n{payload}") ai_response_text = "" try: - rsp_str = "" - if not CFG.NEW_SERVER_MODE: - rsp_obj = requests.post( - urljoin(CFG.MODEL_SERVER, "generate"), - headers=headers, - json=payload, - timeout=120, - ) - rsp_str = rsp_obj.text - else: - ###TODO no stream mode need independent - from pilot.server.llmserver import worker + from pilot.model.worker.manager import worker_manager - output = worker.generate_stream_gate(payload) - for rsp in output: - rsp = rsp.replace(b"\0", b"") - rsp_str = rsp.decode() - print("[TEST: output]:", rsp_str) + model_output = await worker_manager.generate(payload) ### output parse ai_response_text = ( self.prompt_template.output_parser.parse_model_nostream_resp( - rsp_str, self.prompt_template.sep + model_output, self.prompt_template.sep ) ) ### model result deal @@ -245,11 +218,31 @@ class BaseChat(ABC): self.memory.append(self.current_message) return self.current_ai_response() + def _blocking_stream_call(self): + logger.warn( + "_blocking_stream_call is only temporarily used in webserver and will be deleted soon, please use stream_call to replace it for higher performance" + ) + loop = get_or_create_event_loop() + async_gen = self.stream_call() + while True: + try: + value = loop.run_until_complete(async_gen.__anext__()) + yield value + except StopAsyncIteration: + break + + def _blocking_nostream_call(self): + logger.warn( + "_blocking_nostream_call is only temporarily used in webserver and will be deleted soon, please use nostream_call to replace it for higher performance" + ) + loop = get_or_create_event_loop() + return loop.run_until_complete(self.nostream_call()) + def call(self): if self.prompt_template.stream_out: - yield self.stream_call() + yield self._blocking_stream_call() else: - return self.nostream_call() + return self._blocking_nostream_call() def prepare(self): pass diff --git a/pilot/scripts/__init__.py b/pilot/scripts/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/scripts/cli_scripts.py b/pilot/scripts/cli_scripts.py new file mode 100644 index 000000000..537b0ed25 --- /dev/null +++ b/pilot/scripts/cli_scripts.py @@ -0,0 +1,45 @@ +import sys +import click +import os +import copy +import logging + +sys.path.append( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) +) + + +@click.group() +@click.option( + "--log-level", + required=False, + type=str, + default="warn", + help="Log level", +) +@click.version_option() +def cli(log_level: str): + # TODO not working now + logging.basicConfig(level=log_level, encoding="utf-8") + + +def add_command_alias(command, name: str, hidden: bool = False): + new_command = copy.deepcopy(command) + new_command.hidden = hidden + cli.add_command(new_command, name=name) + + +try: + from pilot.model.cli import model_cli_group + + add_command_alias(model_cli_group, name="model") +except ImportError as e: + logging.warning(f"Integrating dbgpt model command line tool failed: {e}") + + +def main(): + return cli() + + +if __name__ == "__main__": + main() diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py index 6c0753fa2..86ca49330 100644 --- a/pilot/server/dbgpt_server.py +++ b/pilot/server/dbgpt_server.py @@ -10,13 +10,7 @@ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__fi 
sys.path.append(ROOT_PATH) import signal from pilot.configs.config import Config - -# from pilot.configs.model_config import ( -# DATASETS_DIR, -# KNOWLEDGE_UPLOAD_ROOT_PATH, -# LLM_MODEL_CONFIG, -# LOGDIR, -# ) +from pilot.configs.model_config import LLM_MODEL_CONFIG from pilot.utils import build_logger from pilot.server.base import server_init @@ -33,8 +27,9 @@ from pilot.openapi.api_v1.api_v1 import router as api_v1 from pilot.openapi.base import validation_exception_handler from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1 from pilot.commands.disply_type.show_chart_gen import static_message_img_path +from pilot.model.worker.manager import initialize_worker_manager_in_client -logging.basicConfig(level=logging.INFO) +logging.basicConfig(level=logging.INFO, encoding="utf-8") static_file_path = os.path.join(os.getcwd(), "server/static") @@ -80,12 +75,19 @@ app.include_router(api_editor_route_v1, prefix="/api") app.include_router(knowledge_router) # app.include_router(api_editor_route_v1) -os.makedirs(static_message_img_path, exist_ok=True) -app.mount( - "/images", StaticFiles(directory=static_message_img_path, html=True), name="static2" -) -app.mount("/_next/static", StaticFiles(directory=static_file_path + "/_next/static")) -app.mount("/", StaticFiles(directory=static_file_path, html=True), name="static") + +def mount_static_files(app): + os.makedirs(static_message_img_path, exist_ok=True) + app.mount( + "/images", + StaticFiles(directory=static_message_img_path, html=True), + name="static2", + ) + app.mount( + "/_next/static", StaticFiles(directory=static_file_path + "/_next/static") + ) + app.mount("/", StaticFiles(directory=static_file_path, html=True), name="static") + app.add_exception_handler(RequestValidationError, validation_exception_handler) @@ -100,6 +102,7 @@ if __name__ == "__main__": parser.add_argument("--port", type=int, default=5000) parser.add_argument("--concurrency-count", type=int, default=10) parser.add_argument("--share", default=False, action="store_true") + parser.add_argument("--log-level", type=str, default="info") parser.add_argument( "-light", "--light", @@ -112,17 +115,28 @@ if __name__ == "__main__": args = parser.parse_args() server_init(args) + model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL] if not args.light: print("Model Unified Deployment Mode!") - from pilot.server.llmserver import worker + initialize_worker_manager_in_client( + app=app, model_name=CFG.LLM_MODEL, model_path=model_path + ) - worker.start_check() CFG.NEW_SERVER_MODE = True else: + # MODEL_SERVER is controller address now + initialize_worker_manager_in_client( + app=app, + model_name=CFG.LLM_MODEL, + model_path=model_path, + run_locally=False, + controller_addr=CFG.MODEL_SERVER, + ) CFG.SERVER_LIGHT_MODE = True + mount_static_files(app) import uvicorn - logging.basicConfig(level=logging.INFO) - uvicorn.run(app, host="0.0.0.0", port=args.port, log_level=0) + logging.basicConfig(level=logging.INFO, encoding="utf-8") + uvicorn.run(app, host="0.0.0.0", port=args.port, log_level=args.log_level) signal.signal(signal.SIGINT, signal_handler()) diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py index 54e5d0694..7b7a2e90c 100644 --- a/pilot/server/llmserver.py +++ b/pilot/server/llmserver.py @@ -1,19 +1,8 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import asyncio -import json import os import sys -from typing import List -import platform - -import uvicorn -from fastapi import BackgroundTasks, FastAPI, Request -from fastapi.responses import 
StreamingResponse - -# from fastapi.middleware.cors import CORSMiddleware -from pydantic import BaseModel global_counter = 0 model_semaphore = None @@ -22,197 +11,26 @@ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__fi sys.path.append(ROOT_PATH) from pilot.configs.config import Config -from pilot.configs.model_config import * -from pilot.model.llm_out.vicuna_base_llm import get_embeddings -from pilot.model.loader import ModelLoader, _get_model_real_path -from pilot.server.chat_adapter import get_llm_chat_adapter -from pilot.scene.base_message import ModelMessage +from pilot.configs.model_config import LLM_MODEL_CONFIG +from pilot.model.worker.manager import run_worker_manager CFG = Config() - -class ModelWorker: - def __init__(self, model_path, model_name, device): - if model_path.endswith("/"): - model_path = model_path[:-1] - model_path = _get_model_real_path(model_name, model_path) - # self.model_name = model_name or model_path.split("/")[-1] - self.device = device - print( - f"Loading {model_name} LLM ModelServer in {device} from model path {model_path}! Please Wait......" - ) - self.ml: ModelLoader = ModelLoader(model_path=model_path, model_name=model_name) - self.model, self.tokenizer = self.ml.loader( - load_8bit=CFG.IS_LOAD_8BIT, - load_4bit=CFG.IS_LOAD_4BIT, - debug=ISDEBUG, - max_gpu_memory=CFG.MAX_GPU_MEMORY, - ) - - if not isinstance(self.model, str): - if hasattr(self.model, "config") and hasattr( - self.model.config, "max_sequence_length" - ): - self.context_len = self.model.config.max_sequence_length - elif hasattr(self.model, "config") and hasattr( - self.model.config, "max_position_embeddings" - ): - self.context_len = self.model.config.max_position_embeddings - - else: - self.context_len = 2048 - - self.llm_chat_adapter = get_llm_chat_adapter(model_name, model_path) - self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func( - model_path - ) - - def start_check(self): - print("LLM Model Loading Success!") - - def get_queue_length(self): - if ( - model_semaphore is None - or model_semaphore._value is None - or model_semaphore._waiters is None - ): - return 0 - else: - ( - CFG.LIMIT_MODEL_CONCURRENCY - - model_semaphore._value - + len(model_semaphore._waiters) - ) - - def generate_stream_gate(self, params): - try: - # params adaptation - params, model_context = self.llm_chat_adapter.model_adaptation( - params, self.ml.model_path, prompt_template=self.ml.prompt_template - ) - - for output in self.generate_stream_func( - self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS - ): - # Please do not open the output in production! - # The gpt4all thread shares stdout with the parent process, - # and opening it may affect the frontend output. 
- if "windows" in platform.platform().lower(): - # Do not print the model output, because it may contain Emoji, there is a problem with the GBK encoding - pass - else: - print("output: ", output) - # return some model context to dgt-server - ret = {"text": output, "error_code": 0, "model_context": model_context} - yield json.dumps(ret).encode() + b"\0" - - except torch.cuda.CudaError: - ret = {"text": "**GPU OutOfMemory, Please Refresh.**", "error_code": 0} - yield json.dumps(ret).encode() + b"\0" - except Exception as e: - ret = { - "text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", - "error_code": 0, - } - yield json.dumps(ret).encode() + b"\0" - - def get_embeddings(self, prompt): - return get_embeddings(self.model, self.tokenizer, prompt) - - model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL] -worker = ModelWorker(model_path=model_path, model_name=CFG.LLM_MODEL, device=DEVICE) +# worker = ModelWorker(model_path=model_path, model_name=CFG.LLM_MODEL, device=DEVICE) -app = FastAPI() -# from pilot.openapi.knowledge.knowledge_controller import router -# -# app.include_router(router) -# -# origins = [ -# "http://localhost", -# "http://localhost:8000", -# "http://localhost:3000", -# ] -# -# app.add_middleware( -# CORSMiddleware, -# allow_origins=origins, -# allow_credentials=True, -# allow_methods=["*"], -# allow_headers=["*"], -# ) - - -class PromptRequest(BaseModel): - messages: List[ModelMessage] - prompt: str - temperature: float - max_new_tokens: int - model: str - stop: str = None - echo: bool = True - - -class StreamRequest(BaseModel): - model: str - prompt: str - temperature: float - max_new_tokens: int - stop: str - - -class EmbeddingRequest(BaseModel): - prompt: str - - -def release_model_semaphore(): - model_semaphore.release() - - -@app.post("/generate_stream") -async def api_generate_stream(request: Request): - global model_semaphore, global_counter - global_counter += 1 - params = await request.json() - - if model_semaphore is None: - model_semaphore = asyncio.Semaphore(CFG.LIMIT_MODEL_CONCURRENCY) - await model_semaphore.acquire() - - generator = worker.generate_stream_gate(params) - background_tasks = BackgroundTasks() - background_tasks.add_task(release_model_semaphore) - return StreamingResponse(generator, background=background_tasks) - - -@app.post("/generate") -def generate(prompt_request: PromptRequest) -> str: - params = { - "messages": prompt_request.messages, - "prompt": prompt_request.prompt, - "temperature": prompt_request.temperature, - "max_new_tokens": prompt_request.max_new_tokens, - "stop": prompt_request.stop, - "echo": prompt_request.echo, - } - - rsp_str = "" - output = worker.generate_stream_gate(params) - for rsp in output: - # rsp = rsp.decode("utf-8") - rsp = rsp.replace(b"\0", b"") - rsp_str = rsp.decode() - - return rsp_str - - -@app.post("/embedding") -def embeddings(prompt_request: EmbeddingRequest): - params = {"prompt": prompt_request.prompt} - print("Received prompt: ", params["prompt"]) - output = worker.get_embeddings(params["prompt"]) - return {"response": [float(x) for x in output]} +# @app.post("/embedding") +# def embeddings(prompt_request: EmbeddingRequest): +# params = {"prompt": prompt_request.prompt} +# print("Received prompt: ", params["prompt"]) +# output = worker.get_embeddings(params["prompt"]) +# return {"response": [float(x) for x in output]} if __name__ == "__main__": - uvicorn.run(app, host="0.0.0.0", port=CFG.MODEL_PORT, log_level="info") + run_worker_manager( + model_name=CFG.LLM_MODEL, + model_path=model_path, + 
standalone=True, + port=CFG.MODEL_PORT, + ) diff --git a/pilot/speech/brian.py b/pilot/speech/brian.py index 505c9a6f8..e0b59a87c 100644 --- a/pilot/speech/brian.py +++ b/pilot/speech/brian.py @@ -2,7 +2,6 @@ import logging import os import requests -from playsound import playsound from pilot.speech.base import VoiceBase @@ -23,6 +22,8 @@ class BrianSpeech(VoiceBase): Returns: bool: True if the request was successful, False otherwise """ + from playsound import playsound + tts_url = ( f"https://api.streamelements.com/kappa/v2/speech?voice=Brian&text={text}" ) diff --git a/pilot/speech/eleven_labs.py b/pilot/speech/eleven_labs.py index dad841517..671a3d729 100644 --- a/pilot/speech/eleven_labs.py +++ b/pilot/speech/eleven_labs.py @@ -2,7 +2,6 @@ import os import requests -from playsound import playsound from pilot.configs.config import Config from pilot.speech.base import VoiceBase @@ -70,6 +69,7 @@ class ElevenLabsSpeech(VoiceBase): bool: True if the request was successful, False otherwise """ from pilot.logs import logger + from playsound import playsound tts_url = ( f"https://api.elevenlabs.io/v1/text-to-speech/{self._voices[voice_index]}" diff --git a/pilot/speech/gtts.py b/pilot/speech/gtts.py index 7ad164f30..8fc7df19c 100644 --- a/pilot/speech/gtts.py +++ b/pilot/speech/gtts.py @@ -2,7 +2,6 @@ import os import gtts -from playsound import playsound from pilot.speech.base import VoiceBase @@ -15,6 +14,8 @@ class GTTSVoice(VoiceBase): def _speech(self, text: str, _: int = 0) -> bool: """Play the given text.""" + from playsound import playsound + tts = gtts.gTTS(text) tts.save("speech.mp3") playsound("speech.mp3", True) diff --git a/pilot/summary/db_summary_client.py b/pilot/summary/db_summary_client.py index 99b992698..a0b31b6a8 100644 --- a/pilot/summary/db_summary_client.py +++ b/pilot/summary/db_summary_client.py @@ -184,5 +184,5 @@ def _get_llm_response(query, db_input, dbsummary): chat: BaseChat = chat_factory.get_implementation( ChatScene.InnerChatDBSummary.value, **chat_param ) - res = chat.nostream_call() + res = chat._blocking_nostream_call() return json.loads(res)["table"] diff --git a/pilot/utils/__init__.py b/pilot/utils/__init__.py new file mode 100644 index 000000000..8a84bc0ec --- /dev/null +++ b/pilot/utils/__init__.py @@ -0,0 +1,9 @@ +from .utils import ( + get_gpu_memory, + build_logger, + StreamToLogger, + disable_torch_init, + pretty_print_semaphore, + server_error_msg, + get_or_create_event_loop, +) diff --git a/pilot/utils/api_utils.py b/pilot/utils/api_utils.py new file mode 100644 index 000000000..ea2bef6ef --- /dev/null +++ b/pilot/utils/api_utils.py @@ -0,0 +1,101 @@ +import httpx +from inspect import signature +import typing_inspect +import logging +from typing import get_type_hints, List, Type, TypeVar, Union, Optional, Tuple +from dataclasses import is_dataclass, asdict + +T = TypeVar("T") + + +def _extract_dataclass_from_generic(type_hint: Type[T]) -> Union[Type[T], None]: + """Extract actual dataclass from generic type hints like List[dataclass], Optional[dataclass], etc.""" + if typing_inspect.is_generic_type(type_hint) and typing_inspect.get_args(type_hint): + return typing_inspect.get_args(type_hint)[0] + return None + + +def _api_remote(path, method="GET"): + def decorator(func): + return_type = get_type_hints(func).get("return") + if return_type is None: + raise TypeError("Return type must be annotated in the decorated function.") + + actual_dataclass = _extract_dataclass_from_generic(return_type) + logging.debug( + f"return_type: {return_type}, 
actual_dataclass: {actual_dataclass}" + ) + if not actual_dataclass: + actual_dataclass = return_type + sig = signature(func) + + async def wrapper(self, *args, **kwargs): + base_url = self.base_url # Get base_url from class instance + + bound = sig.bind(self, *args, **kwargs) + bound.apply_defaults() + + formatted_url = base_url + path.format(**bound.arguments) + + # Extract args names from signature, except "self" + arg_names = list(sig.parameters.keys())[1:] + + # Combine args and kwargs into a single dictionary + combined_args = dict(zip(arg_names, args)) + combined_args.update(kwargs) + + request_data = {} + for key, value in combined_args.items(): + if is_dataclass(value): + # Here, instead of adding it as a nested dictionary, + # we set request_data directly to its dictionary representation. + request_data = asdict(value) + else: + request_data[key] = value + + request_params = {"method": method, "url": formatted_url} + + if method in ["POST", "PUT", "PATCH"]: + request_params["json"] = request_data + else: # For GET, DELETE, etc. + request_params["params"] = request_data + + logging.info( + f"request_params: {request_params}, args: {args}, kwargs: {kwargs}" + ) + + async with httpx.AsyncClient() as client: + response = await client.request(**request_params) + + if response.status_code == 200: + return _parse_response( + response.json(), return_type, actual_dataclass + ) + else: + error_msg = f"Remote request error, error code: {response.status_code}, error msg: {response.text}" + raise Exception(error_msg) + + return wrapper + + return decorator + + +def _parse_response(json_response, return_type, actual_dataclass): + # print(f'return_type.__origin__: {return_type.__origin__}, actual_dataclass: {actual_dataclass}, json_response: {json_response}') + if is_dataclass(actual_dataclass): + if return_type.__origin__ is list: # for List[dataclass] + if isinstance(json_response, list): + return [actual_dataclass(**item) for item in json_response] + else: + raise TypeError( + f"Expected list in response but got {type(json_response)}" + ) + else: + if isinstance(json_response, dict): + return actual_dataclass(**json_response) + else: + raise TypeError( + f"Expected dictionary in response but got {type(json_response)}" + ) + else: + return json_response diff --git a/pilot/utils/model_utils.py b/pilot/utils/model_utils.py new file mode 100644 index 000000000..a7a51ad32 --- /dev/null +++ b/pilot/utils/model_utils.py @@ -0,0 +1,27 @@ +import logging + + +def _clear_torch_cache(device="cuda"): + import gc + + import torch + + gc.collect() + if device != "cpu": + if torch.has_mps: + try: + from torch.mps import empty_cache + + empty_cache() + except Exception as e: + logging.warn(f"Clear mps torch cache error, {str(e)}") + elif torch.has_cuda: + device_count = torch.cuda.device_count() + for device_id in range(device_count): + cuda_device = f"cuda:{device_id}" + logging.info(f"Clear torch cache of device: {cuda_device}") + with torch.cuda.device(cuda_device): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + else: + logging.info("No cuda or mps, not support clear torch cache yet") diff --git a/pilot/utils/module_utils.py b/pilot/utils/module_utils.py new file mode 100644 index 000000000..cbc1db149 --- /dev/null +++ b/pilot/utils/module_utils.py @@ -0,0 +1,26 @@ +from typing import Type +from importlib import import_module + + +def import_from_string(module_path: str): + try: + module_path, class_name = module_path.rsplit(".", 1) + except ValueError: + raise ImportError(f"{module_path} 
doesn't look like a module path") + module = import_module(module_path) + + try: + return getattr(module, class_name) + except AttributeError: + raise ImportError( + f'Module "{module_path}" does not define a "{class_name}" attribute/class' + ) + + +def import_from_checked_string(module_path: str, supper_cls: Type): + cls = import_from_string(module_path) + if not issubclass(cls, supper_cls): + raise ImportError( + f'Module "{module_path}" does not the subclass of {str(supper_cls)}' + ) + return cls diff --git a/pilot/utils/net_utils.py b/pilot/utils/net_utils.py new file mode 100644 index 000000000..8fc803e6f --- /dev/null +++ b/pilot/utils/net_utils.py @@ -0,0 +1,24 @@ +import socket +import errno + + +def _get_ip_address(address: str = "10.254.254.254:1") -> str: + ip, port = address.split(":") + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.settimeout(0) + curr_address = "127.0.0.1" + try: + # doesn't even have to be reachable + s.connect((ip, int(port))) + curr_address = s.getsockname()[0] + except OSError as e: + IP = "127.0.0.1" + if e.errno == errno.ENETUNREACH: + try: + hostname = socket.getfqdn(socket.gethostname()) + curr_address = socket.gethostbyname(hostname) + except Exception: + pass + finally: + s.close() + return curr_address diff --git a/pilot/utils.py b/pilot/utils/utils.py similarity index 89% rename from pilot/utils.py rename to pilot/utils/utils.py index d05387056..ca7cf9d3c 100644 --- a/pilot/utils.py +++ b/pilot/utils/utils.py @@ -5,9 +5,7 @@ import logging import logging.handlers import os import sys - -import requests -import torch +import asyncio from pilot.configs.model_config import LOGDIR @@ -19,6 +17,8 @@ handler = None def get_gpu_memory(max_gpus=None): + import torch + gpu_memory = [] num_gpus = ( torch.cuda.device_count() @@ -73,7 +73,7 @@ def build_logger(logger_name, logger_filename): for name, item in logging.root.manager.loggerDict.items(): if isinstance(item, logging.Logger): item.addHandler(handler) - logging.basicConfig(level=logging.INFO) + logging.basicConfig(level=logging.INFO, encoding="utf-8") # Get logger logger = logging.getLogger(logger_name) @@ -132,3 +132,15 @@ def pretty_print_semaphore(semaphore): if semaphore is None: return "None" return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" + + +def get_or_create_event_loop() -> asyncio.BaseEventLoop: + try: + loop = asyncio.get_event_loop() + except Exception as e: + if not "no running event loop" in str(e): + raise e + logging.warning("Cant not get running event loop, create new event loop now") + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop diff --git a/requirements.txt b/requirements.txt index e4dd4def2..daec0dc85 100644 --- a/requirements.txt +++ b/requirements.txt @@ -76,4 +76,7 @@ bardapi==0.1.29 # TODO moved to optional dependencies pymysql duckdb -duckdb-engine \ No newline at end of file +duckdb-engine + +# cli +prettytable \ No newline at end of file diff --git a/setup.py b/setup.py index a9ee31213..d717b0bc6 100644 --- a/setup.py +++ b/setup.py @@ -267,7 +267,7 @@ def llama_cpp_python_cuda_requires(): llama_cpp_version = "0.1.77" py_version = "cp310" os_pkg_name = "linux_x86_64" if os_type == OSType.LINUX else "win_amd64" - extra_index_url = f"{base_url}/llama_cpp_python_cuda-{llama_cpp_version}+{device}{cpu_avx}-{py_version}-{py_version}-{os_pkg_name}.whl" + extra_index_url = f"{base_url}/llama_cpp_python_cuda-{llama_cpp_version}+{device}-{py_version}-{py_version}-{os_pkg_name}.whl" extra_index_url, _ = 
encode_url(extra_index_url) print(f"Install llama_cpp_python_cuda from {extra_index_url}") @@ -361,7 +361,7 @@ setuptools.setup( extras_require=setup_spec.extras, entry_points={ "console_scripts": [ - "dbgpt_server=pilot.server:webserver", + "dbgpt=pilot.scripts.cli_scripts:main", ], }, ) diff --git a/tools/cli/cli_scripts.py b/tools/cli/cli_scripts.py index 545acb4a5..2176cc272 100644 --- a/tools/cli/cli_scripts.py +++ b/tools/cli/cli_scripts.py @@ -25,7 +25,6 @@ sys.path.append( from pilot.configs.model_config import DATASETS_DIR -from tools.cli.knowledge_client import knowledge_init API_ADDRESS: str = "http://127.0.0.1:5000" @@ -97,6 +96,8 @@ def knowledge( verbose: bool, ): """Knowledge command line tool""" + from tools.cli.knowledge_client import knowledge_init + knowledge_init( API_ADDRESS, vector_name,