feat(web): Add incremental response to streaming response for /v1/chat/completion request (#611)

Close #610
Aries-ckt 2023-09-21 18:50:55 +08:00 committed by GitHub
commit 3590d7bab4
6 changed files with 86 additions and 16 deletions
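
With this change, API clients can opt into OpenAI-style incremental chunks on the streaming endpoint by setting the new incremental field on the conversation body. Below is a minimal client sketch; the base URL and the request fields other than model_name and incremental (user_input, chat_mode, conv_uid) are assumptions for illustration and are not confirmed by this diff.

# Minimal client sketch for the incremental streaming mode added here.
# The URL and the fields marked "hypothetical" are assumptions; only
# model_name and incremental appear in this commit.
import json
import requests

body = {
    "user_input": "Hello",       # hypothetical field name
    "chat_mode": "chat_normal",  # hypothetical field name
    "conv_uid": "test-conv-id",  # hypothetical field name
    "model_name": "vicuna-13b-v1.5",
    "incremental": True,         # request OpenAI-style delta chunks
}

with requests.post(
    "http://localhost:5000/api/v1/chat/completions",  # assumed base path
    json=body,
    stream=True,
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        payload = line[len("data: "):]
        if payload == "[DONE]":
            break
        chunk = json.loads(payload)
        delta = chunk["choices"][0]["delta"].get("content", "")
        print(delta, end="", flush=True)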

View File

@@ -8,7 +8,7 @@ from pilot.model.loader import ModelLoader, _get_model_real_path
from pilot.model.parameter import ModelParameters
from pilot.model.cluster.worker_base import ModelWorker
from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
from pilot.utils.model_utils import _clear_torch_cache
from pilot.utils.model_utils import _clear_model_cache
from pilot.utils.parameter_utils import EnvArgumentParser
logger = logging.getLogger(__name__)
@@ -87,7 +87,7 @@ class DefaultModelWorker(ModelWorker):
del self.tokenizer
self.model = None
self.tokenizer = None
_clear_torch_cache(self._model_params.device)
_clear_model_cache(self._model_params.device)
def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
torch_imported = False

View File

@@ -11,7 +11,7 @@ from pilot.model.parameter import (
)
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.embedding.loader import EmbeddingLoader
from pilot.utils.model_utils import _clear_torch_cache
from pilot.utils.model_utils import _clear_model_cache
from pilot.utils.parameter_utils import EnvArgumentParser
logger = logging.getLogger(__name__)
@@ -79,7 +79,7 @@ class EmbeddingsModelWorker(ModelWorker):
return
del self._embeddings_impl
self._embeddings_impl = None
_clear_torch_cache(self._model_params.device)
_clear_model_cache(self._model_params.device)
def generate_stream(self, params: Dict):
"""Generate stream result, chat scene"""

View File

@@ -18,6 +18,7 @@ from pilot.logs import logger
def _check_multi_gpu_or_4bit_quantization(model_params: ModelParameters):
# TODO: vicuna-v1.5 8-bit quantization info is slow
# TODO: support wizardlm quantization, see: https://huggingface.co/WizardLM/WizardLM-13B-V1.2/discussions/5
# TODO: support internlm quantization
model_name = model_params.model_name.lower()
supported_models = ["llama", "baichuan", "vicuna"]
return any(m in model_name for m in supported_models)

View File

@@ -26,6 +26,9 @@ from pilot.openapi.api_view_model import (
ConversationVo,
MessageVo,
ChatSceneVo,
ChatCompletionResponseStreamChoice,
DeltaMessage,
ChatCompletionStreamResponse,
)
from pilot.connections.db_conn_info import DBConfig, DbTypeInfo
from pilot.configs.config import Config
@@ -383,7 +386,7 @@ async def chat_completions(dialogue: ConversationVo = Body()):
)
else:
return StreamingResponse(
stream_generator(chat),
stream_generator(chat, dialogue.incremental, dialogue.model_name),
headers=headers,
media_type="text/plain",
)
@@ -421,19 +424,48 @@ async def no_stream_generator(chat):
yield f"data: {msg}\n\n"
async def stream_generator(chat):
async def stream_generator(chat, incremental: bool, model_name: str):
"""Generate streaming responses
Our goal is to generate OpenAI-compatible streaming responses.
Currently, only the incremental response is OpenAI-compatible; the full response will be adapted in the future.
Args:
chat (BaseChat): Chat instance.
incremental (bool): Used to control whether the content is returned incrementally or in full each time.
model_name (str): The model name
Yields:
str: Streaming response chunks as SSE data lines.
"""
msg = "[LLM_ERROR]: llm server has no output, maybe your prompt template is wrong."
stream_id = f"chatcmpl-{str(uuid.uuid1())}"
previous_response = ""
async for chunk in chat.stream_call():
if chunk:
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
chunk, chat.skip_echo_len
)
msg = msg.replace("\n", "\\n")
yield f"data:{msg}\n\n"
msg = msg.replace("\ufffd", "")
if incremental:
incremental_output = msg[len(previous_response) :]
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant", content=incremental_output),
)
chunk = ChatCompletionStreamResponse(
id=stream_id, choices=[choice_data], model=model_name
)
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
else:
# TODO: generate an OpenAI-compatible streaming response here as well
msg = msg.replace("\n", "\\n")
yield f"data:{msg}\n\n"
previous_response = msg
await asyncio.sleep(0.02)
if incremental:
yield "data: [DONE]\n\n"
chat.current_message.add_ai_message(msg)
chat.current_message.add_view_message(msg)
chat.memory.append(chat.current_message)
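
Since parse_model_stream_resp_ex returns the full text generated so far on every chunk, the incremental branch tracks previous_response and emits only the new suffix. A standalone sketch of that slicing logic, with made-up sample outputs:

# Standalone sketch of the incremental slicing used by stream_generator;
# the sample strings are illustrative, not real model output.
def iter_deltas(full_outputs):
    previous = ""
    for full in full_outputs:
        full = full.replace("\ufffd", "")  # drop partial-decode replacement chars
        yield full[len(previous):]         # only the newly generated suffix
        previous = full

print(list(iter_deltas(["Hel", "Hello", "Hello, world"])))
# ['Hel', 'lo', ', world']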

View File

@@ -1,5 +1,7 @@
from pydantic import BaseModel, Field
from typing import TypeVar, Generic, Any
from typing import TypeVar, Generic, Any, Optional, Literal, List
import uuid
import time
T = TypeVar("T")
@@ -59,6 +61,11 @@ class ConversationVo(BaseModel):
"""
model_name: str = None
"""Used to control whether the content is returned incrementally or in full each time.
If this parameter is not provided, the default is full return.
"""
incremental: bool = False
class MessageVo(BaseModel):
"""
@@ -83,3 +90,21 @@ class MessageVo(BaseModel):
model_name
"""
model_name: str
class DeltaMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
finish_reason: Optional[Literal["stop", "length"]] = None
class ChatCompletionStreamResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{str(uuid.uuid1())}")
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]
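
For reference, a sketch of the wire format these models produce, assuming pydantic v1 (the endpoint serializes chunks with .json(exclude_unset=True, ensure_ascii=False)):

# Sketch of one serialized stream chunk; field values are made up.
# With exclude_unset=True, defaults such as created and finish_reason
# that were never set explicitly are omitted from the output.
choice = ChatCompletionResponseStreamChoice(
    index=0,
    delta=DeltaMessage(role="assistant", content="Hello"),
)
chunk = ChatCompletionStreamResponse(
    id="chatcmpl-test", model="vicuna-13b-v1.5", choices=[choice]
)
print(chunk.json(exclude_unset=True, ensure_ascii=False))
# {"id": "chatcmpl-test", "model": "vicuna-13b-v1.5",
#  "choices": [{"index": 0, "delta": {"role": "assistant", "content": "Hello"}}]}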

View File

@@ -1,10 +1,22 @@
import logging
logger = logging.getLogger(__name__)
def _clear_model_cache(device="cuda"):
try:
# clear torch cache
import torch
_clear_torch_cache(device)
except ImportError:
logger.warn("Torch not installed, skip clear torch cache")
# TODO clear other cache
def _clear_torch_cache(device="cuda"):
import gc
import torch
import gc
gc.collect()
if device != "cpu":
@@ -14,14 +26,14 @@ def _clear_torch_cache(device="cuda"):
empty_cache()
except Exception as e:
logging.warn(f"Clear mps torch cache error, {str(e)}")
logger.warn(f"Clear mps torch cache error, {str(e)}")
elif torch.has_cuda:
device_count = torch.cuda.device_count()
for device_id in range(device_count):
cuda_device = f"cuda:{device_id}"
logging.info(f"Clear torch cache of device: {cuda_device}")
logger.info(f"Clear torch cache of device: {cuda_device}")
with torch.cuda.device(cuda_device):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
else:
logging.info("No cuda or mps, not support clear torch cache yet")
logger.info("No cuda or mps, not support clear torch cache yet")