feat: Support llama-2 model

FangYin Cheng 2023-07-20 21:43:20 +08:00
parent 412b104797
commit 168c754a3f
12 changed files with 671 additions and 13 deletions

View File

@@ -11,7 +11,7 @@ cp .env.template .env
LLM_MODEL=vicuna-13b
MODEL_SERVER=http://127.0.0.1:8000
```
Now we support the models vicuna-13b, vicuna-7b, chatglm-6b, flan-t5-base, guanaco-33b-merged, falcon-40b, and gorilla-7b.
Now we support the models vicuna-13b, vicuna-7b, chatglm-6b, flan-t5-base, guanaco-33b-merged, falcon-40b, gorilla-7b, llama-2-7b, and llama-2-13b.
If you want to use another model, such as chatglm-6b, you just need to update the .env config file.
```

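For example, switching to Llama-2 is a one-line change to `LLM_MODEL` (a sketch; it assumes the weights already sit under the model path configured in `LLM_MODEL_CONFIG`, shown in the next file):

```
LLM_MODEL=llama-2-13b
MODEL_SERVER=http://127.0.0.1:8000
```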
View File

@@ -47,6 +47,9 @@ LLM_MODEL_CONFIG = {
"gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"),
"gptj-6b": os.path.join(MODEL_PATH, "ggml-gpt4all-j-v1.3-groovy.bin"),
"proxyllm": "proxyllm",
"llama-2-7b": os.path.join(MODEL_PATH, "Llama-2-7b-chat-hf"),
"llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"),
"llama-2-70b": os.path.join(MODEL_PATH, "Llama-2-70b-chat-hf"),
}
# Load model config
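The value of `LLM_MODEL` in .env is used as the key into `LLM_MODEL_CONFIG` to locate the weights on disk. A minimal sketch of that lookup (`MODEL_PATH` and the helper name here are illustrative, not the repo's exact code):

```python
import os

MODEL_PATH = "./models"  # assumption: the configured weights directory

LLM_MODEL_CONFIG = {
    "llama-2-7b": os.path.join(MODEL_PATH, "Llama-2-7b-chat-hf"),
    "llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"),
}

def resolve_model_path(llm_model: str) -> str:
    """Hypothetical helper: map the LLM_MODEL name from .env to a local path."""
    if llm_model not in LLM_MODEL_CONFIG:
        raise ValueError(f"Unsupported model: {llm_model}")
    return LLM_MODEL_CONFIG[llm_model]

print(resolve_model_path("llama-2-13b"))  # ./models/Llama-2-13b-chat-hf
```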

View File

@@ -263,12 +263,26 @@ class ProxyllmAdapter(BaseLLMAdaper):
return "proxyllm", None
class Llama2Adapter(BaseLLMAdaper):
"""The model adapter for llama-2"""
def match(self, model_path: str):
return "llama-2" in model_path.lower()
def loader(self, model_path: str, from_pretrained_kwargs: dict):
model, tokenizer = super().loader(model_path, from_pretrained_kwargs)
model.config.eos_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
return model, tokenizer
register_llm_model_adapters(VicunaLLMAdapater)
register_llm_model_adapters(ChatGLMAdapater)
register_llm_model_adapters(GuanacoAdapter)
register_llm_model_adapters(FalconAdapater)
register_llm_model_adapters(GorillaAdapter)
register_llm_model_adapters(GPT4AllAdapter)
register_llm_model_adapters(Llama2Adapter)
# TODO Vicuna is supported by default; other models still need testing and evaluation
# just for test_py, remove this later
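Each `register_llm_model_adapters` call appends an adapter to a module-level registry, and the loader picks the first adapter whose `match()` accepts the model path. A self-contained sketch of that pattern (the registry internals are simplified, not the repo's exact code):

```python
from typing import List, Type

class BaseLLMAdaper:
    """Simplified stand-in for the repo's base adapter class."""

    def match(self, model_path: str) -> bool:
        return True  # the base adapter is a catch-all fallback

class Llama2Adapter(BaseLLMAdaper):
    def match(self, model_path: str) -> bool:
        return "llama-2" in model_path.lower()

llm_model_adapters: List[BaseLLMAdaper] = []

def register_llm_model_adapters(cls: Type[BaseLLMAdaper]) -> None:
    llm_model_adapters.append(cls())

def get_llm_model_adapter(model_path: str) -> BaseLLMAdaper:
    # First registered adapter whose match() accepts the path wins,
    # which is why the fallback should be registered last.
    for adapter in llm_model_adapters:
        if adapter.match(model_path):
            return adapter
    raise ValueError(f"Invalid model adapter for {model_path}")

register_llm_model_adapters(Llama2Adapter)
register_llm_model_adapters(BaseLLMAdaper)

adapter = get_llm_model_adapter("./models/Llama-2-13b-chat-hf")
print(type(adapter).__name__)  # Llama2Adapter
```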

pilot/model/conversation.py (new file, 308 lines)
View File

@@ -0,0 +1,308 @@
"""
Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
Conversation prompt templates.
"""
import dataclasses
from enum import auto, IntEnum
from typing import List, Any, Dict, Callable
class SeparatorStyle(IntEnum):
"""Separator styles."""
ADD_COLON_SINGLE = auto()
ADD_COLON_TWO = auto()
ADD_COLON_SPACE_SINGLE = auto()
NO_COLON_SINGLE = auto()
NO_COLON_TWO = auto()
ADD_NEW_LINE_SINGLE = auto()
LLAMA2 = auto()
CHATGLM = auto()
CHATML = auto()
CHATINTERN = auto()
DOLLY = auto()
RWKV = auto()
PHOENIX = auto()
ROBIN = auto()
@dataclasses.dataclass
class Conversation:
"""A class that manages prompt templates and keeps all conversation history."""
# The name of this template
name: str
# The system prompt
system: str
# Two roles
roles: List[str]
# All messages. Each item is (role, message).
messages: List[List[str]]
# The number of few shot examples
offset: int
# Separators
sep_style: SeparatorStyle
sep: str
sep2: str = None
# Stop criteria (the default one is EOS token)
stop_str: str = None
# Stops generation if meeting any token in this list
stop_token_ids: List[int] = None
# format system message
system_formatter: Callable = None
def get_prompt(self) -> str:
"""Get the prompt for generation."""
if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
ret = self.system + self.sep
for role, message in self.messages:
if message:
ret += role + ": " + message + self.sep
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
seps = [self.sep, self.sep2]
ret = self.system + seps[0]
for i, (role, message) in enumerate(self.messages):
if message:
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
ret = self.system + self.sep
for role, message in self.messages:
if message:
ret += role + ": " + message + self.sep
else:
ret += role + ": " # must be end with a space
return ret
elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
ret = "" if self.system == "" else self.system + self.sep
for role, message in self.messages:
if message:
ret += role + "\n" + message + self.sep
else:
ret += role + "\n"
return ret
elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
ret = self.system
for role, message in self.messages:
if message:
ret += role + message + self.sep
else:
ret += role
return ret
elif self.sep_style == SeparatorStyle.NO_COLON_TWO:
seps = [self.sep, self.sep2]
ret = self.system
for i, (role, message) in enumerate(self.messages):
if message:
ret += role + message + seps[i % 2]
else:
ret += role
return ret
elif self.sep_style == SeparatorStyle.RWKV:
ret = self.system
for i, (role, message) in enumerate(self.messages):
if message:
ret += (
role
+ ": "
+ message.replace("\r\n", "\n").replace("\n\n", "\n")
)
ret += "\n\n"
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.LLAMA2:
seps = [self.sep, self.sep2]
ret = ""
for i, (role, message) in enumerate(self.messages):
if message:
if i == 0:
ret += self.system + message
else:
ret += role + " " + message + seps[i % 2]
else:
ret += role
return ret
elif self.sep_style == SeparatorStyle.CHATGLM:
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
round_add_n = 1 if self.name == "chatglm2" else 0
if self.system:
ret = self.system + self.sep
else:
ret = ""
for i, (role, message) in enumerate(self.messages):
if i % 2 == 0:
ret += f"[Round {i//2 + round_add_n}]{self.sep}"
if message:
ret += f"{role}{message}{self.sep}"
else:
ret += f"{role}"
return ret
elif self.sep_style == SeparatorStyle.CHATML:
ret = "" if self.system == "" else self.system + self.sep + "\n"
for role, message in self.messages:
if message:
ret += role + "\n" + message + self.sep + "\n"
else:
ret += role + "\n"
return ret
elif self.sep_style == SeparatorStyle.CHATINTERN:
# source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
seps = [self.sep, self.sep2]
ret = self.system
for i, (role, message) in enumerate(self.messages):
if i % 2 == 0:
ret += "<s>"
if message:
ret += role + ":" + message + seps[i % 2] + "\n"
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.DOLLY:
seps = [self.sep, self.sep2]
ret = self.system
for i, (role, message) in enumerate(self.messages):
if message:
ret += role + ":\n" + message + seps[i % 2]
if i % 2 == 1:
ret += "\n\n"
else:
ret += role + ":\n"
return ret
elif self.sep_style == SeparatorStyle.PHOENIX:
ret = self.system
for role, message in self.messages:
if message:
ret += role + ": " + "<s>" + message + "</s>"
else:
ret += role + ": " + "<s>"
return ret
elif self.sep_style == SeparatorStyle.ROBIN:
ret = self.system + self.sep
for role, message in self.messages:
if message:
ret += role + ":\n" + message + self.sep
else:
ret += role + ":\n"
return ret
else:
raise ValueError(f"Invalid style: {self.sep_style}")
def append_message(self, role: str, message: str):
"""Append a new message."""
self.messages.append([role, message])
def update_last_message(self, message: str):
"""Update the last output.
The last message is typically set to be None when constructing the prompt,
so we need to update it in-place after getting the response from a model.
"""
self.messages[-1][1] = message
def update_system_message(self, system_message: str):
"""Update system message"""
if self.system_formatter:
self.system = self.system_formatter(system_message)
else:
self.system = system_message
def to_gradio_chatbot(self):
"""Convert the conversation to gradio chatbot format."""
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
ret.append([msg, None])
else:
ret[-1][-1] = msg
return ret
def to_openai_api_messages(self):
"""Convert the conversation to OpenAI chat completion format."""
ret = [{"role": "system", "content": self.system}]
for i, (_, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
ret.append({"role": "user", "content": msg})
else:
if msg is not None:
ret.append({"role": "assistant", "content": msg})
return ret
def copy(self):
return Conversation(
name=self.name,
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2,
stop_str=self.stop_str,
stop_token_ids=self.stop_token_ids,
system_formatter=self.system_formatter,
)
def dict(self):
return {
"template_name": self.name,
"system": self.system,
"roles": self.roles,
"messages": self.messages,
"offset": self.offset,
}
# A global registry for all conversation templates
conv_templates: Dict[str, Conversation] = {}
def register_conv_template(template: Conversation, override: bool = False):
"""Register a new conversation template."""
if not override:
assert (
template.name not in conv_templates
), f"{template.name} has been registered."
conv_templates[template.name] = template
def get_conv_template(name: str) -> Conversation:
"""Get a conversation template."""
return conv_templates[name].copy()
# llama2 template
# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
register_conv_template(
Conversation(
name="llama-2",
system="<s>[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
"If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
roles=("[INST]", "[/INST]"),
messages=(),
offset=0,
sep_style=SeparatorStyle.LLAMA2,
sep=" ",
sep2=" </s><s>",
stop_token_ids=[2],
system_formatter=lambda msg: f"<s>[INST] <<SYS>>\n{msg}\n<</SYS>>\n\n",
)
)
# TODO Support other model conversation template
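With the template registered, building a Llama-2 prompt is a copy-append-render cycle. A quick sketch using the functions defined above:

```python
conv = get_conv_template("llama-2")  # returns a copy of the registered template
conv.append_message(conv.roles[0], "What is 2 + 2?")
conv.append_message(conv.roles[1], None)  # empty slot the model will fill

print(conv.get_prompt())
# <s>[INST] <<SYS>>
# You are a helpful, respectful and honest assistant. ...
# <</SYS>>
#
# What is 2 + 2?[/INST]
```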

pilot/model/inference.py (new file, 242 lines)
View File

@@ -0,0 +1,242 @@
"""
Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py
"""
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import gc
from typing import Iterable, Dict
import torch
from transformers.generation.logits_process import (
LogitsProcessorList,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
from pilot.model.llm_utils import is_sentence_complete, is_partial_stop
def prepare_logits_processor(
temperature: float, repetition_penalty: float, top_p: float, top_k: int
) -> LogitsProcessorList:
processor_list = LogitsProcessorList()
# TemperatureLogitsWarper doesn't accept 0.0; 1.0 makes it a no-op, so we skip both cases.
if temperature >= 1e-5 and temperature != 1.0:
processor_list.append(TemperatureLogitsWarper(temperature))
if repetition_penalty > 1.0:
processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
if 1e-8 <= top_p < 1.0:
processor_list.append(TopPLogitsWarper(top_p))
if top_k > 0:
processor_list.append(TopKLogitsWarper(top_k))
return processor_list
@torch.inference_mode()
def generate_stream(
model,
tokenizer,
params: Dict,
device: str,
context_len: int,
stream_interval: int = 2,
judge_sent_end: bool = False,
):
# Read parameters
prompt = params["prompt"]
print(f"Prompt of model: \n{prompt}")
len_prompt = len(prompt)
temperature = float(params.get("temperature", 1.0))
repetition_penalty = float(params.get("repetition_penalty", 1.0))
top_p = float(params.get("top_p", 1.0))
top_k = int(params.get("top_k", -1)) # -1 means disable
max_new_tokens = int(params.get("max_new_tokens", 2048))
echo = bool(params.get("echo", True))
stop_str = params.get("stop", None)
stop_token_ids = params.get("stop_token_ids", None) or []
stop_token_ids.append(tokenizer.eos_token_id)
logits_processor = prepare_logits_processor(
temperature, repetition_penalty, top_p, top_k
)
input_ids = tokenizer(prompt).input_ids
if model.config.is_encoder_decoder:
max_src_len = context_len
else: # truncate
max_src_len = context_len - max_new_tokens - 1
input_ids = input_ids[-max_src_len:]
output_ids = list(input_ids)
input_echo_len = len(input_ids)
if model.config.is_encoder_decoder:
encoder_output = model.encoder(
input_ids=torch.as_tensor([input_ids], device=device)
)[0]
start_ids = torch.as_tensor(
[[model.generation_config.decoder_start_token_id]],
dtype=torch.int64,
device=device,
)
past_key_values = out = None
sent_interrupt = False
for i in range(max_new_tokens):
if i == 0: # prefill
if model.config.is_encoder_decoder:
out = model.decoder(
input_ids=start_ids,
encoder_hidden_states=encoder_output,
use_cache=True,
)
logits = model.lm_head(out[0])
else:
out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
logits = out.logits
past_key_values = out.past_key_values
else: # decoding
if model.config.is_encoder_decoder:
out = model.decoder(
input_ids=torch.as_tensor(
[[token] if not sent_interrupt else output_ids], device=device
),
encoder_hidden_states=encoder_output,
use_cache=True,
past_key_values=past_key_values if not sent_interrupt else None,
)
sent_interrupt = False
logits = model.lm_head(out[0])
else:
out = model(
input_ids=torch.as_tensor(
[[token] if not sent_interrupt else output_ids], device=device
),
use_cache=True,
past_key_values=past_key_values if not sent_interrupt else None,
)
sent_interrupt = False
logits = out.logits
past_key_values = out.past_key_values
if logits_processor:
if repetition_penalty > 1.0:
tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
else:
tmp_output_ids = None
last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
else:
last_token_logits = logits[0, -1, :]
if device == "mps":
# Switch to CPU to avoid some bugs in the mps backend.
last_token_logits = last_token_logits.float().to("cpu")
if temperature < 1e-5 or top_p < 1e-8: # greedy
_, indices = torch.topk(last_token_logits, 2)
tokens = [int(index) for index in indices.tolist()]
else:
probs = torch.softmax(last_token_logits, dim=-1)
indices = torch.multinomial(probs, num_samples=2)
tokens = [int(token) for token in indices.tolist()]
token = tokens[0]
output_ids.append(token)
if token in stop_token_ids:
stopped = True
else:
stopped = False
# Yield the output tokens
if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
if echo:
tmp_output_ids = output_ids
rfind_start = len_prompt
else:
tmp_output_ids = output_ids[input_echo_len:]
rfind_start = 0
output = tokenizer.decode(
tmp_output_ids,
skip_special_tokens=True,
spaces_between_special_tokens=False,
clean_up_tokenization_spaces=True,
)
# TODO: patch for the issue of incomplete sentences interrupting output; this could be reworked more elegantly
if judge_sent_end and stopped and not is_sentence_complete(output):
if len(tokens) > 1:
token = tokens[1]
output_ids[-1] = token
else:
output_ids.pop()
stopped = False
sent_interrupt = True
partially_stopped = False
if stop_str:
if isinstance(stop_str, str):
pos = output.rfind(stop_str, rfind_start)
if pos != -1:
output = output[:pos]
stopped = True
else:
partially_stopped = is_partial_stop(output, stop_str)
elif isinstance(stop_str, Iterable):
for each_stop in stop_str:
pos = output.rfind(each_stop, rfind_start)
if pos != -1:
output = output[:pos]
stopped = True
break
else:
partially_stopped = is_partial_stop(output, each_stop)
if partially_stopped:
break
else:
raise ValueError("Invalid stop field type.")
# Prevent yielding partial stop sequence
if not partially_stopped:
yield output
# yield {
# "text": output,
# "usage": {
# "prompt_tokens": input_echo_len,
# "completion_tokens": i,
# "total_tokens": input_echo_len + i,
# },
# "finish_reason": None,
# }
if stopped:
break
# Finish stream event, which contains finish reason
if i == max_new_tokens - 1:
finish_reason = "length"
elif stopped:
finish_reason = "stop"
else:
finish_reason = None
yield output
# yield {
# "text": output,
# "usage": {
# "prompt_tokens": input_echo_len,
# "completion_tokens": i,
# "total_tokens": input_echo_len + i,
# },
# "finish_reason": finish_reason,
# }
# Clean
del past_key_values, out
gc.collect()
torch.cuda.empty_cache()
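Each yield from `generate_stream` is the full decoded text so far, so a caller only needs to keep the last value. A hedged sketch of driving it directly (the model path is an assumption; loading is plain transformers, and any causal LM will do for the demonstration):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "./models/Llama-2-7b-chat-hf"  # assumption: local chat weights
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.float16
).to("cuda")

params = {
    "prompt": "<s>[INST] What is 2 + 2? [/INST]",
    "temperature": 0.7,
    "top_p": 0.9,
    "max_new_tokens": 128,
    "echo": False,          # don't repeat the prompt in the output
    "stop": "</s>",
    "stop_token_ids": [2],  # Llama-2's EOS id, matching the template above
}

output = ""
for partial in generate_stream(model, tokenizer, params, "cuda", context_len=4096):
    output = partial  # progressively longer decoded text
print(output)
```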

View File

@@ -6,7 +6,7 @@ import requests
from typing import List
from pilot.configs.config import Config
from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
from pilot.scene.base_message import ModelMessage
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
CFG = Config()
@@ -25,11 +25,11 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
messages: List[ModelMessage] = params["messages"]
# Add history conversation
for message in messages:
if message.role == "human":
if message.role == ModelMessageRoleType.HUMAN:
history.append({"role": "user", "content": message.content})
elif message.role == "system":
elif message.role == ModelMessageRoleType.SYSTEM:
history.append({"role": "system", "content": message.content})
elif message.role == "ai":
elif message.role == ModelMessageRoleType.AI:
history.append({"role": "assistant", "content": message.content})
else:
pass
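The branch above is a straight role translation; an equivalent dict-driven sketch makes the mapping explicit (it continues from the snippet above and assumes the same imports at the top of the file):

```python
ROLE_MAP = {
    ModelMessageRoleType.HUMAN: "user",
    ModelMessageRoleType.SYSTEM: "system",
    ModelMessageRoleType.AI: "assistant",
}

history = [
    {"role": ROLE_MAP[m.role], "content": m.content}
    for m in messages
    if m.role in ROLE_MAP  # unknown roles are silently skipped, as above
]
```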

View File

@@ -10,7 +10,6 @@ from typing import List, Optional
from pilot.configs.config import Config
from pilot.model.base import Message
from pilot.server.llmserver import generate_output
def create_chat_completion(
@@ -115,3 +114,17 @@ class Iteratorize:
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop_now = True
def is_sentence_complete(output: str):
"""Check whether the output is a complete sentence."""
end_symbols = (".", "?", "!", "...", "。", "?", "!", "…", '"', "'", "」")
return output.endswith(end_symbols)
def is_partial_stop(output: str, stop_str: str):
"""Check whether the output contains a partial stop str."""
for i in range(0, min(len(output), len(stop_str))):
if stop_str.startswith(output[-i:]):
return True
return False
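Both helpers are easy to sanity-check; note that `is_partial_stop` guards the stream from leaking a half-emitted stop string:

```python
assert is_sentence_complete("The answer is 4.")
assert not is_sentence_complete("The answer is")

# "Obser" could be the beginning of the stop string "Observation:",
# so the stream holds this chunk back until the ambiguity resolves.
assert is_partial_stop("Thought: I should check... Obser", "Observation:")
assert not is_partial_stop("The answer is 4.", "Observation:")
```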

View File

@@ -53,8 +53,15 @@ class BaseOutputParser(ABC):
""" TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
"""
model_context = data.get("model_context")
if model_context and "prompt_echo_len_char" in model_context:
prompt_echo_len_char = int(model_context.get("prompt_echo_len_char", -1))
if prompt_echo_len_char != -1:
skip_echo_len = prompt_echo_len_char
if data.get("error_code", 0) == 0:
if "vicuna" in CFG.LLM_MODEL:
if "vicuna" in CFG.LLM_MODEL or "llama-2" in CFG.LLM_MODEL:
# TODO Judging from model_context
# output = data["text"][skip_echo_len + 11:].strip()
output = data["text"][skip_echo_len:].strip()
elif "guanaco" in CFG.LLM_MODEL:

View File

@@ -39,6 +39,7 @@ from pilot.scene.base_message import (
AIMessage,
ViewMessage,
ModelMessage,
ModelMessageRoleType,
)
from pilot.configs.config import Config
@@ -258,7 +259,7 @@ class BaseChat(ABC):
if self.prompt_template.template_define:
messages.append(
ModelMessage(
role="system",
role=ModelMessageRoleType.SYSTEM,
content=self.prompt_template.template_define,
)
)

View File

@@ -89,6 +89,14 @@ class ModelMessage(BaseModel):
content: str
class ModelMessageRoleType:
""" "Type of ModelMessage role"""
SYSTEM = "system"
HUMAN = "human"
AI = "ai"
class Generation(BaseModel):
"""Output of a single generation."""

View File

@@ -2,8 +2,10 @@
# -*- coding: utf-8 -*-
from functools import cache
from typing import List
from typing import List, Dict, Tuple
from pilot.model.llm_out.vicuna_base_llm import generate_stream
from pilot.model.conversation import Conversation, get_conv_template
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
class BaseChatAdpter:
@@ -17,6 +19,52 @@ class BaseChatAdpter:
"""Return the generate stream handler func"""
pass
def get_conv_template(self) -> Conversation:
return None
def model_adaptation(self, params: Dict) -> Tuple[Dict, Dict]:
"""Params adaptation"""
conv = self.get_conv_template()
messages = params.get("messages")
# Some model context to pass back to the dbgpt server
model_context = {"prompt_echo_len_char": -1}
if not conv or not messages:
# Nothing to do
return params, model_context
conv = conv.copy()
system_messages = []
for message in messages:
role, content = None, None
if isinstance(message, ModelMessage):
role = message.role
content = message.content
elif isinstance(message, dict):
role = message["role"]
content = message["content"]
else:
raise ValueError(f"Invalid message type: {message}")
if role == ModelMessageRoleType.SYSTEM:
# Support for multiple system messages
system_messages.append(content)
elif role == ModelMessageRoleType.HUMAN:
conv.append_message(conv.roles[0], content)
elif role == ModelMessageRoleType.AI:
conv.append_message(conv.roles[1], content)
else:
raise ValueError(f"Unknown role: {role}")
if system_messages:
conv.update_system_message("".join(system_messages))
# Add a blank message for the assistant.
conv.append_message(conv.roles[1], None)
new_prompt = conv.get_prompt()
# Overwrite the original prompt
# TODO: take the bos/eos tokens from the model's tokenizer_config.json instead of hardcoding them
prompt_echo_len_char = len(new_prompt.replace("</s>", "").replace("<s>", ""))
model_context["prompt_echo_len_char"] = prompt_echo_len_char
params["prompt"] = new_prompt
return params, model_context
llm_model_chat_adapters: List[BaseChatAdpter] = []
@@ -134,12 +182,26 @@ class GPT4AllChatAdapter(BaseChatAdpter):
return gpt4all_generate_stream
class Llama2ChatAdapter(BaseChatAdpter):
def match(self, model_path: str):
return "llama-2" in model_path.lower()
def get_conv_template(self) -> Conversation:
return get_conv_template("llama-2")
def get_generate_stream_func(self):
from pilot.model.inference import generate_stream
return generate_stream
register_llm_model_chat_adapter(VicunaChatAdapter)
register_llm_model_chat_adapter(ChatGLMChatAdapter)
register_llm_model_chat_adapter(GuanacoChatAdapter)
register_llm_model_chat_adapter(FalconChatAdapter)
register_llm_model_chat_adapter(GorillaChatAdapter)
register_llm_model_chat_adapter(GPT4AllChatAdapter)
register_llm_model_chat_adapter(Llama2ChatAdapter)
# Proxy model for testing and development; it's cheap for us for now.
register_llm_model_chat_adapter(ProxyllmChatAdapter)
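End to end, `model_adaptation` turns the generic message list into a rendered llama-2 prompt and reports the echo length back to the server. A sketch exercising the `Llama2ChatAdapter` defined above (dict-style messages are accepted, per the `isinstance` branch):

```python
adapter = Llama2ChatAdapter()

params = {
    "messages": [
        {"role": "system", "content": "You are concise."},
        {"role": "human", "content": "What is 2 + 2?"},
    ]
}
params, model_context = adapter.model_adaptation(params)

print(params["prompt"])
# <s>[INST] <<SYS>>
# You are concise.
# <</SYS>>
#
# What is 2 + 2?[/INST]
print(model_context["prompt_echo_len_char"])  # prompt length minus <s>/</s> markers
```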

View File

@@ -77,6 +77,8 @@ class ModelWorker:
def generate_stream_gate(self, params):
try:
# params adaptation
params, model_context = self.llm_chat_adapter.model_adaptation(params)
for output in self.generate_stream_func(
self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
):
@@ -84,10 +86,8 @@
# The gpt4all thread shares stdout with the parent process,
# and enabling this print may affect the frontend output.
print("output: ", output)
ret = {
"text": output,
"error_code": 0,
}
# return some model context to the dbgpt server
ret = {"text": output, "error_code": 0, "model_context": model_context}
yield json.dumps(ret).encode() + b"\0"
except torch.cuda.CudaError:
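Each streamed message is a JSON object terminated by a NUL byte, so a client splits on b"\0" and keeps the latest text. A minimal hypothetical reader (the endpoint path and payload shape are assumptions about llmserver, not confirmed by this diff):

```python
import json

import requests

resp = requests.post(
    "http://127.0.0.1:8000/generate_stream",  # assumed endpoint
    json={"prompt": "<s>[INST] What is 2 + 2? [/INST]", "max_new_tokens": 64},
    stream=True,
)

text = ""
for chunk in resp.iter_lines(delimiter=b"\0"):
    if not chunk:
        continue
    ret = json.loads(chunk.decode())
    if ret.get("error_code", 0) == 0:
        text = ret["text"]  # each message carries the full text so far
print(text)
```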