feat: Support multi model (#501)

- **ModelController**: Model registration and management.
- **ModelRegistry**: Abstract base class for a model registry. It
provides an interface for registering, deregistering, fetching
instances, and sending heartbeats for instances.
- **ModelWorker**: Abstract representation of a Model Worker responsible
for model interaction, startup, and shutdown. Supports 'llm' and
'text2vec' models.
- **WorkerManager**: Manages deployed worker instances on the current server.
The `WorkerManager` is also the handle for invoking model services (see the sketch below).
- **Model command line tools**: List, stop, start, restart model
instances.
- Modify `BaseChat`: Asynchronous chat messages for higher performance.
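
As a rough end-to-end illustration of how these pieces fit together (a minimal sketch only, built from the new `ModelRegistryClient` and `RemoteWorkerManager`; the controller address and model name are assumed values):

```python
import asyncio

from pilot.model.controller.registry import ModelRegistryClient
from pilot.model.worker.manager import (
    RemoteWorkerManager,
    WorkerApplyRequest,
    WorkerApplyType,
)


async def restart_model(model_name: str = "vicuna-13b-v1.5", model_type: str = "llm"):
    # Instances are registered under "<model_name>@<worker_type>"
    registry = ModelRegistryClient("http://127.0.0.1:8000")
    print(await registry.get_all_instances(f"{model_name}@{model_type}"))

    worker_manager = RemoteWorkerManager(registry)
    apply_req = WorkerApplyRequest(
        model=model_name, worker_type=model_type, apply_type=WorkerApplyType.RESTART
    )
    print(await worker_manager.worker_apply(apply_req))


asyncio.run(restart_model())
```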
Aries-ckt 2023-08-30 13:15:27 +08:00 committed by GitHub
commit 0983456311
38 changed files with 2428 additions and 333 deletions


@@ -12,19 +12,13 @@ from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
LlamaTokenizer,
BitsAndBytesConfig,
)
from pilot.model.parameter import ModelParameters, LlamaCppModelParameters
from pilot.configs.model_config import DEVICE
from pilot.configs.config import Config
from pilot.logs import logger
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype="bfloat16",
bnb_4bit_use_double_quant=False,
)
CFG = Config()
@@ -203,6 +197,14 @@ class FalconAdapater(BaseLLMAdaper):
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
if CFG.QLoRA:
from transformers import BitsAndBytesConfig
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype="bfloat16",
bnb_4bit_use_double_quant=False,
)
model = AutoModelForCausalLM.from_pretrained(
model_path,
load_in_4bit=True, # quantize


@@ -1,7 +1,10 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import TypedDict
from enum import Enum
from typing import TypedDict, Optional, Dict
from dataclasses import dataclass
from datetime import datetime
class Message(TypedDict):
@@ -9,3 +12,39 @@ class Message(TypedDict):
role: str
content: str
@dataclass
class ModelInstance:
"""Model instance info"""
model_name: str
host: str
port: int
weight: Optional[float] = 1.0
check_healthy: Optional[bool] = True
healthy: Optional[bool] = False
enabled: Optional[bool] = True
prompt_template: Optional[str] = None
last_heartbeat: Optional[datetime] = None
class WorkerApplyType(str, Enum):
START = "start"
STOP = "stop"
RESTART = "restart"
UPDATE_PARAMS = "update_params"
@dataclass
class ModelOutput:
text: str
error_code: int
model_context: Dict = None
@dataclass
class WorkerApplyOutput:
message: str
# Seconds taken to apply the action to the worker instances
timecost: Optional[int] = -1

pilot/model/cli.py (new file, 151 lines)

@@ -0,0 +1,151 @@
import click
import functools
from pilot.model.controller.registry import ModelRegistryClient
from pilot.model.worker.manager import (
RemoteWorkerManager,
WorkerApplyRequest,
WorkerApplyType,
)
from pilot.utils import get_or_create_event_loop
@click.group("model")
def model_cli_group():
pass
@model_cli_group.command()
@click.option(
"--address",
type=str,
default="http://127.0.0.1:8000",
required=False,
help=(
"Address of the Model Controller to connect to."
"Just support light deploy model"
),
)
@click.option(
"--model-name", type=str, default=None, required=False, help=("The name of model")
)
@click.option(
"--model-type", type=str, default="llm", required=False, help=("The type of model")
)
def list(address: str, model_name: str, model_type: str):
"""List model instances"""
from prettytable import PrettyTable
loop = get_or_create_event_loop()
registry = ModelRegistryClient(address)
if not model_name:
instances = loop.run_until_complete(registry.get_all_model_instances())
else:
if not model_type:
model_type = "llm"
register_model_name = f"{model_name}@{model_type}"
instances = loop.run_until_complete(
registry.get_all_instances(register_model_name)
)
table = PrettyTable()
table.field_names = [
"Model Name",
"Model Type",
"Host",
"Port",
"Healthy",
"Enabled",
"Prompt Template",
"Last Heartbeat",
]
for instance in instances:
model_name, model_type = instance.model_name.split("@")
table.add_row(
[
model_name,
model_type,
instance.host,
instance.port,
instance.healthy,
instance.enabled,
instance.prompt_template,
instance.last_heartbeat,
]
)
print(table)
def add_model_options(func):
@click.option(
"--address",
type=str,
default="http://127.0.0.1:8000",
required=False,
help=(
"Address of the Model Controller to connect to."
"Just support light deploy model"
),
)
@click.option(
"--model-name",
type=str,
default=None,
required=True,
help=("The name of model"),
)
@click.option(
"--model-type",
type=str,
default="llm",
required=False,
help=("The type of model"),
)
@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper
@model_cli_group.command()
@add_model_options
def stop(address: str, model_name: str, model_type: str):
"""Stop model instances"""
worker_apply(address, model_name, model_type, WorkerApplyType.STOP)
@model_cli_group.command()
@add_model_options
def start(address: str, model_name: str, model_type: str):
"""Start model instances"""
worker_apply(address, model_name, model_type, WorkerApplyType.START)
@model_cli_group.command()
@add_model_options
def restart(address: str, model_name: str, model_type: str):
"""Restart model instances"""
worker_apply(address, model_name, model_type, WorkerApplyType.RESTART)
# @model_cli_group.command()
# @add_model_options
# def modify(address: str, model_name: str, model_type: str):
# """Restart model instances"""
# worker_apply(address, model_name, model_type, WorkerApplyType.UPDATE_PARAMS)
def worker_apply(
address: str, model_name: str, model_type: str, apply_type: WorkerApplyType
):
loop = get_or_create_event_loop()
registry = ModelRegistryClient(address)
worker_manager = RemoteWorkerManager(registry)
apply_req = WorkerApplyRequest(
model=model_name, worker_type=model_type, apply_type=apply_type
)
res = loop.run_until_complete(worker_manager.worker_apply(apply_req))
print(res)
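
The `model` group above is an ordinary `click.Group`, so it can be attached to whatever top-level CLI the project ships. A hedged sketch; the `cli` group and script name here are hypothetical:

```python
import click

from pilot.model.cli import model_cli_group


@click.group()
def cli():
    """Hypothetical top-level entry point."""


cli.add_command(model_cli_group)

if __name__ == "__main__":
    # e.g. python my_cli.py model list --model-name vicuna-13b-v1.5 --model-type llm
    cli()
```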

@@ -0,0 +1,65 @@
import logging
from typing import List
from fastapi import APIRouter
from pilot.model.base import ModelInstance
from pilot.model.controller.registry import EmbeddedModelRegistry, ModelRegistry
class ModelController:
def __init__(self, registry: ModelRegistry = None) -> None:
if not registry:
registry = EmbeddedModelRegistry()
self.registry = registry
self.deployment = None
async def register_instance(self, instance: ModelInstance) -> bool:
return await self.registry.register_instance(instance)
async def deregister_instance(self, instance: ModelInstance) -> bool:
return await self.registry.deregister_instance(instance)
async def get_all_instances(
self, model_name: str, healthy_only: bool = False
) -> List[ModelInstance]:
logging.info(
f"Get all instances with {model_name}, healthy_only: {healthy_only}"
)
return await self.registry.get_all_instances(model_name, healthy_only)
async def get_all_model_instances(self) -> List[ModelInstance]:
return await self.registry.get_all_model_instances()
async def send_heartbeat(self, instance: ModelInstance) -> bool:
return await self.registry.send_heartbeat(instance)
async def model_apply(self) -> bool:
# TODO
raise NotImplementedError
router = APIRouter()
controller = ModelController()
@router.post("/controller/models")
async def api_register_instance(request: ModelInstance):
return await controller.register_instance(request)
@router.delete("/controller/models")
async def api_deregister_instance(request: ModelInstance):
return await controller.deregister_instance(request)
@router.get("/controller/models")
async def api_get_all_instances(model_name: str = None, healthy_only: bool = False):
if not model_name:
return await controller.get_all_model_instances()
return await controller.get_all_instances(model_name, healthy_only=healthy_only)
@router.post("/controller/heartbeat")
async def api_model_heartbeat(request: ModelInstance):
return await controller.send_heartbeat(request)
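
With the controller router mounted under the `/api` prefix (as the `ModelRegistryClient` paths below assume), the same operations are available over plain HTTP. A hedged `httpx` sketch with placeholder host and port values:

```python
import httpx

CONTROLLER = "http://127.0.0.1:8000"  # assumed controller address

instance = {
    "model_name": "vicuna-13b-v1.5@llm",
    "host": "192.168.1.10",
    "port": 8001,
}

with httpx.Client(base_url=CONTROLLER) as client:
    # Register (or refresh) the instance
    print(client.post("/api/controller/models", json=instance).json())
    # List instances for that model
    print(
        client.get(
            "/api/controller/models", params={"model_name": "vicuna-13b-v1.5@llm"}
        ).json()
    )
```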


@@ -0,0 +1,224 @@
import random
import threading
import time
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Dict, List, Tuple
import itertools
from pilot.model.base import ModelInstance
class ModelRegistry(ABC):
"""
Abstract base class for a model registry. It provides an interface
for registering, deregistering, fetching instances, and sending heartbeats
for instances.
"""
@abstractmethod
async def register_instance(self, instance: ModelInstance) -> bool:
"""
Register a given model instance.
Args:
- instance (ModelInstance): The instance of the model to register.
Returns:
- bool: True if registration is successful, False otherwise.
"""
pass
@abstractmethod
async def deregister_instance(self, instance: ModelInstance) -> bool:
"""
Deregister a given model instance.
Args:
- instance (ModelInstance): The instance of the model to deregister.
Returns:
- bool: True if deregistration is successful, False otherwise.
"""
@abstractmethod
async def get_all_instances(
self, model_name: str, healthy_only: bool = False
) -> List[ModelInstance]:
"""
Fetch all instances of a given model. Optionally, fetch only the healthy instances.
Args:
- model_name (str): Name of the model to fetch instances for.
- healthy_only (bool, optional): If set to True, fetches only the healthy instances.
Defaults to False.
Returns:
- List[ModelInstance]: A list of instances for the given model.
"""
@abstractmethod
async def get_all_model_instances(self) -> List[ModelInstance]:
"""
Fetch all instances of all models
Returns:
- List[ModelInstance]: A list of instances across all models.
"""
async def select_one_health_instance(self, model_name: str) -> ModelInstance:
"""
Selects one healthy and enabled instance for a given model.
Args:
- model_name (str): Name of the model.
Returns:
- ModelInstance: One randomly selected healthy and enabled instance, or None if no such instance exists.
"""
instances = await self.get_all_instances(model_name, healthy_only=True)
instances = [i for i in instances if i.enabled]
if not instances:
return None
return random.choice(instances)
@abstractmethod
async def send_heartbeat(self, instance: ModelInstance) -> bool:
"""
Send a heartbeat for a given model instance. This can be used to
verify if the instance is still alive and functioning.
Args:
- instance (ModelInstance): The instance of the model to send a heartbeat for.
Returns:
- bool: True if heartbeat is successful, False otherwise.
"""
class EmbeddedModelRegistry(ModelRegistry):
def __init__(
self, heartbeat_interval_secs: int = 60, heartbeat_timeout_secs: int = 120
):
self.registry: Dict[str, List[ModelInstance]] = defaultdict(list)
self.heartbeat_interval_secs = heartbeat_interval_secs
self.heartbeat_timeout_secs = heartbeat_timeout_secs
self.heartbeat_thread = threading.Thread(target=self._heartbeat_checker)
self.heartbeat_thread.daemon = True
self.heartbeat_thread.start()
def _get_instances(
self, model_name: str, host: str, port: int, healthy_only: bool = False
) -> Tuple[List[ModelInstance], List[ModelInstance]]:
instances = self.registry[model_name]
if healthy_only:
instances = [ins for ins in instances if ins.healthy == True]
exist_ins = [ins for ins in instances if ins.host == host and ins.port == port]
return instances, exist_ins
def _heartbeat_checker(self):
while True:
for instances in self.registry.values():
for instance in instances:
if (
instance.check_healthy
and datetime.now() - instance.last_heartbeat
> timedelta(seconds=self.heartbeat_timeout_secs)
):
instance.healthy = False
time.sleep(self.heartbeat_interval_secs)
async def register_instance(self, instance: ModelInstance) -> bool:
model_name = instance.model_name.strip()
host = instance.host.strip()
port = instance.port
instances, exist_ins = self._get_instances(
model_name, host, port, healthy_only=False
)
if exist_ins:
# One exist instance at most
ins = exist_ins[0]
# Update instance
ins.weight = instance.weight
ins.healthy = True
ins.prompt_template = instance.prompt_template
ins.last_heartbeat = datetime.now()
else:
instance.healthy = True
instance.last_heartbeat = datetime.now()
instances.append(instance)
return True
async def deregister_instance(self, instance: ModelInstance) -> bool:
model_name = instance.model_name.strip()
host = instance.host.strip()
port = instance.port
_, exist_ins = self._get_instances(model_name, host, port, healthy_only=False)
if exist_ins:
ins = exist_ins[0]
ins.healthy = False
return True
async def get_all_instances(
self, model_name: str, healthy_only: bool = False
) -> List[ModelInstance]:
instances = self.registry[model_name]
if healthy_only:
instances = [ins for ins in instances if ins.healthy == True]
return instances
async def get_all_model_instances(self) -> List[ModelInstance]:
print(self.registry)
return list(itertools.chain(*self.registry.values()))
async def send_heartbeat(self, instance: ModelInstance) -> bool:
_, exist_ins = self._get_instances(
instance.model_name, instance.host, instance.port, healthy_only=False
)
if not exist_ins:
return False
ins = exist_ins[0]
ins.last_heartbeat = datetime.now()
ins.healthy = True
return True
from pilot.utils.api_utils import _api_remote as api_remote
class ModelRegistryClient(ModelRegistry):
def __init__(self, base_url: str) -> None:
self.base_url = base_url
@api_remote(path="/api/controller/models", method="POST")
async def register_instance(self, instance: ModelInstance) -> bool:
pass
@api_remote(path="/api/controller/models", method="DELETE")
async def deregister_instance(self, instance: ModelInstance) -> bool:
pass
@api_remote(path="/api/controller/models")
async def get_all_instances(
self, model_name: str, healthy_only: bool = False
) -> List[ModelInstance]:
pass
@api_remote(path="/api/controller/models")
async def get_all_model_instances(self) -> List[ModelInstance]:
pass
@api_remote(path="/api/controller/models")
async def select_one_health_instance(self, model_name: str) -> ModelInstance:
instances = await self.get_all_instances(model_name, healthy_only=True)
instances = [i for i in instances if i.enabled]
if not instances:
return None
return random.choice(instances)
@api_remote(path="/api/controller/heartbeat", method="POST")
async def send_heartbeat(self, instance: ModelInstance) -> bool:
pass
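
A minimal in-process sketch of the embedded registry (mirroring the unit tests further below; names and addresses are placeholders):

```python
import asyncio

from pilot.model.base import ModelInstance
from pilot.model.controller.registry import EmbeddedModelRegistry


async def main():
    registry = EmbeddedModelRegistry()
    instance = ModelInstance(
        model_name="vicuna-13b-v1.5@llm", host="127.0.0.1", port=8001
    )
    # register_instance marks the instance healthy and stamps last_heartbeat
    await registry.register_instance(instance)
    print(await registry.get_all_instances("vicuna-13b-v1.5@llm", healthy_only=True))
    picked = await registry.select_one_health_instance("vicuna-13b-v1.5@llm")
    print(picked.host, picked.port)


asyncio.run(main())
```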


@@ -0,0 +1,148 @@
import pytest
from datetime import datetime, timedelta
import asyncio
from unittest.mock import patch
from pilot.model.base import ModelInstance
from pilot.model.controller.registry import ModelRegistry, EmbeddedModelRegistry
@pytest.fixture
def model_registry():
return EmbeddedModelRegistry()
@pytest.fixture
def model_instance():
return ModelInstance(
model_name="test_model",
host="192.168.1.1",
port=5000,
)
# Async function to test the registry
@pytest.mark.asyncio
async def test_register_instance(model_registry, model_instance):
"""
Test if an instance can be registered correctly
"""
assert await model_registry.register_instance(model_instance) == True
assert len(model_registry.registry[model_instance.model_name]) == 1
@pytest.mark.asyncio
async def test_deregister_instance(model_registry, model_instance):
"""
Test if an instance can be deregistered correctly
"""
await model_registry.register_instance(model_instance)
assert await model_registry.deregister_instance(model_instance) == True
assert not model_registry.registry[model_instance.model_name][0].healthy
@pytest.mark.asyncio
async def test_get_all_instances(model_registry, model_instance):
"""
Test if all instances can be retrieved, with and without the healthy_only filter
"""
await model_registry.register_instance(model_instance)
assert len(await model_registry.get_all_instances(model_instance.model_name)) == 1
assert (
len(
await model_registry.get_all_instances(
model_instance.model_name, healthy_only=True
)
)
== 1
)
model_instance.healthy = False
assert (
len(
await model_registry.get_all_instances(
model_instance.model_name, healthy_only=True
)
)
== 0
)
@pytest.mark.asyncio
async def test_select_one_health_instance(model_registry, model_instance):
"""
Test if a single healthy instance can be selected
"""
await model_registry.register_instance(model_instance)
selected_instance = await model_registry.select_one_health_instance(
model_instance.model_name
)
assert selected_instance is not None
assert selected_instance.healthy
assert selected_instance.enabled
@pytest.mark.asyncio
async def test_send_heartbeat(model_registry, model_instance):
"""
Test if a heartbeat can be sent and that it correctly updates the last_heartbeat timestamp
"""
await model_registry.register_instance(model_instance)
last_heartbeat = datetime.now() - timedelta(seconds=10)
model_instance.last_heartbeat = last_heartbeat
assert await model_registry.send_heartbeat(model_instance) == True
assert (
model_registry.registry[model_instance.model_name][0].last_heartbeat
> last_heartbeat
)
assert model_registry.registry[model_instance.model_name][0].healthy == True
@pytest.mark.asyncio
async def test_heartbeat_timeout(model_registry, model_instance):
"""
Test if an instance is marked as unhealthy when the heartbeat is not sent within the timeout
"""
model_registry = EmbeddedModelRegistry(1, 1)
await model_registry.register_instance(model_instance)
model_registry.registry[model_instance.model_name][
0
].last_heartbeat = datetime.now() - timedelta(
seconds=model_registry.heartbeat_timeout_secs + 1
)
await asyncio.sleep(model_registry.heartbeat_interval_secs + 1)
assert not model_registry.registry[model_instance.model_name][0].healthy
@pytest.mark.asyncio
async def test_multiple_instances(model_registry, model_instance):
"""
Test if multiple instances of the same model are handled correctly
"""
model_instance2 = ModelInstance(
model_name="test_model",
host="192.168.1.2",
port=5000,
)
await model_registry.register_instance(model_instance)
await model_registry.register_instance(model_instance2)
assert len(await model_registry.get_all_instances(model_instance.model_name)) == 2
@pytest.mark.asyncio
async def test_same_model_name_different_ip_port(model_registry):
"""
Test if instances with the same model name but different IP and port are handled correctly
"""
instance1 = ModelInstance(model_name="test_model", host="192.168.1.1", port=5000)
instance2 = ModelInstance(model_name="test_model", host="192.168.1.2", port=6000)
await model_registry.register_instance(instance1)
await model_registry.register_instance(instance2)
instances = await model_registry.get_all_instances("test_model")
assert len(instances) == 2
assert instances[0].host != instances[1].host
assert instances[0].port != instances[1].port


@@ -60,7 +60,7 @@ def _get_model_real_path(model_name, default_model_path) -> str:
return _genenv_ignoring_key_case("model_path", default_value=default_model_path)
class ModelLoader(metaclass=Singleton):
class ModelLoader:
"""Model loader is a class for model load
Args: model_path
@@ -68,17 +68,11 @@ class ModelLoader(metaclass=Singleton):
TODO: multi model support.
"""
kwargs = {}
def __init__(self, model_path: str, model_name: str = None) -> None:
self.device = DEVICE
self.model_path = model_path
self.model_name = model_name
self.prompt_template: str = None
self.kwargs = {
"torch_dtype": torch.float16,
"device_map": "auto",
}
# TODO multi gpu support
def loader(
@@ -121,6 +115,18 @@ class ModelLoader(metaclass=Singleton):
else:
raise Exception(f"Unkown model type {model_type}")
def loader_with_params(self, model_params: ModelParameters):
llm_adapter = get_llm_model_adapter(self.model_name, self.model_path)
model_type = llm_adapter.model_type()
self.prompt_template = model_params.prompt_template
logger.info(f"model_params:\n{model_params}")
if model_type == ModelType.HF:
return huggingface_loader(llm_adapter, model_params)
elif model_type == ModelType.LLAMA_CPP:
return llamacpp_loader(llm_adapter, model_params)
else:
raise Exception(f"Unkown model type {model_type}")
def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParameters):
device = model_params.device
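
A hedged sketch of the new parameter-driven load path (placeholder model name and path; `loader_with_params` returns the `(model, tokenizer)` pair that `DefaultModelWorker` uses below):

```python
from pilot.model.loader import ModelLoader
from pilot.model.parameter import ModelParameters

# Placeholder values; in the worker these are filled in by EnvArgumentParser
params = ModelParameters(
    model_name="vicuna-13b-v1.5",
    model_path="/data/models/vicuna-13b-v1.5",
    device="cuda",
)

loader = ModelLoader(model_path=params.model_path, model_name=params.model_name)
model, tokenizer = loader.loader_with_params(params)
```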


@@ -1,9 +1,10 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import os
from typing import Any, Optional, Type
from dataclasses import dataclass, field, fields
from enum import Enum
from typing import Any, Dict, List, Optional, Type, Union
from pilot.model.conversation import conv_templates
@@ -20,11 +21,28 @@ def _genenv_ignoring_key_case(env_key: str, env_prefix: str = None, default_valu
class EnvArgumentParser:
@staticmethod
def get_env_prefix(env_key: str) -> str:
if not env_key:
return None
env_key = env_key.replace("-", "_")
return env_key + "_"
def parse_args_into_dataclass(
self, dataclass_type: Type, env_prefix: str = None, **kwargs
self,
dataclass_type: Type,
env_prefix: str = None,
command_args: List[str] = None,
**kwargs,
) -> Any:
"""Parse parameters from environment variables and command lines and populate them into data class"""
parser = argparse.ArgumentParser()
for field in fields(dataclass_type):
env_var_value = _genenv_ignoring_key_case(field.name, env_prefix)
if not env_var_value:
# Read without env prefix
env_var_value = _genenv_ignoring_key_case(field.name)
if env_var_value:
env_var_value = env_var_value.strip()
if field.type is int or field.type == Optional[int]:
@@ -37,17 +55,241 @@ class EnvArgumentParser:
pass
else:
raise ValueError(f"Unsupported parameter type {field.type}")
kwargs[field.name] = env_var_value
if not env_var_value:
env_var_value = kwargs.get(field.name)
if not env_var_value:
env_var_value = field.default
# Add a command-line argument for this field
help_text = field.metadata.get("help", "")
valid_values = field.metadata.get("valid_values", None)
parser.add_argument(
f"--{field.name}",
type=self._get_argparse_type(field.type),
help=help_text,
choices=valid_values,
default=env_var_value,
)
# Parse the command-line arguments
cmd_args, cmd_argv = parser.parse_known_args(args=command_args)
print(f"cmd_args: {cmd_args}")
for field in fields(dataclass_type):
# cmd_line_value = getattr(cmd_args, field.name)
if field.name in cmd_args:
cmd_line_value = getattr(cmd_args, field.name)
if cmd_line_value is not None:
kwargs[field.name] = cmd_line_value
return dataclass_type(**kwargs)
@staticmethod
def _get_argparse_type(field_type: Type) -> Type:
# Return the appropriate type for argparse to use based on the field type
if field_type is int or field_type == Optional[int]:
return int
elif field_type is float or field_type == Optional[float]:
return float
elif field_type is bool or field_type == Optional[bool]:
return bool
elif field_type is str or field_type == Optional[str]:
return str
else:
raise ValueError(f"Unsupported parameter type {field_type}")
@staticmethod
def _get_argparse_type_str(field_type: Type) -> str:
argparse_type = EnvArgumentParser._get_argparse_type(field_type)
if argparse_type is int:
return "int"
elif argparse_type is float:
return "float"
elif argparse_type is bool:
return "bool"
else:
return "str"
@dataclass
class ModelParameters:
device: str = field(metadata={"help": "Device to run model"})
model_name: str = field(metadata={"help": "Model name"})
model_path: str = field(metadata={"help": "Model path"})
class ParameterDescription:
param_name: str
param_type: str
description: str
default_value: Optional[Any]
valid_values: Optional[List[Any]]
def _get_parameter_descriptions(dataclass_type: Type) -> List[ParameterDescription]:
descriptions = []
for field in fields(dataclass_type):
descriptions.append(
ParameterDescription(
param_name=field.name,
param_type=EnvArgumentParser._get_argparse_type_str(field.type),
description=field.metadata.get("help", None),
default_value=field.default, # TODO handle dataclasses._MISSING_TYPE
valid_values=field.metadata.get("valid_values", None),
)
)
return descriptions
class WorkerType(str, Enum):
LLM = "llm"
TEXT2VEC = "text2vec"
@staticmethod
def values():
return [item.value for item in WorkerType]
@dataclass
class BaseParameters:
def update_from(self, source: Union["BaseParameters", dict]) -> bool:
"""
Update the attributes of this object using the values from another object (of the same or parent type) or a dictionary.
Only update if the new value is different from the current value and the field is not marked as "fixed" in metadata.
Args:
source (Union[BaseParameters, dict]): The source to update from. Can be another object of the same type or a dictionary.
Returns:
bool: True if at least one field was updated, otherwise False.
"""
updated = False # Flag to indicate whether any field was updated
if isinstance(source, (BaseParameters, dict)):
for field_info in fields(self):
# Check if the field has a "fixed" tag in metadata
tags = field_info.metadata.get("tags")
tags = [] if not tags else tags.split(",")
if tags and "fixed" in tags:
continue # skip this field
# Get the new value from source (either another BaseParameters object or a dict)
new_value = (
getattr(source, field_info.name)
if isinstance(source, BaseParameters)
else source.get(field_info.name, None)
)
# If the new value is not None and different from the current value, update the field and set the flag
if new_value is not None and new_value != getattr(
self, field_info.name
):
setattr(self, field_info.name, new_value)
updated = True
else:
raise ValueError(
"Source must be an instance of BaseParameters (or its derived class) or a dictionary."
)
return updated
def __str__(self) -> str:
class_name = self.__class__.__name__
parameters = [
f"\n\n=========================== {class_name} ===========================\n"
]
for field_info in fields(self):
value = getattr(self, field_info.name)
parameters.append(f"{field_info.name}: {value}")
parameters.append(
"\n======================================================================\n\n"
)
return "\n".join(parameters)
@dataclass
class ModelWorkerParameters(BaseParameters):
model_name: str = field(metadata={"help": "Model name", "tags": "fixed"})
model_path: str = field(metadata={"help": "Model path", "tags": "fixed"})
worker_type: Optional[str] = field(
default=None,
metadata={"valid_values": WorkerType.values(), "help": "Worker type"},
)
worker_class: Optional[str] = field(
default=None,
metadata={
"help": "Model worker deploy host, pilot.model.worker.default_worker.DefaultModelWorker"
},
)
host: Optional[str] = field(
default="0.0.0.0", metadata={"help": "Model worker deploy host"}
)
port: Optional[int] = field(
default=8000, metadata={"help": "Model worker deploy port"}
)
limit_model_concurrency: Optional[int] = field(
default=5, metadata={"help": "Model concurrency limit"}
)
standalone: Optional[bool] = field(
default=False,
metadata={"help": "Standalone mode. If True, embedded Run ModelController"},
)
register: Optional[bool] = field(
default=True, metadata={"help": "Register current worker to model controller"}
)
worker_register_host: Optional[str] = field(
default=None,
metadata={
"help": "The ip address of current worker to register to ModelController. If None, the address is automatically determined"
},
)
controller_addr: Optional[str] = field(
default=None, metadata={"help": "The Model controller address to register"}
)
send_heartbeat: Optional[bool] = field(
default=True, metadata={"help": "Send heartbeat to model controller"}
)
heartbeat_interval: Optional[int] = field(
default=20, metadata={"help": "The interval for sending heartbeats (seconds)"}
)
@dataclass
class EmbeddingModelParameters(BaseParameters):
model_name: str = field(metadata={"help": "Model name", "tags": "fixed"})
model_path: str = field(metadata={"help": "Model path", "tags": "fixed"})
device: Optional[str] = field(
default=None,
metadata={
"help": "Device to run model. If None, the device is automatically determined"
},
)
normalize_embeddings: Optional[bool] = field(
default=None,
metadata={
"help": "Determines whether the model's embeddings should be normalized."
},
)
def build_kwargs(self, **kwargs) -> Dict:
model_kwargs, encode_kwargs = None, None
if self.device:
model_kwargs = {"device": self.device}
if self.normalize_embeddings:
encode_kwargs = {"normalize_embeddings": self.normalize_embeddings}
if model_kwargs:
kwargs["model_kwargs"] = model_kwargs
if encode_kwargs:
kwargs["encode_kwargs"] = encode_kwargs
return kwargs
@dataclass
class ModelParameters(BaseParameters):
model_name: str = field(metadata={"help": "Model name", "tags": "fixed"})
model_path: str = field(metadata={"help": "Model path", "tags": "fixed"})
device: Optional[str] = field(
default=None,
metadata={
"help": "Device to run model. If None, the device is automatically determined"
},
)
model_type: Optional[str] = field(
default="huggingface", metadata={"help": "Model type, huggingface or llama.cpp"}
default="huggingface",
metadata={"help": "Model type, huggingface or llama.cpp", "tags": "fixed"},
)
prompt_template: Optional[str] = field(
default=None,
@@ -91,7 +333,6 @@ class ModelParameters:
default=True,
metadata={"help": "Nested quantization, only valid when load_4bit=True"},
)
# "bfloat16", "float16", "float32"
compute_dtype: Optional[str] = field(
default=None,
metadata={
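
A hedged sketch of how these dataclasses are populated by `EnvArgumentParser` (priority, lowest to highest: field defaults, `**kwargs`, environment variables, then command-line flags; the values below are illustrative):

```python
from pilot.model.parameter import EnvArgumentParser, ModelWorkerParameters

parser = EnvArgumentParser()
worker_params: ModelWorkerParameters = parser.parse_args_into_dataclass(
    ModelWorkerParameters,
    # get_env_prefix("vicuna-13b-v1.5") -> "vicuna_13b_v1.5_"
    env_prefix=EnvArgumentParser.get_env_prefix("vicuna-13b-v1.5"),
    command_args=["--port", "8001"],
    model_name="vicuna-13b-v1.5",
    model_path="/data/models/vicuna-13b-v1.5",  # placeholder path
)
print(worker_params)
```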


pilot/model/worker/base.py (new file, 114 lines)

@@ -0,0 +1,114 @@
from abc import ABC, abstractmethod
from typing import Dict, Iterator, List, Type
from pilot.model.base import ModelOutput
from pilot.model.parameter import (
ModelParameters,
ParameterDescription,
WorkerType,
_get_parameter_descriptions,
)
class ModelWorker(ABC):
"""
Abstract representation of a Model Worker responsible for model interaction, startup, and shutdown. Supports 'llm' and 'text2vec' models.
"""
def worker_type(self) -> WorkerType:
"""Return the type of worker as LLM."""
return WorkerType.LLM
def model_param_class(self) -> Type:
"""Return the class representing model parameters."""
return ModelParameters
def support_async(self) -> bool:
"""Whether support async, if True, invoke async_generate_stream, async_generate and async_embeddings instead of generate_stream, generate and embeddings"""
return False
@abstractmethod
def parse_parameters(self, command_args: List[str] = None) -> ModelParameters:
"""Parse the parameters using the provided command arguments.
Args:
command_args (List[str]): The command-line arguments. Default is sys.argv[1:].
"""
@abstractmethod
def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
"""Load the worker with the specified model name and path."""
@abstractmethod
def start(
self, model_params: ModelParameters = None, command_args: List[str] = None
) -> None:
"""Start the model worker"""
@abstractmethod
def stop(self) -> None:
"""Stop the model worker and clean up all the resources used."""
def restart(
self, model_params: ModelParameters = None, command_args: List[str] = None
) -> None:
"""Restart the model worker."""
self.stop()
self.start(model_params, command_args)
def parameter_descriptions(self) -> List[ParameterDescription]:
"""Fetch the parameter configuration information for the current model."""
param_cls = self.model_param_class()
return _get_parameter_descriptions(param_cls)
@abstractmethod
def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
"""Generate a stream based on provided parameters.
Args:
params (Dict): Parameters matching the PromptRequest data class format. Example:
{
"messages": [{"role": "user", "content": "Hello world"}], # List of ModelMessage objects
"model": "vicuna-13b-v1.5",
"prompt": "Hello world",
"temperature": 0.7, # Optional; float value between 0 and 1
"max_new_tokens": 2048, # Optional; max number of new tokens for the output
"stop": "#", # Optional; stopping condition for the output
"echo": True # Optional; whether to echo the input in the output
}
Returns:
Iterator[ModelOutput]: Stream of model outputs.
"""
async def async_generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
"""Asynchronously generate a stream based on provided parameters."""
raise NotImplementedError
@abstractmethod
def generate(self, params: Dict) -> ModelOutput:
"""Generate output (non-stream) based on provided parameters."""
async def async_generate(self, params: Dict) -> ModelOutput:
"""Asynchronously generate output (non-stream) based on provided parameters."""
raise NotImplementedError
@abstractmethod
def embeddings(self, params: Dict) -> List[List[float]]:
"""
Return embeddings for the given input parameters.
Args:
params (Dict): Parameters matching the EmbeddingsRequest data class format. Example:
{
"model": "text2vec-large-chinese",
"input": ["Hello world", "DB-GPT is amazing"]
}
Returns:
List[List[float]]: List of embeddings corresponding to each input string.
"""
async def async_embeddings(self, params: Dict) -> List[List[float]]:
"""Return embeddings asynchronously for the given input parameters."""
raise NotImplementedError
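
For reference, a minimal (hypothetical) worker that satisfies this interface; a real implementation such as `DefaultModelWorker` below does the actual model loading and generation. Such a class could be plugged in through the `worker_class` parameter of `ModelWorkerParameters`.

```python
from typing import Dict, Iterator, List

from pilot.model.base import ModelOutput
from pilot.model.parameter import EnvArgumentParser, ModelParameters
from pilot.model.worker.base import ModelWorker


class EchoModelWorker(ModelWorker):
    """Hypothetical worker that simply echoes the prompt back."""

    def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
        self.model_name = model_name
        self.model_path = model_path

    def parse_parameters(self, command_args: List[str] = None) -> ModelParameters:
        return EnvArgumentParser().parse_args_into_dataclass(
            ModelParameters,
            command_args=command_args,
            model_name=self.model_name,
            model_path=self.model_path,
        )

    def start(self, model_params=None, command_args=None) -> None:
        self._model_params = model_params or self.parse_parameters(command_args)

    def stop(self) -> None:
        pass

    def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
        yield ModelOutput(text=params.get("prompt", ""), error_code=0)

    def generate(self, params: Dict) -> ModelOutput:
        output = None
        for output in self.generate_stream(params):
            pass
        return output

    def embeddings(self, params: Dict) -> List[List[float]]:
        raise NotImplementedError
```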


@@ -0,0 +1,133 @@
import logging
import platform
from typing import Dict, Iterator, List
import torch
from pilot.configs.model_config import DEVICE
from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper
from pilot.model.base import ModelOutput
from pilot.model.loader import ModelLoader, _get_model_real_path
from pilot.model.parameter import EnvArgumentParser, ModelParameters
from pilot.model.worker.base import ModelWorker
from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
from pilot.utils.model_utils import _clear_torch_cache
logger = logging.getLogger("model_worker")
class DefaultModelWorker(ModelWorker):
def __init__(self) -> None:
self.model = None
self.tokenizer = None
self._model_params = None
self.llm_adapter: BaseLLMAdaper = None
self.llm_chat_adapter: BaseChatAdpter = None
def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
if model_path.endswith("/"):
model_path = model_path[:-1]
model_path = _get_model_real_path(model_name, model_path)
self.model_name = model_name
self.model_path = model_path
self.llm_adapter = get_llm_model_adapter(self.model_name, self.model_path)
model_type = self.llm_adapter.model_type()
self.param_cls = self.llm_adapter.model_param_class(model_type)
self.llm_chat_adapter = get_llm_chat_adapter(self.model_name, self.model_path)
self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func(
self.model_path
)
self.ml: ModelLoader = ModelLoader(
model_path=self.model_path, model_name=self.model_name
)
# TODO read context len from model config
self.context_len = 2048
def model_param_class(self) -> ModelParameters:
return self.param_cls
def parse_parameters(self, command_args: List[str] = None) -> ModelParameters:
param_cls = self.model_param_class()
model_args = EnvArgumentParser()
env_prefix = EnvArgumentParser.get_env_prefix(self.model_name)
model_type = self.llm_adapter.model_type()
model_params: ModelParameters = model_args.parse_args_into_dataclass(
param_cls,
env_prefix=env_prefix,
command_args=command_args,
model_name=self.model_name,
model_path=self.model_path,
model_type=model_type,
)
if not model_params.device:
model_params.device = DEVICE
logger.info(
f"[DefaultModelWorker] Parameters of device is None, use {model_params.device}"
)
return model_params
def start(
self, model_params: ModelParameters = None, command_args: List[str] = None
) -> None:
if not model_params:
model_params = self.parse_parameters(command_args)
self._model_params = model_params
logger.info(f"Begin load model, model params: {model_params}")
self.model, self.tokenizer = self.ml.loader_with_params(model_params)
def stop(self) -> None:
if not self.model:
return
del self.model
del self.tokenizer
self.model = None
self.tokenizer = None
_clear_torch_cache(self._model_params.device)
def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
try:
# params adaptation
params, model_context = self.llm_chat_adapter.model_adaptation(
params, self.ml.model_path, prompt_template=self.ml.prompt_template
)
for output in self.generate_stream_func(
self.model, self.tokenizer, params, DEVICE, self.context_len
):
# Do not enable this output printing in production!
# The gpt4all thread shares stdout with the parent process,
# and printing here may interfere with the frontend output.
if "windows" in platform.platform().lower():
# Do not print the model output on Windows: it may contain emoji,
# which breaks consoles using GBK encoding
pass
else:
print("output: ", output)
# Return some model context to the dbgpt-server
model_output = ModelOutput(
text=output, error_code=0, model_context=model_context
)
yield model_output
except torch.cuda.CudaError:
model_output = ModelOutput(
text="**GPU OutOfMemory, Please Refresh.**", error_code=0
)
yield model_output
except Exception as e:
model_output = ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=0,
)
yield model_output
def generate(self, params: Dict) -> ModelOutput:
"""Generate non stream result"""
output = None
for out in self.generate_stream(params):
output = out
return output
def embeddings(self, params: Dict) -> List[List[float]]:
raise NotImplementedError


@@ -0,0 +1,100 @@
import logging
from typing import Dict, List, Type
from pilot.configs.model_config import DEVICE
from pilot.model.loader import _get_model_real_path
from pilot.model.parameter import (
EmbeddingModelParameters,
EnvArgumentParser,
WorkerType,
)
from pilot.model.worker.base import ModelWorker
from pilot.utils.model_utils import _clear_torch_cache
logger = logging.getLogger("model_worker")
class EmbeddingsModelWorker(ModelWorker):
def __init__(self) -> None:
try:
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.embeddings.base import Embeddings
except ImportError as exc:
raise ImportError(
"Could not import langchain.embeddings.HuggingFaceEmbeddings python package. "
"Please install it with `pip install langchain`."
) from exc
# Private name so the attribute does not shadow the embeddings() method below
self._embeddings: Embeddings = None
self._model_params = None
def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
if model_path.endswith("/"):
model_path = model_path[:-1]
model_path = _get_model_real_path(model_name, model_path)
self.model_name = model_name
self.model_path = model_path
def worker_type(self) -> WorkerType:
return WorkerType.TEXT2VEC
def model_param_class(self) -> Type:
return EmbeddingModelParameters
def parse_parameters(
self, command_args: List[str] = None
) -> EmbeddingModelParameters:
param_cls = self.model_param_class()
model_args = EnvArgumentParser()
env_prefix = EnvArgumentParser.get_env_prefix(self.model_name)
model_params: EmbeddingModelParameters = model_args.parse_args_into_dataclass(
param_cls,
env_prefix=env_prefix,
command_args=command_args,
model_name=self.model_name,
model_path=self.model_path,
)
if not model_params.device:
model_params.device = DEVICE
logger.info(
f"[EmbeddingsModelWorker] Parameters of device is None, use {model_params.device}"
)
return model_params
def start(
self,
model_params: EmbeddingModelParameters = None,
command_args: List[str] = None,
) -> None:
"""Start model worker"""
from langchain.embeddings import HuggingFaceEmbeddings
if not model_params:
model_params = self.parse_parameters(command_args)
self._model_params = model_params
kwargs = model_params.build_kwargs(model_name=model_params.model_path)
logger.info(f"Start HuggingFaceEmbeddings with kwargs: {kwargs}")
self._embeddings = HuggingFaceEmbeddings(**kwargs)
def __del__(self):
self.stop()
def stop(self) -> None:
if not self._embeddings:
return
del self._embeddings
self._embeddings = None
_clear_torch_cache(self._model_params.device)
def generate_stream(self, params: Dict):
"""Generate stream result, chat scene"""
raise NotImplementedError("Not supported generate_stream for embeddings model")
def generate(self, params: Dict):
"""Generate non stream result"""
raise NotImplementedError("Not supported generate for embeddings model")
def embeddings(self, params: Dict) -> List[List[float]]:
input: List[str] = params["input"]
return self._embeddings.embed_documents(input)
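
A hedged sketch of calling the embeddings worker over the worker API once it is running (route prefix and default port are taken from the worker manager below; the model name is a placeholder):

```python
import httpx

payload = {
    "model": "text2vec-large-chinese",  # the name the worker was started with
    "input": ["Hello world", "DB-GPT is amazing"],
}
resp = httpx.post("http://127.0.0.1:8000/api/worker/embeddings", json=payload)
vectors = resp.json()  # one vector (List[float]) per input string
print(len(vectors), len(vectors[0]))
```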


@@ -0,0 +1,705 @@
import asyncio
import httpx
import itertools
import json
import os
import random
import time
from abc import ABC, abstractmethod
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import asdict, dataclass, field
from datetime import datetime
from typing import Awaitable, Callable, Dict, Iterator, List, Optional
import uvicorn
from fastapi import APIRouter, FastAPI, Request
from fastapi.responses import StreamingResponse
from pilot.configs.model_config import LOGDIR
from pilot.model.base import (
ModelInstance,
ModelOutput,
WorkerApplyType,
WorkerApplyOutput,
)
from pilot.model.controller.registry import ModelRegistry
from pilot.model.parameter import (
EnvArgumentParser,
ModelParameters,
ModelWorkerParameters,
WorkerType,
ParameterDescription,
)
from pilot.model.worker.base import ModelWorker
from pilot.scene.base_message import ModelMessage
from pilot.utils import build_logger
from pydantic import BaseModel
logger = build_logger("model_worker", LOGDIR + "/model_worker.log")
class PromptRequest(BaseModel):
messages: List[ModelMessage]
model: str
prompt: str = None
temperature: float = None
max_new_tokens: int = None
stop: str = None
echo: bool = True
class EmbeddingsRequest(BaseModel):
model: str
input: List[str]
class WorkerApplyRequest(BaseModel):
model: str
apply_type: WorkerApplyType
worker_type: WorkerType = WorkerType.LLM
params: Dict = None
apply_user: str = None
class WorkerParameterRequest(BaseModel):
model: str
worker_type: WorkerType = WorkerType.LLM
@dataclass
class WorkerRunData:
worker_key: str
worker: ModelWorker
worker_params: ModelWorkerParameters
model_params: ModelParameters
stop_event: asyncio.Event
semaphore: asyncio.Semaphore = None
command_args: List[str] = None
_heartbeat_future: Optional[Future] = None
_last_heartbeat: Optional[datetime] = None
RegisterFunc = Callable[[WorkerRunData], Awaitable[None]]
DeregisterFunc = Callable[[WorkerRunData], Awaitable[None]]
SendHeartbeatFunc = Callable[[WorkerRunData], Awaitable[None]]
ApplyFunction = Callable[[WorkerRunData], Awaitable[None]]
class WorkerManager(ABC):
@abstractmethod
async def get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:
"""Get model instances by worker type and model name"""
@abstractmethod
async def select_one_instanes(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> WorkerRunData:
"""Select one instances"""
@abstractmethod
async def generate_stream(self, params: Dict, **kwargs) -> Iterator[ModelOutput]:
"""Generate stream result, chat scene"""
@abstractmethod
async def generate(self, params: Dict) -> ModelOutput:
"""Generate non stream result"""
@abstractmethod
async def embeddings(self, params: Dict) -> List[List[float]]:
"""Embed input"""
@abstractmethod
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
"""Worker apply"""
@abstractmethod
async def parameter_descriptions(
self, worker_type: str, model_name: str
) -> List[ParameterDescription]:
"""Get parameter descriptions of model"""
async def _async_heartbeat_sender(
worker_run_data: WorkerRunData, send_heartbeat_func: SendHeartbeatFunc
):
while not worker_run_data.stop_event.is_set():
try:
await send_heartbeat_func(worker_run_data)
except Exception as e:
logger.warn(f"Send heartbeat func error: {str(e)}")
finally:
await asyncio.sleep(worker_run_data.worker_params.heartbeat_interval)
class LocalWorkerManager(WorkerManager):
def __init__(
self,
register_func: RegisterFunc = None,
deregister_func: DeregisterFunc = None,
send_heartbeat_func: SendHeartbeatFunc = None,
model_registry: ModelRegistry = None,
) -> None:
self.workers: Dict[str, List[WorkerRunData]] = dict()
self.executor = ThreadPoolExecutor(max_workers=os.cpu_count() * 5)
self.register_func = register_func
self.deregister_func = deregister_func
self.send_heartbeat_func = send_heartbeat_func
self.model_registry = model_registry
def _worker_key(self, worker_type: str, model_name: str) -> str:
return f"{model_name}@{worker_type}"
def add_worker(
self,
worker: ModelWorker,
worker_params: ModelWorkerParameters,
embedded_mod: bool = True,
command_args: List[str] = None,
):
if not command_args:
import sys
command_args = sys.argv[1:]
worker.load_worker(**asdict(worker_params))
if not worker_params.worker_type:
worker_params.worker_type = worker.worker_type()
worker_key = self._worker_key(
worker_params.worker_type, worker_params.model_name
)
host = worker_params.host
port = worker_params.port
instances = self.workers.get(worker_key)
if not instances:
instances = []
self.workers[worker_key] = instances
logger.info(f"Init empty instances list for {worker_key}")
# Load model params from persist storage
model_params = worker.parse_parameters(command_args=command_args)
worker_run_data = WorkerRunData(
worker_key=worker_key,
worker=worker,
worker_params=worker_params,
model_params=model_params,
stop_event=asyncio.Event(),
semaphore=asyncio.Semaphore(worker_params.limit_model_concurrency),
command_args=command_args,
)
if not embedded_mod:
exist_instances = [
ins for ins in instances if ins.worker_params.host == host and ins.worker_params.port == port
]
if not exist_instances:
instances.append(worker_run_data)
else:
instances.append(worker_run_data)
async def get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:
worker_key = self._worker_key(worker_type, model_name)
return self.workers.get(worker_key)
async def select_one_instanes(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> WorkerRunData:
worker_instances = await self.get_model_instances(
worker_type, model_name, healthy_only
)
if not worker_instances:
raise Exception(
f"Cound not found worker instances for model name {model_name} and worker type {worker_type}"
)
worker_run_data = random.choice(worker_instances)
return worker_run_data
async def _get_model(self, params: Dict, worker_type: str = "llm") -> WorkerRunData:
model = params.get("model")
if not model:
raise Exception("Model name count not be empty")
return await self.select_one_instanes(worker_type, model, healthy_only=True)
async def generate_stream(
self, params: Dict, async_wrapper=None, **kwargs
) -> Iterator[ModelOutput]:
"""Generate stream result, chat scene"""
worker_run_data = await self._get_model(params)
async with worker_run_data.semaphore:
if worker_run_data.worker.support_async():
async for output in worker_run_data.worker.async_generate_stream(
params
):
yield output
else:
if not async_wrapper:
from starlette.concurrency import iterate_in_threadpool
async_wrapper = iterate_in_threadpool
async for output in async_wrapper(
worker_run_data.worker.generate_stream(params)
):
yield output
async def generate(self, params: Dict) -> ModelOutput:
"""Generate non stream result"""
worker_run_data = await self._get_model(params)
async with worker_run_data.semaphore:
if worker_run_data.worker.support_async():
return await worker_run_data.worker.async_generate(params)
else:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor, worker_run_data.worker.generate, params
)
async def embeddings(self, params: Dict) -> List[List[float]]:
"""Embed input"""
worker_run_data = await self._get_model(params, worker_type="text2vec")
async with worker_run_data.semaphore:
if worker_run_data.worker.support_async():
return await worker_run_data.worker.async_embeddings(params)
else:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor, worker_run_data.worker.embeddings, params
)
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
apply_func: Callable[[WorkerApplyRequest], Awaitable[str]] = None
if apply_req.apply_type == WorkerApplyType.START:
apply_func = self._start_all_worker
elif apply_req.apply_type == WorkerApplyType.STOP:
apply_func = self._stop_all_worker
elif apply_req.apply_type == WorkerApplyType.UPDATE_PARAMS:
apply_func = self._update_all_worker_params
else:
raise ValueError(f"Unsupported apply type {apply_req.apply_type}")
return await apply_func(apply_req)
async def parameter_descriptions(
self, worker_type: str, model_name: str
) -> List[ParameterDescription]:
worker_instances = await self.get_model_instances(worker_type, model_name)
if not worker_instances:
raise Exception(
f"Not worker instances for model name {model_name} worker type {worker_type}"
)
worker_run_data = worker_instances[0]
return worker_run_data.worker.parameter_descriptions()
async def _apply_worker(
self, apply_req: WorkerApplyRequest, apply_func: ApplyFunction
) -> None:
"""Apply function to worker instances in parallel
Args:
apply_req (WorkerApplyRequest): Worker apply request
apply_func (ApplyFunction): Function to apply to worker instances, now function is async function
"""
if apply_req:
worker_type = apply_req.worker_type.value
model_name = apply_req.model
worker_instances = await self.get_model_instances(worker_type, model_name)
if not worker_instances:
raise Exception(
f"No worker instance found for the model {model_name} worker type {worker_type}"
)
else:
# Apply to all workers
worker_instances = list(itertools.chain(*self.workers.values()))
logger.info(f"Apply to all workers: {worker_instances}")
return await asyncio.gather(
*(apply_func(worker) for worker in worker_instances)
)
async def _start_all_worker(
self, apply_req: WorkerApplyRequest
) -> WorkerApplyOutput:
start_time = time.time()
logger.info(f"Begin start all worker, apply_req: {apply_req}")
async def _start_worker(worker_run_data: WorkerRunData):
worker_run_data.worker.start(
worker_run_data.model_params, worker_run_data.command_args
)
worker_run_data.stop_event.clear()
if worker_run_data.worker_params.register and self.register_func:
# Register worker to controller
await self.register_func(worker_run_data)
if (
worker_run_data.worker_params.send_heartbeat
and self.send_heartbeat_func
):
asyncio.create_task(
_async_heartbeat_sender(
worker_run_data, self.send_heartbeat_func
)
)
await self._apply_worker(apply_req, _start_worker)
timecost = time.time() - start_time
return WorkerApplyOutput(
message=f"Worker started successfully", timecost=timecost
)
async def _stop_all_worker(
self, apply_req: WorkerApplyRequest
) -> WorkerApplyOutput:
start_time = time.time()
async def _stop_worker(worker_run_data: WorkerRunData):
worker_run_data.worker.stop()
# Set stop event
worker_run_data.stop_event.set()
if worker_run_data._heartbeat_future:
# Wait thread finish
worker_run_data._heartbeat_future.result()
worker_run_data._heartbeat_future = None
if (
worker_run_data.worker_params.register
and self.register_func
and self.deregister_func
):
await self.deregister_func(worker_run_data)
await self._apply_worker(apply_req, _stop_worker)
timecost = time.time() - start_time
return WorkerApplyOutput(
message=f"Worker stopped successfully", timecost=timecost
)
async def _update_all_worker_params(
self, apply_req: WorkerApplyRequest
) -> WorkerApplyOutput:
start_time = time.time()
need_restart = False
async def update_params(worker_run_data: WorkerRunData):
nonlocal need_restart
new_params = apply_req.params
if not new_params:
return
if worker_run_data.model_params.update_from(new_params):
need_restart = True
await self._apply_worker(apply_req, update_params)
message = f"Update worker params successfully"
timecost = time.time() - start_time
if need_restart:
logger.info("Model params update successfully, begin restart worker")
await self._stop_all_worker(apply_req)
await self._start_all_worker(apply_req)
timecost = time.time() - start_time
message = f"Update worker params and restart successfully"
return WorkerApplyOutput(message=message, timecost=timecost)
class RemoteWorkerManager(LocalWorkerManager):
def __init__(self, model_registry: ModelRegistry = None) -> None:
super().__init__(model_registry=model_registry)
async def get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:
from pilot.model.worker.remote_worker import RemoteModelWorker
worker_key = self._worker_key(worker_type, model_name)
instances: List[ModelInstance] = await self.model_registry.get_all_instances(
worker_key, healthy_only
)
worker_instances = []
for ins in instances:
worker = RemoteModelWorker()
worker.load_worker(model_name, model_name, host=ins.host, port=ins.port)
wr = WorkerRunData(
worker_key=ins.model_name,
worker=worker,
worker_params=None,
model_params=None,
stop_event=asyncio.Event(),
semaphore=asyncio.Semaphore(100), # Not limit in client
)
worker_instances.append(wr)
return worker_instances
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
async def _remote_apply_func(worker_run_data: WorkerRunData):
worker_addr = worker_run_data.worker.worker_addr
async with httpx.AsyncClient() as client:
response = await client.post(
worker_addr + "/apply",
headers=worker_run_data.worker.headers,
json=apply_req.dict(),
timeout=worker_run_data.worker.timeout,
)
if response.status_code == 200:
output = WorkerApplyOutput(**response.json())
logger.info(f"worker_apply success: {output}")
else:
output = WorkerApplyOutput(message=response.text)
logger.warn(f"worker_apply failed: {output}")
return output
results = await self._apply_worker(apply_req, _remote_apply_func)
if results:
return results[0]
class WorkerManagerAdapter(WorkerManager):
def __init__(self, worker_manager: WorkerManager = None) -> None:
self.worker_manager = worker_manager
async def get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:
return await self.worker_manager.get_model_instances(
worker_type, model_name, healthy_only
)
async def select_one_instanes(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> WorkerRunData:
return await self.worker_manager.select_one_instanes(
worker_type, model_name, healthy_only
)
async def generate_stream(self, params: Dict, **kwargs) -> Iterator[ModelOutput]:
async for output in self.worker_manager.generate_stream(params, **kwargs):
yield output
async def generate(self, params: Dict) -> ModelOutput:
return await self.worker_manager.generate(params)
async def embeddings(self, params: Dict) -> List[List[float]]:
return await self.worker_manager.embeddings(params)
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
return await self.worker_manager.worker_apply(apply_req)
async def parameter_descriptions(
self, worker_type: str, model_name: str
) -> List[ParameterDescription]:
return await self.worker_manager.parameter_descriptions(worker_type, model_name)
worker_manager = WorkerManagerAdapter()
router = APIRouter()
async def generate_json_stream(params):
from starlette.concurrency import iterate_in_threadpool
async for output in worker_manager.generate_stream(
params, async_wrapper=iterate_in_threadpool
):
yield json.dumps(asdict(output), ensure_ascii=False).encode() + b"\0"
@router.post("/worker/generate_stream")
async def api_generate_stream(request: Request):
params = await request.json()
generator = generate_json_stream(params)
return StreamingResponse(generator)
@router.post("/worker/generate")
async def api_generate(request: PromptRequest):
params = request.dict(exclude_none=True)
output = await worker_manager.generate(params)
return output
@router.post("/worker/embeddings")
async def api_embeddings(request: EmbeddingsRequest):
params = request.dict(exclude_none=True)
output = await worker_manager.embeddings(params)
return output
@router.post("/worker/apply")
async def api_worker_apply(request: WorkerApplyRequest):
output = await worker_manager.worker_apply(request)
return output
@router.get("/worker/parameter/descriptions")
async def api_worker_parameter_descs(
model: str, worker_type: str = WorkerType.LLM.value
):
output = await worker_manager.parameter_descriptions(worker_type, model)
return output
def _setup_fastapi(worker_params: ModelWorkerParameters):
app = FastAPI()
if worker_params.standalone:
from pilot.model.controller.controller import router as controller_router
if not worker_params.controller_addr:
worker_params.controller_addr = f"http://127.0.0.1:{worker_params.port}"
app.include_router(controller_router, prefix="/api")
@app.on_event("startup")
async def startup_event():
asyncio.create_task(
worker_manager.worker_manager._start_all_worker(apply_req=None)
)
return app
def _parse_worker_params(
model_name: str = None, model_path: str = None, **kwargs
) -> ModelWorkerParameters:
worker_args = EnvArgumentParser()
worker_params: ModelWorkerParameters = worker_args.parse_args_into_dataclass(
ModelWorkerParameters, model_name=model_name, model_path=model_path, **kwargs
)
env_prefix = EnvArgumentParser.get_env_prefix(worker_params.model_name)
# Read parameters again with the model name as prefix.
new_worker_params = worker_args.parse_args_into_dataclass(
ModelWorkerParameters,
env_prefix=env_prefix,
model_name=worker_params.model_name,
model_path=worker_params.model_path,
**kwargs,
)
worker_params.update_from(new_worker_params)
logger.info(f"Worker params: {worker_params}")
return worker_params
def _create_local_model_manager(
worker_params: ModelWorkerParameters,
) -> LocalWorkerManager:
if not worker_params.register or not worker_params.controller_addr:
logger.info(
f"Not register current to controller, register: {worker_params.register}, controller_addr: {worker_params.controller_addr}"
)
return LocalWorkerManager()
else:
from pilot.model.controller.registry import ModelRegistryClient
from pilot.utils.net_utils import _get_ip_address
client = ModelRegistryClient(worker_params.controller_addr)
host = _get_ip_address()
port = worker_params.port
async def register_func(worker_run_data: WorkerRunData):
instance = ModelInstance(
model_name=worker_run_data.worker_key, host=host, port=port
)
return await client.register_instance(instance)
async def send_heartbeat_func(worker_run_data: WorkerRunData):
instance = ModelInstance(
model_name=worker_run_data.worker_key, host=host, port=port
)
return await client.send_heartbeat(instance)
return LocalWorkerManager(
register_func=register_func, send_heartbeat_func=send_heartbeat_func
)
def _start_local_worker(
worker_manager: WorkerManagerAdapter,
worker_params: ModelWorkerParameters,
embedded_mod=True,
):
from pilot.utils.module_utils import import_from_checked_string
if worker_params.worker_class:
worker_cls = import_from_checked_string(worker_params.worker_class, ModelWorker)
logger.info(
f"Import worker class from {worker_params.worker_class} successfully"
)
worker: ModelWorker = worker_cls()
else:
from pilot.model.worker.default_worker import DefaultModelWorker
worker = DefaultModelWorker()
worker_manager.worker_manager = _create_local_model_manager(worker_params)
worker_manager.worker_manager.add_worker(
worker, worker_params, embedded_mod=embedded_mod
)
def initialize_worker_manager_in_client(
app=None,
include_router: bool = True,
model_name: str = None,
model_path: str = None,
run_locally: bool = True,
controller_addr: str = None,
):
global worker_manager
worker_params: ModelWorkerParameters = _parse_worker_params(
model_name=model_name, model_path=model_path, controller_addr=controller_addr
)
logger.info(f"Worker params: {worker_params}")
if run_locally:
worker_params.register = False
_start_local_worker(worker_manager, worker_params, True)
loop = asyncio.get_event_loop()
loop.run_until_complete(
worker_manager.worker_manager._start_all_worker(apply_req=None)
)
else:
from pilot.model.controller.registry import ModelRegistryClient
if not worker_params.controller_addr:
raise ValueError("Controller can`t be None")
client = ModelRegistryClient(worker_params.controller_addr)
worker_manager.worker_manager = RemoteWorkerManager(client)
if include_router and app:
app.include_router(router, prefix="/api")
def run_worker_manager(
app=None,
include_router: bool = True,
model_name: str = None,
model_path: str = None,
standalone: bool = False,
port: int = None,
):
global worker_manager
worker_params: ModelWorkerParameters = _parse_worker_params(
model_name=model_name, model_path=model_path, standalone=standalone, port=port
)
embedded_mod = True
if not app:
# Run worker manager independently
embedded_mod = False
app = _setup_fastapi(worker_params)
_start_local_worker(worker_manager, worker_params, embedded_mod=False)
else:
_start_local_worker(worker_manager, worker_params, embedded_mod=False)
loop = asyncio.get_event_loop()
loop.run_until_complete(
worker_manager.worker_manager._start_all_worker(apply_req=None)
)
if include_router:
app.include_router(router, prefix="/api")
if not embedded_mod:
uvicorn.run(
app, host=worker_params.host, port=worker_params.port, log_level="info"
)
if __name__ == "__main__":
run_worker_manager()
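
For orientation, a minimal client sketch (not part of this commit; the address, model name, and prompt are assumptions) of how the /api/worker endpoints registered above can be consumed, using the same NUL-delimited JSON framing that generate_json_stream produces:

import json
import httpx

async def stream_worker_outputs(prompt: str, model: str = "vicuna-13b"):
    # Assumes a worker manager started via run_worker_manager() is listening on port 8000.
    params = {"prompt": prompt, "model": model}
    async with httpx.AsyncClient() as client:
        async with client.stream(
            "POST", "http://127.0.0.1:8000/api/worker/generate_stream", json=params
        ) as response:
            buffer = b""
            async for raw_chunk in response.aiter_raw():
                buffer += raw_chunk
                # Each output is a JSON object terminated by a NUL byte (b"\0").
                while b"\0" in buffer:
                    chunk, buffer = buffer.split(b"\0", 1)
                    if chunk:
                        yield json.loads(chunk.decode())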

View File

View File

@ -0,0 +1,98 @@
import json
from typing import Dict, Iterator, List
import httpx
from pilot.model.base import ModelOutput
from pilot.model.parameter import ModelParameters
from pilot.model.worker.base import ModelWorker
class RemoteModelWorker(ModelWorker):
def __init__(self) -> None:
self.headers = {}
self.timeout = 60
self.host = None
self.port = None
@property
def worker_addr(self) -> str:
return f"http://{self.host}:{self.port}/api/worker"
def support_async(self) -> bool:
return True
def parse_parameters(self, command_args: List[str] = None) -> ModelParameters:
return None
def load_worker(self, model_name: str, model_path: str, **kwargs):
self.host = kwargs.get("host")
self.port = kwargs.get("port")
def start(
self, model_params: ModelParameters = None, command_args: List[str] = None
) -> None:
"""Start model worker"""
pass
# raise NotImplementedError("Remote model worker not support start methods")
def stop(self) -> None:
raise NotImplementedError("Remote model worker not support stop methods")
def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
"""Generate stream"""
raise NotImplementedError
async def async_generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
"""Asynchronous generate stream"""
print(f"Send async_generate_stream, params: {params}")
async with httpx.AsyncClient() as client:
delimiter = b"\0"
buffer = b""
async with client.stream(
"POST",
self.worker_addr + "/generate_stream",
headers=self.headers,
json=params,
timeout=self.timeout,
) as response:
async for raw_chunk in response.aiter_raw():
buffer += raw_chunk
while delimiter in buffer:
chunk, buffer = buffer.split(delimiter, 1)
if not chunk:
continue
chunk = chunk.decode()
data = json.loads(chunk)
yield ModelOutput(**data)
def generate(self, params: Dict) -> ModelOutput:
"""Generate non stream"""
raise NotImplementedError
async def async_generate(self, params: Dict) -> ModelOutput:
"""Asynchronous generate non stream"""
print(f"Send async_generate_stream, params: {params}")
async with httpx.AsyncClient() as client:
response = await client.post(
self.worker_addr + "/generate",
headers=self.headers,
json=params,
timeout=self.timeout,
)
return ModelOutput(**response.json())
def embeddings(self, params: Dict) -> List[List[float]]:
"""Get embeddings for input"""
raise NotImplementedError
async def async_embeddings(self, params: Dict) -> List[List[float]]:
"""Asynchronous get embeddings for input"""
async with httpx.AsyncClient() as client:
response = await client.post(
self.worker_addr + "/embeddings",
headers=self.headers,
json=params,
timeout=self.timeout,
)
return response.json()
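
A minimal usage sketch (the host, port, and model name are hypothetical; in practice RemoteWorkerManager presumably builds these workers from registry instances):

import asyncio

async def _demo_remote_worker():
    worker = RemoteModelWorker()
    # load_worker only records where the remote worker lives.
    worker.load_worker("vicuna-13b", "vicuna-13b", host="127.0.0.1", port=8001)
    output = await worker.async_generate({"prompt": "Hello", "model": "vicuna-13b"})
    print(output.text, output.error_code)

# asyncio.run(_demo_remote_worker())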

View File

@ -311,33 +311,23 @@ async def chat_completions(dialogue: ConversationVo = Body()):
async def no_stream_generator(chat):
msg = chat.nostream_call()
msg = await chat.nostream_call()
msg = msg.replace("\n", "\\n")
yield f"data: {msg}\n\n"
async def stream_generator(chat):
model_response = chat.stream_call()
msg = "[LLM_ERROR]: llm server has no output, maybe your prompt template is wrong."
if not CFG.NEW_SERVER_MODE:
for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
chunk, chat.skip_echo_len
)
msg = msg.replace("\n", "\\n")
yield f"data:{msg}\n\n"
await asyncio.sleep(0.02)
else:
for chunk in model_response:
if chunk:
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
chunk, chat.skip_echo_len
)
msg = msg.replace("\n", "\\n")
yield f"data:{msg}\n\n"
await asyncio.sleep(0.02)
async for chunk in chat.stream_call():
if chunk:
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
chunk, chat.skip_echo_len
)
msg = msg.replace("\n", "\\n")
yield f"data:{msg}\n\n"
await asyncio.sleep(0.02)
chat.current_message.add_ai_message(msg)
chat.current_message.add_view_message(msg)

View File

@ -1,26 +1,19 @@
from __future__ import annotations
import json
from abc import ABC, abstractmethod
from typing import (
Any,
Dict,
Generic,
List,
NamedTuple,
Optional,
Sequence,
TypeVar,
Union,
)
from pilot.utils import build_logger
import re
from abc import ABC, abstractmethod
from dataclasses import asdict
from typing import Any, Dict, TypeVar, Union
from pydantic import BaseModel, Extra, Field, root_validator
from pilot.configs.model_config import LOGDIR
from pilot.configs.config import Config
from pilot.configs.model_config import LOGDIR
from pilot.model.base import ModelOutput
from pilot.utils import build_logger
T = TypeVar("T")
ResponseTye = Union[str, bytes, ModelOutput]
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
CFG = Config()
@ -46,11 +39,8 @@ class BaseOutputParser(ABC):
code = sep.join(blocks)
return code
def parse_model_stream_resp_ex(self, chunk, skip_echo_len):
if b"\0" in chunk:
chunk = chunk.replace(b"\0", b"")
data = json.loads(chunk.decode())
def parse_model_stream_resp_ex(self, chunk: ResponseTye, skip_echo_len):
data = _parse_model_response(chunk)
""" TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
"""
model_context = data.get("model_context")
@ -103,8 +93,8 @@ class BaseOutputParser(ABC):
output = data["text"] + f" (error_code: {data['error_code']})"
yield output
def parse_model_nostream_resp(self, response, sep: str):
resp_obj_ex = json.loads(response)
def parse_model_nostream_resp(self, response: ResponseTye, sep: str):
resp_obj_ex = _parse_model_response(response)
if isinstance(resp_obj_ex, str):
resp_obj_ex = json.loads(resp_obj_ex)
if resp_obj_ex["error_code"] == 0:
@ -240,3 +230,17 @@ class BaseOutputParser(ABC):
output_parser_dict = super().dict()
output_parser_dict["_type"] = self._type
return output_parser_dict
def _parse_model_response(response: ResponseTye):
if isinstance(response, ModelOutput):
resp_obj_ex = asdict(response)
elif isinstance(response, str):
resp_obj_ex = json.loads(response)
elif isinstance(response, bytes):
if b"\0" in response:
response = response.replace(b"\0", b"")
resp_obj_ex = json.loads(response.decode())
else:
raise ValueError(f"Unsupported response type {type(response)}")
return resp_obj_ex
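
Illustrative only (the values are made up): _parse_model_response normalizes the three response shapes the parsers now receive into a plain dict:

from pilot.model.base import ModelOutput

as_dataclass = _parse_model_response(ModelOutput(text="hi", error_code=0))
as_string = _parse_model_response('{"text": "hi", "error_code": 0}')
as_bytes = _parse_model_response(b'{"text": "hi", "error_code": 0}\0')
assert as_dataclass["text"] == as_string["text"] == as_bytes["text"] == "hi"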

View File

@ -28,10 +28,7 @@ from pilot.memory.chat_history.mem_history import MemHistoryMemory
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.configs.model_config import LOGDIR, DATASETS_DIR
from pilot.utils import (
build_logger,
server_error_msg,
)
from pilot.utils import build_logger, server_error_msg, get_or_create_event_loop
from pilot.scene.base_message import (
BaseMessage,
SystemMessage,
@ -158,7 +155,7 @@ class BaseChat(ABC):
}
return payload
def stream_call(self):
async def stream_call(self):
# TODO Retry when server connection error
payload = self.__call_base()
@ -166,19 +163,10 @@ class BaseChat(ABC):
logger.info(f"Requert: \n{payload}")
ai_response_text = ""
try:
if not CFG.NEW_SERVER_MODE:
response = requests.post(
urljoin(CFG.MODEL_SERVER, "generate_stream"),
headers=headers,
json=payload,
stream=True,
timeout=120,
)
return response
else:
from pilot.server.llmserver import worker
from pilot.model.worker.manager import worker_manager
return worker.generate_stream_gate(payload)
async for output in worker_manager.generate_stream(payload):
yield output
except Exception as e:
print(traceback.format_exc())
logger.error("model response parase faild" + str(e))
@ -188,34 +176,19 @@ class BaseChat(ABC):
### store current conversation
self.memory.append(self.current_message)
def nostream_call(self):
async def nostream_call(self):
payload = self.__call_base()
logger.info(f"Requert: \n{payload}")
ai_response_text = ""
try:
rsp_str = ""
if not CFG.NEW_SERVER_MODE:
rsp_obj = requests.post(
urljoin(CFG.MODEL_SERVER, "generate"),
headers=headers,
json=payload,
timeout=120,
)
rsp_str = rsp_obj.text
else:
###TODO no stream mode need independent
from pilot.server.llmserver import worker
from pilot.model.worker.manager import worker_manager
output = worker.generate_stream_gate(payload)
for rsp in output:
rsp = rsp.replace(b"\0", b"")
rsp_str = rsp.decode()
print("[TEST: output]:", rsp_str)
model_output = await worker_manager.generate(payload)
### output parse
ai_response_text = (
self.prompt_template.output_parser.parse_model_nostream_resp(
rsp_str, self.prompt_template.sep
model_output, self.prompt_template.sep
)
)
### model result deal
@ -245,11 +218,31 @@ class BaseChat(ABC):
self.memory.append(self.current_message)
return self.current_ai_response()
def _blocking_stream_call(self):
logger.warn(
"_blocking_stream_call is only temporarily used in webserver and will be deleted soon, please use stream_call to replace it for higher performance"
)
loop = get_or_create_event_loop()
async_gen = self.stream_call()
while True:
try:
value = loop.run_until_complete(async_gen.__anext__())
yield value
except StopAsyncIteration:
break
def _blocking_nostream_call(self):
logger.warn(
"_blocking_nostream_call is only temporarily used in webserver and will be deleted soon, please use nostream_call to replace it for higher performance"
)
loop = get_or_create_event_loop()
return loop.run_until_complete(self.nostream_call())
def call(self):
if self.prompt_template.stream_out:
yield self.stream_call()
yield self._blocking_stream_call()
else:
return self.nostream_call()
return self._blocking_nostream_call()
def prepare(self):
pass
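
A small sketch (chat stands for any BaseChat subclass instance) of the two calling styles after this change: async callers iterate stream_call directly, while legacy synchronous code uses the temporary blocking helpers above:

async def consume_async(chat):
    # Preferred path: fully asynchronous streaming.
    async for chunk in chat.stream_call():
        print(chunk)

def consume_sync(chat):
    # Temporary path for synchronous call sites.
    for chunk in chat._blocking_stream_call():
        print(chunk)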

View File

View File

@ -0,0 +1,45 @@
import sys
import click
import os
import copy
import logging
sys.path.append(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
)
@click.group()
@click.option(
"--log-level",
required=False,
type=str,
default="warn",
help="Log level",
)
@click.version_option()
def cli(log_level: str):
# TODO not working now
logging.basicConfig(level=log_level, encoding="utf-8")
def add_command_alias(command, name: str, hidden: bool = False):
new_command = copy.deepcopy(command)
new_command.hidden = hidden
cli.add_command(new_command, name=name)
try:
from pilot.model.cli import model_cli_group
add_command_alias(model_cli_group, name="model")
except ImportError as e:
logging.warning(f"Integrating dbgpt model command line tool failed: {e}")
def main():
return cli()
if __name__ == "__main__":
main()

View File

@ -10,13 +10,7 @@ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__fi
sys.path.append(ROOT_PATH)
import signal
from pilot.configs.config import Config
# from pilot.configs.model_config import (
# DATASETS_DIR,
# KNOWLEDGE_UPLOAD_ROOT_PATH,
# LLM_MODEL_CONFIG,
# LOGDIR,
# )
from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.utils import build_logger
from pilot.server.base import server_init
@ -33,8 +27,9 @@ from pilot.openapi.api_v1.api_v1 import router as api_v1
from pilot.openapi.base import validation_exception_handler
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
from pilot.commands.disply_type.show_chart_gen import static_message_img_path
from pilot.model.worker.manager import initialize_worker_manager_in_client
logging.basicConfig(level=logging.INFO)
logging.basicConfig(level=logging.INFO, encoding="utf-8")
static_file_path = os.path.join(os.getcwd(), "server/static")
@ -80,12 +75,19 @@ app.include_router(api_editor_route_v1, prefix="/api")
app.include_router(knowledge_router)
# app.include_router(api_editor_route_v1)
os.makedirs(static_message_img_path, exist_ok=True)
app.mount(
"/images", StaticFiles(directory=static_message_img_path, html=True), name="static2"
)
app.mount("/_next/static", StaticFiles(directory=static_file_path + "/_next/static"))
app.mount("/", StaticFiles(directory=static_file_path, html=True), name="static")
def mount_static_files(app):
os.makedirs(static_message_img_path, exist_ok=True)
app.mount(
"/images",
StaticFiles(directory=static_message_img_path, html=True),
name="static2",
)
app.mount(
"/_next/static", StaticFiles(directory=static_file_path + "/_next/static")
)
app.mount("/", StaticFiles(directory=static_file_path, html=True), name="static")
app.add_exception_handler(RequestValidationError, validation_exception_handler)
@ -100,6 +102,7 @@ if __name__ == "__main__":
parser.add_argument("--port", type=int, default=5000)
parser.add_argument("--concurrency-count", type=int, default=10)
parser.add_argument("--share", default=False, action="store_true")
parser.add_argument("--log-level", type=str, default="info")
parser.add_argument(
"-light",
"--light",
@ -112,17 +115,28 @@ if __name__ == "__main__":
args = parser.parse_args()
server_init(args)
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
if not args.light:
print("Model Unified Deployment Mode!")
from pilot.server.llmserver import worker
initialize_worker_manager_in_client(
app=app, model_name=CFG.LLM_MODEL, model_path=model_path
)
worker.start_check()
CFG.NEW_SERVER_MODE = True
else:
# MODEL_SERVER is controller address now
initialize_worker_manager_in_client(
app=app,
model_name=CFG.LLM_MODEL,
model_path=model_path,
run_locally=False,
controller_addr=CFG.MODEL_SERVER,
)
CFG.SERVER_LIGHT_MODE = True
mount_static_files(app)
import uvicorn
logging.basicConfig(level=logging.INFO)
uvicorn.run(app, host="0.0.0.0", port=args.port, log_level=0)
logging.basicConfig(level=logging.INFO, encoding="utf-8")
uvicorn.run(app, host="0.0.0.0", port=args.port, log_level=args.log_level)
signal.signal(signal.SIGINT, signal_handler())

View File

@ -1,19 +1,8 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import asyncio
import json
import os
import sys
from typing import List
import platform
import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import StreamingResponse
# from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
global_counter = 0
model_semaphore = None
@ -22,197 +11,26 @@ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__fi
sys.path.append(ROOT_PATH)
from pilot.configs.config import Config
from pilot.configs.model_config import *
from pilot.model.llm_out.vicuna_base_llm import get_embeddings
from pilot.model.loader import ModelLoader, _get_model_real_path
from pilot.server.chat_adapter import get_llm_chat_adapter
from pilot.scene.base_message import ModelMessage
from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.model.worker.manager import run_worker_manager
CFG = Config()
class ModelWorker:
def __init__(self, model_path, model_name, device):
if model_path.endswith("/"):
model_path = model_path[:-1]
model_path = _get_model_real_path(model_name, model_path)
# self.model_name = model_name or model_path.split("/")[-1]
self.device = device
print(
f"Loading {model_name} LLM ModelServer in {device} from model path {model_path}! Please Wait......"
)
self.ml: ModelLoader = ModelLoader(model_path=model_path, model_name=model_name)
self.model, self.tokenizer = self.ml.loader(
load_8bit=CFG.IS_LOAD_8BIT,
load_4bit=CFG.IS_LOAD_4BIT,
debug=ISDEBUG,
max_gpu_memory=CFG.MAX_GPU_MEMORY,
)
if not isinstance(self.model, str):
if hasattr(self.model, "config") and hasattr(
self.model.config, "max_sequence_length"
):
self.context_len = self.model.config.max_sequence_length
elif hasattr(self.model, "config") and hasattr(
self.model.config, "max_position_embeddings"
):
self.context_len = self.model.config.max_position_embeddings
else:
self.context_len = 2048
self.llm_chat_adapter = get_llm_chat_adapter(model_name, model_path)
self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func(
model_path
)
def start_check(self):
print("LLM Model Loading Success")
def get_queue_length(self):
if (
model_semaphore is None
or model_semaphore._value is None
or model_semaphore._waiters is None
):
return 0
else:
(
CFG.LIMIT_MODEL_CONCURRENCY
- model_semaphore._value
+ len(model_semaphore._waiters)
)
def generate_stream_gate(self, params):
try:
# params adaptation
params, model_context = self.llm_chat_adapter.model_adaptation(
params, self.ml.model_path, prompt_template=self.ml.prompt_template
)
for output in self.generate_stream_func(
self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
):
# Please do not open the output in production!
# The gpt4all thread shares stdout with the parent process,
# and opening it may affect the frontend output.
if "windows" in platform.platform().lower():
# Do not print the model output, because it may contain Emoji, there is a problem with the GBK encoding
pass
else:
print("output: ", output)
# return some model context to dgt-server
ret = {"text": output, "error_code": 0, "model_context": model_context}
yield json.dumps(ret).encode() + b"\0"
except torch.cuda.CudaError:
ret = {"text": "**GPU OutOfMemory, Please Refresh.**", "error_code": 0}
yield json.dumps(ret).encode() + b"\0"
except Exception as e:
ret = {
"text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
"error_code": 0,
}
yield json.dumps(ret).encode() + b"\0"
def get_embeddings(self, prompt):
return get_embeddings(self.model, self.tokenizer, prompt)
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
worker = ModelWorker(model_path=model_path, model_name=CFG.LLM_MODEL, device=DEVICE)
# worker = ModelWorker(model_path=model_path, model_name=CFG.LLM_MODEL, device=DEVICE)
app = FastAPI()
# from pilot.openapi.knowledge.knowledge_controller import router
#
# app.include_router(router)
#
# origins = [
# "http://localhost",
# "http://localhost:8000",
# "http://localhost:3000",
# ]
#
# app.add_middleware(
# CORSMiddleware,
# allow_origins=origins,
# allow_credentials=True,
# allow_methods=["*"],
# allow_headers=["*"],
# )
class PromptRequest(BaseModel):
messages: List[ModelMessage]
prompt: str
temperature: float
max_new_tokens: int
model: str
stop: str = None
echo: bool = True
class StreamRequest(BaseModel):
model: str
prompt: str
temperature: float
max_new_tokens: int
stop: str
class EmbeddingRequest(BaseModel):
prompt: str
def release_model_semaphore():
model_semaphore.release()
@app.post("/generate_stream")
async def api_generate_stream(request: Request):
global model_semaphore, global_counter
global_counter += 1
params = await request.json()
if model_semaphore is None:
model_semaphore = asyncio.Semaphore(CFG.LIMIT_MODEL_CONCURRENCY)
await model_semaphore.acquire()
generator = worker.generate_stream_gate(params)
background_tasks = BackgroundTasks()
background_tasks.add_task(release_model_semaphore)
return StreamingResponse(generator, background=background_tasks)
@app.post("/generate")
def generate(prompt_request: PromptRequest) -> str:
params = {
"messages": prompt_request.messages,
"prompt": prompt_request.prompt,
"temperature": prompt_request.temperature,
"max_new_tokens": prompt_request.max_new_tokens,
"stop": prompt_request.stop,
"echo": prompt_request.echo,
}
rsp_str = ""
output = worker.generate_stream_gate(params)
for rsp in output:
# rsp = rsp.decode("utf-8")
rsp = rsp.replace(b"\0", b"")
rsp_str = rsp.decode()
return rsp_str
@app.post("/embedding")
def embeddings(prompt_request: EmbeddingRequest):
params = {"prompt": prompt_request.prompt}
print("Received prompt: ", params["prompt"])
output = worker.get_embeddings(params["prompt"])
return {"response": [float(x) for x in output]}
# @app.post("/embedding")
# def embeddings(prompt_request: EmbeddingRequest):
# params = {"prompt": prompt_request.prompt}
# print("Received prompt: ", params["prompt"])
# output = worker.get_embeddings(params["prompt"])
# return {"response": [float(x) for x in output]}
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=CFG.MODEL_PORT, log_level="info")
run_worker_manager(
model_name=CFG.LLM_MODEL,
model_path=model_path,
standalone=True,
port=CFG.MODEL_PORT,
)

View File

@ -2,7 +2,6 @@ import logging
import os
import requests
from playsound import playsound
from pilot.speech.base import VoiceBase
@ -23,6 +22,8 @@ class BrianSpeech(VoiceBase):
Returns:
bool: True if the request was successful, False otherwise
"""
from playsound import playsound
tts_url = (
f"https://api.streamelements.com/kappa/v2/speech?voice=Brian&text={text}"
)

View File

@ -2,7 +2,6 @@
import os
import requests
from playsound import playsound
from pilot.configs.config import Config
from pilot.speech.base import VoiceBase
@ -70,6 +69,7 @@ class ElevenLabsSpeech(VoiceBase):
bool: True if the request was successful, False otherwise
"""
from pilot.logs import logger
from playsound import playsound
tts_url = (
f"https://api.elevenlabs.io/v1/text-to-speech/{self._voices[voice_index]}"

View File

@ -2,7 +2,6 @@
import os
import gtts
from playsound import playsound
from pilot.speech.base import VoiceBase
@ -15,6 +14,8 @@ class GTTSVoice(VoiceBase):
def _speech(self, text: str, _: int = 0) -> bool:
"""Play the given text."""
from playsound import playsound
tts = gtts.gTTS(text)
tts.save("speech.mp3")
playsound("speech.mp3", True)

View File

@ -184,5 +184,5 @@ def _get_llm_response(query, db_input, dbsummary):
chat: BaseChat = chat_factory.get_implementation(
ChatScene.InnerChatDBSummary.value, **chat_param
)
res = chat.nostream_call()
res = chat._blocking_nostream_call()
return json.loads(res)["table"]

9
pilot/utils/__init__.py Normal file
View File

@ -0,0 +1,9 @@
from .utils import (
get_gpu_memory,
build_logger,
StreamToLogger,
disable_torch_init,
pretty_print_semaphore,
server_error_msg,
get_or_create_event_loop,
)

101
pilot/utils/api_utils.py Normal file
View File

@ -0,0 +1,101 @@
import httpx
from inspect import signature
import typing_inspect
import logging
from typing import get_type_hints, List, Type, TypeVar, Union, Optional, Tuple
from dataclasses import is_dataclass, asdict
T = TypeVar("T")
def _extract_dataclass_from_generic(type_hint: Type[T]) -> Union[Type[T], None]:
"""Extract actual dataclass from generic type hints like List[dataclass], Optional[dataclass], etc."""
if typing_inspect.is_generic_type(type_hint) and typing_inspect.get_args(type_hint):
return typing_inspect.get_args(type_hint)[0]
return None
def _api_remote(path, method="GET"):
def decorator(func):
return_type = get_type_hints(func).get("return")
if return_type is None:
raise TypeError("Return type must be annotated in the decorated function.")
actual_dataclass = _extract_dataclass_from_generic(return_type)
logging.debug(
f"return_type: {return_type}, actual_dataclass: {actual_dataclass}"
)
if not actual_dataclass:
actual_dataclass = return_type
sig = signature(func)
async def wrapper(self, *args, **kwargs):
base_url = self.base_url # Get base_url from class instance
bound = sig.bind(self, *args, **kwargs)
bound.apply_defaults()
formatted_url = base_url + path.format(**bound.arguments)
# Extract args names from signature, except "self"
arg_names = list(sig.parameters.keys())[1:]
# Combine args and kwargs into a single dictionary
combined_args = dict(zip(arg_names, args))
combined_args.update(kwargs)
request_data = {}
for key, value in combined_args.items():
if is_dataclass(value):
# Here, instead of adding it as a nested dictionary,
# we set request_data directly to its dictionary representation.
request_data = asdict(value)
else:
request_data[key] = value
request_params = {"method": method, "url": formatted_url}
if method in ["POST", "PUT", "PATCH"]:
request_params["json"] = request_data
else: # For GET, DELETE, etc.
request_params["params"] = request_data
logging.info(
f"request_params: {request_params}, args: {args}, kwargs: {kwargs}"
)
async with httpx.AsyncClient() as client:
response = await client.request(**request_params)
if response.status_code == 200:
return _parse_response(
response.json(), return_type, actual_dataclass
)
else:
error_msg = f"Remote request error, error code: {response.status_code}, error msg: {response.text}"
raise Exception(error_msg)
return wrapper
return decorator
def _parse_response(json_response, return_type, actual_dataclass):
# print(f'return_type.__origin__: {return_type.__origin__}, actual_dataclass: {actual_dataclass}, json_response: {json_response}')
if is_dataclass(actual_dataclass):
if getattr(return_type, "__origin__", None) is list:  # for List[dataclass]
if isinstance(json_response, list):
return [actual_dataclass(**item) for item in json_response]
else:
raise TypeError(
f"Expected list in response but got {type(json_response)}"
)
else:
if isinstance(json_response, dict):
return actual_dataclass(**json_response)
else:
raise TypeError(
f"Expected dictionary in response but got {type(json_response)}"
)
else:
return json_response
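
As an illustration (the class name and endpoint path below are assumptions for the sketch, not taken from this commit), a client built on _api_remote only declares the signature and return type; the decorator derives the URL, payload, and response parsing:

from typing import List
from pilot.model.base import ModelInstance

class ExampleRegistryClient:
    def __init__(self, base_url: str):
        self.base_url = base_url  # _api_remote reads base_url from the instance

    @_api_remote(path="/api/controller/models", method="GET")
    async def get_all_instances(
        self, model_name: str, healthy_only: bool = False
    ) -> List[ModelInstance]:
        # Body is never executed; the decorator performs the HTTP call.
        pass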

View File

@ -0,0 +1,27 @@
import logging
def _clear_torch_cache(device="cuda"):
import gc
import torch
gc.collect()
if device != "cpu":
if torch.has_mps:
try:
from torch.mps import empty_cache
empty_cache()
except Exception as e:
logging.warn(f"Clear mps torch cache error, {str(e)}")
elif torch.has_cuda:
device_count = torch.cuda.device_count()
for device_id in range(device_count):
cuda_device = f"cuda:{device_id}"
logging.info(f"Clear torch cache of device: {cuda_device}")
with torch.cuda.device(cuda_device):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
else:
logging.info("No cuda or mps, not support clear torch cache yet")

View File

@ -0,0 +1,26 @@
from typing import Type
from importlib import import_module
def import_from_string(module_path: str):
try:
module_path, class_name = module_path.rsplit(".", 1)
except ValueError:
raise ImportError(f"{module_path} doesn't look like a module path")
module = import_module(module_path)
try:
return getattr(module, class_name)
except AttributeError:
raise ImportError(
f'Module "{module_path}" does not define a "{class_name}" attribute/class'
)
def import_from_checked_string(module_path: str, supper_cls: Type):
cls = import_from_string(module_path)
if not issubclass(cls, supper_cls):
raise ImportError(
f'Class imported from "{module_path}" is not a subclass of {str(supper_cls)}'
)
return cls
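
For example (the dotted path matches the default worker shipped in this commit), the worker manager can load a custom worker class and verify it derives from ModelWorker before instantiating it:

from pilot.model.worker.base import ModelWorker

worker_cls = import_from_checked_string(
    "pilot.model.worker.default_worker.DefaultModelWorker", ModelWorker
)
worker = worker_cls()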

24
pilot/utils/net_utils.py Normal file
View File

@ -0,0 +1,24 @@
import socket
import errno
def _get_ip_address(address: str = "10.254.254.254:1") -> str:
ip, port = address.split(":")
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.settimeout(0)
curr_address = "127.0.0.1"
try:
# doesn't even have to be reachable
s.connect((ip, int(port)))
curr_address = s.getsockname()[0]
except OSError as e:
if e.errno == errno.ENETUNREACH:
try:
hostname = socket.getfqdn(socket.gethostname())
curr_address = socket.gethostbyname(hostname)
except Exception:
pass
finally:
s.close()
return curr_address

View File

@ -5,9 +5,7 @@ import logging
import logging.handlers
import os
import sys
import requests
import torch
import asyncio
from pilot.configs.model_config import LOGDIR
@ -19,6 +17,8 @@ handler = None
def get_gpu_memory(max_gpus=None):
import torch
gpu_memory = []
num_gpus = (
torch.cuda.device_count()
@ -73,7 +73,7 @@ def build_logger(logger_name, logger_filename):
for name, item in logging.root.manager.loggerDict.items():
if isinstance(item, logging.Logger):
item.addHandler(handler)
logging.basicConfig(level=logging.INFO)
logging.basicConfig(level=logging.INFO, encoding="utf-8")
# Get logger
logger = logging.getLogger(logger_name)
@ -132,3 +132,15 @@ def pretty_print_semaphore(semaphore):
if semaphore is None:
return "None"
return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
def get_or_create_event_loop() -> asyncio.BaseEventLoop:
try:
loop = asyncio.get_event_loop()
except Exception as e:
if not "no running event loop" in str(e):
raise e
logging.warning("Cant not get running event loop, create new event loop now")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop
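
A tiny sketch of how the helper is used by synchronous call sites such as the temporary _blocking_* helpers in BaseChat (the coroutine is a placeholder):

import asyncio

async def _ping() -> str:
    await asyncio.sleep(0)
    return "pong"

loop = get_or_create_event_loop()
assert loop.run_until_complete(_ping()) == "pong"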

View File

@ -76,4 +76,7 @@ bardapi==0.1.29
# TODO moved to optional dependencies
pymysql
duckdb
duckdb-engine
duckdb-engine
# cli
prettytable

View File

@ -267,7 +267,7 @@ def llama_cpp_python_cuda_requires():
llama_cpp_version = "0.1.77"
py_version = "cp310"
os_pkg_name = "linux_x86_64" if os_type == OSType.LINUX else "win_amd64"
extra_index_url = f"{base_url}/llama_cpp_python_cuda-{llama_cpp_version}+{device}{cpu_avx}-{py_version}-{py_version}-{os_pkg_name}.whl"
extra_index_url = f"{base_url}/llama_cpp_python_cuda-{llama_cpp_version}+{device}-{py_version}-{py_version}-{os_pkg_name}.whl"
extra_index_url, _ = encode_url(extra_index_url)
print(f"Install llama_cpp_python_cuda from {extra_index_url}")
@ -361,7 +361,7 @@ setuptools.setup(
extras_require=setup_spec.extras,
entry_points={
"console_scripts": [
"dbgpt_server=pilot.server:webserver",
"dbgpt=pilot.scripts.cli_scripts:main",
],
},
)

View File

@ -25,7 +25,6 @@ sys.path.append(
from pilot.configs.model_config import DATASETS_DIR
from tools.cli.knowledge_client import knowledge_init
API_ADDRESS: str = "http://127.0.0.1:5000"
@ -97,6 +96,8 @@ def knowledge(
verbose: bool,
):
"""Knowledge command line tool"""
from tools.cli.knowledge_client import knowledge_init
knowledge_init(
API_ADDRESS,
vector_name,