refactor: The first refactored version for sdk release (#907)

Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
Author: FangYin Cheng (committed by GitHub)
Date: 2023-12-08 14:45:59 +08:00
parent e7e4aff667
commit cd725db1fb
573 changed files with 2094 additions and 3571 deletions


@@ -0,0 +1,103 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
from typing import List
import re
import torch
from dbgpt.app.scene import ModelMessage, _parse_model_messages
# TODO move sep to scene prompt of model
_CHATGLM_SEP = "\n"
_CHATGLM2_SEP = "\n\n"
@torch.inference_mode()
def chatglm_generate_stream(
model, tokenizer, params, device, context_len=2048, stream_interval=2
):
"""Generate text using chatglm model's chat api_v1"""
prompt = params["prompt"]
temperature = float(params.get("temperature", 1.0))
top_p = float(params.get("top_p", 1.0))
stop = params.get("stop", "###")
echo = params.get("echo", False)
generate_kwargs = {
"do_sample": True if temperature > 1e-5 else False,
"top_p": top_p,
"repetition_penalty": 1.0,
"logits_processor": None,
}
if temperature > 1e-5:
generate_kwargs["temperature"] = temperature
# TODO, Fix this
# print(prompt)
# messages = prompt.split(stop)
messages: List[ModelMessage] = params["messages"]
query, system_messages, hist = _parse_model_messages(messages)
system_messages_str = "".join(system_messages)
if not hist:
# No conversation history, but system messages exist; merge them into the user's query
query = prompt_adaptation(system_messages_str, query)
else:
# History exists; prepend the system messages to the first history entry
hist[0][0] = system_messages_str + _CHATGLM2_SEP + hist[0][0]
print("Query Message: ", query)
print("hist: ", hist)
for i, (response, new_hist) in enumerate(
model.stream_chat(tokenizer, query, hist, **generate_kwargs)
):
if echo:
output = query + " " + response
else:
output = response
yield output
yield output
class HistoryEntry:
def __init__(self, question: str = "", answer: str = ""):
self.question = question
self.answer = answer
def add_question(self, question: str):
self.question += question
def add_answer(self, answer: str):
self.answer += answer
def to_list(self):
if self.question == "" or self.answer == "":
return None
return [self.question, self.answer]
def build_history(hist: List[HistoryEntry]) -> List[List[str]]:
return list(filter(lambda hl: hl is not None, map(lambda h: h.to_list(), hist)))
def prompt_adaptation(system_messages_str: str, human_message: str) -> str:
if not system_messages_str or system_messages_str == "":
return human_message
# TODO Multi-model prompt adaptation
adaptation_rules = [
r"Question:\s*{}\s*", # chat_db scene
r"Goals:\s*{}\s*", # chat_execution
r"问题:\s*{}\s*", # chat_knowledge zh
r"question:\s*{}\s*", # chat_knowledge en
]
# The system message may already include the human question
for rule in adaptation_rules:
pattern = re.compile(rule.format(re.escape(human_message)))
if re.search(pattern, system_messages_str):
return system_messages_str
# https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
return system_messages_str + _CHATGLM2_SEP + human_message
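
A minimal usage sketch for the streaming entry point above. It assumes ChatGLM2-6B loaded through transformers with trust_remote_code; the checkpoint name, device, and ModelMessage field values shown here are illustrative assumptions, not something this file prescribes.

from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).half().cuda()

params = {
    "prompt": "What is DB-GPT?",
    "temperature": 0.7,
    "messages": [ModelMessage(role="human", content="What is DB-GPT?")],
}
# Each yielded value is the full response generated so far.
for partial in chatglm_generate_stream(model, tokenizer, params, device="cuda"):
    print(partial, end="\r")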


@@ -0,0 +1,53 @@
import torch
from threading import Thread
from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
def falcon_generate_output(model, tokenizer, params, device, context_len=2048):
"""Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
tokenizer.bos_token_id = 1
print(params)
stop = params.get("stop", "###")
prompt = params["prompt"]
query = prompt
print("Query Message: ", query)
input_ids = tokenizer(query, return_tensors="pt").input_ids
input_ids = input_ids.to(model.device)
streamer = TextIteratorStreamer(
tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
)
tokenizer.bos_token_id = 1
stop_token_ids = [0]
class StopOnTokens(StoppingCriteria):
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> bool:
for stop_id in stop_token_ids:
if input_ids[0][-1] == stop_id:
return True
return False
stop = StopOnTokens()
generate_kwargs = dict(
input_ids=input_ids,
max_new_tokens=512,
temperature=1.0,
do_sample=True,
top_k=1,
streamer=streamer,
repetition_penalty=1.7,
stopping_criteria=StoppingCriteriaList([stop]),
)
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
out = ""
for new_text in streamer:
out += new_text
yield out
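
A hedged invocation sketch for falcon_generate_output, assuming a Falcon instruct checkpoint served through transformers; the model id, dtype, and prompt are placeholders.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-7b-instruct"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

params = {"prompt": "Write one sentence about vector databases.", "stop": "###"}
final = ""
# Each yielded value is the accumulated generation so far.
for final in falcon_generate_output(model, tokenizer, params, device=model.device):
    pass
print(final)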


@@ -0,0 +1,62 @@
import torch
@torch.inference_mode()
def generate_stream(
model, tokenizer, params, device, context_len=42048, stream_interval=2
):
"""Fork from https://github.com/ShishirPatil/gorilla/blob/main/inference/serve/gorilla_cli.py"""
prompt = params["prompt"]
l_prompt = len(prompt)
max_new_tokens = int(params.get("max_new_tokens", 1024))
stop_str = params.get("stop", None)
input_ids = tokenizer(prompt).input_ids
output_ids = list(input_ids)
input_echo_len = len(input_ids)
max_src_len = context_len - max_new_tokens - 8
input_ids = input_ids[-max_src_len:]
past_key_values = out = None
for i in range(max_new_tokens):
if i == 0:
out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
logits = out.logits
past_key_values = out.past_key_values
else:
out = model(
input_ids=torch.as_tensor([[token]], device=device),
use_cache=True,
past_key_values=past_key_values,
)
logits = out.logits
past_key_values = out.past_key_values
last_token_logits = logits[0][-1]
probs = torch.softmax(last_token_logits, dim=-1)
token = int(torch.multinomial(probs, num_samples=1))
output_ids.append(token)
if token == tokenizer.eos_token_id:
stopped = True
else:
stopped = False
if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
tmp_output_ids = output_ids[input_echo_len:]
output = tokenizer.decode(
tmp_output_ids,
skip_special_tokens=True,
spaces_between_special_tokens=False,
)
pos = output.rfind(stop_str, l_prompt)
if pos != -1:
output = output[:pos]
stopped = True
yield output
if stopped:
break
del past_key_values
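
A hedged driver for the token-by-token loop above. The checkpoint name is a placeholder, and a stop string is passed explicitly because the loop calls output.rfind(stop_str, ...) without a None guard.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "your-org/your-gorilla-checkpoint"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

params = {
    "prompt": "###USER: How do I load a CSV with pandas?\n###ASSISTANT:",
    "max_new_tokens": 256,
    "stop": "###USER",
}
partial = ""
for partial in generate_stream(model, tokenizer, params, device):
    pass
print(partial)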


@@ -0,0 +1,10 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
def gpt4all_generate_stream(model, tokenizer, params, device, max_position_embeddings):
stop = params.get("stop", "###")
prompt = params["prompt"]
role, query = prompt.split(stop)[0].split(":")
print(f"gpt4all, role: {role}, query: {query}")
yield model.generate(prompt=query, streaming=True)
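
A hedged sketch against the gpt4all Python bindings. Note that this wrapper yields the streaming generator itself (model.generate(..., streaming=True)), so the caller iterates the yielded object; the model file name below is a placeholder.

from gpt4all import GPT4All

model = GPT4All("ggml-gpt4all-model.bin")  # placeholder local model file
params = {"prompt": "user: What is a vector database?###", "stop": "###"}
for token_stream in gpt4all_generate_stream(
    model, tokenizer=None, params=params, device="cpu", max_position_embeddings=2048
):
    for token in token_stream:
        print(token, end="")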


@@ -0,0 +1,109 @@
import torch
from threading import Thread
from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
"""Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
print(params)
stop = params.get("stop", "###")
prompt = params["prompt"]
query = prompt
print("Query Message: ", query)
input_ids = tokenizer(query, return_tensors="pt").input_ids
input_ids = input_ids.to(model.device)
streamer = TextIteratorStreamer(
tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
)
stop_token_ids = [0]
class StopOnTokens(StoppingCriteria):
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> bool:
for stop_id in stop_token_ids:
if input_ids[0][-1] == stop_id:
return True
return False
stop = StopOnTokens()
generate_kwargs = dict(
input_ids=input_ids,
max_new_tokens=512,
temperature=1.0,
do_sample=True,
top_k=1,
streamer=streamer,
repetition_penalty=1.7,
stopping_criteria=StoppingCriteriaList([stop]),
)
t1 = Thread(target=model.generate, kwargs=generate_kwargs)
t1.start()
generator = model.generate(**generate_kwargs)
for output in generator:
# new_tokens = len(output) - len(input_ids[0])
decoded_output = tokenizer.decode(output)
if output[-1] in [tokenizer.eos_token_id]:
break
out = decoded_output.split("### Response:")[-1].strip()
yield out
def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048):
"""Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
tokenizer.bos_token_id = 1
print(params)
stop = params.get("stop", "###")
prompt = params["prompt"]
max_new_tokens = params.get("max_new_tokens", 512)
temperature = params.get("temperature", 1.0)
query = prompt
print("Query Message: ", query)
input_ids = tokenizer(query, return_tensors="pt").input_ids
input_ids = input_ids.to(model.device)
streamer = TextIteratorStreamer(
tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
)
tokenizer.bos_token_id = 1
stop_token_ids = [0]
class StopOnTokens(StoppingCriteria):
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> bool:
for stop_id in stop_token_ids:
if input_ids[-1][-1] == stop_id:
return True
return False
stop = StopOnTokens()
generate_kwargs = dict(
input_ids=input_ids,
max_new_tokens=max_new_tokens,
temperature=temperature,
do_sample=True,
top_k=1,
streamer=streamer,
repetition_penalty=1.7,
stopping_criteria=StoppingCriteriaList([stop]),
)
model.generate(**generate_kwargs)
out = ""
for new_text in streamer:
out += new_text
yield out
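
A hedged invocation of guanaco_generate_stream with a transformers causal LM; the checkpoint, dtype, and prompt template are illustrative assumptions.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "your-org/your-guanaco-checkpoint"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)

params = {
    "prompt": "### Human: Summarize what DB-GPT does.\n### Assistant:",
    "max_new_tokens": 256,
    "temperature": 0.7,
}
out = ""
for out in guanaco_generate_stream(model, tokenizer, params, device=model.device):
    pass
print(out)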


@@ -0,0 +1,53 @@
import logging
import torch
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
logger = logging.getLogger(__name__)
@torch.inference_mode()
def huggingface_chat_generate_stream(
model: AutoModelForCausalLM,
tokenizer: AutoTokenizer,
params,
device,
context_len=4096,
):
prompt = params["prompt"]
temperature = float(params.get("temperature", 0.7))
top_p = float(params.get("top_p", 1.0))
echo = params.get("echo", False)
max_new_tokens = int(params.get("max_new_tokens", 2048))
input_ids = tokenizer(prompt).input_ids
# input_ids = input_ids.to(device)
if model.config.is_encoder_decoder:
max_src_len = context_len
else: # truncate
max_src_len = context_len - max_new_tokens - 1
input_ids = input_ids[-max_src_len:]
input_echo_len = len(input_ids)
input_ids = torch.as_tensor([input_ids], device=device)
# messages = params["messages"]
# messages = ModelMessage.to_openai_messages(messages)
# input_ids = tokenizer.apply_chat_template(conversation=messages, tokenize=True, add_generation_prompt=True, return_tensors='pt')
# input_ids = input_ids.to(device)
streamer = TextIteratorStreamer(
tokenizer, skip_prompt=not echo, skip_special_tokens=True
)
generate_kwargs = {
"input_ids": input_ids,
"max_length": context_len,
"temperature": temperature,
"streamer": streamer,
}
thread = Thread(target=model.generate, kwargs=generate_kwargs)
thread.start()
out = ""
for new_text in streamer:
out += new_text
yield out
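
A quick smoke test of the generic streamer above with a small public causal LM; any chat-tuned transformers model can be substituted, and gpt2 is used here only because it is small and public.

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

params = {"prompt": "DB-GPT is", "temperature": 0.7, "max_new_tokens": 32}
for text in huggingface_chat_generate_stream(
    model, tokenizer, params, device="cpu", context_len=512
):
    print(text, end="\r")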


@@ -0,0 +1,8 @@
from typing import Dict
import torch
@torch.inference_mode()
def generate_stream(model, tokenizer, params: Dict, device: str, context_len: int):
# Only LlamaCppModel is supported here
return model.generate_streaming(params=params, context_len=context_len)
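
The wrapper above only forwards to the adapted model object, so anything exposing generate_streaming(params=..., context_len=...) satisfies the contract. The stub below is a hypothetical stand-in used to illustrate that contract, not DB-GPT's real LlamaCppModel.

from typing import Dict, Iterator

class FakeLlamaCppModel:
    """Hypothetical stand-in mirroring the expected streaming contract."""

    def generate_streaming(self, params: Dict, context_len: int) -> Iterator[str]:
        prompt = params["prompt"]
        for i in range(3):
            yield f"{prompt} ... chunk {i} (context_len={context_len})"

for chunk in generate_stream(
    FakeLlamaCppModel(), tokenizer=None, params={"prompt": "hi"}, device="cpu", context_len=2048
):
    print(chunk)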


@@ -0,0 +1,38 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import time
from dbgpt.model.proxy.llms.chatgpt import chatgpt_generate_stream
from dbgpt.model.proxy.llms.bard import bard_generate_stream
from dbgpt.model.proxy.llms.claude import claude_generate_stream
from dbgpt.model.proxy.llms.wenxin import wenxin_generate_stream
from dbgpt.model.proxy.llms.tongyi import tongyi_generate_stream
from dbgpt.model.proxy.llms.zhipu import zhipu_generate_stream
from dbgpt.model.proxy.llms.baichuan import baichuan_generate_stream
from dbgpt.model.proxy.llms.spark import spark_generate_stream
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
def proxyllm_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
generator_mapping = {
"proxyllm": chatgpt_generate_stream,
"chatgpt_proxyllm": chatgpt_generate_stream,
"bard_proxyllm": bard_generate_stream,
"claude_proxyllm": claude_generate_stream,
# "gpt4_proxyllm": gpt4_generate_stream, move to chatgpt_generate_stream
"wenxin_proxyllm": wenxin_generate_stream,
"tongyi_proxyllm": tongyi_generate_stream,
"zhipu_proxyllm": zhipu_generate_stream,
"bc_proxyllm": baichuan_generate_stream,
"spark_proxyllm": spark_generate_stream,
}
model_params = model.get_params()
model_name = model_params.model_name
default_error_message = f"{model_name} LLM is not supported"
generator_function = generator_mapping.get(
model_name, lambda *args, **kwargs: iter([default_error_message])  # fallback must accept the same arguments as the generators
)
yield from generator_function(model, tokenizer, params, device, context_len)
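
Dispatch is keyed on the proxy model's configured name. The sketch below exercises the lookup with a hypothetical stand-in for ProxyModel (real callers pass the ProxyModel built by the model worker); with the *args fallback above, an unknown name yields the default error message instead of raising.

from types import SimpleNamespace

class FakeProxyModel:
    """Hypothetical stand-in exposing the get_params() contract used above."""

    def __init__(self, model_name: str):
        self._params = SimpleNamespace(model_name=model_name)

    def get_params(self):
        return self._params

params = {"prompt": "Hello", "messages": [], "temperature": 0.7}
for chunk in proxyllm_generate_stream(FakeProxyModel("unknown_proxyllm"), None, params, None):
    print(chunk)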


@@ -0,0 +1,233 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
@torch.inference_mode()
def generate_stream(
model, tokenizer, params, device, context_len=4096, stream_interval=2
):
"""Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py"""
prompt = params["prompt"]
l_prompt = len(prompt)
prompt = prompt.replace("ai:", "assistant:").replace("human:", "user:")
temperature = float(params.get("temperature", 1.0))
max_new_tokens = int(params.get("max_new_tokens", 2048))
stop_str = params.get("stop", None)
input_ids = tokenizer(prompt).input_ids
output_ids = list(input_ids)
max_src_len = context_len - max_new_tokens - 8
input_ids = input_ids[-max_src_len:]
for i in range(max_new_tokens):
if i == 0:
out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
logits = out.logits
past_key_values = out.past_key_values
else:
attention_mask = torch.ones(
1, past_key_values[0][0].shape[-2] + 1, device=device
)
out = model(
input_ids=torch.as_tensor([[token]], device=device),
use_cache=True,
attention_mask=attention_mask,
past_key_values=past_key_values,
)
logits = out.logits
past_key_values = out.past_key_values
last_token_logits = logits[0][-1]
if device == "mps":
# Switch to CPU to avoid some bugs in the mps backend.
last_token_logits = last_token_logits.float().to("cpu")
if temperature < 1e-4:
token = int(torch.argmax(last_token_logits))
else:
probs = torch.softmax(last_token_logits / temperature, dim=-1)
token = int(torch.multinomial(probs, num_samples=1))
output_ids.append(token)
if token == tokenizer.eos_token_id:
stopped = True
else:
stopped = False
if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
output = tokenizer.decode(output_ids, skip_special_tokens=True)
pos = output.rfind(stop_str, l_prompt)
if pos != -1:
output = output[:pos]
stopped = True
yield output
if stopped:
break
del past_key_values
@torch.inference_mode()
def generate_output(
model, tokenizer, params, device, context_len=4096, stream_interval=2
):
"""Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py"""
prompt = params["prompt"]
l_prompt = len(prompt)
temperature = float(params.get("temperature", 1.0))
max_new_tokens = int(params.get("max_new_tokens", 2048))
stop_str = params.get("stop", None)
input_ids = tokenizer(prompt).input_ids
output_ids = list(input_ids)
max_src_len = context_len - max_new_tokens - 8
input_ids = input_ids[-max_src_len:]
for i in range(max_new_tokens):
if i == 0:
out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
logits = out.logits
past_key_values = out.past_key_values
else:
attention_mask = torch.ones(
1, past_key_values[0][0].shape[-2] + 1, device=device
)
out = model(
input_ids=torch.as_tensor([[token]], device=device),
use_cache=True,
attention_mask=attention_mask,
past_key_values=past_key_values,
)
logits = out.logits
past_key_values = out.past_key_values
last_token_logits = logits[0][-1]
if device == "mps":
# Switch to CPU to avoid some bugs in the mps backend.
last_token_logits = last_token_logits.float().to("cpu")
if temperature < 1e-4:
token = int(torch.argmax(last_token_logits))
else:
probs = torch.softmax(last_token_logits / temperature, dim=-1)
token = int(torch.multinomial(probs, num_samples=1))
output_ids.append(token)
if token == tokenizer.eos_token_id:
stopped = True
else:
stopped = False
if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
output = tokenizer.decode(output_ids, skip_special_tokens=True)
pos = output.rfind(stop_str, l_prompt)
if pos != -1:
output = output[:pos]
stopped = True
return output
if stopped:
break
del past_key_values
@torch.inference_mode()
def generate_output_ex(
model, tokenizer, params, device, context_len=2048, stream_interval=2
):
prompt = params["prompt"]
temperature = float(params.get("temperature", 1.0))
max_new_tokens = int(params.get("max_new_tokens", 2048))
stop_parameter = params.get("stop", None)
if stop_parameter == tokenizer.eos_token:
stop_parameter = None
stop_strings = []
if isinstance(stop_parameter, str):
stop_strings.append(stop_parameter)
elif isinstance(stop_parameter, list):
stop_strings = stop_parameter
elif stop_parameter is None:
pass
else:
raise TypeError("Stop parameter must be string or list of strings.")
input_ids = tokenizer(prompt).input_ids
output_ids = []
max_src_len = context_len - max_new_tokens - 8
input_ids = input_ids[-max_src_len:]
stop_word = None
for i in range(max_new_tokens):
if i == 0:
out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
logits = out.logits
past_key_values = out.past_key_values
else:
out = model(
input_ids=torch.as_tensor([[token]], device=device),
use_cache=True,
past_key_values=past_key_values,
)
logits = out.logits
past_key_values = out.past_key_values
last_token_logits = logits[0][-1]
if temperature < 1e-4:
token = int(torch.argmax(last_token_logits))
else:
probs = torch.softmax(last_token_logits / temperature, dim=-1)
token = int(torch.multinomial(probs, num_samples=1))
output_ids.append(token)
if token == tokenizer.eos_token_id:
stopped = True
else:
stopped = False
output = tokenizer.decode(output_ids, skip_special_tokens=True)
# print("Partial output:", output)
for stop_str in stop_strings:
# print(f"Looking for '{stop_str}' in '{output[:l_prompt]}'#END")
pos = output.rfind(stop_str)
if pos != -1:
# print("Found stop str: ", output)
output = output[:pos]
# print("Trimmed output: ", output)
stopped = True
stop_word = stop_str
break
else:
pass
# print("Not found")
if stopped:
break
del past_key_values
# output has already been trimmed at the stop word (if any) inside the loop
return output
@torch.inference_mode()
def get_embeddings(model, tokenizer, prompt):
input_ids = tokenizer(prompt).input_ids
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_embeddings = model.get_input_embeddings().to(device)
embeddings = input_embeddings(torch.LongTensor([input_ids]).to(device))
mean = torch.mean(embeddings[0], 0).cpu().detach()
return mean.to(device)
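
A hedged smoke test for generate_stream and get_embeddings above, using a small public checkpoint; the prompt, stop string, and model id are illustrative, and a stop string is passed explicitly since the loop calls rfind on it.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

params = {
    "prompt": "human: hello\nai:",
    "temperature": 0.7,
    "max_new_tokens": 32,
    "stop": "human:",
}
for partial in generate_stream(model, tokenizer, params, device, context_len=512):
    print(partial, end="\r")

emb = get_embeddings(model, tokenizer, "hello world")
print(emb.shape)  # mean-pooled input-embedding vector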


@@ -0,0 +1,93 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import json
from typing import Any, List, Mapping, Optional
from urllib.parse import urljoin
import requests
from langchain.embeddings.base import Embeddings
from langchain.llms.base import LLM
from dbgpt._private.pydantic import BaseModel
from dbgpt._private.config import Config
CFG = Config()
class VicunaLLM(LLM):
vicuna_generate_path = "generate_stream"
def _call(
self,
prompt: str,
temperature: float,
max_new_tokens: int,
stop: Optional[List[str]] = None,
) -> str:
params = {
"prompt": prompt,
"temperature": temperature,
"max_new_tokens": max_new_tokens,
"stop": stop,
}
response = requests.post(
url=urljoin(CFG.MODEL_SERVER, self.vicuna_generate_path),
data=json.dumps(params),
)
skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
if data["error_code"] == 0:
output = data["text"][skip_echo_len:].strip()
yield output
@property
def _llm_type(self) -> str:
return "custome"
def _identifying_params(self) -> Mapping[str, Any]:
return {}
class VicunaEmbeddingLLM(BaseModel, Embeddings):
vicuna_embedding_path = "embedding"
def _call(self, prompt: str) -> str:
p = prompt.strip()
print("Sending prompt ", p)
response = requests.post(
url=urljoin(CFG.MODEL_SERVER, self.vicuna_embedding_path),
json={"prompt": p},
)
response.raise_for_status()
return response.json()["response"]
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Call out to Vicuna's server embedding endpoint for embedding search docs.
Args:
texts: The list of text to embed
Returns:
List of embeddings, one for each text.
"""
results = []
for text in texts:
response = self.embed_query(text)
results.append(response)
return results
def embed_query(self, text: str) -> List[float]:
"""Call out to Vicuna's server embedding endpoint for embedding query text.
Args:
text: The text to embed.
Returns:
Embedding for the text
"""
embedding = self._call(text)
return embedding
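
A hedged usage sketch, assuming a DB-GPT model server is reachable at CFG.MODEL_SERVER and exposes the generate_stream and embedding endpoints used above; the prompt and parameters are illustrative. Note that _call on VicunaLLM is written as a generator, so the caller iterates it.

llm = VicunaLLM()
for partial in llm._call(
    prompt="Hello, who are you?", temperature=0.7, max_new_tokens=128, stop=["###"]
):
    print(partial, end="\r")

embedder = VicunaEmbeddingLLM()
doc_vectors = embedder.embed_documents(["DB-GPT is a database-centric LLM framework."])
query_vector = embedder.embed_query("What is DB-GPT?")
print(len(doc_vectors), len(query_vector))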


@@ -0,0 +1,95 @@
from typing import Dict
import os
from vllm import AsyncLLMEngine
from vllm.utils import random_uuid
from vllm.sampling_params import SamplingParams
_IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true"
async def generate_stream(
model: AsyncLLMEngine, tokenizer, params: Dict, device: str, context_len: int
):
"""
Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/vllm_worker.py
"""
prompt = params["prompt"]
request_id = params.pop("request_id") if "request_id" in params else random_uuid()
temperature = float(params.get("temperature", 1.0))
top_p = float(params.get("top_p", 1.0))
max_new_tokens = int(params.get("max_new_tokens", 2048))
echo = bool(params.get("echo", True))
stop_str = params.get("stop", None)
stop_token_ids = params.get("stop_token_ids", None) or []
if tokenizer.eos_token_id is not None:
stop_token_ids.append(tokenizer.eos_token_id)
# Handle stop_str
stop = set()
if isinstance(stop_str, str) and stop_str != "":
stop.add(stop_str)
elif isinstance(stop_str, list) and stop_str != []:
stop.update(stop_str)
for tid in stop_token_ids:
if tid is not None:
stop.add(tokenizer.decode(tid))
# Build the sampling parameters for vLLM
top_p = max(top_p, 1e-5)
if temperature <= 1e-5:
top_p = 1.0
gen_params = {
"stop": list(stop),
"ignore_eos": False,
}
prompt_token_ids = None
if _IS_BENCHMARK:
gen_params["stop"] = []
gen_params["ignore_eos"] = True
prompt_len = context_len - max_new_tokens - 2
prompt_token_ids = tokenizer([prompt]).input_ids[0]
prompt_token_ids = prompt_token_ids[-prompt_len:]
sampling_params = SamplingParams(
n=1,
temperature=temperature,
top_p=top_p,
use_beam_search=False,
max_tokens=max_new_tokens,
**gen_params
)
results_generator = model.generate(
prompt, sampling_params, request_id, prompt_token_ids=prompt_token_ids
)
async for request_output in results_generator:
prompt = request_output.prompt
if echo:
text_outputs = [prompt + output.text for output in request_output.outputs]
else:
text_outputs = [output.text for output in request_output.outputs]
text_outputs = " ".join(text_outputs)
# Note: usage is not supported yet
prompt_tokens = len(request_output.prompt_token_ids)
completion_tokens = sum(
len(output.token_ids) for output in request_output.outputs
)
usage = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
}
finish_reason = (
request_output.outputs[0].finish_reason
if len(request_output.outputs) == 1
else [output.finish_reason for output in request_output.outputs]
)
yield {
"text": text_outputs,
"error_code": 0,
"usage": usage,
"finish_reason": finish_reason,
}
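
A hedged async driver for the vLLM path. Engine construction below assumes vLLM's AsyncEngineArgs/AsyncLLMEngine.from_engine_args API of the versions this refactor targets, and the model id is a placeholder.

import asyncio
from transformers import AutoTokenizer
from vllm import AsyncEngineArgs, AsyncLLMEngine

async def main():
    model_id = "your-org/your-model"  # placeholder
    engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(model=model_id))
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    params = {
        "prompt": "What is DB-GPT?",
        "temperature": 0.7,
        "max_new_tokens": 128,
        "echo": False,
    }
    async for chunk in generate_stream(engine, tokenizer, params, device="cuda", context_len=4096):
        print(chunk["text"], end="\r")

asyncio.run(main())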