feat(web): Add incremental response to streaming response for /v1/chat/completion request (#611)

Close #610

Commit: 3590d7bab4
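
With incremental set to true in the request body, the /v1/chat/completions endpoint now emits OpenAI-compatible delta chunks and terminates the stream with a [DONE] frame. Below is a minimal client sketch; the host, port, URL prefix, and every request field other than model_name and incremental are illustrative assumptions rather than values taken from this commit.

import json
import requests  # assumed HTTP client; any client with streaming support works

resp = requests.post(
    "http://localhost:5000/api/v1/chat/completions",  # assumed host/port/prefix
    json={
        "user_input": "Hello",               # assumed field, for illustration
        "chat_mode": "chat_normal",          # assumed field, for illustration
        "conv_uid": "example-conversation",  # assumed field, for illustration
        "model_name": "vicuna-13b-v1.5",
        "incremental": True,  # new in this commit: ask for OpenAI-style deltas
    },
    stream=True,
)

for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data:"):
        continue
    payload = line[len("data:"):].strip()
    if payload == "[DONE]":  # incremental mode ends with a [DONE] frame
        break
    chunk = json.loads(payload)
    # Each frame carries only the newly generated text in choices[0].delta.content
    print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)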
@@ -8,7 +8,7 @@ from pilot.model.loader import ModelLoader, _get_model_real_path
 from pilot.model.parameter import ModelParameters
 from pilot.model.cluster.worker_base import ModelWorker
 from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
-from pilot.utils.model_utils import _clear_torch_cache
+from pilot.utils.model_utils import _clear_model_cache
 from pilot.utils.parameter_utils import EnvArgumentParser
 
 logger = logging.getLogger(__name__)
@@ -87,7 +87,7 @@ class DefaultModelWorker(ModelWorker):
         del self.tokenizer
         self.model = None
         self.tokenizer = None
-        _clear_torch_cache(self._model_params.device)
+        _clear_model_cache(self._model_params.device)
 
     def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
         torch_imported = False
@@ -11,7 +11,7 @@ from pilot.model.parameter import (
 )
 from pilot.model.cluster.worker_base import ModelWorker
 from pilot.model.cluster.embedding.loader import EmbeddingLoader
-from pilot.utils.model_utils import _clear_torch_cache
+from pilot.utils.model_utils import _clear_model_cache
 from pilot.utils.parameter_utils import EnvArgumentParser
 
 logger = logging.getLogger(__name__)
@@ -79,7 +79,7 @@ class EmbeddingsModelWorker(ModelWorker):
             return
         del self._embeddings_impl
         self._embeddings_impl = None
-        _clear_torch_cache(self._model_params.device)
+        _clear_model_cache(self._model_params.device)
 
     def generate_stream(self, params: Dict):
         """Generate stream result, chat scene"""
@@ -18,6 +18,7 @@ from pilot.logs import logger
 def _check_multi_gpu_or_4bit_quantization(model_params: ModelParameters):
     # TODO: vicuna-v1.5 8-bit quantization info is slow
     # TODO: support wizardlm quantization, see: https://huggingface.co/WizardLM/WizardLM-13B-V1.2/discussions/5
+    # TODO: support internlm quantization
     model_name = model_params.model_name.lower()
     supported_models = ["llama", "baichuan", "vicuna"]
     return any(m in model_name for m in supported_models)
@@ -26,6 +26,9 @@ from pilot.openapi.api_view_model import (
     ConversationVo,
     MessageVo,
     ChatSceneVo,
+    ChatCompletionResponseStreamChoice,
+    DeltaMessage,
+    ChatCompletionStreamResponse,
 )
 from pilot.connections.db_conn_info import DBConfig, DbTypeInfo
 from pilot.configs.config import Config
@@ -383,7 +386,7 @@ async def chat_completions(dialogue: ConversationVo = Body()):
         )
     else:
         return StreamingResponse(
-            stream_generator(chat),
+            stream_generator(chat, dialogue.incremental, dialogue.model_name),
             headers=headers,
             media_type="text/plain",
         )
@@ -421,19 +424,48 @@ async def no_stream_generator(chat):
     yield f"data: {msg}\n\n"
 
 
-async def stream_generator(chat):
+async def stream_generator(chat, incremental: bool, model_name: str):
+    """Generate streaming responses
+
+    Our goal is to generate an openai-compatible streaming responses.
+    Currently, the incremental response is compatible, and the full response will be transformed in the future.
+
+    Args:
+        chat (BaseChat): Chat instance.
+        incremental (bool): Used to control whether the content is returned incrementally or in full each time.
+        model_name (str): The model name
+
+    Yields:
+        _type_: streaming responses
+    """
     msg = "[LLM_ERROR]: llm server has no output, maybe your prompt template is wrong."
+
+    stream_id = f"chatcmpl-{str(uuid.uuid1())}"
+    previous_response = ""
     async for chunk in chat.stream_call():
         if chunk:
             msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
                 chunk, chat.skip_echo_len
             )
-            msg = msg.replace("\n", "\\n")
-            yield f"data:{msg}\n\n"
+            msg = msg.replace("\ufffd", "")
+            if incremental:
+                incremental_output = msg[len(previous_response) :]
+                choice_data = ChatCompletionResponseStreamChoice(
+                    index=0,
+                    delta=DeltaMessage(role="assistant", content=incremental_output),
+                )
+                chunk = ChatCompletionStreamResponse(
+                    id=stream_id, choices=[choice_data], model=model_name
+                )
+                yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+            else:
+                # TODO generate an openai-compatible streaming responses
+                msg = msg.replace("\n", "\\n")
+                yield f"data:{msg}\n\n"
+            previous_response = msg
             await asyncio.sleep(0.02)
+    if incremental:
+        yield "data: [DONE]\n\n"
     chat.current_message.add_ai_message(msg)
     chat.current_message.add_view_message(msg)
     chat.memory.append(chat.current_message)
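
The slicing in the incremental branch above relies on chat.stream_call() yielding the full response accumulated so far on every iteration, so the delta sent to the client is just the suffix added since the previous chunk. A standalone toy sketch of that logic (not DB-GPT code):

def deltas(accumulated_messages):
    # Each element is the full text generated so far; yield only what is new.
    previous = ""
    for msg in accumulated_messages:
        yield msg[len(previous):]
        previous = msg

print(list(deltas(["He", "Hello", "Hello, wor", "Hello, world!"])))
# ['He', 'llo', ', wor', 'ld!']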
@@ -1,5 +1,7 @@
 from pydantic import BaseModel, Field
-from typing import TypeVar, Generic, Any
+from typing import TypeVar, Generic, Any, Optional, Literal, List
+import uuid
+import time
 
 T = TypeVar("T")
 
@@ -59,6 +61,11 @@ class ConversationVo(BaseModel):
     """
     model_name: str = None
 
+    """Used to control whether the content is returned incrementally or in full each time.
+    If this parameter is not provided, the default is full return.
+    """
+    incremental: bool = False
+
 
 class MessageVo(BaseModel):
     """
@@ -83,3 +90,21 @@ class MessageVo(BaseModel):
     model_name
     """
     model_name: str
+
+
+class DeltaMessage(BaseModel):
+    role: Optional[str] = None
+    content: Optional[str] = None
+
+
+class ChatCompletionResponseStreamChoice(BaseModel):
+    index: int
+    delta: DeltaMessage
+    finish_reason: Optional[Literal["stop", "length"]] = None
+
+
+class ChatCompletionStreamResponse(BaseModel):
+    id: str = Field(default_factory=lambda: f"chatcmpl-{str(uuid.uuid1())}")
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseStreamChoice]
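
These three models mirror the shape of OpenAI's streaming chat completion chunks. The sketch below shows how one frame serializes; it assumes pydantic v1 (which provides the .json() method used in this commit) and that the pilot package is importable, otherwise copy the three classes from the hunk above.

from pilot.openapi.api_view_model import (
    DeltaMessage,
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
)

choice = ChatCompletionResponseStreamChoice(
    index=0,
    delta=DeltaMessage(role="assistant", content="Hello"),
)
chunk = ChatCompletionStreamResponse(
    id="chatcmpl-example", choices=[choice], model="vicuna-13b-v1.5"
)

# stream_generator() yields this prefixed with "data: " and followed by a blank line.
# exclude_unset=True drops fields left at their defaults, e.g. created and finish_reason.
print(f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n")
# data: {"id": "chatcmpl-example", "model": "vicuna-13b-v1.5",
#        "choices": [{"index": 0, "delta": {"role": "assistant", "content": "Hello"}}]}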
@@ -1,10 +1,22 @@
 import logging
 
+logger = logging.getLogger(__name__)
+
+
+def _clear_model_cache(device="cuda"):
+    try:
+        # clear torch cache
+        import torch
+
+        _clear_torch_cache(device)
+    except ImportError:
+        logger.warn("Torch not installed, skip clear torch cache")
+    # TODO clear other cache
+
+
 def _clear_torch_cache(device="cuda"):
-    import gc
-
     import torch
+    import gc
 
     gc.collect()
     if device != "cpu":
@@ -14,14 +26,14 @@ def _clear_torch_cache(device="cuda"):
 
                 empty_cache()
             except Exception as e:
-                logging.warn(f"Clear mps torch cache error, {str(e)}")
+                logger.warn(f"Clear mps torch cache error, {str(e)}")
         elif torch.has_cuda:
             device_count = torch.cuda.device_count()
             for device_id in range(device_count):
                 cuda_device = f"cuda:{device_id}"
-                logging.info(f"Clear torch cache of device: {cuda_device}")
+                logger.info(f"Clear torch cache of device: {cuda_device}")
                 with torch.cuda.device(cuda_device):
                     torch.cuda.empty_cache()
                     torch.cuda.ipc_collect()
         else:
-            logging.info("No cuda or mps, not support clear torch cache yet")
+            logger.info("No cuda or mps, not support clear torch cache yet")
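
The new _clear_model_cache wrapper guards the torch-specific cleanup with an ImportError check, so workers running without torch installed (for example a proxy-model deployment) can call it safely. A minimal usage sketch, assuming the pilot package is importable:

from pilot.utils.model_utils import _clear_model_cache

# After dropping references to a model (as the workers do in their stop paths),
# release cached accelerator memory; without torch installed this only logs a warning.
_clear_model_cache("cuda")  # "cpu" and "mps" devices are also handled by _clear_torch_cache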