feat(model): Add new LLMClient and new build tools (#967)

Fangyin Cheng
2023-12-23 16:33:01 +08:00
committed by GitHub
parent 12234ae258
commit 0c46c339ca
30 changed files with 1072 additions and 133 deletions

View File

@@ -0,0 +1,4 @@
from dbgpt.model.cluster.client import DefaultLLMClient
from dbgpt.model.utils.chatgpt_utils import OpenAILLMClient
__all__ = ["DefaultLLMClient", "OpenAILLMClient"]

View File

@@ -30,6 +30,15 @@ class EmbeddingsRequest(BaseModel):
span_id: str = None
class CountTokenRequest(BaseModel):
model: str
prompt: str
class ModelMetadataRequest(BaseModel):
model: str
class WorkerApplyRequest(BaseModel):
model: str
apply_type: WorkerApplyType
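For reference, a minimal sketch of JSON payloads matching the new request models (field names come from the definitions above; the model name is only an example):
count_token_payload = {"model": "vicuna-13b-v1.5", "prompt": "hello"}
model_metadata_payload = {"model": "vicuna-13b-v1.5"}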

View File

@@ -0,0 +1,40 @@
from typing import AsyncIterator, List
import asyncio
from dbgpt.core.interface.llm import LLMClient, ModelRequest, ModelOutput, ModelMetadata
from dbgpt.model.parameter import WorkerType
from dbgpt.model.cluster.manager_base import WorkerManager
class DefaultLLMClient(LLMClient):
def __init__(self, worker_manager: WorkerManager):
self._worker_manager = worker_manager
async def generate(self, request: ModelRequest) -> ModelOutput:
return await self._worker_manager.generate(request.to_dict())
async def generate_stream(
self, request: ModelRequest
) -> AsyncIterator[ModelOutput]:
async for output in self._worker_manager.generate_stream(request.to_dict()):
yield output
async def models(self) -> List[ModelMetadata]:
instances = await self._worker_manager.get_all_model_instances(
WorkerType.LLM.value, healthy_only=True
)
query_metadata_task = []
for instance in instances:
worker_name, _ = WorkerType.parse_worker_key(instance.worker_key)
query_metadata_task.append(
self._worker_manager.get_model_metadata({"model": worker_name})
)
models: List[ModelMetadata] = await asyncio.gather(*query_metadata_task)
model_map = {}
for single_model in models:
model_map[single_model.model] = single_model
return [model_map[model_name] for model_name in sorted(model_map.keys())]
async def count_token(self, model: str, prompt: str) -> int:
return await self._worker_manager.count_token(
{"model": model, "prompt": prompt}
)
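A minimal usage sketch of DefaultLLMClient. The import path follows the package __init__ shown at the top of this commit; obtaining a running WorkerManager is assumed and deployment-specific.
from dbgpt.model import DefaultLLMClient


async def show_models(worker_manager):
    # worker_manager: an already-initialized WorkerManager from your cluster
    # (how to obtain it is deployment-specific and not shown here).
    client = DefaultLLMClient(worker_manager)
    for metadata in await client.models():
        print(metadata.model, metadata.context_length)
    print(await client.count_token("vicuna-13b-v1.5", "hello"))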

View File

@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from datetime import datetime
from concurrent.futures import Future
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.core import ModelOutput
from dbgpt.core import ModelOutput, ModelMetadata
from dbgpt.model.base import WorkerSupportedModel, WorkerApplyOutput
from dbgpt.model.cluster.worker_base import ModelWorker
from dbgpt.model.cluster.base import WorkerStartupRequest, WorkerApplyRequest
@@ -38,6 +38,11 @@ class WorkerRunData:
port = self.port
return f"model {model_name}@{model_type}({host}:{port})"
@property
def stopped(self):
"""Check if the worker is stopped""" ""
return self.stop_event.is_set()
class WorkerManager(ABC):
@abstractmethod
@@ -62,6 +67,20 @@ class WorkerManager(ABC):
) -> List[WorkerRunData]:
"""Asynchronous get model instances by worker type and model name"""
@abstractmethod
async def get_all_model_instances(
self, worker_type: str, healthy_only: bool = True
) -> List[WorkerRunData]:
"""Asynchronous get all model instances
Args:
worker_type (str): worker type
healthy_only (bool, optional): only return healthy instances. Defaults to True.
Returns:
List[WorkerRunData]: worker run data list
"""
@abstractmethod
def sync_get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
@@ -112,6 +131,25 @@ class WorkerManager(ABC):
We must provide a synchronous version.
"""
@abstractmethod
async def count_token(self, params: Dict) -> int:
"""Count token of prompt
Args:
params (Dict): parameters, e.g. {"prompt": "hello", "model": "vicuna-13b-v1.5"}
Returns:
int: token count
"""
@abstractmethod
async def get_model_metadata(self, params: Dict) -> ModelMetadata:
"""Get model metadata
Args:
params (Dict): parameters, e.g. {"model": "vicuna-13b-v1.5"}
"""
@abstractmethod
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
"""Worker apply"""

View File

@@ -3,7 +3,7 @@ import pytest_asyncio
from contextlib import contextmanager, asynccontextmanager
from typing import List, Iterator, Dict, Tuple
from dbgpt.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
from dbgpt.core import ModelOutput
from dbgpt.core import ModelOutput, ModelMetadata
from dbgpt.model.cluster.worker_base import ModelWorker
from dbgpt.model.cluster.worker.manager import (
WorkerManager,
@@ -80,6 +80,14 @@ class MockModelWorker(ModelWorker):
output = out
return output
def count_token(self, prompt: str) -> int:
return len(prompt)
def get_model_metadata(self, params: Dict) -> ModelMetadata:
return ModelMetadata(
model=self.model_parameters.model_name,
)
def embeddings(self, params: Dict) -> List[List[float]]:
return self._embeddings

View File

@@ -8,7 +8,7 @@ import traceback
from dbgpt.configs.model_config import get_device
from dbgpt.model.adapter.base import LLMModelAdapter
from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
from dbgpt.core import ModelOutput, ModelInferenceMetrics
from dbgpt.core import ModelOutput, ModelInferenceMetrics, ModelMetadata
from dbgpt.model.loader import ModelLoader, _get_model_real_path
from dbgpt.model.parameter import ModelParameters
from dbgpt.model.cluster.worker_base import ModelWorker
@@ -118,6 +118,8 @@ class DefaultModelWorker(ModelWorker):
f"Parse model max length {model_max_length} from model {self.model_name}."
)
self.context_len = model_max_length
elif hasattr(model_params, "max_context_size"):
self.context_len = model_params.max_context_size
def stop(self) -> None:
if not self.model:
@@ -186,6 +188,22 @@ class DefaultModelWorker(ModelWorker):
output = out
return output
def count_token(self, prompt: str) -> int:
return _try_to_count_token(prompt, self.tokenizer)
async def async_count_token(self, prompt: str) -> int:
# TODO: this does not work when the model is deployed with vLLM; we should run the transformers-based _try_to_count_token asynchronously instead
raise NotImplementedError
def get_model_metadata(self, params: Dict) -> ModelMetadata:
return ModelMetadata(
model=self.model_name,
context_length=self.context_len,
)
async def async_get_model_metadata(self, params: Dict) -> ModelMetadata:
return self.get_model_metadata(params)
def embeddings(self, params: Dict) -> List[List[float]]:
raise NotImplementedError
@@ -436,6 +454,25 @@ def _new_metrics_from_model_output(
return metrics
def _try_to_count_token(prompt: str, tokenizer) -> int:
"""Try to count token of prompt
Args:
prompt (str): prompt
tokenizer ([type]): tokenizer
Returns:
int: token count, if error return -1
TODO: More implementation
"""
try:
return len(tokenizer(prompt).input_ids[0])
except Exception as e:
logger.warning(f"Count token error, detail: {e}, return -1")
return -1
def _try_import_torch():
global torch
global _torch_imported

View File

@@ -2,6 +2,7 @@ import logging
from typing import Dict, List, Type, Optional
from dbgpt.configs.model_config import get_device
from dbgpt.core import ModelMetadata
from dbgpt.model.loader import _get_model_real_path
from dbgpt.model.parameter import (
EmbeddingModelParameters,
@@ -89,6 +90,14 @@ class EmbeddingsModelWorker(ModelWorker):
"""Generate non stream result"""
raise NotImplementedError("Not supported generate for embeddings model")
def count_token(self, prompt: str) -> int:
raise NotImplementedError("Not supported count_token for embeddings model")
def get_model_metadata(self, params: Dict) -> ModelMetadata:
raise NotImplementedError(
"Not supported get_model_metadata for embeddings model"
)
def embeddings(self, params: Dict) -> List[List[float]]:
model = params.get("model")
logger.info(f"Receive embeddings request, model: {model}")

View File

@@ -15,7 +15,7 @@ from fastapi.responses import StreamingResponse
from dbgpt.component import SystemApp
from dbgpt.configs.model_config import LOGDIR
from dbgpt.core import ModelOutput
from dbgpt.core import ModelOutput, ModelMetadata
from dbgpt.model.base import (
ModelInstance,
WorkerApplyOutput,
@@ -271,6 +271,18 @@ class LocalWorkerManager(WorkerManager):
) -> List[WorkerRunData]:
return self.sync_get_model_instances(worker_type, model_name, healthy_only)
async def get_all_model_instances(
self, worker_type: str, healthy_only: bool = True
) -> List[WorkerRunData]:
instances = list(itertools.chain(*self.workers.values()))
result = []
for instance in instances:
name, wt = WorkerType.parse_worker_key(instance.worker_key)
if wt != worker_type or (healthy_only and instance.stopped):
continue
result.append(instance)
return result
def sync_get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:
@@ -390,6 +402,43 @@ class LocalWorkerManager(WorkerManager):
worker_run_data = self._sync_get_model(params, worker_type="text2vec")
return worker_run_data.worker.embeddings(params)
async def count_token(self, params: Dict) -> int:
"""Count token of prompt"""
with root_tracer.start_span(
"WorkerManager.count_token", params.get("span_id")
) as span:
params["span_id"] = span.span_id
try:
worker_run_data = await self._get_model(params)
except Exception as e:
raise e
prompt = params.get("prompt")
async with worker_run_data.semaphore:
if worker_run_data.worker.support_async():
return await worker_run_data.worker.async_count_token(prompt)
else:
return await self.run_blocking_func(
worker_run_data.worker.count_token, prompt
)
async def get_model_metadata(self, params: Dict) -> ModelMetadata:
"""Get model metadata"""
with root_tracer.start_span(
"WorkerManager.get_model_metadata", params.get("span_id")
) as span:
params["span_id"] = span.span_id
try:
worker_run_data = await self._get_model(params)
except Exception as e:
raise e
async with worker_run_data.semaphore:
if worker_run_data.worker.support_async():
return await worker_run_data.worker.async_get_model_metadata(params)
else:
return await self.run_blocking_func(
worker_run_data.worker.get_model_metadata, params
)
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
apply_func: Callable[[WorkerApplyRequest], Awaitable[str]] = None
if apply_req.apply_type == WorkerApplyType.START:
@@ -601,6 +650,13 @@ class WorkerManagerAdapter(WorkerManager):
worker_type, model_name, healthy_only
)
async def get_all_model_instances(
self, worker_type: str, healthy_only: bool = True
) -> List[WorkerRunData]:
return await self.worker_manager.get_all_model_instances(
worker_type, healthy_only
)
def sync_get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:
@@ -635,6 +691,12 @@ class WorkerManagerAdapter(WorkerManager):
def sync_embeddings(self, params: Dict) -> List[List[float]]:
return self.worker_manager.sync_embeddings(params)
async def count_token(self, params: Dict) -> int:
return await self.worker_manager.count_token(params)
async def get_model_metadata(self, params: Dict) -> ModelMetadata:
return await self.worker_manager.get_model_metadata(params)
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
return await self.worker_manager.worker_apply(apply_req)
@@ -696,6 +758,24 @@ async def api_embeddings(request: EmbeddingsRequest):
return await worker_manager.embeddings(params)
@router.post("/worker/count_token")
async def api_count_token(request: CountTokenRequest):
params = request.dict(exclude_none=True)
span_id = root_tracer.get_current_span_id()
if "span_id" not in params and span_id:
params["span_id"] = span_id
return await worker_manager.count_token(params)
@router.post("/worker/model_metadata")
async def api_get_model_metadata(request: ModelMetadataRequest):
params = request.dict(exclude_none=True)
span_id = root_tracer.get_current_span_id()
if "span_id" not in params and span_id:
params["span_id"] = span_id
return await worker_manager.get_model_metadata(params)
@router.post("/worker/apply")
async def api_worker_apply(request: WorkerApplyRequest):
return await worker_manager.worker_apply(request)
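A hedged sketch of exercising the new HTTP endpoints with httpx (already used by the remote worker). The base URL is an assumption for a local deployment; the routes and payload fields come from the router and request models above.
import httpx

# Assumed local address where this router is mounted; adjust host, port and
# prefix to your deployment.
BASE_URL = "http://127.0.0.1:8000/api"


async def query_worker_api():
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            f"{BASE_URL}/worker/count_token",
            json={"model": "vicuna-13b-v1.5", "prompt": "hello"},
        )
        print(resp.json())  # integer token count
        resp = await client.post(
            f"{BASE_URL}/worker/model_metadata",
            json={"model": "vicuna-13b-v1.5"},
        )
        print(resp.json())  # serialized ModelMetadata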

View File

@@ -133,22 +133,29 @@ class RemoteWorkerManager(LocalWorkerManager):
self, model_name: str, instances: List[ModelInstance]
) -> List[WorkerRunData]:
worker_instances = []
for ins in instances:
worker = RemoteModelWorker()
worker.load_worker(model_name, model_name, host=ins.host, port=ins.port)
wr = WorkerRunData(
host=ins.host,
port=ins.port,
worker_key=ins.model_name,
worker=worker,
worker_params=None,
model_params=None,
stop_event=asyncio.Event(),
semaphore=asyncio.Semaphore(100), # Not limit in client
for instance in instances:
worker_instances.append(
self._build_single_worker_instance(model_name, instance)
)
worker_instances.append(wr)
return worker_instances
def _build_single_worker_instance(self, model_name: str, instance: ModelInstance):
worker = RemoteModelWorker()
worker.load_worker(
model_name, model_name, host=instance.host, port=instance.port
)
wr = WorkerRunData(
host=instance.host,
port=instance.port,
worker_key=instance.model_name,
worker=worker,
worker_params=None,
model_params=None,
stop_event=asyncio.Event(),
semaphore=asyncio.Semaphore(100), # Not limit in client
)
return wr
async def get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:
@@ -158,6 +165,20 @@ class RemoteWorkerManager(LocalWorkerManager):
)
return self._build_worker_instances(model_name, instances)
async def get_all_model_instances(
self, worker_type: str, healthy_only: bool = True
) -> List[WorkerRunData]:
instances: List[
ModelInstance
] = await self.model_registry.get_all_model_instances(healthy_only=healthy_only)
result = []
for instance in instances:
name, wt = WorkerType.parse_worker_key(instance.model_name)
if wt != worker_type:
continue
result.append(self._build_single_worker_instance(name, instance))
return result
def sync_get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:

View File

@@ -1,7 +1,7 @@
import json
from typing import Dict, Iterator, List
import logging
from dbgpt.core import ModelOutput
from dbgpt.core import ModelOutput, ModelMetadata
from dbgpt.model.parameter import ModelParameters
from dbgpt.model.cluster.worker_base import ModelWorker
@@ -90,6 +90,44 @@ class RemoteModelWorker(ModelWorker):
)
return ModelOutput(**response.json())
def count_token(self, prompt: str) -> int:
raise NotImplementedError
async def async_count_token(self, prompt: str) -> int:
import httpx
async with httpx.AsyncClient() as client:
url = self.worker_addr + "/count_token"
logger.debug(f"Send async_count_token to url {url}, params: {prompt}")
response = await client.post(
url,
headers=self.headers,
json={"prompt": prompt},
timeout=self.timeout,
)
return response.json()
async def async_get_model_metadata(self, params: Dict) -> ModelMetadata:
"""Asynchronously get model metadata"""
import httpx
async with httpx.AsyncClient() as client:
url = self.worker_addr + "/model_metadata"
logger.debug(
f"Send async_get_model_metadata to url {url}, params: {params}"
)
response = await client.post(
url,
headers=self.headers,
json=params,
timeout=self.timeout,
)
return ModelMetadata(**response.json())
def get_model_metadata(self, params: Dict) -> ModelMetadata:
"""Get model metadata"""
raise NotImplementedError
def embeddings(self, params: Dict) -> List[List[float]]:
"""Get embeddings for input"""
import requests

View File

@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from typing import Dict, Iterator, List, Type
from dbgpt.core import ModelOutput
from dbgpt.core import ModelOutput, ModelMetadata
from dbgpt.model.parameter import ModelParameters, WorkerType
from dbgpt.util.parameter_utils import (
ParameterDescription,
@@ -92,6 +92,42 @@ class ModelWorker(ABC):
"""Asynchronously generate output (non-stream) based on provided parameters."""
raise NotImplementedError
@abstractmethod
def count_token(self, prompt: str) -> int:
"""Count token of prompt
Args:
prompt (str): prompt
Returns:
int: token count
"""
async def async_count_token(self, prompt: str) -> int:
"""Asynchronously count token of prompt
Args:
prompt (str): prompt
Returns:
int: token count
"""
raise NotImplementedError
@abstractmethod
def get_model_metadata(self, params: Dict) -> ModelMetadata:
"""Get model metadata
Args:
params (Dict): parameters, e.g. {"model": "vicuna-13b-v1.5"}
"""
async def async_get_model_metadata(self, params: Dict) -> ModelMetadata:
"""Asynchronously get model metadata
Args:
params (Dict): parameters, e.g. {"model": "vicuna-13b-v1.5"}
"""
raise NotImplementedError
@abstractmethod
def embeddings(self, params: Dict) -> List[List[float]]:
"""

View File

@@ -70,7 +70,7 @@ def _initialize_openai_v1(params: ProxyModelParameters):
api_type = params.proxy_api_type or os.getenv("OPENAI_API_TYPE", "open_ai")
base_url = params.proxy_api_base or os.getenv(
"OPENAI_API_TYPE",
"OPENAI_API_BASE",
os.getenv("AZURE_OPENAI_ENDPOINT") if api_type == "azure" else None,
)
api_key = params.proxy_api_key or os.getenv(

View File

@@ -0,0 +1,282 @@
from __future__ import annotations
import os
import logging
from dataclasses import dataclass
import importlib.metadata as metadata
from typing import List, Dict, Any, Optional, TYPE_CHECKING, Union, AsyncIterator
from dbgpt.core.interface.llm import ModelMetadata, LLMClient
from dbgpt.core.interface.llm import ModelOutput, ModelRequest
if TYPE_CHECKING:
import httpx
from httpx._types import ProxiesTypes
from openai import AsyncAzureOpenAI
from openai import AsyncOpenAI
ClientType = Union[AsyncAzureOpenAI, AsyncOpenAI]
logger = logging.getLogger(__name__)
@dataclass
class OpenAIParameters:
"""A class to represent a LLM model."""
api_type: str = "open_ai"
api_base: Optional[str] = None
api_key: Optional[str] = None
api_version: Optional[str] = None
full_url: Optional[str] = None
proxies: Optional["ProxiesTypes"] = None
def _initialize_openai_v1(init_params: OpenAIParameters):
try:
from openai import OpenAI
except ImportError as exc:
raise ValueError(
"Could not import python package: openai "
"Please install openai by command `pip install openai"
) from exc
if not metadata.version("openai") >= "1.0.0":
raise ImportError("Please upgrade openai package to version 1.0.0 or above")
api_type: Optional[str] = init_params.api_type
api_base: Optional[str] = init_params.api_base
api_key: Optional[str] = init_params.api_key
api_version: Optional[str] = init_params.api_version
full_url: Optional[str] = init_params.full_url
api_type = api_type or os.getenv("OPENAI_API_TYPE", "open_ai")
base_url = api_base or os.getenv(
"OPENAI_API_BASE",
os.getenv("AZURE_OPENAI_ENDPOINT") if api_type == "azure" else None,
)
api_key = api_key or os.getenv(
"OPENAI_API_KEY",
os.getenv("AZURE_OPENAI_KEY") if api_type == "azure" else None,
)
api_version = api_version or os.getenv("OPENAI_API_VERSION")
if not base_url and full_url:
base_url = full_url.split("/chat/completions")[0]
if api_key is None:
raise ValueError("api_key is required, please set OPENAI_API_KEY environment")
if base_url is None:
raise ValueError("base_url is required, please set OPENAI_BASE_URL environment")
if base_url.endswith("/"):
base_url = base_url[:-1]
openai_params = {
"api_key": api_key,
"base_url": base_url,
}
return openai_params, api_type, api_version
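A small sketch of the configuration resolution above, assuming this new module is dbgpt.model.utils.chatgpt_utils (per the import at the top of the commit); the key and endpoint values are placeholders.
from dbgpt.model.utils.chatgpt_utils import OpenAIParameters, _initialize_openai_v1

# Plain OpenAI-compatible endpoint; values are placeholders. With api_type="azure",
# the fallbacks instead read AZURE_OPENAI_ENDPOINT / AZURE_OPENAI_KEY and
# OPENAI_API_VERSION from the environment.
params = OpenAIParameters(
    api_type="open_ai",
    api_base="https://api.openai.com/v1",
    api_key="sk-...",
)
openai_kwargs, api_type, api_version = _initialize_openai_v1(params)
# openai_kwargs -> {"api_key": "sk-...", "base_url": "https://api.openai.com/v1"}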
def _build_openai_client(init_params: OpenAIParameters):
import httpx
openai_params, api_type, api_version = _initialize_openai_v1(init_params)
if api_type == "azure":
from openai import AsyncAzureOpenAI
return AsyncAzureOpenAI(
api_key=openai_params["api_key"],
api_version=api_version,
azure_endpoint=openai_params["base_url"],
http_client=httpx.AsyncClient(proxies=init_params.proxies),
)
else:
from openai import AsyncOpenAI
return AsyncOpenAI(
**openai_params, http_client=httpx.AsyncClient(proxies=init_params.proxies)
)
class OpenAILLMClient(LLMClient):
"""An implementation of LLMClient using OpenAI API.
In order to have as few dependencies as possible, we directly use the http API.
"""
def __init__(
self,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_type: Optional[str] = None,
api_version: Optional[str] = None,
model: Optional[str] = "gpt-3.5-turbo",
proxies: Optional["ProxiesTypes"] = None,
timeout: Optional[int] = 240,
model_alias: Optional[str] = "chatgpt_proxyllm",
context_length: Optional[int] = 8192,
openai_client: Optional["ClientType"] = None,
openai_kwargs: Optional[Dict[str, Any]] = None,
):
self._init_params = OpenAIParameters(
api_type=api_type,
api_base=api_base,
api_key=api_key,
api_version=api_version,
proxies=proxies,
)
self._model = model
self._proxies = proxies
self._timeout = timeout
self._model_alias = model_alias
self._context_length = context_length
self._client = openai_client
self._openai_kwargs = openai_kwargs or {}
@property
def client(self) -> ClientType:
if self._client is None:
self._client = _build_openai_client(init_params=self._init_params)
return self._client
def _build_request(
self, request: ModelRequest, stream: Optional[bool] = False
) -> Dict[str, Any]:
payload = {"model": request.model or self._model, "stream": stream}
# Apply openai kwargs
for k, v in self._openai_kwargs.items():
payload[k] = v
if request.temperature:
payload["temperature"] = request.temperature
if request.max_new_tokens:
payload["max_tokens"] = request.max_new_tokens
return payload
async def generate(self, request: ModelRequest) -> ModelOutput:
messages = request.to_openai_messages()
payload = self._build_request(request)
try:
chat_completion = await self.client.chat.completions.create(
messages=messages, **payload
)
text = chat_completion.choices[0].message.content
usage = chat_completion.usage.dict()
return ModelOutput(text=text, error_code=0, usage=usage)
except Exception as e:
return ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=1,
)
async def generate_stream(
self, request: ModelRequest
) -> AsyncIterator[ModelOutput]:
messages = request.to_openai_messages()
payload = self._build_request(request, stream=True)
try:
chat_completion = await self.client.chat.completions.create(
messages=messages, **payload
)
text = ""
async for r in chat_completion:
if len(r.choices) == 0:
continue
if r.choices[0].delta.content is not None:
content = r.choices[0].delta.content
text += content
yield ModelOutput(text=text, error_code=0)
except Exception as e:
yield ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=1,
)
async def models(self) -> List[ModelMetadata]:
model_metadata = ModelMetadata(
model=self._model_alias,
context_length=await self.get_context_length(),
)
return [model_metadata]
async def get_context_length(self) -> int:
"""Get the context length of the model.
Returns:
int: The context length.
TODO: This is a temporary solution; we should have a better way to get the context length,
e.g. query the real context length from the OpenAI API.
"""
return self._context_length
async def count_token(self, model: str, prompt: str) -> int:
"""Count the number of tokens in a given prompt.
TODO: Get the real number of tokens from the openai api or tiktoken package
"""
raise NotImplementedError()
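A minimal usage sketch of OpenAILLMClient. The constructor arguments mirror the signature above; api_key and api_base fall back to the environment variables handled by _initialize_openai_v1.
import asyncio

from dbgpt.model.utils.chatgpt_utils import OpenAILLMClient


async def main():
    # api_key / api_base fall back to OPENAI_API_KEY / OPENAI_API_BASE when omitted.
    client = OpenAILLMClient(model="gpt-3.5-turbo", context_length=8192)
    # Metadata is reported under the model alias ("chatgpt_proxyllm" by default).
    print(await client.models())
    # generate() / generate_stream() accept a dbgpt.core ModelRequest, whose
    # messages are converted with request.to_openai_messages() before sending.


asyncio.run(main())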
async def _to_openai_stream(
model: str, output_iter: AsyncIterator[ModelOutput]
) -> AsyncIterator[str]:
"""Convert the output_iter to openai stream format.
Args:
model (str): The model name.
output_iter (AsyncIterator[ModelOutput]): The output iterator.
"""
import json
import shortuuid
from fastchat.protocol.openai_api_protocol import (
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
DeltaMessage,
)
id = f"chatcmpl-{shortuuid.random()}"
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant"),
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(id=id, choices=[choice_data], model=model)
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
previous_text = ""
finish_stream_events = []
async for model_output in output_iter:
model_output: ModelOutput = model_output
if model_output.error_code != 0:
yield f"data: {json.dumps(model_output.to_dict(), ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
return
decoded_unicode = model_output.text.replace("\ufffd", "")
delta_text = decoded_unicode[len(previous_text) :]
previous_text = (
decoded_unicode
if len(decoded_unicode) > len(previous_text)
else previous_text
)
if len(delta_text) == 0:
delta_text = None
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(content=delta_text),
finish_reason=model_output.finish_reason,
)
chunk = ChatCompletionStreamResponse(id=id, choices=[choice_data], model=model)
if delta_text is None:
if model_output.finish_reason is not None:
finish_stream_events.append(chunk)
continue
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
for finish_chunk in finish_stream_events:
yield f"data: {finish_chunk.json(exclude_none=True, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"