mirror of https://github.com/csunny/DB-GPT.git
feat(cache): Not cache the failed model output
@@ -17,6 +17,8 @@ class PromptRequest(BaseModel):
     temperature: float = None
     max_new_tokens: int = None
     stop: str = None
+    stop_token_ids: List[int] = []
+    context_len: int = None
     echo: bool = True
     span_id: str = None

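
Illustrative sketch (not part of the commit): a request payload that exercises the two new optional fields. Only the field names shown in this hunk are taken from the diff; any other keys and the concrete values are assumptions.

# Hypothetical payload for a PromptRequest-style call.
prompt_request_payload = {
    "temperature": 0.7,
    "max_new_tokens": 512,
    "stop": "###",            # optional custom stop string
    "stop_token_ids": [2],    # new: custom stop token ids, defaults to []
    "context_len": 4096,      # new: per-request context window, defaults to None
    "echo": False,
}
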
@@ -1,6 +1,6 @@
 import os
 import logging
-from typing import Dict, Iterator, List
+from typing import Dict, Iterator, List, Optional

 from pilot.configs.model_config import get_device
 from pilot.model.model_adapter import get_llm_model_adapter, LLMModelAdaper
@@ -60,7 +60,7 @@ class DefaultModelWorker(ModelWorker):
         self.ml: ModelLoader = ModelLoader(
             model_path=self.model_path, model_name=self.model_name
         )
-        # TODO read context len from model config
+        # Default model context len
         self.context_len = 2048

     def model_param_class(self) -> ModelParameters:
@@ -111,6 +111,12 @@ class DefaultModelWorker(ModelWorker):
         self.model, self.tokenizer = self.ml.loader_with_params(
             model_params, self.llm_adapter
         )
+        model_max_length = _parse_model_max_length(self.model, self.tokenizer)
+        if model_max_length:
+            logger.info(
+                f"Parse model max length {model_max_length} from model {self.model_name}."
+            )
+            self.context_len = model_max_length

     def stop(self) -> None:
         if not self.model:
@@ -138,9 +144,9 @@ class DefaultModelWorker(ModelWorker):
         )

         previous_response = ""
-
+        context_len = params.get("context_len") or self.context_len
         for output in generate_stream_func(
-            self.model, self.tokenizer, params, get_device(), self.context_len
+            self.model, self.tokenizer, params, get_device(), context_len
         ):
             model_output, incremental_output, output_str = self._handle_output(
                 output, previous_response, model_context
@@ -183,9 +189,10 @@ class DefaultModelWorker(ModelWorker):
         )

         previous_response = ""
+        context_len = params.get("context_len") or self.context_len

         async for output in generate_stream_func(
-            self.model, self.tokenizer, params, get_device(), self.context_len
+            self.model, self.tokenizer, params, get_device(), context_len
         ):
             model_output, incremental_output, output_str = self._handle_output(
                 output, previous_response, model_context
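
A minimal standalone sketch of the fallback rule used in both the synchronous and asynchronous generate paths above; the helper name is illustrative, and the default mirrors the worker's self.context_len (2048 unless a model max length is parsed).

def resolve_context_len(params: dict, worker_default: int = 2048) -> int:
    # Same expression as in the diff: a truthy caller-supplied "context_len"
    # wins; otherwise fall back to the worker's default.
    return params.get("context_len") or worker_default

assert resolve_context_len({}) == 2048
assert resolve_context_len({"context_len": 4096}) == 4096
assert resolve_context_len({"context_len": None}) == 2048  # falsy values also fall back
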
@@ -279,11 +286,27 @@ class DefaultModelWorker(ModelWorker):
         # Check if the exception is a torch.cuda.CudaError and if torch was imported.
         if _torch_imported and isinstance(e, torch.cuda.CudaError):
             model_output = ModelOutput(
-                text="**GPU OutOfMemory, Please Refresh.**", error_code=0
+                text="**GPU OutOfMemory, Please Refresh.**", error_code=1
             )
         else:
             model_output = ModelOutput(
                 text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
-                error_code=0,
+                error_code=1,
             )
         return model_output
+
+
+def _parse_model_max_length(model, tokenizer) -> Optional[int]:
+    if not (tokenizer or model):
+        return None
+    try:
+        if tokenizer and hasattr(tokenizer, "model_max_length"):
+            return tokenizer.model_max_length
+        if model and hasattr(model, "config"):
+            model_config = model.config
+            if hasattr(model_config, "max_sequence_length"):
+                return model_config.max_sequence_length
+            if hasattr(model_config, "max_position_embeddings"):
+                return model_config.max_position_embeddings
+    except Exception:
+        return None

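
The switch from error_code=0 to error_code=1 is what lets the cache layer further down distinguish failed outputs from successful ones. Assuming _parse_model_max_length from the hunk above is in scope, a small sketch of its lookup order with stand-in objects (only the attributes it probes are needed, not real tokenizer/model classes):

from types import SimpleNamespace

tokenizer = SimpleNamespace(model_max_length=4096)
model = SimpleNamespace(config=SimpleNamespace(max_position_embeddings=2048))

assert _parse_model_max_length(model, tokenizer) == 4096  # tokenizer wins if it has model_max_length
assert _parse_model_max_length(model, None) == 2048       # otherwise fall back to the model config
assert _parse_model_max_length(None, None) is None        # nothing to inspect
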
@@ -119,7 +119,10 @@ class LocalWorkerManager(WorkerManager):
             _async_heartbeat_sender(self.run_data, 20, self.send_heartbeat_func)
         )
         for listener in self.start_listeners:
-            listener(self)
+            if asyncio.iscoroutinefunction(listener):
+                await listener(self)
+            else:
+                listener(self)

     async def stop(self, ignore_exception: bool = False):
         if not self.run_data.stop_event.is_set():
@@ -325,7 +328,7 @@ class LocalWorkerManager(WorkerManager):
         except Exception as e:
             yield ModelOutput(
                 text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
-                error_code=0,
+                error_code=1,
             )
             return
         async with worker_run_data.semaphore:
@@ -355,7 +358,7 @@ class LocalWorkerManager(WorkerManager):
         except Exception as e:
             return ModelOutput(
                 text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
-                error_code=0,
+                error_code=1,
             )
         async with worker_run_data.semaphore:
             if worker_run_data.worker.support_async():
@@ -996,6 +999,7 @@ def run_worker_manager(
     port: int = None,
     embedding_model_name: str = None,
     embedding_model_path: str = None,
+    start_listener: Callable[["WorkerManager"], None] = None,
 ):
     global worker_manager

@@ -1029,6 +1033,8 @@ def run_worker_manager(
         worker_manager, embedding_model_name, embedding_model_path
     )

+    worker_manager.after_start(start_listener)
+
     if include_router:
         app.include_router(router, prefix="/api")

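
Sketch of how a caller might use the new start_listener parameter; the listener body is illustrative, and per the next hunk it may be either a plain function or a coroutine function.

def on_worker_manager_started(worker_manager):
    # e.g. register routes, warm caches, or log readiness; body is illustrative.
    print("worker manager started")

# run_worker_manager(..., start_listener=on_worker_manager_started)
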
@@ -15,7 +15,10 @@ class RemoteWorkerManager(LocalWorkerManager):

     async def start(self):
         for listener in self.start_listeners:
-            listener(self)
+            if asyncio.iscoroutinefunction(listener):
+                await listener(self)
+            else:
+                listener(self)

     async def stop(self, ignore_exception: bool = False):
         pass
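
A self-contained sketch of the dispatch added for start listeners in both worker managers: synchronous listeners keep working, while coroutine functions are now awaited. The listener bodies here are illustrative.

import asyncio

def sync_listener(worker_manager):
    print("started (sync listener)")

async def async_listener(worker_manager):
    await asyncio.sleep(0)  # stand-in for real async setup work
    print("started (async listener)")

async def notify_start_listeners(listeners, worker_manager):
    # Mirrors the dispatch in the diff above.
    for listener in listeners:
        if asyncio.iscoroutinefunction(listener):
            await listener(worker_manager)
        else:
            listener(worker_manager)

asyncio.run(notify_start_listeners([sync_listener, async_listener], worker_manager=None))
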
@@ -170,9 +170,12 @@ class LLMModelAdaper:
             model_context["has_format_prompt"] = True
             params["prompt"] = new_prompt

-        # Overwrite model params:
-        params["stop"] = conv.stop_str
-        params["stop_token_ids"] = conv.stop_token_ids
+        custom_stop = params.get("stop")
+        custom_stop_token_ids = params.get("stop_token_ids")
+
+        # Prefer the value passed in from the input parameter
+        params["stop"] = custom_stop or conv.stop_str
+        params["stop_token_ids"] = custom_stop_token_ids or conv.stop_token_ids

         return params, model_context

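
A minimal sketch of the new precedence, with a stand-in object for the conversation template (only the stop_str and stop_token_ids attributes used in the diff are assumed):

from types import SimpleNamespace

conv = SimpleNamespace(stop_str="###", stop_token_ids=[2])

def merge_stop(params: dict) -> dict:
    # Same logic as the diff: caller-supplied values win over template defaults.
    custom_stop = params.get("stop")
    custom_stop_token_ids = params.get("stop_token_ids")
    params["stop"] = custom_stop or conv.stop_str
    params["stop_token_ids"] = custom_stop_token_ids or conv.stop_token_ids
    return params

assert merge_stop({"stop": "</s>"})["stop"] == "</s>"  # caller wins
assert merge_stop({})["stop_token_ids"] == [2]         # falls back to the template
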
@@ -1,4 +1,4 @@
-from typing import AsyncIterator, Dict, Union
+from typing import AsyncIterator, Dict, List, Union
 import logging
 from pilot.awel import (
     BranchFunc,
@@ -227,7 +227,7 @@ class ModelStreamSaveCacheOperator(
             )
             outputs.append(out)
             yield out
-        if llm_cache_key:
+        if llm_cache_key and _is_success_model_output(outputs):
             llm_cache_value: LLMCacheValue = self._client.new_value(output=outputs)
             await self._client.set(llm_cache_key, llm_cache_value)

@@ -258,7 +258,7 @@ class ModelSaveCacheOperator(MapOperator[ModelOutput, ModelOutput]):
             _LLM_MODEL_INPUT_VALUE_KEY
         )
         llm_cache_value: LLMCacheValue = self._client.new_value(output=input_value)
-        if llm_cache_key:
+        if llm_cache_key and _is_success_model_output(input_value):
             await self._client.set(llm_cache_key, llm_cache_value)
         return input_value

@@ -284,3 +284,17 @@ def _parse_cache_key_dict(input_value: Dict) -> Dict:
         # TODO pass model_type
         "model_type": input_value.get("model_type", "huggingface"),
     }
+
+
+def _is_success_model_output(out: Union[Dict, ModelOutput, List[ModelOutput]]) -> bool:
+    if not out:
+        return False
+    if isinstance(out, list):
+        # check last model output
+        out = out[-1]
+    error_code = 0
+    if isinstance(out, ModelOutput):
+        error_code = out.error_code
+    else:
+        error_code = int(out.get("error_code", 0))
+    return error_code == 0

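
Assuming ModelOutput and the new _is_success_model_output are in scope, a quick illustration of the guard that now protects both cache-save operators: only a response whose final error_code is 0 gets written to the cache.

ok = ModelOutput(text="hello", error_code=0)
failed = ModelOutput(text="**GPU OutOfMemory, Please Refresh.**", error_code=1)

assert _is_success_model_output(ok)
assert not _is_success_model_output(failed)
assert not _is_success_model_output([ok, failed])   # streamed outputs: the last chunk decides
assert _is_success_model_output({"error_code": 0})  # dict form is also accepted
assert not _is_success_model_output(None)           # empty output is never cached
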
@@ -173,7 +173,7 @@ class BaseChat(ABC):
             "messages": llm_messages,
             "temperature": float(self.prompt_template.temperature),
             "max_new_tokens": int(self.prompt_template.max_new_tokens),
-            "stop": self.prompt_template.sep,
+            # "stop": self.prompt_template.sep,
             "echo": self.llm_echo,
         }
         return payload