From 3d7481d369b955415d847b7e8efd615e3944571d Mon Sep 17 00:00:00 2001
From: FangYin Cheng
Date: Fri, 17 Nov 2023 17:04:37 +0800
Subject: [PATCH] feat(model): Usage infos with vLLM

---
 pilot/configs/model_config.py   |  2 ++
 pilot/model/llm_out/vllm_llm.py | 23 ++++++++++++++++++++++-
 2 files changed, 24 insertions(+), 1 deletion(-)

diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index cedfa8554..fec343f2a 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -2,6 +2,7 @@
 # -*- coding:utf-8 -*-
 
 import os
+from functools import cache
 
 ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 MODEL_PATH = os.path.join(ROOT_PATH, "models")
@@ -22,6 +23,7 @@
 new_directory = PILOT_PATH
 os.chdir(new_directory)
 
+@cache
 def get_device() -> str:
     try:
         import torch
diff --git a/pilot/model/llm_out/vllm_llm.py b/pilot/model/llm_out/vllm_llm.py
index 07d43dc74..375e9589a 100644
--- a/pilot/model/llm_out/vllm_llm.py
+++ b/pilot/model/llm_out/vllm_llm.py
@@ -53,4 +53,25 @@ async def generate_stream(
         else:
             text_outputs = [output.text for output in request_output.outputs]
         text_outputs = " ".join(text_outputs)
-        yield {"text": text_outputs, "error_code": 0, "usage": {}}
+
+        # Compute usage information from the vLLM request output
+        prompt_tokens = len(request_output.prompt_token_ids)
+        completion_tokens = sum(
+            len(output.token_ids) for output in request_output.outputs
+        )
+        usage = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": prompt_tokens + completion_tokens,
+        }
+        finish_reason = (
+            request_output.outputs[0].finish_reason
+            if len(request_output.outputs) == 1
+            else [output.finish_reason for output in request_output.outputs]
+        )
+        yield {
+            "text": text_outputs,
+            "error_code": 0,
+            "usage": usage,
+            "finish_reason": finish_reason,
+        }
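
Note (not part of the commit): after this patch, generate_stream yields, for every streamed chunk, a dict of the form {"text", "error_code", "usage", "finish_reason"}. Below is a minimal consumer sketch; it assumes only the chunk shape shown in the diff, and the helper name print_stream is hypothetical, not a DB-GPT API.

    from typing import AsyncIterator, Dict


    async def print_stream(chunks: AsyncIterator[Dict]) -> None:
        """Consume chunks shaped like the dicts yielded by generate_stream above."""
        text = ""
        usage: Dict = {}
        finish_reason = None
        async for chunk in chunks:
            if chunk["error_code"] != 0:
                raise RuntimeError(f"generation failed: {chunk}")
            # Assumes vLLM semantics: each RequestOutput carries the full text
            # generated so far, so "text" is cumulative rather than a delta.
            text = chunk["text"]
            usage = chunk.get("usage", {})
            finish_reason = chunk.get("finish_reason")
        print(text)
        print(
            f"prompt_tokens={usage.get('prompt_tokens')} "
            f"completion_tokens={usage.get('completion_tokens')} "
            f"total_tokens={usage.get('total_tokens')} "
            f"finish_reason={finish_reason}"
        )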