Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-23 02:27:55 +00:00)
feat: Support multi model (#501)
- **ModelController**: Registers and manages model instances.
- **ModelRegistry**: Abstract base class for a model registry. It provides an interface for registering, deregistering, fetching instances, and sending heartbeats for instances.
- **ModelWorker**: Abstract representation of a model worker responsible for model interaction, startup, and shutdown. Supports 'llm' and 'text2vec' models.
- **WorkerManager**: Manages the worker instances deployed on the current server; the `WorkerManager` is also the handle used to invoke the model service.
- **Model command-line tools**: List, stop, start, and restart model instances.
- Modify `BaseChat`: asynchronous chat messages for higher performance.

A minimal sketch of the new registration flow follows the commit metadata below.
This commit is contained in: commit 0983456311
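As a quick orientation, the sketch below shows the core registration flow introduced by this commit: a worker publishes a ModelInstance to a registry and a caller picks a healthy instance back out. This is a minimal, illustrative example only; the model name, host, and port are placeholders, and it assumes the modules added in the diff below.

import asyncio

from pilot.model.base import ModelInstance
from pilot.model.controller.registry import EmbeddedModelRegistry


async def main():
    registry = EmbeddedModelRegistry()
    # Workers register themselves under "<model_name>@<worker_type>"
    instance = ModelInstance(model_name="vicuna-13b@llm", host="127.0.0.1", port=8001)
    await registry.register_instance(instance)
    # Pick one healthy, enabled instance for that model
    print(await registry.select_one_health_instance("vicuna-13b@llm"))


asyncio.run(main())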
@ -12,19 +12,13 @@ from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaTokenizer,
    BitsAndBytesConfig,
)
from pilot.model.parameter import ModelParameters, LlamaCppModelParameters
from pilot.configs.model_config import DEVICE
from pilot.configs.config import Config
from pilot.logs import logger

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype="bfloat16",
    bnb_4bit_use_double_quant=False,
)

CFG = Config()
@ -203,6 +197,14 @@ class FalconAdapater(BaseLLMAdaper):
        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

        if CFG.QLoRA:
            from transformers import BitsAndBytesConfig

            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype="bfloat16",
                bnb_4bit_use_double_quant=False,
            )
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            load_in_4bit=True,  # quantize
@ -1,7 +1,10 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from typing import TypedDict
from enum import Enum
from typing import TypedDict, Optional, Dict
from dataclasses import dataclass
from datetime import datetime


class Message(TypedDict):
@ -9,3 +12,39 @@ class Message(TypedDict):

    role: str
    content: str


@dataclass
class ModelInstance:
    """Model instance info"""

    model_name: str
    host: str
    port: int
    weight: Optional[float] = 1.0
    check_healthy: Optional[bool] = True
    healthy: Optional[bool] = False
    enabled: Optional[bool] = True
    prompt_template: Optional[str] = None
    last_heartbeat: Optional[datetime] = None


class WorkerApplyType(str, Enum):
    START = "start"
    STOP = "stop"
    RESTART = "restart"
    UPDATE_PARAMS = "update_params"


@dataclass
class ModelOutput:
    text: str
    error_code: int
    model_context: Dict = None


@dataclass
class WorkerApplyOutput:
    message: str
    # Seconds taken to apply the action to the worker instances
    timecost: Optional[int] = -1
pilot/model/cli.py (new file, 151 lines)
@ -0,0 +1,151 @@
|
||||
import click
|
||||
import functools
|
||||
|
||||
from pilot.model.controller.registry import ModelRegistryClient
|
||||
from pilot.model.worker.manager import (
|
||||
RemoteWorkerManager,
|
||||
WorkerApplyRequest,
|
||||
WorkerApplyType,
|
||||
)
|
||||
from pilot.utils import get_or_create_event_loop
|
||||
|
||||
|
||||
@click.group("model")
|
||||
def model_cli_group():
|
||||
pass
|
||||
|
||||
|
||||
@model_cli_group.command()
|
||||
@click.option(
|
||||
"--address",
|
||||
type=str,
|
||||
default="http://127.0.0.1:8000",
|
||||
required=False,
|
||||
help=(
|
||||
"Address of the Model Controller to connect to."
|
||||
"Just support light deploy model"
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--model-name", type=str, default=None, required=False, help=("The name of model")
|
||||
)
|
||||
@click.option(
|
||||
"--model-type", type=str, default="llm", required=False, help=("The type of model")
|
||||
)
|
||||
def list(address: str, model_name: str, model_type: str):
|
||||
"""List model instances"""
|
||||
from prettytable import PrettyTable
|
||||
|
||||
loop = get_or_create_event_loop()
|
||||
registry = ModelRegistryClient(address)
|
||||
|
||||
if not model_name:
|
||||
instances = loop.run_until_complete(registry.get_all_model_instances())
|
||||
else:
|
||||
if not model_type:
|
||||
model_type = "llm"
|
||||
register_model_name = f"{model_name}@{model_type}"
|
||||
instances = loop.run_until_complete(
|
||||
registry.get_all_instances(register_model_name)
|
||||
)
|
||||
table = PrettyTable()
|
||||
|
||||
table.field_names = [
|
||||
"Model Name",
|
||||
"Model Type",
|
||||
"Host",
|
||||
"Port",
|
||||
"Healthy",
|
||||
"Enabled",
|
||||
"Prompt Template",
|
||||
"Last Heartbeat",
|
||||
]
|
||||
for instance in instances:
|
||||
model_name, model_type = instance.model_name.split("@")
|
||||
table.add_row(
|
||||
[
|
||||
model_name,
|
||||
model_type,
|
||||
instance.host,
|
||||
instance.port,
|
||||
instance.healthy,
|
||||
instance.enabled,
|
||||
instance.prompt_template,
|
||||
instance.last_heartbeat,
|
||||
]
|
||||
)
|
||||
|
||||
print(table)
|
||||
|
||||
|
||||
def add_model_options(func):
|
||||
@click.option(
|
||||
"--address",
|
||||
type=str,
|
||||
default="http://127.0.0.1:8000",
|
||||
required=False,
|
||||
help=(
|
||||
"Address of the Model Controller to connect to."
|
||||
"Just support light deploy model"
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--model-name",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help=("The name of model"),
|
||||
)
|
||||
@click.option(
|
||||
"--model-type",
|
||||
type=str,
|
||||
default="llm",
|
||||
required=False,
|
||||
help=("The type of model"),
|
||||
)
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@model_cli_group.command()
|
||||
@add_model_options
|
||||
def stop(address: str, model_name: str, model_type: str):
|
||||
"""Stop model instances"""
|
||||
worker_apply(address, model_name, model_type, WorkerApplyType.STOP)
|
||||
|
||||
|
||||
@model_cli_group.command()
|
||||
@add_model_options
|
||||
def start(address: str, model_name: str, model_type: str):
|
||||
"""Start model instances"""
|
||||
worker_apply(address, model_name, model_type, WorkerApplyType.START)
|
||||
|
||||
|
||||
@model_cli_group.command()
|
||||
@add_model_options
|
||||
def restart(address: str, model_name: str, model_type: str):
|
||||
"""Restart model instances"""
|
||||
worker_apply(address, model_name, model_type, WorkerApplyType.RESTART)
|
||||
|
||||
|
||||
# @model_cli_group.command()
|
||||
# @add_model_options
|
||||
# def modify(address: str, model_name: str, model_type: str):
|
||||
# """Restart model instances"""
|
||||
# worker_apply(address, model_name, model_type, WorkerApplyType.UPDATE_PARAMS)
|
||||
|
||||
|
||||
def worker_apply(
|
||||
address: str, model_name: str, model_type: str, apply_type: WorkerApplyType
|
||||
):
|
||||
loop = get_or_create_event_loop()
|
||||
registry = ModelRegistryClient(address)
|
||||
worker_manager = RemoteWorkerManager(registry)
|
||||
apply_req = WorkerApplyRequest(
|
||||
model=model_name, worker_type=model_type, apply_type=apply_type
|
||||
)
|
||||
res = loop.run_until_complete(worker_manager.worker_apply(apply_req))
|
||||
print(res)
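For reference, the new `model` command group can also be driven programmatically with click's test runner. This is only a hedged sketch: it assumes a Model Controller is already reachable at the default address used above.

from click.testing import CliRunner

from pilot.model.cli import model_cli_group

runner = CliRunner()
# Equivalent to running: ... model list --address http://127.0.0.1:8000
result = runner.invoke(model_cli_group, ["list", "--address", "http://127.0.0.1:8000"])
print(result.output)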
|
pilot/model/controller/__init__.py (new empty file)
pilot/model/controller/controller.py (new file, 65 lines)
@ -0,0 +1,65 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter
|
||||
from pilot.model.base import ModelInstance
|
||||
from pilot.model.controller.registry import EmbeddedModelRegistry, ModelRegistry
|
||||
|
||||
|
||||
class ModelController:
|
||||
def __init__(self, registry: ModelRegistry = None) -> None:
|
||||
if not registry:
|
||||
registry = EmbeddedModelRegistry()
|
||||
self.registry = registry
|
||||
self.deployment = None
|
||||
|
||||
async def register_instance(self, instance: ModelInstance) -> bool:
|
||||
return await self.registry.register_instance(instance)
|
||||
|
||||
async def deregister_instance(self, instance: ModelInstance) -> bool:
|
||||
return await self.registry.deregister_instance(instance)
|
||||
|
||||
async def get_all_instances(
|
||||
self, model_name: str, healthy_only: bool = False
|
||||
) -> List[ModelInstance]:
|
||||
logging.info(
|
||||
f"Get all instances with {model_name}, healthy_only: {healthy_only}"
|
||||
)
|
||||
return await self.registry.get_all_instances(model_name, healthy_only)
|
||||
|
||||
async def get_all_model_instances(self) -> List[ModelInstance]:
|
||||
return await self.registry.get_all_model_instances()
|
||||
|
||||
async def send_heartbeat(self, instance: ModelInstance) -> bool:
|
||||
return await self.registry.send_heartbeat(instance)
|
||||
|
||||
async def model_apply(self) -> bool:
|
||||
# TODO
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
controller = ModelController()
|
||||
|
||||
|
||||
@router.post("/controller/models")
|
||||
async def api_register_instance(request: ModelInstance):
|
||||
return await controller.register_instance(request)
|
||||
|
||||
|
||||
@router.delete("/controller/models")
|
||||
async def api_deregister_instance(request: ModelInstance):
|
||||
return await controller.deregister_instance(request)
|
||||
|
||||
|
||||
@router.get("/controller/models")
|
||||
async def api_get_all_instances(model_name: str = None, healthy_only: bool = False):
|
||||
if not model_name:
|
||||
return await controller.get_all_model_instances()
|
||||
return await controller.get_all_instances(model_name, healthy_only=healthy_only)
|
||||
|
||||
|
||||
@router.post("/controller/heartbeat")
|
||||
async def api_model_heartbeat(request: ModelInstance):
|
||||
return await controller.send_heartbeat(request)
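A minimal sketch of serving this router. Mounting it under an `/api` prefix is an assumption here, chosen to match the paths used by `ModelRegistryClient` in registry.py below.

import uvicorn
from fastapi import FastAPI

from pilot.model.controller.controller import router

app = FastAPI()
app.include_router(router, prefix="/api")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)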
|
pilot/model/controller/ray_controller.py (new empty file)
pilot/model/controller/registry.py (new file, 224 lines)
@ -0,0 +1,224 @@
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Tuple
|
||||
import itertools
|
||||
|
||||
from pilot.model.base import ModelInstance
|
||||
|
||||
|
||||
class ModelRegistry(ABC):
|
||||
"""
|
||||
Abstract base class for a model registry. It provides an interface
|
||||
for registering, deregistering, fetching instances, and sending heartbeats
|
||||
for instances.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def register_instance(self, instance: ModelInstance) -> bool:
|
||||
"""
|
||||
Register a given model instance.
|
||||
|
||||
Args:
|
||||
- instance (ModelInstance): The instance of the model to register.
|
||||
|
||||
Returns:
|
||||
- bool: True if registration is successful, False otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def deregister_instance(self, instance: ModelInstance) -> bool:
|
||||
"""
|
||||
Deregister a given model instance.
|
||||
|
||||
Args:
|
||||
- instance (ModelInstance): The instance of the model to deregister.
|
||||
|
||||
Returns:
|
||||
- bool: True if deregistration is successful, False otherwise.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_all_instances(
|
||||
self, model_name: str, healthy_only: bool = False
|
||||
) -> List[ModelInstance]:
|
||||
"""
|
||||
Fetch all instances of a given model. Optionally, fetch only the healthy instances.
|
||||
|
||||
Args:
|
||||
- model_name (str): Name of the model to fetch instances for.
|
||||
- healthy_only (bool, optional): If set to True, fetches only the healthy instances.
|
||||
Defaults to False.
|
||||
|
||||
Returns:
|
||||
- List[ModelInstance]: A list of instances for the given model.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_all_model_instances(self) -> List[ModelInstance]:
|
||||
"""
|
||||
Fetch all instances of all models
|
||||
|
||||
Returns:
|
||||
- List[ModelInstance]: A list of instances for all models.
|
||||
"""
|
||||
|
||||
async def select_one_health_instance(self, model_name: str) -> ModelInstance:
|
||||
"""
|
||||
Selects one healthy and enabled instance for a given model.
|
||||
|
||||
Args:
|
||||
- model_name (str): Name of the model.
|
||||
|
||||
Returns:
|
||||
- ModelInstance: One randomly selected healthy and enabled instance, or None if no such instance exists.
|
||||
"""
|
||||
instances = await self.get_all_instances(model_name, healthy_only=True)
|
||||
instances = [i for i in instances if i.enabled]
|
||||
if not instances:
|
||||
return None
|
||||
return random.choice(instances)
|
||||
|
||||
@abstractmethod
|
||||
async def send_heartbeat(self, instance: ModelInstance) -> bool:
|
||||
"""
|
||||
Send a heartbeat for a given model instance. This can be used to
|
||||
verify if the instance is still alive and functioning.
|
||||
|
||||
Args:
|
||||
- instance (ModelInstance): The instance of the model to send a heartbeat for.
|
||||
|
||||
Returns:
|
||||
- bool: True if heartbeat is successful, False otherwise.
|
||||
"""
|
||||
|
||||
|
||||
class EmbeddedModelRegistry(ModelRegistry):
|
||||
def __init__(
|
||||
self, heartbeat_interval_secs: int = 60, heartbeat_timeout_secs: int = 120
|
||||
):
|
||||
self.registry: Dict[str, List[ModelInstance]] = defaultdict(list)
|
||||
self.heartbeat_interval_secs = heartbeat_interval_secs
|
||||
self.heartbeat_timeout_secs = heartbeat_timeout_secs
|
||||
self.heartbeat_thread = threading.Thread(target=self._heartbeat_checker)
|
||||
self.heartbeat_thread.daemon = True
|
||||
self.heartbeat_thread.start()
|
||||
|
||||
def _get_instances(
|
||||
self, model_name: str, host: str, port: int, healthy_only: bool = False
|
||||
) -> Tuple[List[ModelInstance], List[ModelInstance]]:
|
||||
instances = self.registry[model_name]
|
||||
if healthy_only:
|
||||
instances = [ins for ins in instances if ins.healthy == True]
|
||||
exist_ins = [ins for ins in instances if ins.host == host and ins.port == port]
|
||||
return instances, exist_ins
|
||||
|
||||
def _heartbeat_checker(self):
|
||||
while True:
|
||||
for instances in self.registry.values():
|
||||
for instance in instances:
|
||||
if (
|
||||
instance.check_healthy
|
||||
and datetime.now() - instance.last_heartbeat
|
||||
> timedelta(seconds=self.heartbeat_timeout_secs)
|
||||
):
|
||||
instance.healthy = False
|
||||
time.sleep(self.heartbeat_interval_secs)
|
||||
|
||||
async def register_instance(self, instance: ModelInstance) -> bool:
|
||||
model_name = instance.model_name.strip()
|
||||
host = instance.host.strip()
|
||||
port = instance.port
|
||||
|
||||
instances, exist_ins = self._get_instances(
|
||||
model_name, host, port, healthy_only=False
|
||||
)
|
||||
if exist_ins:
|
||||
# At most one existing instance
|
||||
ins = exist_ins[0]
|
||||
# Update instance
|
||||
ins.weight = instance.weight
|
||||
ins.healthy = True
|
||||
ins.prompt_template = instance.prompt_template
|
||||
ins.last_heartbeat = datetime.now()
|
||||
else:
|
||||
instance.healthy = True
|
||||
instance.last_heartbeat = datetime.now()
|
||||
instances.append(instance)
|
||||
return True
|
||||
|
||||
async def deregister_instance(self, instance: ModelInstance) -> bool:
|
||||
model_name = instance.model_name.strip()
|
||||
host = instance.host.strip()
|
||||
port = instance.port
|
||||
_, exist_ins = self._get_instances(model_name, host, port, healthy_only=False)
|
||||
if exist_ins:
|
||||
ins = exist_ins[0]
|
||||
ins.healthy = False
|
||||
return True
|
||||
|
||||
async def get_all_instances(
|
||||
self, model_name: str, healthy_only: bool = False
|
||||
) -> List[ModelInstance]:
|
||||
instances = self.registry[model_name]
|
||||
if healthy_only:
|
||||
instances = [ins for ins in instances if ins.healthy == True]
|
||||
return instances
|
||||
|
||||
async def get_all_model_instances(self) -> List[ModelInstance]:
|
||||
print(self.registry)
|
||||
return list(itertools.chain(*self.registry.values()))
|
||||
|
||||
async def send_heartbeat(self, instance: ModelInstance) -> bool:
|
||||
_, exist_ins = self._get_instances(
|
||||
instance.model_name, instance.host, instance.port, healthy_only=False
|
||||
)
|
||||
if not exist_ins:
|
||||
return False
|
||||
|
||||
ins = exist_ins[0]
|
||||
ins.last_heartbeat = datetime.now()
|
||||
ins.healthy = True
|
||||
return True
|
||||
|
||||
|
||||
from pilot.utils.api_utils import _api_remote as api_remote
|
||||
|
||||
|
||||
class ModelRegistryClient(ModelRegistry):
|
||||
def __init__(self, base_url: str) -> None:
|
||||
self.base_url = base_url
|
||||
|
||||
@api_remote(path="/api/controller/models", method="POST")
|
||||
async def register_instance(self, instance: ModelInstance) -> bool:
|
||||
pass
|
||||
|
||||
@api_remote(path="/api/controller/models", method="DELETE")
|
||||
async def deregister_instance(self, instance: ModelInstance) -> bool:
|
||||
pass
|
||||
|
||||
@api_remote(path="/api/controller/models")
|
||||
async def get_all_instances(
|
||||
self, model_name: str, healthy_only: bool = False
|
||||
) -> List[ModelInstance]:
|
||||
pass
|
||||
|
||||
@api_remote(path="/api/controller/models")
|
||||
async def get_all_model_instances(self) -> List[ModelInstance]:
|
||||
pass
|
||||
|
||||
@api_remote(path="/api/controller/models")
|
||||
async def select_one_health_instance(self, model_name: str) -> ModelInstance:
|
||||
instances = await self.get_all_instances(model_name, healthy_only=True)
|
||||
instances = [i for i in instances if i.enabled]
|
||||
if not instances:
|
||||
return None
|
||||
return random.choice(instances)
|
||||
|
||||
@api_remote(path="/api/controller/heartbeat", method="POST")
|
||||
async def send_heartbeat(self, instance: ModelInstance) -> bool:
|
||||
pass
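A hedged sketch of the remote client side: it assumes a controller is already serving the `/api/controller/*` routes on port 8000 (for example, mounted as in the controller.py sketch above), and the instance values are placeholders.

import asyncio

from pilot.model.base import ModelInstance
from pilot.model.controller.registry import ModelRegistryClient


async def main():
    client = ModelRegistryClient("http://127.0.0.1:8000")
    instance = ModelInstance(model_name="vicuna-13b@llm", host="127.0.0.1", port=8001)
    await client.register_instance(instance)
    print(await client.get_all_instances("vicuna-13b@llm", healthy_only=True))


asyncio.run(main())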
|
pilot/model/controller/tests/__init__.py (new empty file)
pilot/model/controller/tests/test_registry.py (new file, 148 lines)
@ -0,0 +1,148 @@
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import patch
|
||||
from pilot.model.base import ModelInstance
|
||||
from pilot.model.controller.registry import ModelRegistry, EmbeddedModelRegistry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_registry():
|
||||
return EmbeddedModelRegistry()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_instance():
|
||||
return ModelInstance(
|
||||
model_name="test_model",
|
||||
ip="192.168.1.1",
|
||||
port=5000,
|
||||
)
|
||||
|
||||
|
||||
# Async function to test the registry
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_instance(model_registry, model_instance):
|
||||
"""
|
||||
Test if an instance can be registered correctly
|
||||
"""
|
||||
assert await model_registry.register_instance(model_instance) == True
|
||||
assert len(model_registry.registry[model_instance.model_name]) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deregister_instance(model_registry, model_instance):
|
||||
"""
|
||||
Test if an instance can be deregistered correctly
|
||||
"""
|
||||
await model_registry.register_instance(model_instance)
|
||||
assert await model_registry.deregister_instance(model_instance) == True
|
||||
assert not model_registry.registry[model_instance.model_name][0].healthy
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_instances(model_registry, model_instance):
|
||||
"""
|
||||
Test if all instances can be retrieved, with and without the healthy_only filter
|
||||
"""
|
||||
await model_registry.register_instance(model_instance)
|
||||
assert len(await model_registry.get_all_instances(model_instance.model_name)) == 1
|
||||
assert (
|
||||
len(
|
||||
await model_registry.get_all_instances(
|
||||
model_instance.model_name, healthy_only=True
|
||||
)
|
||||
)
|
||||
== 1
|
||||
)
|
||||
model_instance.healthy = False
|
||||
assert (
|
||||
len(
|
||||
await model_registry.get_all_instances(
|
||||
model_instance.model_name, healthy_only=True
|
||||
)
|
||||
)
|
||||
== 0
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_select_one_health_instance(model_registry, model_instance):
|
||||
"""
|
||||
Test if a single healthy instance can be selected
|
||||
"""
|
||||
await model_registry.register_instance(model_instance)
|
||||
selected_instance = await model_registry.select_one_health_instance(
|
||||
model_instance.model_name
|
||||
)
|
||||
assert selected_instance is not None
|
||||
assert selected_instance.healthy
|
||||
assert selected_instance.enabled
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_heartbeat(model_registry, model_instance):
|
||||
"""
|
||||
Test if a heartbeat can be sent and that it correctly updates the last_heartbeat timestamp
|
||||
"""
|
||||
await model_registry.register_instance(model_instance)
|
||||
last_heartbeat = datetime.now() - timedelta(seconds=10)
|
||||
model_instance.last_heartbeat = last_heartbeat
|
||||
assert (
|
||||
await model_registry.send_heartbeat(model_instance)
|
||||
== True
|
||||
)
|
||||
assert (
|
||||
model_registry.registry[model_instance.model_name][0].last_heartbeat
|
||||
> last_heartbeat
|
||||
)
|
||||
assert model_registry.registry[model_instance.model_name][0].healthy == True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_heartbeat_timeout(model_registry, model_instance):
|
||||
"""
|
||||
Test if an instance is marked as unhealthy when the heartbeat is not sent within the timeout
|
||||
"""
|
||||
model_registry = EmbeddedModelRegistry(1, 1)
|
||||
await model_registry.register_instance(model_instance)
|
||||
model_registry.registry[model_instance.model_name][
|
||||
0
|
||||
].last_heartbeat = datetime.now() - timedelta(
|
||||
seconds=model_registry.heartbeat_timeout_secs + 1
|
||||
)
|
||||
await asyncio.sleep(model_registry.heartbeat_interval_secs + 1)
|
||||
assert not model_registry.registry[model_instance.model_name][0].healthy
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_instances(model_registry, model_instance):
|
||||
"""
|
||||
Test if multiple instances of the same model are handled correctly
|
||||
"""
|
||||
model_instance2 = ModelInstance(
|
||||
model_name="test_model",
|
||||
ip="192.168.1.2",
|
||||
port=5000,
|
||||
)
|
||||
await model_registry.register_instance(model_instance)
|
||||
await model_registry.register_instance(model_instance2)
|
||||
assert len(await model_registry.get_all_instances(model_instance.model_name)) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_same_model_name_different_ip_port(model_registry):
|
||||
"""
|
||||
Test if instances with the same model name but different IP and port are handled correctly
|
||||
"""
|
||||
instance1 = ModelInstance(model_name="test_model", host="192.168.1.1", port=5000)
|
||||
instance2 = ModelInstance(model_name="test_model", host="192.168.1.2", port=6000)
|
||||
await model_registry.register_instance(instance1)
|
||||
await model_registry.register_instance(instance2)
|
||||
instances = await model_registry.get_all_instances("test_model")
|
||||
assert len(instances) == 2
|
||||
assert instances[0].host != instances[1].host
|
||||
assert instances[0].port != instances[1].port
|
@ -60,7 +60,7 @@ def _get_model_real_path(model_name, default_model_path) -> str:
|
||||
return _genenv_ignoring_key_case("model_path", default_value=default_model_path)
|
||||
|
||||
|
||||
class ModelLoader(metaclass=Singleton):
|
||||
class ModelLoader:
|
||||
"""Model loader is a class for model load
|
||||
|
||||
Args: model_path
|
||||
@ -68,17 +68,11 @@ class ModelLoader(metaclass=Singleton):
|
||||
TODO: multi model support.
|
||||
"""
|
||||
|
||||
kwargs = {}
|
||||
|
||||
def __init__(self, model_path: str, model_name: str = None) -> None:
|
||||
self.device = DEVICE
|
||||
self.model_path = model_path
|
||||
self.model_name = model_name
|
||||
self.prompt_template: str = None
|
||||
self.kwargs = {
|
||||
"torch_dtype": torch.float16,
|
||||
"device_map": "auto",
|
||||
}
|
||||
|
||||
# TODO multi gpu support
|
||||
def loader(
|
||||
@ -121,6 +115,18 @@ class ModelLoader(metaclass=Singleton):
|
||||
else:
|
||||
raise Exception(f"Unkown model type {model_type}")
|
||||
|
||||
def loader_with_params(self, model_params: ModelParameters):
|
||||
llm_adapter = get_llm_model_adapter(self.model_name, self.model_path)
|
||||
model_type = llm_adapter.model_type()
|
||||
self.prompt_template = model_params.prompt_template
|
||||
logger.info(f"model_params:\n{model_params}")
|
||||
if model_type == ModelType.HF:
|
||||
return huggingface_loader(llm_adapter, model_params)
|
||||
elif model_type == ModelType.LLAMA_CPP:
|
||||
return llamacpp_loader(llm_adapter, model_params)
|
||||
else:
|
||||
raise Exception(f"Unkown model type {model_type}")
|
||||
|
||||
|
||||
def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParameters):
|
||||
device = model_params.device
|
||||
|
@ -1,9 +1,10 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from typing import Any, Optional, Type
|
||||
from dataclasses import dataclass, field, fields
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
|
||||
from pilot.model.conversation import conv_templates
|
||||
|
||||
@ -20,11 +21,28 @@ def _genenv_ignoring_key_case(env_key: str, env_prefix: str = None, default_valu
|
||||
|
||||
|
||||
class EnvArgumentParser:
|
||||
@staticmethod
|
||||
def get_env_prefix(env_key: str) -> str:
|
||||
if not env_key:
|
||||
return None
|
||||
env_key = env_key.replace("-", "_")
|
||||
return env_key + "_"
|
||||
|
||||
def parse_args_into_dataclass(
|
||||
self, dataclass_type: Type, env_prefix: str = None, **kwargs
|
||||
self,
|
||||
dataclass_type: Type,
|
||||
env_prefix: str = None,
|
||||
command_args: List[str] = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Parse parameters from environment variables and command lines and populate them into data class"""
|
||||
parser = argparse.ArgumentParser()
|
||||
for field in fields(dataclass_type):
|
||||
env_var_value = _genenv_ignoring_key_case(field.name, env_prefix)
|
||||
if not env_var_value:
|
||||
# Read without env prefix
|
||||
env_var_value = _genenv_ignoring_key_case(field.name)
|
||||
|
||||
if env_var_value:
|
||||
env_var_value = env_var_value.strip()
|
||||
if field.type is int or field.type == Optional[int]:
|
||||
@ -37,17 +55,241 @@ class EnvArgumentParser:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"Unsupported parameter type {field.type}")
|
||||
kwargs[field.name] = env_var_value
|
||||
if not env_var_value:
|
||||
env_var_value = kwargs.get(field.name)
|
||||
if not env_var_value:
|
||||
env_var_value = field.default
|
||||
|
||||
# Add a command-line argument for this field
|
||||
help_text = field.metadata.get("help", "")
|
||||
valid_values = field.metadata.get("valid_values", None)
|
||||
parser.add_argument(
|
||||
f"--{field.name}",
|
||||
type=self._get_argparse_type(field.type),
|
||||
help=help_text,
|
||||
choices=valid_values,
|
||||
default=env_var_value,
|
||||
)
|
||||
|
||||
# Parse the command-line arguments
|
||||
cmd_args, cmd_argv = parser.parse_known_args(args=command_args)
|
||||
print(f"cmd_args: {cmd_args}")
|
||||
for field in fields(dataclass_type):
|
||||
# cmd_line_value = getattr(cmd_args, field.name)
|
||||
if field.name in cmd_args:
|
||||
cmd_line_value = getattr(cmd_args, field.name)
|
||||
if cmd_line_value is not None:
|
||||
kwargs[field.name] = cmd_line_value
|
||||
|
||||
return dataclass_type(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _get_argparse_type(field_type: Type) -> Type:
|
||||
# Return the appropriate type for argparse to use based on the field type
|
||||
if field_type is int or field_type == Optional[int]:
|
||||
return int
|
||||
elif field_type is float or field_type == Optional[float]:
|
||||
return float
|
||||
elif field_type is bool or field_type == Optional[bool]:
|
||||
return bool
|
||||
elif field_type is str or field_type == Optional[str]:
|
||||
return str
|
||||
else:
|
||||
raise ValueError(f"Unsupported parameter type {field_type}")
|
||||
|
||||
@staticmethod
|
||||
def _get_argparse_type_str(field_type: Type) -> str:
|
||||
argparse_type = EnvArgumentParser._get_argparse_type(field_type)
|
||||
if argparse_type is int:
|
||||
return "int"
|
||||
elif argparse_type is float:
|
||||
return "float"
|
||||
elif argparse_type is bool:
|
||||
return "bool"
|
||||
else:
|
||||
return "str"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelParameters:
|
||||
device: str = field(metadata={"help": "Device to run model"})
|
||||
model_name: str = field(metadata={"help": "Model name"})
|
||||
model_path: str = field(metadata={"help": "Model path"})
|
||||
class ParameterDescription:
|
||||
param_name: str
|
||||
param_type: str
|
||||
description: str
|
||||
default_value: Optional[Any]
|
||||
valid_values: Optional[List[Any]]
|
||||
|
||||
|
||||
def _get_parameter_descriptions(dataclass_type: Type) -> List[ParameterDescription]:
|
||||
descriptions = []
|
||||
for field in fields(dataclass_type):
|
||||
descriptions.append(
|
||||
ParameterDescription(
|
||||
param_name=field.name,
|
||||
param_type=EnvArgumentParser._get_argparse_type_str(field.type),
|
||||
description=field.metadata.get("help", None),
|
||||
default_value=field.default, # TODO handle dataclasses._MISSING_TYPE
|
||||
valid_values=field.metadata.get("valid_values", None),
|
||||
)
|
||||
)
|
||||
return descriptions
|
||||
|
||||
|
||||
class WorkerType(str, Enum):
|
||||
LLM = "llm"
|
||||
TEXT2VEC = "text2vec"
|
||||
|
||||
@staticmethod
|
||||
def values():
|
||||
return [item.value for item in WorkerType]
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseParameters:
|
||||
def update_from(self, source: Union["BaseParameters", dict]) -> bool:
|
||||
"""
|
||||
Update the attributes of this object using the values from another object (of the same or parent type) or a dictionary.
|
||||
Only update if the new value is different from the current value and the field is not marked as "fixed" in metadata.
|
||||
|
||||
Args:
|
||||
source (Union[BaseParameters, dict]): The source to update from. Can be another object of the same type or a dictionary.
|
||||
|
||||
Returns:
|
||||
bool: True if at least one field was updated, otherwise False.
|
||||
"""
|
||||
updated = False # Flag to indicate whether any field was updated
|
||||
if isinstance(source, (BaseParameters, dict)):
|
||||
for field_info in fields(self):
|
||||
# Check if the field has a "fixed" tag in metadata
|
||||
tags = field_info.metadata.get("tags")
|
||||
tags = [] if not tags else tags.split(",")
|
||||
if tags and "fixed" in tags:
|
||||
continue # skip this field
|
||||
# Get the new value from source (either another BaseParameters object or a dict)
|
||||
new_value = (
|
||||
getattr(source, field_info.name)
|
||||
if isinstance(source, BaseParameters)
|
||||
else source.get(field_info.name, None)
|
||||
)
|
||||
|
||||
# If the new value is not None and different from the current value, update the field and set the flag
|
||||
if new_value is not None and new_value != getattr(
|
||||
self, field_info.name
|
||||
):
|
||||
setattr(self, field_info.name, new_value)
|
||||
updated = True
|
||||
else:
|
||||
raise ValueError(
|
||||
"Source must be an instance of BaseParameters (or its derived class) or a dictionary."
|
||||
)
|
||||
|
||||
return updated
|
||||
|
||||
def __str__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
parameters = [
|
||||
f"\n\n=========================== {class_name} ===========================\n"
|
||||
]
|
||||
for field_info in fields(self):
|
||||
value = getattr(self, field_info.name)
|
||||
parameters.append(f"{field_info.name}: {value}")
|
||||
parameters.append(
|
||||
"\n======================================================================\n\n"
|
||||
)
|
||||
return "\n".join(parameters)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelWorkerParameters(BaseParameters):
|
||||
model_name: str = field(metadata={"help": "Model name", "tags": "fixed"})
|
||||
model_path: str = field(metadata={"help": "Model path", "tags": "fixed"})
|
||||
worker_type: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"valid_values": WorkerType.values(), "help": "Worker type"},
|
||||
)
|
||||
worker_class: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Model worker deploy host, pilot.model.worker.default_worker.DefaultModelWorker"
|
||||
},
|
||||
)
|
||||
host: Optional[str] = field(
|
||||
default="0.0.0.0", metadata={"help": "Model worker deploy host"}
|
||||
)
|
||||
|
||||
port: Optional[int] = field(
|
||||
default=8000, metadata={"help": "Model worker deploy port"}
|
||||
)
|
||||
limit_model_concurrency: Optional[int] = field(
|
||||
default=5, metadata={"help": "Model concurrency limit"}
|
||||
)
|
||||
standalone: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Standalone mode. If True, embedded Run ModelController"},
|
||||
)
|
||||
register: Optional[bool] = field(
|
||||
default=True, metadata={"help": "Register current worker to model controller"}
|
||||
)
|
||||
worker_register_host: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The ip address of current worker to register to ModelController. If None, the address is automatically determined"
|
||||
},
|
||||
)
|
||||
controller_addr: Optional[str] = field(
|
||||
default=None, metadata={"help": "The Model controller address to register"}
|
||||
)
|
||||
send_heartbeat: Optional[bool] = field(
|
||||
default=True, metadata={"help": "Send heartbeat to model controller"}
|
||||
)
|
||||
heartbeat_interval: Optional[int] = field(
|
||||
default=20, metadata={"help": "The interval for sending heartbeats (seconds)"}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingModelParameters(BaseParameters):
|
||||
model_name: str = field(metadata={"help": "Model name", "tags": "fixed"})
|
||||
model_path: str = field(metadata={"help": "Model path", "tags": "fixed"})
|
||||
device: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Device to run model. If None, the device is automatically determined"
|
||||
},
|
||||
)
|
||||
|
||||
normalize_embeddings: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Determines whether the model's embeddings should be normalized."
|
||||
},
|
||||
)
|
||||
|
||||
def build_kwargs(self, **kwargs) -> Dict:
|
||||
model_kwargs, encode_kwargs = None, None
|
||||
if self.device:
|
||||
model_kwargs = {"device": self.device}
|
||||
if self.normalize_embeddings:
|
||||
encode_kwargs = {"normalize_embeddings": self.normalize_embeddings}
|
||||
if model_kwargs:
|
||||
kwargs["model_kwargs"] = model_kwargs
|
||||
if encode_kwargs:
|
||||
kwargs["encode_kwargs"] = encode_kwargs
|
||||
return kwargs
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelParameters(BaseParameters):
|
||||
model_name: str = field(metadata={"help": "Model name", "tags": "fixed"})
|
||||
model_path: str = field(metadata={"help": "Model path", "tags": "fixed"})
|
||||
device: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Device to run model. If None, the device is automatically determined"
|
||||
},
|
||||
)
|
||||
model_type: Optional[str] = field(
|
||||
default="huggingface", metadata={"help": "Model type, huggingface or llama.cpp"}
|
||||
default="huggingface",
|
||||
metadata={"help": "Model type, huggingface or llama.cpp", "tags": "fixed"},
|
||||
)
|
||||
prompt_template: Optional[str] = field(
|
||||
default=None,
|
||||
@ -91,7 +333,6 @@ class ModelParameters:
|
||||
default=True,
|
||||
metadata={"help": "Nested quantization, only valid when load_4bit=True"},
|
||||
)
|
||||
# "bfloat16", "float16", "float32"
|
||||
compute_dtype: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
|
pilot/model/worker/__init__.py (new empty file)
pilot/model/worker/base.py (new file, 114 lines)
@ -0,0 +1,114 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Iterator, List, Type
|
||||
|
||||
from pilot.model.base import ModelOutput
|
||||
from pilot.model.parameter import (
|
||||
ModelParameters,
|
||||
ParameterDescription,
|
||||
WorkerType,
|
||||
_get_parameter_descriptions,
|
||||
)
|
||||
|
||||
|
||||
class ModelWorker(ABC):
|
||||
"""
|
||||
Abstract representation of a Model Worker responsible for model interaction, startup, and shutdown. Supports 'llm' and 'text2vec' models.
|
||||
"""
|
||||
|
||||
def worker_type(self) -> WorkerType:
|
||||
"""Return the type of worker as LLM."""
|
||||
return WorkerType.LLM
|
||||
|
||||
def model_param_class(self) -> Type:
|
||||
"""Return the class representing model parameters."""
|
||||
return ModelParameters
|
||||
|
||||
def support_async(self) -> bool:
|
||||
"""Whether support async, if True, invoke async_generate_stream, async_generate and async_embeddings instead of generate_stream, generate and embeddings"""
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def parse_parameters(self, command_args: List[str] = None) -> ModelParameters:
|
||||
"""Parse the parameters using the provided command arguments.
|
||||
|
||||
Args:
|
||||
command_args (List[str]): The command-line arguments. Default is sys.argv[1:].
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
|
||||
"""Load the worker with the specified model name and path."""
|
||||
|
||||
@abstractmethod
|
||||
def start(
|
||||
self, model_params: ModelParameters = None, command_args: List[str] = None
|
||||
) -> None:
|
||||
"""Start the model worker"""
|
||||
|
||||
@abstractmethod
|
||||
def stop(self) -> None:
|
||||
"""Stop the model worker and clean up all the resources used."""
|
||||
|
||||
def restart(
|
||||
self, model_params: ModelParameters = None, command_args: List[str] = None
|
||||
) -> None:
|
||||
"""Restart the model worker."""
|
||||
self.stop()
|
||||
self.start(model_params, command_args)
|
||||
|
||||
def parameter_descriptions(self) -> List[ParameterDescription]:
|
||||
"""Fetch the parameter configuration information for the current model."""
|
||||
param_cls = self.model_param_class()
|
||||
return _get_parameter_descriptions(param_cls)
|
||||
|
||||
@abstractmethod
|
||||
def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
|
||||
"""Generate a stream based on provided parameters.
|
||||
|
||||
Args:
|
||||
params (Dict): Parameters matching the PromptRequest data class format. Example:
|
||||
{
|
||||
"messages": [{"role": "user", "content": "Hello world"}], # List of ModelMessage objects
|
||||
"model": "vicuna-13b-v1.5",
|
||||
"prompt": "Hello world",
|
||||
"temperature": 0.7, # Optional; float value between 0 and 1
|
||||
"max_new_tokens": 2048, # Optional; max number of new tokens for the output
|
||||
"stop": "#", # Optional; stopping condition for the output
|
||||
"echo": True # Optional; whether to echo the input in the output
|
||||
}
|
||||
|
||||
Returns:
|
||||
Iterator[ModelOutput]: Stream of model outputs.
|
||||
"""
|
||||
|
||||
async def async_generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
|
||||
"""Asynchronously generate a stream based on provided parameters."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def generate(self, params: Dict) -> ModelOutput:
|
||||
"""Generate output (non-stream) based on provided parameters."""
|
||||
|
||||
async def async_generate(self, params: Dict) -> ModelOutput:
|
||||
"""Asynchronously generate output (non-stream) based on provided parameters."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def embeddings(self, params: Dict) -> List[List[float]]:
|
||||
"""
|
||||
Return embeddings for the given input parameters.
|
||||
|
||||
Args:
|
||||
params (Dict): Parameters matching the EmbeddingsRequest data class format. Example:
|
||||
{
|
||||
"model": "text2vec-large-chinese",
|
||||
"input": ["Hello world", "DB-GPT is amazing"]
|
||||
}
|
||||
|
||||
Returns:
|
||||
List[List[float]]: List of embeddings corresponding to each input string.
|
||||
"""
|
||||
|
||||
async def async_embeddings(self, params: Dict) -> List[List[float]]:
|
||||
"""Return embeddings asynchronously for the given input parameters."""
|
||||
raise NotImplementedError
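To make the contract above concrete, here is a minimal, hypothetical worker built against this interface (an echo worker, not part of this commit); it exercises parse_parameters, start/stop, and the generate paths without loading a real model.

from typing import Dict, Iterator, List

from pilot.model.base import ModelOutput
from pilot.model.parameter import EnvArgumentParser, ModelParameters
from pilot.model.worker.base import ModelWorker


class EchoModelWorker(ModelWorker):
    """Toy worker that echoes the prompt back; useful for wiring tests."""

    def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
        self.model_name = model_name
        self.model_path = model_path

    def parse_parameters(self, command_args: List[str] = None) -> ModelParameters:
        parser = EnvArgumentParser()
        return parser.parse_args_into_dataclass(
            ModelParameters,
            env_prefix=EnvArgumentParser.get_env_prefix(self.model_name),
            command_args=command_args,
            model_name=self.model_name,
            model_path=self.model_path,
        )

    def start(self, model_params=None, command_args: List[str] = None) -> None:
        self._model_params = model_params or self.parse_parameters(command_args)

    def stop(self) -> None:
        self._model_params = None

    def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
        yield ModelOutput(text="echo: " + params.get("prompt", ""), error_code=0)

    def generate(self, params: Dict) -> ModelOutput:
        last = None
        for out in self.generate_stream(params):
            last = out
        return last

    def embeddings(self, params: Dict) -> List[List[float]]:
        raise NotImplementedError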
|
pilot/model/worker/default_worker.py (new file, 133 lines)
@ -0,0 +1,133 @@
|
||||
import logging
|
||||
import platform
|
||||
from typing import Dict, Iterator, List
|
||||
|
||||
import torch
|
||||
from pilot.configs.model_config import DEVICE
|
||||
from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper
|
||||
from pilot.model.base import ModelOutput
|
||||
from pilot.model.loader import ModelLoader, _get_model_real_path
|
||||
from pilot.model.parameter import EnvArgumentParser, ModelParameters
|
||||
from pilot.model.worker.base import ModelWorker
|
||||
from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
|
||||
from pilot.utils.model_utils import _clear_torch_cache
|
||||
|
||||
logger = logging.getLogger("model_worker")
|
||||
|
||||
|
||||
class DefaultModelWorker(ModelWorker):
|
||||
def __init__(self) -> None:
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
self._model_params = None
|
||||
self.llm_adapter: BaseLLMAdaper = None
|
||||
self.llm_chat_adapter: BaseChatAdpter = None
|
||||
|
||||
def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
|
||||
if model_path.endswith("/"):
|
||||
model_path = model_path[:-1]
|
||||
model_path = _get_model_real_path(model_name, model_path)
|
||||
self.model_name = model_name
|
||||
self.model_path = model_path
|
||||
|
||||
self.llm_adapter = get_llm_model_adapter(self.model_name, self.model_path)
|
||||
model_type = self.llm_adapter.model_type()
|
||||
self.param_cls = self.llm_adapter.model_param_class(model_type)
|
||||
|
||||
self.llm_chat_adapter = get_llm_chat_adapter(self.model_name, self.model_path)
|
||||
self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func(
|
||||
self.model_path
|
||||
)
|
||||
|
||||
self.ml: ModelLoader = ModelLoader(
|
||||
model_path=self.model_path, model_name=self.model_name
|
||||
)
|
||||
# TODO read context len from model config
|
||||
self.context_len = 2048
|
||||
|
||||
def model_param_class(self) -> ModelParameters:
|
||||
return self.param_cls
|
||||
|
||||
def parse_parameters(self, command_args: List[str] = None) -> ModelParameters:
|
||||
param_cls = self.model_param_class()
|
||||
model_args = EnvArgumentParser()
|
||||
env_prefix = EnvArgumentParser.get_env_prefix(self.model_name)
|
||||
model_type = self.llm_adapter.model_type()
|
||||
model_params: ModelParameters = model_args.parse_args_into_dataclass(
|
||||
param_cls,
|
||||
env_prefix=env_prefix,
|
||||
command_args=command_args,
|
||||
model_name=self.model_name,
|
||||
model_path=self.model_path,
|
||||
model_type=model_type,
|
||||
)
|
||||
if not model_params.device:
|
||||
model_params.device = DEVICE
|
||||
logger.info(
|
||||
f"[DefaultModelWorker] Parameters of device is None, use {model_params.device}"
|
||||
)
|
||||
return model_params
|
||||
|
||||
def start(
|
||||
self, model_params: ModelParameters = None, command_args: List[str] = None
|
||||
) -> None:
|
||||
if not model_params:
|
||||
model_params = self.parse_parameters(command_args)
|
||||
self._model_params = model_params
|
||||
logger.info(f"Begin load model, model params: {model_params}")
|
||||
self.model, self.tokenizer = self.ml.loader_with_params(model_params)
|
||||
|
||||
def stop(self) -> None:
|
||||
if not self.model:
|
||||
return
|
||||
del self.model
|
||||
del self.tokenizer
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
_clear_torch_cache(self._model_params.device)
|
||||
|
||||
def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
|
||||
try:
|
||||
# params adaptation
|
||||
params, model_context = self.llm_chat_adapter.model_adaptation(
|
||||
params, self.ml.model_path, prompt_template=self.ml.prompt_template
|
||||
)
|
||||
|
||||
for output in self.generate_stream_func(
|
||||
self.model, self.tokenizer, params, DEVICE, self.context_len
|
||||
):
|
||||
# Please do not enable this output in production!
|
||||
# The gpt4all thread shares stdout with the parent process,
|
||||
# and opening it may affect the frontend output.
|
||||
if "windows" in platform.platform().lower():
|
||||
# Do not print the model output; it may contain emoji, which causes problems with GBK encoding
|
||||
pass
|
||||
else:
|
||||
print("output: ", output)
|
||||
# Return some model context to the dbgpt-server
|
||||
model_output = ModelOutput(
|
||||
text=output, error_code=0, model_context=model_context
|
||||
)
|
||||
yield model_output
|
||||
|
||||
except torch.cuda.CudaError:
|
||||
model_output = ModelOutput(
|
||||
text="**GPU OutOfMemory, Please Refresh.**", error_code=0
|
||||
)
|
||||
yield model_output
|
||||
except Exception as e:
|
||||
model_output = ModelOutput(
|
||||
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
|
||||
error_code=0,
|
||||
)
|
||||
yield model_output
|
||||
|
||||
def generate(self, params: Dict) -> ModelOutput:
|
||||
"""Generate non stream result"""
|
||||
output = None
|
||||
for out in self.generate_stream(params):
|
||||
output = out
|
||||
return output
|
||||
|
||||
def embeddings(self, params: Dict) -> List[List[float]]:
|
||||
raise NotImplementedError
|
pilot/model/worker/embedding_worker.py (new file, 100 lines)
@ -0,0 +1,100 @@
|
||||
import logging
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from pilot.configs.model_config import DEVICE
|
||||
from pilot.model.loader import _get_model_real_path
|
||||
from pilot.model.parameter import (
|
||||
EmbeddingModelParameters,
|
||||
EnvArgumentParser,
|
||||
WorkerType,
|
||||
)
|
||||
from pilot.model.worker.base import ModelWorker
|
||||
from pilot.utils.model_utils import _clear_torch_cache
|
||||
|
||||
logger = logging.getLogger("model_worker")
|
||||
|
||||
|
||||
class EmbeddingsModelWorker(ModelWorker):
|
||||
def __init__(self) -> None:
|
||||
try:
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
from langchain.embeddings.base import Embeddings
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Could not import langchain.embeddings.HuggingFaceEmbeddings python package. "
|
||||
"Please install it with `pip install langchain`."
|
||||
) from exc
|
||||
self._embeddings: Embeddings = None  # private attribute so it does not shadow the embeddings() method below
|
||||
self._model_params = None
|
||||
|
||||
def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
|
||||
if model_path.endswith("/"):
|
||||
model_path = model_path[:-1]
|
||||
model_path = _get_model_real_path(model_name, model_path)
|
||||
|
||||
self.model_name = model_name
|
||||
self.model_path = model_path
|
||||
|
||||
def worker_type(self) -> WorkerType:
|
||||
return WorkerType.TEXT2VEC
|
||||
|
||||
def model_param_class(self) -> Type:
|
||||
return EmbeddingModelParameters
|
||||
|
||||
def parse_parameters(
|
||||
self, command_args: List[str] = None
|
||||
) -> EmbeddingModelParameters:
|
||||
param_cls = self.model_param_class()
|
||||
model_args = EnvArgumentParser()
|
||||
env_prefix = EnvArgumentParser.get_env_prefix(self.model_name)
|
||||
model_params: EmbeddingModelParameters = model_args.parse_args_into_dataclass(
|
||||
param_cls,
|
||||
env_prefix=env_prefix,
|
||||
command_args=command_args,
|
||||
model_name=self.model_name,
|
||||
model_path=self.model_path,
|
||||
)
|
||||
if not model_params.device:
|
||||
model_params.device = DEVICE
|
||||
logger.info(
|
||||
f"[EmbeddingsModelWorker] Parameters of device is None, use {model_params.device}"
|
||||
)
|
||||
return model_params
|
||||
|
||||
def start(
|
||||
self,
|
||||
model_params: EmbeddingModelParameters = None,
|
||||
command_args: List[str] = None,
|
||||
) -> None:
|
||||
"""Start model worker"""
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
|
||||
if not model_params:
|
||||
model_params = self.parse_parameters(command_args)
|
||||
self._model_params = model_params
|
||||
|
||||
kwargs = model_params.build_kwargs(model_name=model_params.model_path)
|
||||
logger.info(f"Start HuggingFaceEmbeddings with kwargs: {kwargs}")
|
||||
self._embeddings = HuggingFaceEmbeddings(**kwargs)
|
||||
|
||||
def __del__(self):
|
||||
self.stop()
|
||||
|
||||
def stop(self) -> None:
|
||||
if not self._embeddings:
|
||||
return
|
||||
del self._embeddings
|
||||
self._embeddings = None
|
||||
_clear_torch_cache(self._model_params.device)
|
||||
|
||||
def generate_stream(self, params: Dict):
|
||||
"""Generate stream result, chat scene"""
|
||||
raise NotImplementedError("Not supported generate_stream for embeddings model")
|
||||
|
||||
def generate(self, params: Dict):
|
||||
"""Generate non stream result"""
|
||||
raise NotImplementedError("Not supported generate for embeddings model")
|
||||
|
||||
def embeddings(self, params: Dict) -> List[List[float]]:
|
||||
input: List[str] = params["input"]
|
||||
return self._embeddings.embed_documents(input)
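A hedged usage sketch for this worker: the model path is a placeholder, and it assumes langchain plus a local text2vec model are available.

from pilot.model.worker.embedding_worker import EmbeddingsModelWorker

worker = EmbeddingsModelWorker()
worker.load_worker("text2vec", "/data/models/text2vec-large-chinese")
worker.start()  # builds HuggingFaceEmbeddings from the parsed parameters
print(worker.embeddings({"input": ["Hello world", "DB-GPT is amazing"]}))
worker.stop()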
|
pilot/model/worker/manager.py (new file, 705 lines)
@ -0,0 +1,705 @@
|
||||
import asyncio
|
||||
import httpx
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Awaitable, Callable, Dict, Iterator, List, Optional
|
||||
|
||||
import uvicorn
|
||||
from fastapi import APIRouter, FastAPI, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
from pilot.model.base import (
|
||||
ModelInstance,
|
||||
ModelOutput,
|
||||
WorkerApplyType,
|
||||
WorkerApplyOutput,
|
||||
)
|
||||
from pilot.model.controller.registry import ModelRegistry
|
||||
from pilot.model.parameter import (
|
||||
EnvArgumentParser,
|
||||
ModelParameters,
|
||||
ModelWorkerParameters,
|
||||
WorkerType,
|
||||
ParameterDescription,
|
||||
)
|
||||
from pilot.model.worker.base import ModelWorker
|
||||
from pilot.scene.base_message import ModelMessage
|
||||
from pilot.utils import build_logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = build_logger("model_worker", LOGDIR + "/model_worker.log")
|
||||
|
||||
|
||||
class PromptRequest(BaseModel):
|
||||
messages: List[ModelMessage]
|
||||
model: str
|
||||
prompt: str = None
|
||||
temperature: float = None
|
||||
max_new_tokens: int = None
|
||||
stop: str = None
|
||||
echo: bool = True
|
||||
|
||||
|
||||
class EmbeddingsRequest(BaseModel):
|
||||
model: str
|
||||
input: List[str]
|
||||
|
||||
|
||||
class WorkerApplyRequest(BaseModel):
|
||||
model: str
|
||||
apply_type: WorkerApplyType
|
||||
worker_type: WorkerType = WorkerType.LLM
|
||||
params: Dict = None
|
||||
apply_user: str = None
|
||||
|
||||
|
||||
class WorkerParameterRequest(BaseModel):
|
||||
model: str
|
||||
worker_type: WorkerType = WorkerType.LLM
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkerRunData:
|
||||
worker_key: str
|
||||
worker: ModelWorker
|
||||
worker_params: ModelWorkerParameters
|
||||
model_params: ModelParameters
|
||||
stop_event: asyncio.Event
|
||||
semaphore: asyncio.Semaphore = None
|
||||
command_args: List[str] = None
|
||||
_heartbeat_future: Optional[Future] = None
|
||||
_last_heartbeat: Optional[datetime] = None
|
||||
|
||||
|
||||
RegisterFunc = Callable[[WorkerRunData], Awaitable[None]]
|
||||
DeregisterFunc = Callable[[WorkerRunData], Awaitable[None]]
|
||||
SendHeartbeatFunc = Callable[[WorkerRunData], Awaitable[None]]
|
||||
ApplyFunction = Callable[[WorkerRunData], Awaitable[None]]
|
||||
|
||||
|
||||
class WorkerManager(ABC):
|
||||
@abstractmethod
|
||||
async def get_model_instances(
|
||||
self, worker_type: str, model_name: str, healthy_only: bool = True
|
||||
) -> List[WorkerRunData]:
|
||||
"""Get model instances by worker type and model name"""
|
||||
|
||||
@abstractmethod
|
||||
async def select_one_instanes(
|
||||
self, worker_type: str, model_name: str, healthy_only: bool = True
|
||||
) -> WorkerRunData:
|
||||
"""Select one instances"""
|
||||
|
||||
@abstractmethod
|
||||
async def generate_stream(self, params: Dict, **kwargs) -> Iterator[ModelOutput]:
|
||||
"""Generate stream result, chat scene"""
|
||||
|
||||
@abstractmethod
|
||||
async def generate(self, params: Dict) -> ModelOutput:
|
||||
"""Generate non stream result"""
|
||||
|
||||
@abstractmethod
|
||||
async def embeddings(self, params: Dict) -> List[List[float]]:
|
||||
"""Embed input"""
|
||||
|
||||
@abstractmethod
|
||||
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
|
||||
"""Worker apply"""
|
||||
|
||||
@abstractmethod
|
||||
async def parameter_descriptions(
|
||||
self, worker_type: str, model_name: str
|
||||
) -> List[ParameterDescription]:
|
||||
"""Get parameter descriptions of model"""
|
||||
|
||||
|
||||
async def _async_heartbeat_sender(
|
||||
worker_run_data: WorkerRunData, send_heartbeat_func: SendHeartbeatFunc
|
||||
):
|
||||
while not worker_run_data.stop_event.is_set():
|
||||
try:
|
||||
await send_heartbeat_func(worker_run_data)
|
||||
except Exception as e:
|
||||
logger.warn(f"Send heartbeat func error: {str(e)}")
|
||||
finally:
|
||||
await asyncio.sleep(worker_run_data.worker_params.heartbeat_interval)
|
||||
|
||||
|
||||
class LocalWorkerManager(WorkerManager):
|
||||
def __init__(
|
||||
self,
|
||||
register_func: RegisterFunc = None,
|
||||
deregister_func: DeregisterFunc = None,
|
||||
send_heartbeat_func: SendHeartbeatFunc = None,
|
||||
model_registry: ModelRegistry = None,
|
||||
) -> None:
|
||||
self.workers: Dict[str, List[WorkerRunData]] = dict()
|
||||
self.executor = ThreadPoolExecutor(max_workers=os.cpu_count() * 5)
|
||||
self.register_func = register_func
|
||||
self.deregister_func = deregister_func
|
||||
self.send_heartbeat_func = send_heartbeat_func
|
||||
self.model_registry = model_registry
|
||||
|
||||
def _worker_key(self, worker_type: str, model_name: str) -> str:
|
||||
return f"{model_name}@{worker_type}"
|
||||
|
||||
def add_worker(
|
||||
self,
|
||||
worker: ModelWorker,
|
||||
worker_params: ModelWorkerParameters,
|
||||
embedded_mod: bool = True,
|
||||
command_args: List[str] = None,
|
||||
):
|
||||
if not command_args:
|
||||
import sys
|
||||
|
||||
command_args = sys.argv[1:]
|
||||
worker.load_worker(**asdict(worker_params))
|
||||
|
||||
if not worker_params.worker_type:
|
||||
worker_params.worker_type = worker.worker_type()
|
||||
|
||||
worker_key = self._worker_key(
|
||||
worker_params.worker_type, worker_params.model_name
|
||||
)
|
||||
host = worker_params.host
|
||||
port = worker_params.port
|
||||
|
||||
instances = self.workers.get(worker_key)
|
||||
if not instances:
|
||||
instances = []
|
||||
self.workers[worker_key] = instances
|
||||
logger.info(f"Init empty instances list for {worker_key}")
|
||||
# Load model params from persist storage
|
||||
model_params = worker.parse_parameters(command_args=command_args)
|
||||
|
||||
worker_run_data = WorkerRunData(
|
||||
worker_key=worker_key,
|
||||
worker=worker,
|
||||
worker_params=worker_params,
|
||||
model_params=model_params,
|
||||
stop_event=asyncio.Event(),
|
||||
semaphore=asyncio.Semaphore(worker_params.limit_model_concurrency),
|
||||
command_args=command_args,
|
||||
)
|
||||
if not embedded_mod:
|
||||
exist_instances = [
ins for ins in instances if ins.worker_params.host == host and ins.worker_params.port == port
]
|
||||
if not exist_instances:
|
||||
instances.append(worker_run_data)
|
||||
else:
|
||||
instances.append(worker_run_data)
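A rough sketch of wiring a worker in by hand, assuming ModelWorkerParameters accepts model_name and model_path directly (placeholder values); normally _start_local_worker below does this from parsed parameters.

from pilot.model.worker.default_worker import DefaultModelWorker

_params = ModelWorkerParameters(
    model_name="vicuna-13b", model_path="/data/models/vicuna-13b"  # placeholders
)
_manager = LocalWorkerManager()
# Stored under the "{model_name}@{worker_type}" key built by _worker_key
_manager.add_worker(DefaultModelWorker(), _params, embedded_mod=True)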
|
||||
|
||||
async def get_model_instances(
|
||||
self, worker_type: str, model_name: str, healthy_only: bool = True
|
||||
) -> List[WorkerRunData]:
|
||||
worker_key = self._worker_key(worker_type, model_name)
|
||||
return self.workers.get(worker_key)
|
||||
|
||||
async def select_one_instanes(
|
||||
self, worker_type: str, model_name: str, healthy_only: bool = True
|
||||
) -> WorkerRunData:
|
||||
worker_instances = await self.get_model_instances(
|
||||
worker_type, model_name, healthy_only
|
||||
)
|
||||
if not worker_instances:
|
||||
raise Exception(
|
||||
f"Cound not found worker instances for model name {model_name} and worker type {worker_type}"
|
||||
)
|
||||
worker_run_data = random.choice(worker_instances)
|
||||
return worker_run_data
|
||||
|
||||
async def _get_model(self, params: Dict, worker_type: str = "llm") -> WorkerRunData:
|
||||
model = params.get("model")
|
||||
if not model:
|
||||
raise Exception("Model name count not be empty")
|
||||
return await self.select_one_instanes(worker_type, model, healthy_only=True)
|
||||
|
||||
async def generate_stream(
|
||||
self, params: Dict, async_wrapper=None, **kwargs
|
||||
) -> Iterator[ModelOutput]:
|
||||
"""Generate stream result, chat scene"""
|
||||
worker_run_data = await self._get_model(params)
|
||||
async with worker_run_data.semaphore:
|
||||
if worker_run_data.worker.support_async():
|
||||
async for output in worker_run_data.worker.async_generate_stream(
|
||||
params
|
||||
):
|
||||
yield output
|
||||
else:
|
||||
if not async_wrapper:
|
||||
from starlette.concurrency import iterate_in_threadpool
|
||||
|
||||
async_wrapper = iterate_in_threadpool
|
||||
async for output in async_wrapper(
|
||||
worker_run_data.worker.generate_stream(params)
|
||||
):
|
||||
yield output
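A small sketch of consuming this stream from an async caller; the payload values are illustrative.

async def _stream_demo(manager: "LocalWorkerManager"):
    async for out in manager.generate_stream({"model": "vicuna-13b", "prompt": "Hi"}):
        print(out.text)  # each item is a ModelOutput from the selected instance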
|
||||
|
||||
async def generate(self, params: Dict) -> ModelOutput:
|
||||
"""Generate non stream result"""
|
||||
worker_run_data = await self._get_model(params)
|
||||
async with worker_run_data.semaphore:
|
||||
if worker_run_data.worker.support_async():
|
||||
return await worker_run_data.worker.async_generate(params)
|
||||
else:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
self.executor, worker_run_data.worker.generate, params
|
||||
)
|
||||
|
||||
async def embeddings(self, params: Dict) -> List[List[float]]:
|
||||
"""Embed input"""
|
||||
worker_run_data = await self._get_model(params, worker_type="text2vec")
|
||||
async with worker_run_data.semaphore:
|
||||
if worker_run_data.worker.support_async():
|
||||
return await worker_run_data.worker.async_embeddings(params)
|
||||
else:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
self.executor, worker_run_data.worker.embeddings, params
|
||||
)
|
||||
|
||||
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
|
||||
apply_func: Callable[[WorkerApplyRequest], Awaitable[str]] = None
|
||||
if apply_req.apply_type == WorkerApplyType.START:
|
||||
apply_func = self._start_all_worker
|
||||
elif apply_req.apply_type == WorkerApplyType.STOP:
|
||||
apply_func = self._stop_all_worker
|
||||
elif apply_req.apply_type == WorkerApplyType.UPDATE_PARAMS:
|
||||
apply_func = self._update_all_worker_params
|
||||
else:
|
||||
raise ValueError(f"Unsupported apply type {apply_req.apply_type}")
|
||||
return await apply_func(apply_req)
|
||||
|
||||
async def parameter_descriptions(
|
||||
self, worker_type: str, model_name: str
|
||||
) -> List[ParameterDescription]:
|
||||
worker_instances = await self.get_model_instances(worker_type, model_name)
|
||||
if not worker_instances:
|
||||
raise Exception(
|
||||
f"Not worker instances for model name {model_name} worker type {worker_type}"
|
||||
)
|
||||
worker_run_data = worker_instances[0]
|
||||
return worker_run_data.worker.parameter_descriptions()
|
||||
|
||||
async def _apply_worker(
|
||||
self, apply_req: WorkerApplyRequest, apply_func: ApplyFunction
|
||||
) -> None:
|
||||
"""Apply function to worker instances in parallel
|
||||
|
||||
Args:
|
||||
apply_req (WorkerApplyRequest): Worker apply request
|
||||
apply_func (ApplyFunction): Function to apply to worker instances, now function is async function
|
||||
"""
|
||||
if apply_req:
|
||||
worker_type = apply_req.worker_type.value
|
||||
model_name = apply_req.model
|
||||
worker_instances = await self.get_model_instances(worker_type, model_name)
|
||||
if not worker_instances:
|
||||
raise Exception(
|
||||
f"No worker instance found for the model {model_name} worker type {worker_type}"
|
||||
)
|
||||
else:
|
||||
# Apply to all workers
|
||||
worker_instances = list(itertools.chain(*self.workers.values()))
|
||||
logger.info(f"Apply to all workers: {worker_instances}")
|
||||
return await asyncio.gather(
|
||||
*(apply_func(worker) for worker in worker_instances)
|
||||
)
|
||||
|
||||
async def _start_all_worker(
|
||||
self, apply_req: WorkerApplyRequest
|
||||
) -> WorkerApplyOutput:
|
||||
start_time = time.time()
|
||||
logger.info(f"Begin start all worker, apply_req: {apply_req}")
|
||||
|
||||
async def _start_worker(worker_run_data: WorkerRunData):
|
||||
worker_run_data.worker.start(
|
||||
worker_run_data.model_params, worker_run_data.command_args
|
||||
)
|
||||
worker_run_data.stop_event.clear()
|
||||
if worker_run_data.worker_params.register and self.register_func:
|
||||
# Register worker to controller
|
||||
await self.register_func(worker_run_data)
|
||||
if (
|
||||
worker_run_data.worker_params.send_heartbeat
|
||||
and self.send_heartbeat_func
|
||||
):
|
||||
asyncio.create_task(
|
||||
_async_heartbeat_sender(
|
||||
worker_run_data, self.send_heartbeat_func
|
||||
)
|
||||
)
|
||||
|
||||
await self._apply_worker(apply_req, _start_worker)
|
||||
timecost = time.time() - start_time
|
||||
return WorkerApplyOutput(
|
||||
message=f"Worker started successfully", timecost=timecost
|
||||
)
|
||||
|
||||
async def _stop_all_worker(
|
||||
self, apply_req: WorkerApplyRequest
|
||||
) -> WorkerApplyOutput:
|
||||
start_time = time.time()
|
||||
|
||||
async def _stop_worker(worker_run_data: WorkerRunData):
|
||||
worker_run_data.worker.stop()
|
||||
# Set stop event
|
||||
worker_run_data.stop_event.set()
|
||||
if worker_run_data._heartbeat_future:
|
||||
# Wait thread finish
|
||||
worker_run_data._heartbeat_future.result()
|
||||
worker_run_data._heartbeat_future = None
|
||||
if (
|
||||
worker_run_data.worker_params.register
|
||||
and self.register_func
|
||||
and self.deregister_func
|
||||
):
|
||||
await self.deregister_func(worker_run_data)
|
||||
|
||||
await self._apply_worker(apply_req, _stop_worker)
|
||||
timecost = time.time() - start_time
|
||||
return WorkerApplyOutput(
|
||||
message=f"Worker stopped successfully", timecost=timecost
|
||||
)
|
||||
|
||||
async def _update_all_worker_params(
|
||||
self, apply_req: WorkerApplyRequest
|
||||
) -> WorkerApplyOutput:
|
||||
start_time = time.time()
|
||||
need_restart = False
|
||||
|
||||
async def update_params(worker_run_data: WorkerRunData):
|
||||
nonlocal need_restart
|
||||
new_params = apply_req.params
|
||||
if not new_params:
|
||||
return
|
||||
if worker_run_data.model_params.update_from(new_params):
|
||||
need_restart = True
|
||||
|
||||
await self._apply_worker(apply_req, update_params)
|
||||
message = f"Update worker params successfully"
|
||||
timecost = time.time() - start_time
|
||||
if need_restart:
|
||||
logger.info("Model params update successfully, begin restart worker")
|
||||
await self._stop_all_worker(apply_req)
|
||||
await self._start_all_worker(apply_req)
|
||||
timecost = time.time() - start_time
|
||||
message = f"Update worker params and restart successfully"
|
||||
return WorkerApplyOutput(message=message, timecost=timecost)
|
||||
|
||||
|
||||
class RemoteWorkerManager(LocalWorkerManager):
|
||||
def __init__(self, model_registry: ModelRegistry = None) -> None:
|
||||
super().__init__(model_registry=model_registry)
|
||||
|
||||
async def get_model_instances(
|
||||
self, worker_type: str, model_name: str, healthy_only: bool = True
|
||||
) -> List[WorkerRunData]:
|
||||
from pilot.model.worker.remote_worker import RemoteModelWorker
|
||||
|
||||
worker_key = self._worker_key(worker_type, model_name)
|
||||
instances: List[ModelInstance] = await self.model_registry.get_all_instances(
|
||||
worker_key, healthy_only
|
||||
)
|
||||
worker_instances = []
|
||||
for ins in instances:
|
||||
worker = RemoteModelWorker()
|
||||
worker.load_worker(model_name, model_name, host=ins.host, port=ins.port)
|
||||
wr = WorkerRunData(
|
||||
worker_key=ins.model_name,
|
||||
worker=worker,
|
||||
worker_params=None,
|
||||
model_params=None,
|
||||
stop_event=asyncio.Event(),
|
||||
semaphore=asyncio.Semaphore(100), # Not limit in client
|
||||
)
|
||||
worker_instances.append(wr)
|
||||
return worker_instances
|
||||
|
||||
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
|
||||
async def _remote_apply_func(worker_run_data: WorkerRunData):
|
||||
worker_addr = worker_run_data.worker.worker_addr
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
worker_addr + "/apply",
|
||||
headers=worker_run_data.worker.headers,
|
||||
json=apply_req.dict(),
|
||||
timeout=worker_run_data.worker.timeout,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
output = WorkerApplyOutput(**response.json())
|
||||
logger.info(f"worker_apply success: {output}")
|
||||
else:
|
||||
output = WorkerApplyOutput(message=response.text)
|
||||
logger.warn(f"worker_apply failed: {output}")
|
||||
return output
|
||||
|
||||
results = await self._apply_worker(apply_req, _remote_apply_func)
|
||||
if results:
|
||||
return results[0]
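Putting the remote path together, roughly: a RemoteWorkerManager resolves instances through the controller's registry and forwards requests to each worker's HTTP endpoint. The controller address and payload below are assumptions.

from pilot.model.controller.registry import ModelRegistryClient

_registry = ModelRegistryClient("http://127.0.0.1:8000")  # hypothetical controller address
_remote = RemoteWorkerManager(_registry)

async def _remote_demo():
    return await _remote.generate({"model": "vicuna-13b", "prompt": "Hello"})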
|
||||
|
||||
|
||||
class WorkerManagerAdapter(WorkerManager):
|
||||
def __init__(self, worker_manager: WorkerManager = None) -> None:
|
||||
self.worker_manager = worker_manager
|
||||
|
||||
async def get_model_instances(
|
||||
self, worker_type: str, model_name: str, healthy_only: bool = True
|
||||
) -> List[WorkerRunData]:
|
||||
return await self.worker_manager.get_model_instances(
|
||||
worker_type, model_name, healthy_only
|
||||
)
|
||||
|
||||
async def select_one_instanes(
|
||||
self, worker_type: str, model_name: str, healthy_only: bool = True
|
||||
) -> WorkerRunData:
|
||||
return await self.worker_manager.select_one_instanes(
|
||||
worker_type, model_name, healthy_only
|
||||
)
|
||||
|
||||
async def generate_stream(self, params: Dict, **kwargs) -> Iterator[ModelOutput]:
|
||||
async for output in self.worker_manager.generate_stream(params, **kwargs):
|
||||
yield output
|
||||
|
||||
async def generate(self, params: Dict) -> ModelOutput:
|
||||
return await self.worker_manager.generate(params)
|
||||
|
||||
async def embeddings(self, params: Dict) -> List[List[float]]:
|
||||
return await self.worker_manager.embeddings(params)
|
||||
|
||||
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
|
||||
return await self.worker_manager.worker_apply(apply_req)
|
||||
|
||||
async def parameter_descriptions(
|
||||
self, worker_type: str, model_name: str
|
||||
) -> List[ParameterDescription]:
|
||||
return await self.worker_manager.parameter_descriptions(worker_type, model_name)
|
||||
|
||||
|
||||
worker_manager = WorkerManagerAdapter()
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def generate_json_stream(params):
|
||||
from starlette.concurrency import iterate_in_threadpool
|
||||
|
||||
async for output in worker_manager.generate_stream(
|
||||
params, async_wrapper=iterate_in_threadpool
|
||||
):
|
||||
yield json.dumps(asdict(output), ensure_ascii=False).encode() + b"\0"
|
||||
|
||||
|
||||
@router.post("/worker/generate_stream")
|
||||
async def api_generate_stream(request: Request):
|
||||
params = await request.json()
|
||||
generator = generate_json_stream(params)
|
||||
return StreamingResponse(generator)
|
||||
|
||||
|
||||
@router.post("/worker/generate")
|
||||
async def api_generate(request: PromptRequest):
|
||||
params = request.dict(exclude_none=True)
|
||||
output = await worker_manager.generate(params)
|
||||
return output
|
||||
|
||||
|
||||
@router.post("/worker/embeddings")
|
||||
async def api_embeddings(request: EmbeddingsRequest):
|
||||
params = request.dict(exclude_none=True)
|
||||
output = await worker_manager.embeddings(params)
|
||||
return output
|
||||
|
||||
|
||||
@router.post("/worker/apply")
|
||||
async def api_worker_apply(request: WorkerApplyRequest):
|
||||
output = await worker_manager.worker_apply(request)
|
||||
return output
|
||||
|
||||
|
||||
@router.get("/worker/parameter/descriptions")
|
||||
async def api_worker_parameter_descs(
|
||||
model: str, worker_type: str = WorkerType.LLM.value
|
||||
):
|
||||
output = await worker_manager.parameter_descriptions(worker_type, model)
|
||||
return output
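Outside of this process, the same endpoints can be reached over HTTP; a hedged sketch with httpx (host, port and payload values are assumptions, the "/api" prefix matches how the router is mounted below).

import httpx

async def _call_worker_http():
    payload = {"model": "vicuna-13b", "prompt": "Hello", "temperature": 0.7}
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            "http://127.0.0.1:8000/api/worker/generate", json=payload, timeout=60
        )
        return resp.json()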
|
||||
|
||||
|
||||
def _setup_fastapi(worker_params: ModelWorkerParameters):
|
||||
app = FastAPI()
|
||||
if worker_params.standalone:
|
||||
from pilot.model.controller.controller import router as controller_router
|
||||
|
||||
if not worker_params.controller_addr:
|
||||
worker_params.controller_addr = f"http://127.0.0.1:{worker_params.port}"
|
||||
app.include_router(controller_router, prefix="/api")
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
asyncio.create_task(
|
||||
worker_manager.worker_manager._start_all_worker(apply_req=None)
|
||||
)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def _parse_worker_params(
|
||||
model_name: str = None, model_path: str = None, **kwargs
|
||||
) -> ModelWorkerParameters:
|
||||
worker_args = EnvArgumentParser()
|
||||
worker_params: ModelWorkerParameters = worker_args.parse_args_into_dataclass(
|
||||
ModelWorkerParameters, model_name=model_name, model_path=model_path, **kwargs
|
||||
)
|
||||
env_prefix = EnvArgumentParser.get_env_prefix(worker_params.model_name)
|
||||
# Read parameters again with the prefix of the model name.
|
||||
new_worker_params = worker_args.parse_args_into_dataclass(
|
||||
ModelWorkerParameters,
|
||||
env_prefix=env_prefix,
|
||||
model_name=worker_params.model_name,
|
||||
model_path=worker_params.model_path,
|
||||
**kwargs,
|
||||
)
|
||||
worker_params.update_from(new_worker_params)
|
||||
|
||||
logger.info(f"Worker params: {worker_params}")
|
||||
return worker_params
|
||||
|
||||
|
||||
def _create_local_model_manager(
|
||||
worker_params: ModelWorkerParameters,
|
||||
) -> LocalWorkerManager:
|
||||
if not worker_params.register or not worker_params.controller_addr:
|
||||
logger.info(
|
||||
f"Not register current to controller, register: {worker_params.register}, controller_addr: {worker_params.controller_addr}"
|
||||
)
|
||||
return LocalWorkerManager()
|
||||
else:
|
||||
from pilot.model.controller.registry import ModelRegistryClient
|
||||
from pilot.utils.net_utils import _get_ip_address
|
||||
|
||||
client = ModelRegistryClient(worker_params.controller_addr)
|
||||
host = _get_ip_address()
|
||||
port = worker_params.port
|
||||
|
||||
async def register_func(worker_run_data: WorkerRunData):
|
||||
instance = ModelInstance(
|
||||
model_name=worker_run_data.worker_key, host=host, port=port
|
||||
)
|
||||
return await client.register_instance(instance)
|
||||
|
||||
async def send_heartbeat_func(worker_run_data: WorkerRunData):
|
||||
instance = ModelInstance(
|
||||
model_name=worker_run_data.worker_key, host=host, port=port
|
||||
)
|
||||
return await client.send_heartbeat(instance)
|
||||
|
||||
return LocalWorkerManager(
|
||||
register_func=register_func, send_heartbeat_func=send_heartbeat_func
|
||||
)
|
||||
|
||||
|
||||
def _start_local_worker(
|
||||
worker_manager: WorkerManagerAdapter,
|
||||
worker_params: ModelWorkerParameters,
|
||||
embedded_mod=True,
|
||||
):
|
||||
from pilot.utils.module_utils import import_from_checked_string
|
||||
|
||||
if worker_params.worker_class:
|
||||
worker_cls = import_from_checked_string(worker_params.worker_class, ModelWorker)
|
||||
logger.info(
|
||||
f"Import worker class from {worker_params.worker_class} successfully"
|
||||
)
|
||||
worker: ModelWorker = worker_cls()
|
||||
else:
|
||||
from pilot.model.worker.default_worker import DefaultModelWorker
|
||||
|
||||
worker = DefaultModelWorker()
|
||||
|
||||
worker_manager.worker_manager = _create_local_model_manager(worker_params)
|
||||
worker_manager.worker_manager.add_worker(
|
||||
worker, worker_params, embedded_mod=embedded_mod
|
||||
)
|
||||
|
||||
|
||||
def initialize_worker_manager_in_client(
|
||||
app=None,
|
||||
include_router: bool = True,
|
||||
model_name: str = None,
|
||||
model_path: str = None,
|
||||
run_locally: bool = True,
|
||||
controller_addr: str = None,
|
||||
):
|
||||
global worker_manager
|
||||
|
||||
worker_params: ModelWorkerParameters = _parse_worker_params(
|
||||
model_name=model_name, model_path=model_path, controller_addr=controller_addr
|
||||
)
|
||||
|
||||
logger.info(f"Worker params: {worker_params}")
|
||||
if run_locally:
|
||||
worker_params.register = False
|
||||
_start_local_worker(worker_manager, worker_params, True)
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(
|
||||
worker_manager.worker_manager._start_all_worker(apply_req=None)
|
||||
)
|
||||
else:
|
||||
from pilot.model.controller.registry import ModelRegistryClient
|
||||
|
||||
if not worker_params.controller_addr:
|
||||
raise ValueError("Controller can`t be None")
|
||||
client = ModelRegistryClient(worker_params.controller_addr)
|
||||
worker_manager.worker_manager = RemoteWorkerManager(client)
|
||||
|
||||
if include_router and app:
|
||||
app.include_router(router, prefix="/api")
|
||||
|
||||
|
||||
def run_worker_manager(
|
||||
app=None,
|
||||
include_router: bool = True,
|
||||
model_name: str = None,
|
||||
model_path: str = None,
|
||||
standalone: bool = False,
|
||||
port: int = None,
|
||||
):
|
||||
global worker_manager
|
||||
|
||||
worker_params: ModelWorkerParameters = _parse_worker_params(
|
||||
model_name=model_name, model_path=model_path, standalone=standalone, port=port
|
||||
)
|
||||
|
||||
embedded_mod = True
|
||||
if not app:
|
||||
# Run worker manager independently
|
||||
embedded_mod = False
|
||||
app = _setup_fastapi(worker_params)
|
||||
_start_local_worker(worker_manager, worker_params, embedded_mod=False)
|
||||
else:
|
||||
_start_local_worker(worker_manager, worker_params, embedded_mod=False)
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(
|
||||
worker_manager.worker_manager._start_all_worker(apply_req=None)
|
||||
)
|
||||
|
||||
if include_router:
|
||||
app.include_router(router, prefix="/api")
|
||||
|
||||
if not embedded_mod:
|
||||
uvicorn.run(
|
||||
app, host=worker_params.host, port=worker_params.port, log_level="info"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_worker_manager()
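run_worker_manager can also be embedded into an existing FastAPI app, in which case the caller keeps serving the app itself; a sketch with placeholder model settings.

from fastapi import FastAPI

_app = FastAPI()
# Placeholder model name/path; the webserver resolves these from LLM_MODEL_CONFIG.
run_worker_manager(
    app=_app, model_name="vicuna-13b", model_path="/data/models/vicuna-13b"
)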
|
0
pilot/model/worker/ray_worker.py
Normal file
98
pilot/model/worker/remote_worker.py
Normal file
@ -0,0 +1,98 @@
|
||||
import json
|
||||
from typing import Dict, Iterator, List
|
||||
|
||||
import httpx
|
||||
from pilot.model.base import ModelOutput
|
||||
from pilot.model.parameter import ModelParameters
|
||||
from pilot.model.worker.base import ModelWorker
|
||||
|
||||
|
||||
class RemoteModelWorker(ModelWorker):
|
||||
def __init__(self) -> None:
|
||||
self.headers = {}
|
||||
self.timeout = 60
|
||||
self.host = None
|
||||
self.port = None
|
||||
|
||||
@property
|
||||
def worker_addr(self) -> str:
|
||||
return f"http://{self.host}:{self.port}/api/worker"
|
||||
|
||||
def support_async(self) -> bool:
|
||||
return True
|
||||
|
||||
def parse_parameters(self, command_args: List[str] = None) -> ModelParameters:
|
||||
return None
|
||||
|
||||
def load_worker(self, model_name: str, model_path: str, **kwargs):
|
||||
self.host = kwargs.get("host")
|
||||
self.port = kwargs.get("port")
|
||||
|
||||
def start(
|
||||
self, model_params: ModelParameters = None, command_args: List[str] = None
|
||||
) -> None:
|
||||
"""Start model worker"""
|
||||
pass
|
||||
# raise NotImplementedError("Remote model worker not support start methods")
|
||||
|
||||
def stop(self) -> None:
|
||||
raise NotImplementedError("Remote model worker not support stop methods")
|
||||
|
||||
def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
|
||||
"""Generate stream"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def async_generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
|
||||
"""Asynchronous generate stream"""
|
||||
print(f"Send async_generate_stream, params: {params}")
|
||||
async with httpx.AsyncClient() as client:
|
||||
delimiter = b"\0"
|
||||
buffer = b""
|
||||
async with client.stream(
|
||||
"POST",
|
||||
self.worker_addr + "/generate_stream",
|
||||
headers=self.headers,
|
||||
json=params,
|
||||
timeout=self.timeout,
|
||||
) as response:
|
||||
async for raw_chunk in response.aiter_raw():
|
||||
buffer += raw_chunk
|
||||
while delimiter in buffer:
|
||||
chunk, buffer = buffer.split(delimiter, 1)
|
||||
if not chunk:
|
||||
continue
|
||||
chunk = chunk.decode()
|
||||
data = json.loads(chunk)
|
||||
yield ModelOutput(**data)
|
||||
|
||||
def generate(self, params: Dict) -> ModelOutput:
|
||||
"""Generate non stream"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def async_generate(self, params: Dict) -> ModelOutput:
|
||||
"""Asynchronous generate non stream"""
|
||||
print(f"Send async_generate_stream, params: {params}")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self.worker_addr + "/generate",
|
||||
headers=self.headers,
|
||||
json=params,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
return ModelOutput(**response.json())
|
||||
|
||||
def embeddings(self, params: Dict) -> List[List[float]]:
|
||||
"""Get embeddings for input"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def async_embeddings(self, params: Dict) -> List[List[float]]:
|
||||
"""Asynchronous get embeddings for input"""
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self.worker_addr + "/embeddings",
|
||||
headers=self.headers,
|
||||
json=params,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
return response.json()
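A short sketch of exercising RemoteModelWorker directly against a running worker; the address and payload are assumptions.

async def _remote_worker_demo():
    worker = RemoteModelWorker()
    # Point the client at a running model worker (hypothetical address)
    worker.load_worker("vicuna-13b", "vicuna-13b", host="127.0.0.1", port=8000)
    return await worker.async_generate({"model": "vicuna-13b", "prompt": "Hello"})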
|
@ -311,33 +311,23 @@ async def chat_completions(dialogue: ConversationVo = Body()):
|
||||
|
||||
|
||||
async def no_stream_generator(chat):
|
||||
msg = chat.nostream_call()
|
||||
msg = await chat.nostream_call()
|
||||
msg = msg.replace("\n", "\\n")
|
||||
yield f"data: {msg}\n\n"
|
||||
|
||||
|
||||
async def stream_generator(chat):
|
||||
model_response = chat.stream_call()
|
||||
msg = "[LLM_ERROR]: llm server has no output, maybe your prompt template is wrong."
|
||||
if not CFG.NEW_SERVER_MODE:
|
||||
for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
||||
if chunk:
|
||||
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
|
||||
chunk, chat.skip_echo_len
|
||||
)
|
||||
msg = msg.replace("\n", "\\n")
|
||||
yield f"data:{msg}\n\n"
|
||||
await asyncio.sleep(0.02)
|
||||
else:
|
||||
for chunk in model_response:
|
||||
if chunk:
|
||||
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
|
||||
chunk, chat.skip_echo_len
|
||||
)
|
||||
|
||||
msg = msg.replace("\n", "\\n")
|
||||
yield f"data:{msg}\n\n"
|
||||
await asyncio.sleep(0.02)
|
||||
async for chunk in chat.stream_call():
|
||||
if chunk:
|
||||
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
|
||||
chunk, chat.skip_echo_len
|
||||
)
|
||||
|
||||
msg = msg.replace("\n", "\\n")
|
||||
yield f"data:{msg}\n\n"
|
||||
await asyncio.sleep(0.02)
|
||||
|
||||
chat.current_message.add_ai_message(msg)
|
||||
chat.current_message.add_view_message(msg)
|
||||
|
@ -1,26 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from pilot.utils import build_logger
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import asdict
|
||||
from typing import Any, Dict, TypeVar, Union
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
from pilot.model.base import ModelOutput
|
||||
from pilot.utils import build_logger
|
||||
|
||||
T = TypeVar("T")
|
||||
ResponseTye = Union[str, bytes, ModelOutput]
|
||||
|
||||
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
|
||||
|
||||
CFG = Config()
|
||||
@ -46,11 +39,8 @@ class BaseOutputParser(ABC):
|
||||
code = sep.join(blocks)
|
||||
return code
|
||||
|
||||
def parse_model_stream_resp_ex(self, chunk, skip_echo_len):
|
||||
if b"\0" in chunk:
|
||||
chunk = chunk.replace(b"\0", b"")
|
||||
data = json.loads(chunk.decode())
|
||||
|
||||
def parse_model_stream_resp_ex(self, chunk: ResponseTye, skip_echo_len):
|
||||
data = _parse_model_response(chunk)
|
||||
""" TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
|
||||
"""
|
||||
model_context = data.get("model_context")
|
||||
@ -103,8 +93,8 @@ class BaseOutputParser(ABC):
|
||||
output = data["text"] + f" (error_code: {data['error_code']})"
|
||||
yield output
|
||||
|
||||
def parse_model_nostream_resp(self, response, sep: str):
|
||||
resp_obj_ex = json.loads(response)
|
||||
def parse_model_nostream_resp(self, response: ResponseTye, sep: str):
|
||||
resp_obj_ex = _parse_model_response(response)
|
||||
if isinstance(resp_obj_ex, str):
|
||||
resp_obj_ex = json.loads(resp_obj_ex)
|
||||
if resp_obj_ex["error_code"] == 0:
|
||||
@ -240,3 +230,17 @@ class BaseOutputParser(ABC):
|
||||
output_parser_dict = super().dict()
|
||||
output_parser_dict["_type"] = self._type
|
||||
return output_parser_dict
|
||||
|
||||
|
||||
def _parse_model_response(response: ResponseTye):
|
||||
if isinstance(response, ModelOutput):
|
||||
resp_obj_ex = asdict(response)
|
||||
elif isinstance(response, str):
|
||||
resp_obj_ex = json.loads(response)
|
||||
elif isinstance(response, bytes):
|
||||
if b"\0" in response:
|
||||
response = response.replace(b"\0", b"")
|
||||
resp_obj_ex = json.loads(response.decode())
|
||||
else:
|
||||
raise ValueError(f"Unsupported response type {type(response)}")
|
||||
return resp_obj_ex
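For illustration, the three accepted response shapes (the values are made up):

_parse_model_response(ModelOutput(text="hi", error_code=0))      # dataclass -> asdict
_parse_model_response('{"text": "hi", "error_code": 0}')         # JSON string -> json.loads
_parse_model_response(b'{"text": "hi", "error_code": 0}\0')      # NUL-delimited bytes chunk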
|
||||
|
@ -28,10 +28,7 @@ from pilot.memory.chat_history.mem_history import MemHistoryMemory
|
||||
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
|
||||
|
||||
from pilot.configs.model_config import LOGDIR, DATASETS_DIR
|
||||
from pilot.utils import (
|
||||
build_logger,
|
||||
server_error_msg,
|
||||
)
|
||||
from pilot.utils import build_logger, server_error_msg, get_or_create_event_loop
|
||||
from pilot.scene.base_message import (
|
||||
BaseMessage,
|
||||
SystemMessage,
|
||||
@ -158,7 +155,7 @@ class BaseChat(ABC):
|
||||
}
|
||||
return payload
|
||||
|
||||
def stream_call(self):
|
||||
async def stream_call(self):
|
||||
# TODO Retry when server connection error
|
||||
payload = self.__call_base()
|
||||
|
||||
@ -166,19 +163,10 @@ class BaseChat(ABC):
|
||||
logger.info(f"Requert: \n{payload}")
|
||||
ai_response_text = ""
|
||||
try:
|
||||
if not CFG.NEW_SERVER_MODE:
|
||||
response = requests.post(
|
||||
urljoin(CFG.MODEL_SERVER, "generate_stream"),
|
||||
headers=headers,
|
||||
json=payload,
|
||||
stream=True,
|
||||
timeout=120,
|
||||
)
|
||||
return response
|
||||
else:
|
||||
from pilot.server.llmserver import worker
|
||||
from pilot.model.worker.manager import worker_manager
|
||||
|
||||
return worker.generate_stream_gate(payload)
|
||||
async for output in worker_manager.generate_stream(payload):
|
||||
yield output
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
logger.error("model response parase faild!" + str(e))
|
||||
@ -188,34 +176,19 @@ class BaseChat(ABC):
|
||||
### store current conversation
|
||||
self.memory.append(self.current_message)
|
||||
|
||||
def nostream_call(self):
|
||||
async def nostream_call(self):
|
||||
payload = self.__call_base()
|
||||
logger.info(f"Requert: \n{payload}")
|
||||
ai_response_text = ""
|
||||
try:
|
||||
rsp_str = ""
|
||||
if not CFG.NEW_SERVER_MODE:
|
||||
rsp_obj = requests.post(
|
||||
urljoin(CFG.MODEL_SERVER, "generate"),
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=120,
|
||||
)
|
||||
rsp_str = rsp_obj.text
|
||||
else:
|
||||
###TODO no stream mode need independent
|
||||
from pilot.server.llmserver import worker
|
||||
from pilot.model.worker.manager import worker_manager
|
||||
|
||||
output = worker.generate_stream_gate(payload)
|
||||
for rsp in output:
|
||||
rsp = rsp.replace(b"\0", b"")
|
||||
rsp_str = rsp.decode()
|
||||
print("[TEST: output]:", rsp_str)
|
||||
model_output = await worker_manager.generate(payload)
|
||||
|
||||
### output parse
|
||||
ai_response_text = (
|
||||
self.prompt_template.output_parser.parse_model_nostream_resp(
|
||||
rsp_str, self.prompt_template.sep
|
||||
model_output, self.prompt_template.sep
|
||||
)
|
||||
)
|
||||
### model result deal
|
||||
@ -245,11 +218,31 @@ class BaseChat(ABC):
|
||||
self.memory.append(self.current_message)
|
||||
return self.current_ai_response()
|
||||
|
||||
def _blocking_stream_call(self):
|
||||
logger.warn(
|
||||
"_blocking_stream_call is only temporarily used in webserver and will be deleted soon, please use stream_call to replace it for higher performance"
|
||||
)
|
||||
loop = get_or_create_event_loop()
|
||||
async_gen = self.stream_call()
|
||||
while True:
|
||||
try:
|
||||
value = loop.run_until_complete(async_gen.__anext__())
|
||||
yield value
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
def _blocking_nostream_call(self):
|
||||
logger.warn(
|
||||
"_blocking_nostream_call is only temporarily used in webserver and will be deleted soon, please use nostream_call to replace it for higher performance"
|
||||
)
|
||||
loop = get_or_create_event_loop()
|
||||
return loop.run_until_complete(self.nostream_call())
|
||||
|
||||
def call(self):
|
||||
if self.prompt_template.stream_out:
|
||||
yield self.stream_call()
|
||||
yield self._blocking_stream_call()
|
||||
else:
|
||||
return self.nostream_call()
|
||||
return self._blocking_nostream_call()
|
||||
|
||||
def prepare(self):
|
||||
pass
|
||||
|
0
pilot/scripts/__init__.py
Normal file
45
pilot/scripts/cli_scripts.py
Normal file
@ -0,0 +1,45 @@
|
||||
import sys
|
||||
import click
|
||||
import os
|
||||
import copy
|
||||
import logging
|
||||
|
||||
sys.path.append(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
|
||||
)
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.option(
|
||||
"--log-level",
|
||||
required=False,
|
||||
type=str,
|
||||
default="warn",
|
||||
help="Log level",
|
||||
)
|
||||
@click.version_option()
|
||||
def cli(log_level: str):
|
||||
# TODO not working now
|
||||
logging.basicConfig(level=log_level, encoding="utf-8")
|
||||
|
||||
|
||||
def add_command_alias(command, name: str, hidden: bool = False):
|
||||
new_command = copy.deepcopy(command)
|
||||
new_command.hidden = hidden
|
||||
cli.add_command(new_command, name=name)
|
||||
|
||||
|
||||
try:
|
||||
from pilot.model.cli import model_cli_group
|
||||
|
||||
add_command_alias(model_cli_group, name="model")
|
||||
except ImportError as e:
|
||||
logging.warning(f"Integrating dbgpt model command line tool failed: {e}")
|
||||
|
||||
|
||||
def main():
|
||||
return cli()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
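Other command groups could be attached under an alias the same way; a hedged sketch with a hypothetical knowledge_cli_group (module path and group name are assumptions).

# Hypothetical: integrate another click group under an alias.
try:
    from pilot.server.knowledge.cli import knowledge_cli_group  # assumed module path

    add_command_alias(knowledge_cli_group, name="knowledge")
except ImportError as e:
    logging.warning(f"Integrating knowledge command line tool failed: {e}")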
|
@ -10,13 +10,7 @@ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__fi
|
||||
sys.path.append(ROOT_PATH)
|
||||
import signal
|
||||
from pilot.configs.config import Config
|
||||
|
||||
# from pilot.configs.model_config import (
|
||||
# DATASETS_DIR,
|
||||
# KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
# LLM_MODEL_CONFIG,
|
||||
# LOGDIR,
|
||||
# )
|
||||
from pilot.configs.model_config import LLM_MODEL_CONFIG
|
||||
from pilot.utils import build_logger
|
||||
|
||||
from pilot.server.base import server_init
|
||||
@ -33,8 +27,9 @@ from pilot.openapi.api_v1.api_v1 import router as api_v1
|
||||
from pilot.openapi.base import validation_exception_handler
|
||||
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
|
||||
from pilot.commands.disply_type.show_chart_gen import static_message_img_path
|
||||
from pilot.model.worker.manager import initialize_worker_manager_in_client
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logging.basicConfig(level=logging.INFO, encoding="utf-8")
|
||||
|
||||
static_file_path = os.path.join(os.getcwd(), "server/static")
|
||||
|
||||
@ -80,12 +75,19 @@ app.include_router(api_editor_route_v1, prefix="/api")
|
||||
app.include_router(knowledge_router)
|
||||
# app.include_router(api_editor_route_v1)
|
||||
|
||||
os.makedirs(static_message_img_path, exist_ok=True)
|
||||
app.mount(
|
||||
"/images", StaticFiles(directory=static_message_img_path, html=True), name="static2"
|
||||
)
|
||||
app.mount("/_next/static", StaticFiles(directory=static_file_path + "/_next/static"))
|
||||
app.mount("/", StaticFiles(directory=static_file_path, html=True), name="static")
|
||||
|
||||
def mount_static_files(app):
|
||||
os.makedirs(static_message_img_path, exist_ok=True)
|
||||
app.mount(
|
||||
"/images",
|
||||
StaticFiles(directory=static_message_img_path, html=True),
|
||||
name="static2",
|
||||
)
|
||||
app.mount(
|
||||
"/_next/static", StaticFiles(directory=static_file_path + "/_next/static")
|
||||
)
|
||||
app.mount("/", StaticFiles(directory=static_file_path, html=True), name="static")
|
||||
|
||||
|
||||
app.add_exception_handler(RequestValidationError, validation_exception_handler)
|
||||
|
||||
@ -100,6 +102,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--port", type=int, default=5000)
|
||||
parser.add_argument("--concurrency-count", type=int, default=10)
|
||||
parser.add_argument("--share", default=False, action="store_true")
|
||||
parser.add_argument("--log-level", type=str, default="info")
|
||||
parser.add_argument(
|
||||
"-light",
|
||||
"--light",
|
||||
@ -112,17 +115,28 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
server_init(args)
|
||||
|
||||
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
|
||||
if not args.light:
|
||||
print("Model Unified Deployment Mode!")
|
||||
from pilot.server.llmserver import worker
|
||||
initialize_worker_manager_in_client(
|
||||
app=app, model_name=CFG.LLM_MODEL, model_path=model_path
|
||||
)
|
||||
|
||||
worker.start_check()
|
||||
CFG.NEW_SERVER_MODE = True
|
||||
else:
|
||||
# MODEL_SERVER is controller address now
|
||||
initialize_worker_manager_in_client(
|
||||
app=app,
|
||||
model_name=CFG.LLM_MODEL,
|
||||
model_path=model_path,
|
||||
run_locally=False,
|
||||
controller_addr=CFG.MODEL_SERVER,
|
||||
)
|
||||
CFG.SERVER_LIGHT_MODE = True
|
||||
|
||||
mount_static_files(app)
|
||||
import uvicorn
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
uvicorn.run(app, host="0.0.0.0", port=args.port, log_level=0)
|
||||
logging.basicConfig(level=logging.INFO, encoding="utf-8")
|
||||
uvicorn.run(app, host="0.0.0.0", port=args.port, log_level=args.log_level)
|
||||
signal.signal(signal.SIGINT, signal_handler())
|
||||
|
@ -1,19 +1,8 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import List
|
||||
import platform
|
||||
|
||||
import uvicorn
|
||||
from fastapi import BackgroundTasks, FastAPI, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
# from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
|
||||
global_counter = 0
|
||||
model_semaphore = None
|
||||
@ -22,197 +11,26 @@ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__fi
|
||||
sys.path.append(ROOT_PATH)
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import *
|
||||
from pilot.model.llm_out.vicuna_base_llm import get_embeddings
|
||||
from pilot.model.loader import ModelLoader, _get_model_real_path
|
||||
from pilot.server.chat_adapter import get_llm_chat_adapter
|
||||
from pilot.scene.base_message import ModelMessage
|
||||
from pilot.configs.model_config import LLM_MODEL_CONFIG
|
||||
from pilot.model.worker.manager import run_worker_manager
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class ModelWorker:
|
||||
def __init__(self, model_path, model_name, device):
|
||||
if model_path.endswith("/"):
|
||||
model_path = model_path[:-1]
|
||||
model_path = _get_model_real_path(model_name, model_path)
|
||||
# self.model_name = model_name or model_path.split("/")[-1]
|
||||
self.device = device
|
||||
print(
|
||||
f"Loading {model_name} LLM ModelServer in {device} from model path {model_path}! Please Wait......"
|
||||
)
|
||||
self.ml: ModelLoader = ModelLoader(model_path=model_path, model_name=model_name)
|
||||
self.model, self.tokenizer = self.ml.loader(
|
||||
load_8bit=CFG.IS_LOAD_8BIT,
|
||||
load_4bit=CFG.IS_LOAD_4BIT,
|
||||
debug=ISDEBUG,
|
||||
max_gpu_memory=CFG.MAX_GPU_MEMORY,
|
||||
)
|
||||
|
||||
if not isinstance(self.model, str):
|
||||
if hasattr(self.model, "config") and hasattr(
|
||||
self.model.config, "max_sequence_length"
|
||||
):
|
||||
self.context_len = self.model.config.max_sequence_length
|
||||
elif hasattr(self.model, "config") and hasattr(
|
||||
self.model.config, "max_position_embeddings"
|
||||
):
|
||||
self.context_len = self.model.config.max_position_embeddings
|
||||
|
||||
else:
|
||||
self.context_len = 2048
|
||||
|
||||
self.llm_chat_adapter = get_llm_chat_adapter(model_name, model_path)
|
||||
self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func(
|
||||
model_path
|
||||
)
|
||||
|
||||
def start_check(self):
|
||||
print("LLM Model Loading Success!")
|
||||
|
||||
def get_queue_length(self):
|
||||
if (
|
||||
model_semaphore is None
|
||||
or model_semaphore._value is None
|
||||
or model_semaphore._waiters is None
|
||||
):
|
||||
return 0
|
||||
else:
|
||||
(
|
||||
CFG.LIMIT_MODEL_CONCURRENCY
|
||||
- model_semaphore._value
|
||||
+ len(model_semaphore._waiters)
|
||||
)
|
||||
|
||||
def generate_stream_gate(self, params):
|
||||
try:
|
||||
# params adaptation
|
||||
params, model_context = self.llm_chat_adapter.model_adaptation(
|
||||
params, self.ml.model_path, prompt_template=self.ml.prompt_template
|
||||
)
|
||||
|
||||
for output in self.generate_stream_func(
|
||||
self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
|
||||
):
|
||||
# Please do not open the output in production!
|
||||
# The gpt4all thread shares stdout with the parent process,
|
||||
# and opening it may affect the frontend output.
|
||||
if "windows" in platform.platform().lower():
|
||||
# Do not print the model output, because it may contain Emoji, there is a problem with the GBK encoding
|
||||
pass
|
||||
else:
|
||||
print("output: ", output)
|
||||
# return some model context to dgt-server
|
||||
ret = {"text": output, "error_code": 0, "model_context": model_context}
|
||||
yield json.dumps(ret).encode() + b"\0"
|
||||
|
||||
except torch.cuda.CudaError:
|
||||
ret = {"text": "**GPU OutOfMemory, Please Refresh.**", "error_code": 0}
|
||||
yield json.dumps(ret).encode() + b"\0"
|
||||
except Exception as e:
|
||||
ret = {
|
||||
"text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
|
||||
"error_code": 0,
|
||||
}
|
||||
yield json.dumps(ret).encode() + b"\0"
|
||||
|
||||
def get_embeddings(self, prompt):
|
||||
return get_embeddings(self.model, self.tokenizer, prompt)
|
||||
|
||||
|
||||
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
|
||||
worker = ModelWorker(model_path=model_path, model_name=CFG.LLM_MODEL, device=DEVICE)
|
||||
# worker = ModelWorker(model_path=model_path, model_name=CFG.LLM_MODEL, device=DEVICE)
|
||||
|
||||
app = FastAPI()
|
||||
# from pilot.openapi.knowledge.knowledge_controller import router
|
||||
#
|
||||
# app.include_router(router)
|
||||
#
|
||||
# origins = [
|
||||
# "http://localhost",
|
||||
# "http://localhost:8000",
|
||||
# "http://localhost:3000",
|
||||
# ]
|
||||
#
|
||||
# app.add_middleware(
|
||||
# CORSMiddleware,
|
||||
# allow_origins=origins,
|
||||
# allow_credentials=True,
|
||||
# allow_methods=["*"],
|
||||
# allow_headers=["*"],
|
||||
# )
|
||||
|
||||
|
||||
class PromptRequest(BaseModel):
|
||||
messages: List[ModelMessage]
|
||||
prompt: str
|
||||
temperature: float
|
||||
max_new_tokens: int
|
||||
model: str
|
||||
stop: str = None
|
||||
echo: bool = True
|
||||
|
||||
|
||||
class StreamRequest(BaseModel):
|
||||
model: str
|
||||
prompt: str
|
||||
temperature: float
|
||||
max_new_tokens: int
|
||||
stop: str
|
||||
|
||||
|
||||
class EmbeddingRequest(BaseModel):
|
||||
prompt: str
|
||||
|
||||
|
||||
def release_model_semaphore():
|
||||
model_semaphore.release()
|
||||
|
||||
|
||||
@app.post("/generate_stream")
|
||||
async def api_generate_stream(request: Request):
|
||||
global model_semaphore, global_counter
|
||||
global_counter += 1
|
||||
params = await request.json()
|
||||
|
||||
if model_semaphore is None:
|
||||
model_semaphore = asyncio.Semaphore(CFG.LIMIT_MODEL_CONCURRENCY)
|
||||
await model_semaphore.acquire()
|
||||
|
||||
generator = worker.generate_stream_gate(params)
|
||||
background_tasks = BackgroundTasks()
|
||||
background_tasks.add_task(release_model_semaphore)
|
||||
return StreamingResponse(generator, background=background_tasks)
|
||||
|
||||
|
||||
@app.post("/generate")
|
||||
def generate(prompt_request: PromptRequest) -> str:
|
||||
params = {
|
||||
"messages": prompt_request.messages,
|
||||
"prompt": prompt_request.prompt,
|
||||
"temperature": prompt_request.temperature,
|
||||
"max_new_tokens": prompt_request.max_new_tokens,
|
||||
"stop": prompt_request.stop,
|
||||
"echo": prompt_request.echo,
|
||||
}
|
||||
|
||||
rsp_str = ""
|
||||
output = worker.generate_stream_gate(params)
|
||||
for rsp in output:
|
||||
# rsp = rsp.decode("utf-8")
|
||||
rsp = rsp.replace(b"\0", b"")
|
||||
rsp_str = rsp.decode()
|
||||
|
||||
return rsp_str
|
||||
|
||||
|
||||
@app.post("/embedding")
|
||||
def embeddings(prompt_request: EmbeddingRequest):
|
||||
params = {"prompt": prompt_request.prompt}
|
||||
print("Received prompt: ", params["prompt"])
|
||||
output = worker.get_embeddings(params["prompt"])
|
||||
return {"response": [float(x) for x in output]}
|
||||
# @app.post("/embedding")
|
||||
# def embeddings(prompt_request: EmbeddingRequest):
|
||||
# params = {"prompt": prompt_request.prompt}
|
||||
# print("Received prompt: ", params["prompt"])
|
||||
# output = worker.get_embeddings(params["prompt"])
|
||||
# return {"response": [float(x) for x in output]}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(app, host="0.0.0.0", port=CFG.MODEL_PORT, log_level="info")
|
||||
run_worker_manager(
|
||||
model_name=CFG.LLM_MODEL,
|
||||
model_path=model_path,
|
||||
standalone=True,
|
||||
port=CFG.MODEL_PORT,
|
||||
)
|
||||
|
@ -2,7 +2,6 @@ import logging
|
||||
import os
|
||||
|
||||
import requests
|
||||
from playsound import playsound
|
||||
|
||||
from pilot.speech.base import VoiceBase
|
||||
|
||||
@ -23,6 +22,8 @@ class BrianSpeech(VoiceBase):
|
||||
Returns:
|
||||
bool: True if the request was successful, False otherwise
|
||||
"""
|
||||
from playsound import playsound
|
||||
|
||||
tts_url = (
|
||||
f"https://api.streamelements.com/kappa/v2/speech?voice=Brian&text={text}"
|
||||
)
|
||||
|
@ -2,7 +2,6 @@
|
||||
import os
|
||||
|
||||
import requests
|
||||
from playsound import playsound
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.speech.base import VoiceBase
|
||||
@ -70,6 +69,7 @@ class ElevenLabsSpeech(VoiceBase):
|
||||
bool: True if the request was successful, False otherwise
|
||||
"""
|
||||
from pilot.logs import logger
|
||||
from playsound import playsound
|
||||
|
||||
tts_url = (
|
||||
f"https://api.elevenlabs.io/v1/text-to-speech/{self._voices[voice_index]}"
|
||||
|
@ -2,7 +2,6 @@
|
||||
import os
|
||||
|
||||
import gtts
|
||||
from playsound import playsound
|
||||
|
||||
from pilot.speech.base import VoiceBase
|
||||
|
||||
@ -15,6 +14,8 @@ class GTTSVoice(VoiceBase):
|
||||
|
||||
def _speech(self, text: str, _: int = 0) -> bool:
|
||||
"""Play the given text."""
|
||||
from playsound import playsound
|
||||
|
||||
tts = gtts.gTTS(text)
|
||||
tts.save("speech.mp3")
|
||||
playsound("speech.mp3", True)
|
||||
|
@ -184,5 +184,5 @@ def _get_llm_response(query, db_input, dbsummary):
|
||||
chat: BaseChat = chat_factory.get_implementation(
|
||||
ChatScene.InnerChatDBSummary.value, **chat_param
|
||||
)
|
||||
res = chat.nostream_call()
|
||||
res = chat._blocking_nostream_call()
|
||||
return json.loads(res)["table"]
|
||||
|
9
pilot/utils/__init__.py
Normal file
@ -0,0 +1,9 @@
|
||||
from .utils import (
|
||||
get_gpu_memory,
|
||||
build_logger,
|
||||
StreamToLogger,
|
||||
disable_torch_init,
|
||||
pretty_print_semaphore,
|
||||
server_error_msg,
|
||||
get_or_create_event_loop,
|
||||
)
|
101
pilot/utils/api_utils.py
Normal file
@ -0,0 +1,101 @@
|
||||
import httpx
|
||||
from inspect import signature
|
||||
import typing_inspect
|
||||
import logging
|
||||
from typing import get_type_hints, List, Type, TypeVar, Union, Optional, Tuple
|
||||
from dataclasses import is_dataclass, asdict
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _extract_dataclass_from_generic(type_hint: Type[T]) -> Union[Type[T], None]:
|
||||
"""Extract actual dataclass from generic type hints like List[dataclass], Optional[dataclass], etc."""
|
||||
if typing_inspect.is_generic_type(type_hint) and typing_inspect.get_args(type_hint):
|
||||
return typing_inspect.get_args(type_hint)[0]
|
||||
return None
|
||||
|
||||
|
||||
def _api_remote(path, method="GET"):
|
||||
def decorator(func):
|
||||
return_type = get_type_hints(func).get("return")
|
||||
if return_type is None:
|
||||
raise TypeError("Return type must be annotated in the decorated function.")
|
||||
|
||||
actual_dataclass = _extract_dataclass_from_generic(return_type)
|
||||
logging.debug(
|
||||
f"return_type: {return_type}, actual_dataclass: {actual_dataclass}"
|
||||
)
|
||||
if not actual_dataclass:
|
||||
actual_dataclass = return_type
|
||||
sig = signature(func)
|
||||
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
base_url = self.base_url # Get base_url from class instance
|
||||
|
||||
bound = sig.bind(self, *args, **kwargs)
|
||||
bound.apply_defaults()
|
||||
|
||||
formatted_url = base_url + path.format(**bound.arguments)
|
||||
|
||||
# Extract args names from signature, except "self"
|
||||
arg_names = list(sig.parameters.keys())[1:]
|
||||
|
||||
# Combine args and kwargs into a single dictionary
|
||||
combined_args = dict(zip(arg_names, args))
|
||||
combined_args.update(kwargs)
|
||||
|
||||
request_data = {}
|
||||
for key, value in combined_args.items():
|
||||
if is_dataclass(value):
|
||||
# Here, instead of adding it as a nested dictionary,
|
||||
# we set request_data directly to its dictionary representation.
|
||||
request_data = asdict(value)
|
||||
else:
|
||||
request_data[key] = value
|
||||
|
||||
request_params = {"method": method, "url": formatted_url}
|
||||
|
||||
if method in ["POST", "PUT", "PATCH"]:
|
||||
request_params["json"] = request_data
|
||||
else: # For GET, DELETE, etc.
|
||||
request_params["params"] = request_data
|
||||
|
||||
logging.info(
|
||||
f"request_params: {request_params}, args: {args}, kwargs: {kwargs}"
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.request(**request_params)
|
||||
|
||||
if response.status_code == 200:
|
||||
return _parse_response(
|
||||
response.json(), return_type, actual_dataclass
|
||||
)
|
||||
else:
|
||||
error_msg = f"Remote request error, error code: {response.status_code}, error msg: {response.text}"
|
||||
raise Exception(error_msg)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _parse_response(json_response, return_type, actual_dataclass):
|
||||
# print(f'return_type.__origin__: {return_type.__origin__}, actual_dataclass: {actual_dataclass}, json_response: {json_response}')
|
||||
if is_dataclass(actual_dataclass):
|
||||
if return_type.__origin__ is list: # for List[dataclass]
|
||||
if isinstance(json_response, list):
|
||||
return [actual_dataclass(**item) for item in json_response]
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected list in response but got {type(json_response)}"
|
||||
)
|
||||
else:
|
||||
if isinstance(json_response, dict):
|
||||
return actual_dataclass(**json_response)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected dictionary in response but got {type(json_response)}"
|
||||
)
|
||||
else:
|
||||
return json_response
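A sketch of how a client class might use this decorator; the endpoint path and dataclass are assumptions (the real ModelRegistryClient lives in pilot.model.controller.registry), and base_url is the only attribute the generated wrapper reads from the instance.

from dataclasses import dataclass
from typing import List


@dataclass
class _DemoInstance:  # hypothetical response dataclass
    model_name: str
    host: str
    port: int


class _DemoRegistryClient:
    def __init__(self, base_url: str):
        self.base_url = base_url  # read by the generated wrapper at call time

    @_api_remote(path="/api/controller/models", method="GET")
    async def get_all_instances(self, model_name: str) -> List[_DemoInstance]:
        """The List[_DemoInstance] return hint drives _parse_response."""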
|
27
pilot/utils/model_utils.py
Normal file
@ -0,0 +1,27 @@
|
||||
import logging
|
||||
|
||||
|
||||
def _clear_torch_cache(device="cuda"):
|
||||
import gc
|
||||
|
||||
import torch
|
||||
|
||||
gc.collect()
|
||||
if device != "cpu":
|
||||
if torch.has_mps:
|
||||
try:
|
||||
from torch.mps import empty_cache
|
||||
|
||||
empty_cache()
|
||||
except Exception as e:
|
||||
logging.warn(f"Clear mps torch cache error, {str(e)}")
|
||||
elif torch.has_cuda:
|
||||
device_count = torch.cuda.device_count()
|
||||
for device_id in range(device_count):
|
||||
cuda_device = f"cuda:{device_id}"
|
||||
logging.info(f"Clear torch cache of device: {cuda_device}")
|
||||
with torch.cuda.device(cuda_device):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
else:
|
||||
logging.info("No cuda or mps, not support clear torch cache yet")
|
26
pilot/utils/module_utils.py
Normal file
@ -0,0 +1,26 @@
|
||||
from typing import Type
|
||||
from importlib import import_module
|
||||
|
||||
|
||||
def import_from_string(module_path: str):
|
||||
try:
|
||||
module_path, class_name = module_path.rsplit(".", 1)
|
||||
except ValueError:
|
||||
raise ImportError(f"{module_path} doesn't look like a module path")
|
||||
module = import_module(module_path)
|
||||
|
||||
try:
|
||||
return getattr(module, class_name)
|
||||
except AttributeError:
|
||||
raise ImportError(
|
||||
f'Module "{module_path}" does not define a "{class_name}" attribute/class'
|
||||
)
|
||||
|
||||
|
||||
def import_from_checked_string(module_path: str, supper_cls: Type):
|
||||
cls = import_from_string(module_path)
|
||||
if not issubclass(cls, supper_cls):
|
||||
raise ImportError(
|
||||
f'Module "{module_path}" does not the subclass of {str(supper_cls)}'
|
||||
)
|
||||
return cls
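Usage mirrors _start_local_worker in the worker manager; the dotted path below is hypothetical.

from pilot.model.worker.base import ModelWorker

# Load a custom worker class by its dotted path and verify its base class.
_worker_cls = import_from_checked_string("my_pkg.my_worker.MyModelWorker", ModelWorker)
_worker = _worker_cls()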
|
24
pilot/utils/net_utils.py
Normal file
@ -0,0 +1,24 @@
|
||||
import socket
|
||||
import errno
|
||||
|
||||
|
||||
def _get_ip_address(address: str = "10.254.254.254:1") -> str:
|
||||
ip, port = address.split(":")
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
s.settimeout(0)
|
||||
curr_address = "127.0.0.1"
|
||||
try:
|
||||
# doesn't even have to be reachable
|
||||
s.connect((ip, int(port)))
|
||||
curr_address = s.getsockname()[0]
|
||||
except OSError as e:
|
||||
IP = "127.0.0.1"
|
||||
if e.errno == errno.ENETUNREACH:
|
||||
try:
|
||||
hostname = socket.getfqdn(socket.gethostname())
|
||||
curr_address = socket.gethostbyname(hostname)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
s.close()
|
||||
return curr_address
|
@ -5,9 +5,7 @@ import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
import sys
|
||||
|
||||
import requests
|
||||
import torch
|
||||
import asyncio
|
||||
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
|
||||
@ -19,6 +17,8 @@ handler = None
|
||||
|
||||
|
||||
def get_gpu_memory(max_gpus=None):
|
||||
import torch
|
||||
|
||||
gpu_memory = []
|
||||
num_gpus = (
|
||||
torch.cuda.device_count()
|
||||
@ -73,7 +73,7 @@ def build_logger(logger_name, logger_filename):
|
||||
for name, item in logging.root.manager.loggerDict.items():
|
||||
if isinstance(item, logging.Logger):
|
||||
item.addHandler(handler)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logging.basicConfig(level=logging.INFO, encoding="utf-8")
|
||||
|
||||
# Get logger
|
||||
logger = logging.getLogger(logger_name)
|
||||
@ -132,3 +132,15 @@ def pretty_print_semaphore(semaphore):
|
||||
if semaphore is None:
|
||||
return "None"
|
||||
return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
|
||||
|
||||
|
||||
def get_or_create_event_loop() -> asyncio.BaseEventLoop:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except Exception as e:
|
||||
if not "no running event loop" in str(e):
|
||||
raise e
|
||||
logging.warning("Cant not get running event loop, create new event loop now")
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
return loop
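This backs the temporary _blocking_* helpers in BaseChat; a minimal sketch of driving a coroutine from synchronous code.

async def _answer() -> str:  # stand-in coroutine
    return "done"

_loop = get_or_create_event_loop()
_result = _loop.run_until_complete(_answer())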
|
@ -76,4 +76,7 @@ bardapi==0.1.29
|
||||
# TODO moved to optional dependencies
|
||||
pymysql
|
||||
duckdb
|
||||
duckdb-engine
|
||||
duckdb-engine
|
||||
|
||||
# cli
|
||||
prettytable
|
4
setup.py
@ -267,7 +267,7 @@ def llama_cpp_python_cuda_requires():
|
||||
llama_cpp_version = "0.1.77"
|
||||
py_version = "cp310"
|
||||
os_pkg_name = "linux_x86_64" if os_type == OSType.LINUX else "win_amd64"
|
||||
extra_index_url = f"{base_url}/llama_cpp_python_cuda-{llama_cpp_version}+{device}{cpu_avx}-{py_version}-{py_version}-{os_pkg_name}.whl"
|
||||
extra_index_url = f"{base_url}/llama_cpp_python_cuda-{llama_cpp_version}+{device}-{py_version}-{py_version}-{os_pkg_name}.whl"
|
||||
extra_index_url, _ = encode_url(extra_index_url)
|
||||
print(f"Install llama_cpp_python_cuda from {extra_index_url}")
|
||||
|
||||
@ -361,7 +361,7 @@ setuptools.setup(
|
||||
extras_require=setup_spec.extras,
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
"dbgpt_server=pilot.server:webserver",
|
||||
"dbgpt=pilot.scripts.cli_scripts:main",
|
||||
],
|
||||
},
|
||||
)
|
||||
|
@ -25,7 +25,6 @@ sys.path.append(
|
||||
|
||||
from pilot.configs.model_config import DATASETS_DIR
|
||||
|
||||
from tools.cli.knowledge_client import knowledge_init
|
||||
|
||||
API_ADDRESS: str = "http://127.0.0.1:5000"
|
||||
|
||||
@ -97,6 +96,8 @@ def knowledge(
|
||||
verbose: bool,
|
||||
):
|
||||
"""Knowledge command line tool"""
|
||||
from tools.cli.knowledge_client import knowledge_init
|
||||
|
||||
knowledge_init(
|
||||
API_ADDRESS,
|
||||
vector_name,
|
||||
|