Mirror of https://github.com/csunny/DB-GPT.git
feat(cache): Do not cache failed model output
@@ -17,6 +17,8 @@ class PromptRequest(BaseModel):
     temperature: float = None
     max_new_tokens: int = None
     stop: str = None
+    stop_token_ids: List[int] = []
+    context_len: int = None
     echo: bool = True
     span_id: str = None

@@ -1,6 +1,6 @@
 import os
 import logging
-from typing import Dict, Iterator, List
+from typing import Dict, Iterator, List, Optional

 from pilot.configs.model_config import get_device
 from pilot.model.model_adapter import get_llm_model_adapter, LLMModelAdaper

@@ -60,7 +60,7 @@ class DefaultModelWorker(ModelWorker):
         self.ml: ModelLoader = ModelLoader(
             model_path=self.model_path, model_name=self.model_name
         )
-        # TODO read context len from model config
+        # Default model context len
         self.context_len = 2048

     def model_param_class(self) -> ModelParameters:

@@ -111,6 +111,12 @@ class DefaultModelWorker(ModelWorker):
         self.model, self.tokenizer = self.ml.loader_with_params(
             model_params, self.llm_adapter
         )
+        model_max_length = _parse_model_max_length(self.model, self.tokenizer)
+        if model_max_length:
+            logger.info(
+                f"Parse model max length {model_max_length} from model {self.model_name}."
+            )
+            self.context_len = model_max_length

     def stop(self) -> None:
         if not self.model:

@@ -138,9 +144,9 @@ class DefaultModelWorker(ModelWorker):
         )

         previous_response = ""
+        context_len = params.get("context_len") or self.context_len
         for output in generate_stream_func(
-            self.model, self.tokenizer, params, get_device(), self.context_len
+            self.model, self.tokenizer, params, get_device(), context_len
         ):
             model_output, incremental_output, output_str = self._handle_output(
                 output, previous_response, model_context

@@ -183,9 +189,10 @@ class DefaultModelWorker(ModelWorker):
         )

         previous_response = ""
+        context_len = params.get("context_len") or self.context_len

         async for output in generate_stream_func(
-            self.model, self.tokenizer, params, get_device(), self.context_len
+            self.model, self.tokenizer, params, get_device(), context_len
         ):
             model_output, incremental_output, output_str = self._handle_output(
                 output, previous_response, model_context

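For reference, the per-request override above is a plain truthy fallback: a non-zero context_len from the request wins, otherwise the worker default applies. A minimal sketch of that resolution order, with illustrative values rather than anything taken from the repository:

def resolve_context_len(params, default_context_len=2048):
    # A request-supplied, non-zero context_len wins; otherwise fall back to the
    # worker default (parsed from the model where possible, else 2048).
    return params.get("context_len") or default_context_len

assert resolve_context_len({"context_len": 4096}) == 4096
assert resolve_context_len({}) == 2048
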
@@ -279,11 +286,27 @@ class DefaultModelWorker(ModelWorker):
         # Check if the exception is a torch.cuda.CudaError and if torch was imported.
         if _torch_imported and isinstance(e, torch.cuda.CudaError):
             model_output = ModelOutput(
-                text="**GPU OutOfMemory, Please Refresh.**", error_code=0
+                text="**GPU OutOfMemory, Please Refresh.**", error_code=1
             )
         else:
             model_output = ModelOutput(
                 text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
-                error_code=0,
+                error_code=1,
             )
         return model_output
+
+
+def _parse_model_max_length(model, tokenizer) -> Optional[int]:
+    if not (tokenizer or model):
+        return None
+    try:
+        if tokenizer and hasattr(tokenizer, "model_max_length"):
+            return tokenizer.model_max_length
+        if model and hasattr(model, "config"):
+            model_config = model.config
+            if hasattr(model_config, "max_sequence_length"):
+                return model_config.max_sequence_length
+            if hasattr(model_config, "max_position_embeddings"):
+                return model_config.max_position_embeddings
+    except Exception:
+        return None

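The helper added above prefers tokenizer.model_max_length and only then falls back to the model config. A standalone sketch of that fallback order, reproducing the helper with hypothetical SimpleNamespace stand-ins in place of real model and tokenizer objects:

from types import SimpleNamespace

def _parse_model_max_length(model, tokenizer):
    # Same logic as the helper in the diff: tokenizer.model_max_length first,
    # then config.max_sequence_length, then config.max_position_embeddings.
    if not (tokenizer or model):
        return None
    try:
        if tokenizer and hasattr(tokenizer, "model_max_length"):
            return tokenizer.model_max_length
        if model and hasattr(model, "config"):
            model_config = model.config
            if hasattr(model_config, "max_sequence_length"):
                return model_config.max_sequence_length
            if hasattr(model_config, "max_position_embeddings"):
                return model_config.max_position_embeddings
    except Exception:
        return None

# Hypothetical fixtures: a tokenizer that exposes model_max_length wins,
# otherwise the model config is consulted.
tokenizer = SimpleNamespace(model_max_length=4096)
model = SimpleNamespace(config=SimpleNamespace(max_position_embeddings=2048))
assert _parse_model_max_length(model, tokenizer) == 4096   # tokenizer wins
assert _parse_model_max_length(model, None) == 2048        # falls back to the config
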
@@ -119,7 +119,10 @@ class LocalWorkerManager(WorkerManager):
             _async_heartbeat_sender(self.run_data, 20, self.send_heartbeat_func)
         )
         for listener in self.start_listeners:
-            listener(self)
+            if asyncio.iscoroutinefunction(listener):
+                await listener(self)
+            else:
+                listener(self)

     async def stop(self, ignore_exception: bool = False):
         if not self.run_data.stop_event.is_set():

@@ -325,7 +328,7 @@ class LocalWorkerManager(WorkerManager):
         except Exception as e:
             yield ModelOutput(
                 text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
-                error_code=0,
+                error_code=1,
             )
             return
         async with worker_run_data.semaphore:

@@ -355,7 +358,7 @@ class LocalWorkerManager(WorkerManager):
         except Exception as e:
             return ModelOutput(
                 text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
-                error_code=0,
+                error_code=1,
             )
         async with worker_run_data.semaphore:
             if worker_run_data.worker.support_async():

@@ -996,6 +999,7 @@ def run_worker_manager(
    port: int = None,
    embedding_model_name: str = None,
    embedding_model_path: str = None,
+   start_listener: Callable[["WorkerManager"], None] = None,
 ):
     global worker_manager

@@ -1029,6 +1033,8 @@ def run_worker_manager(
         worker_manager, embedding_model_name, embedding_model_path
     )

+    worker_manager.after_start(start_listener)
+
     if include_router:
         app.include_router(router, prefix="/api")

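As a rough sketch of what the new start_listener hook allows, a callback (sync or async, since listeners are now awaited when they are coroutine functions) can be registered when starting the worker manager. The on_start callback below is illustrative, and the other arguments of run_worker_manager are omitted:

def on_start(worker_manager):
    # Invoked once the worker manager has started via after_start().
    print("worker manager ready:", worker_manager)

# run_worker_manager(..., start_listener=on_start)
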
@@ -15,7 +15,10 @@ class RemoteWorkerManager(LocalWorkerManager):

     async def start(self):
         for listener in self.start_listeners:
-            listener(self)
+            if asyncio.iscoroutinefunction(listener):
+                await listener(self)
+            else:
+                listener(self)

     async def stop(self, ignore_exception: bool = False):
         pass

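Both managers now use the same dispatch, so synchronous and asynchronous start listeners can be mixed. A minimal, self-contained sketch of that pattern with hypothetical listeners, relying only on the standard library:

import asyncio

async def run_listeners(listeners, manager):
    # Same dispatch as in the diff: await coroutine functions, call plain callables.
    for listener in listeners:
        if asyncio.iscoroutinefunction(listener):
            await listener(manager)
        else:
            listener(manager)

def sync_listener(manager):
    print("sync listener saw", manager)

async def async_listener(manager):
    print("async listener saw", manager)

asyncio.run(run_listeners([sync_listener, async_listener], "worker_manager"))
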
@@ -170,9 +170,12 @@ class LLMModelAdaper:
             model_context["has_format_prompt"] = True
             params["prompt"] = new_prompt

-        # Overwrite model params:
-        params["stop"] = conv.stop_str
-        params["stop_token_ids"] = conv.stop_token_ids
+        custom_stop = params.get("stop")
+        custom_stop_token_ids = params.get("stop_token_ids")
+
+        # Prefer the value passed in from the input parameter
+        params["stop"] = custom_stop or conv.stop_str
+        params["stop_token_ids"] = custom_stop_token_ids or conv.stop_token_ids

         return params, model_context

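A quick check of the new precedence: values supplied in the request survive, while missing or empty values still fall back to the conversation defaults. The conv object here is a hypothetical stand-in for the adapter's conversation template:

from types import SimpleNamespace

conv = SimpleNamespace(stop_str="</s>", stop_token_ids=[2])

def apply_stop(params, conv):
    # Prefer the value passed in from the input parameter, as in the diff above.
    custom_stop = params.get("stop")
    custom_stop_token_ids = params.get("stop_token_ids")
    params["stop"] = custom_stop or conv.stop_str
    params["stop_token_ids"] = custom_stop_token_ids or conv.stop_token_ids
    return params

print(apply_stop({"stop": "###"}, conv))
# {'stop': '###', 'stop_token_ids': [2]}  -- the caller-provided stop string is kept
print(apply_stop({}, conv))
# {'stop': '</s>', 'stop_token_ids': [2]} -- conversation defaults apply otherwise
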
@@ -1,4 +1,4 @@
-from typing import AsyncIterator, Dict, Union
+from typing import AsyncIterator, Dict, List, Union
 import logging
 from pilot.awel import (
     BranchFunc,

@@ -227,7 +227,7 @@ class ModelStreamSaveCacheOperator(
             )
             outputs.append(out)
             yield out
-        if llm_cache_key:
+        if llm_cache_key and _is_success_model_output(outputs):
            llm_cache_value: LLMCacheValue = self._client.new_value(output=outputs)
            await self._client.set(llm_cache_key, llm_cache_value)

@@ -258,7 +258,7 @@ class ModelSaveCacheOperator(MapOperator[ModelOutput, ModelOutput]):
             _LLM_MODEL_INPUT_VALUE_KEY
         )
         llm_cache_value: LLMCacheValue = self._client.new_value(output=input_value)
-        if llm_cache_key:
+        if llm_cache_key and _is_success_model_output(input_value):
             await self._client.set(llm_cache_key, llm_cache_value)
         return input_value

@@ -284,3 +284,17 @@ def _parse_cache_key_dict(input_value: Dict) -> Dict:
         # TODO pass model_type
         "model_type": input_value.get("model_type", "huggingface"),
     }
+
+
+def _is_success_model_output(out: Union[Dict, ModelOutput, List[ModelOutput]]) -> bool:
+    if not out:
+        return False
+    if isinstance(out, list):
+        # check last model output
+        out = out[-1]
+    error_code = 0
+    if isinstance(out, ModelOutput):
+        error_code = out.error_code
+    else:
+        error_code = int(out.get("error_code", 0))
+    return error_code == 0

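Combined with the error_code changes earlier in this commit (failures now report error_code=1), this guard keeps failed generations out of the cache. A minimal sketch of how the check behaves; ModelOutput is stubbed with a dataclass here purely for illustration:

from dataclasses import dataclass
from typing import Dict, List, Union

@dataclass
class ModelOutput:  # stand-in for the real ModelOutput class
    text: str
    error_code: int = 0

def is_success_model_output(out: Union[Dict, ModelOutput, List[ModelOutput]]) -> bool:
    # Same logic as _is_success_model_output above: only the last streamed
    # chunk decides, and a non-zero error_code marks the output as failed.
    if not out:
        return False
    if isinstance(out, list):
        out = out[-1]
    if isinstance(out, ModelOutput):
        error_code = out.error_code
    else:
        error_code = int(out.get("error_code", 0))
    return error_code == 0

stream = [ModelOutput("partial", 0), ModelOutput("**GPU OutOfMemory, Please Refresh.**", 1)]
assert not is_success_model_output(stream)          # failed stream: not cached
assert is_success_model_output(ModelOutput("ok"))   # successful output: cached
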
@@ -173,7 +173,7 @@ class BaseChat(ABC):
             "messages": llm_messages,
             "temperature": float(self.prompt_template.temperature),
             "max_new_tokens": int(self.prompt_template.max_new_tokens),
-            "stop": self.prompt_template.sep,
+            # "stop": self.prompt_template.sep,
             "echo": self.llm_echo,
         }
         return payload