feat(cache): Not cache the failed model output

This commit is contained in:
FangYin Cheng
2023-11-17 11:06:36 +08:00
parent 995772077c
commit 5109c89daa
7 changed files with 69 additions and 18 deletions

View File

@@ -17,6 +17,8 @@ class PromptRequest(BaseModel):
temperature: float = None
max_new_tokens: int = None
stop: str = None
stop_token_ids: List[int] = []
context_len: int = None
echo: bool = True
span_id: str = None

View File

@@ -1,6 +1,6 @@
import os
import logging
from typing import Dict, Iterator, List
from typing import Dict, Iterator, List, Optional
from pilot.configs.model_config import get_device
from pilot.model.model_adapter import get_llm_model_adapter, LLMModelAdaper
@@ -60,7 +60,7 @@ class DefaultModelWorker(ModelWorker):
self.ml: ModelLoader = ModelLoader(
model_path=self.model_path, model_name=self.model_name
)
# TODO read context len from model config
# Default model context len
self.context_len = 2048
def model_param_class(self) -> ModelParameters:
@@ -111,6 +111,12 @@ class DefaultModelWorker(ModelWorker):
self.model, self.tokenizer = self.ml.loader_with_params(
model_params, self.llm_adapter
)
model_max_length = _parse_model_max_length(self.model, self.tokenizer)
if model_max_length:
logger.info(
f"Parse model max length {model_max_length} from model {self.model_name}."
)
self.context_len = model_max_length
def stop(self) -> None:
if not self.model:
@@ -138,9 +144,9 @@ class DefaultModelWorker(ModelWorker):
)
previous_response = ""
context_len = params.get("context_len") or self.context_len
for output in generate_stream_func(
self.model, self.tokenizer, params, get_device(), self.context_len
self.model, self.tokenizer, params, get_device(), context_len
):
model_output, incremental_output, output_str = self._handle_output(
output, previous_response, model_context
@@ -183,9 +189,10 @@ class DefaultModelWorker(ModelWorker):
)
previous_response = ""
context_len = params.get("context_len") or self.context_len
async for output in generate_stream_func(
self.model, self.tokenizer, params, get_device(), self.context_len
self.model, self.tokenizer, params, get_device(), context_len
):
model_output, incremental_output, output_str = self._handle_output(
output, previous_response, model_context
@@ -279,11 +286,27 @@ class DefaultModelWorker(ModelWorker):
# Check if the exception is a torch.cuda.CudaError and if torch was imported.
if _torch_imported and isinstance(e, torch.cuda.CudaError):
model_output = ModelOutput(
text="**GPU OutOfMemory, Please Refresh.**", error_code=0
text="**GPU OutOfMemory, Please Refresh.**", error_code=1
)
else:
model_output = ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=0,
error_code=1,
)
return model_output
def _parse_model_max_length(model, tokenizer) -> Optional[int]:
if not (tokenizer or model):
return None
try:
if tokenizer and hasattr(tokenizer, "model_max_length"):
return tokenizer.model_max_length
if model and hasattr(model, "config"):
model_config = model.config
if hasattr(model_config, "max_sequence_length"):
return model_config.max_sequence_length
if hasattr(model_config, "max_position_embeddings"):
return model_config.max_position_embeddings
except Exception:
return None

View File

@@ -119,7 +119,10 @@ class LocalWorkerManager(WorkerManager):
_async_heartbeat_sender(self.run_data, 20, self.send_heartbeat_func)
)
for listener in self.start_listeners:
listener(self)
if asyncio.iscoroutinefunction(listener):
await listener(self)
else:
listener(self)
async def stop(self, ignore_exception: bool = False):
if not self.run_data.stop_event.is_set():
@@ -325,7 +328,7 @@ class LocalWorkerManager(WorkerManager):
except Exception as e:
yield ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=0,
error_code=1,
)
return
async with worker_run_data.semaphore:
@@ -355,7 +358,7 @@ class LocalWorkerManager(WorkerManager):
except Exception as e:
return ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=0,
error_code=1,
)
async with worker_run_data.semaphore:
if worker_run_data.worker.support_async():
@@ -996,6 +999,7 @@ def run_worker_manager(
port: int = None,
embedding_model_name: str = None,
embedding_model_path: str = None,
start_listener: Callable[["WorkerManager"], None] = None,
):
global worker_manager
@@ -1029,6 +1033,8 @@ def run_worker_manager(
worker_manager, embedding_model_name, embedding_model_path
)
worker_manager.after_start(start_listener)
if include_router:
app.include_router(router, prefix="/api")

View File

@@ -15,7 +15,10 @@ class RemoteWorkerManager(LocalWorkerManager):
async def start(self):
for listener in self.start_listeners:
listener(self)
if asyncio.iscoroutinefunction(listener):
await listener(self)
else:
listener(self)
async def stop(self, ignore_exception: bool = False):
pass

View File

@@ -170,9 +170,12 @@ class LLMModelAdaper:
model_context["has_format_prompt"] = True
params["prompt"] = new_prompt
# Overwrite model params:
params["stop"] = conv.stop_str
params["stop_token_ids"] = conv.stop_token_ids
custom_stop = params.get("stop")
custom_stop_token_ids = params.get("stop_token_ids")
# Prefer the value passed in from the input parameter
params["stop"] = custom_stop or conv.stop_str
params["stop_token_ids"] = custom_stop_token_ids or conv.stop_token_ids
return params, model_context

View File

@@ -1,4 +1,4 @@
from typing import AsyncIterator, Dict, Union
from typing import AsyncIterator, Dict, List, Union
import logging
from pilot.awel import (
BranchFunc,
@@ -227,7 +227,7 @@ class ModelStreamSaveCacheOperator(
)
outputs.append(out)
yield out
if llm_cache_key:
if llm_cache_key and _is_success_model_output(outputs):
llm_cache_value: LLMCacheValue = self._client.new_value(output=outputs)
await self._client.set(llm_cache_key, llm_cache_value)
@@ -258,7 +258,7 @@ class ModelSaveCacheOperator(MapOperator[ModelOutput, ModelOutput]):
_LLM_MODEL_INPUT_VALUE_KEY
)
llm_cache_value: LLMCacheValue = self._client.new_value(output=input_value)
if llm_cache_key:
if llm_cache_key and _is_success_model_output(input_value):
await self._client.set(llm_cache_key, llm_cache_value)
return input_value
@@ -284,3 +284,17 @@ def _parse_cache_key_dict(input_value: Dict) -> Dict:
# TODO pass model_type
"model_type": input_value.get("model_type", "huggingface"),
}
def _is_success_model_output(out: Union[Dict, ModelOutput, List[ModelOutput]]) -> bool:
if not out:
return False
if isinstance(out, list):
# check last model output
out = out[-1]
error_code = 0
if isinstance(out, ModelOutput):
error_code = out.error_code
else:
error_code = int(out.get("error_code", 0))
return error_code == 0

View File

@@ -173,7 +173,7 @@ class BaseChat(ABC):
"messages": llm_messages,
"temperature": float(self.prompt_template.temperature),
"max_new_tokens": int(self.prompt_template.max_new_tokens),
"stop": self.prompt_template.sep,
# "stop": self.prompt_template.sep,
"echo": self.llm_echo,
}
return payload