fix race condition

YeAnbang
2025-07-21 17:21:07 +08:00
parent ddda79c36f
commit 2336d7f6d6
10 changed files with 100 additions and 33 deletions


@@ -59,6 +59,7 @@ class TransformersInferenceBackend(BaseInferenceBackend):
         generate_config: Dict[str, Any],
         tokenizer: PreTrainedTokenizer,
         num_generations: int = 8,
+        tokenizer_config: Dict[str, Any] = None,
     ):
         model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
         model_config.update(self.FORCE_MODEL_CONFIG)
@@ -132,6 +133,7 @@ class SGLangInferenceBackend(BaseInferenceBackend):
         generate_config: Dict[str, Any],
         tokenizer: PreTrainedTokenizer,
         num_generations: int = 8,
+        tokenizer_config: Dict[str, Any] = None,
     ):
         if sgl is None:
             raise ImportError("sglang is not installed")
@@ -196,12 +198,14 @@ class VLLMInferenceBackend(BaseInferenceBackend):
         generate_config: Dict[str, Any],
         tokenizer: PreTrainedTokenizer,
         num_generations: int = 8,
+        tokenizer_config: Dict[str, Any] = None,
     ):
         if LLM is None:
             raise ImportError("vllm is not installed")
         model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
         path = model_config.pop("path")
-        self.llm = LLM(model=path, **model_config)
+        tokenizer_path = tokenizer_config.get("path", None) if tokenizer_config is not None else None
+        self.llm = LLM(model=path, tokenizer=tokenizer_path, **model_config)
         generate_config = generate_config.copy()
         generate_config.update(self.FORCE_GENERATE_CONFIG)
         generate_config.update({"n": num_generations})
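For context, a minimal caller-side sketch of how the new tokenizer_config argument could be used with the vLLM backend: the backend only reads an optional "path" key from it and forwards that path to vLLM's LLM(tokenizer=...), so the tokenizer can be loaded from a different directory than the model weights; when tokenizer_config is omitted, tokenizer_path stays None and vLLM loads the tokenizer from the model path as before. The import path, file paths, and config values below are assumptions for illustration, not taken from this commit.

    # import path assumed from ColossalChat's layout; adjust if the module lives elsewhere
    from coati.distributed.inference_backend import VLLMInferenceBackend
    from transformers import AutoTokenizer

    # placeholder paths and settings, not from this commit
    model_config = {"path": "/models/Qwen2.5-7B-Instruct", "dtype": "bfloat16"}
    generate_config = {"temperature": 0.8, "top_p": 0.95, "max_tokens": 512}
    tokenizer_config = {"path": "/models/Qwen2.5-7B-Instruct-tokenizer"}  # separate tokenizer directory

    # the tokenizer object passed here is loaded from the same path the backend will hand to vLLM
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_config["path"])

    backend = VLLMInferenceBackend(
        model_config=model_config,
        generate_config=generate_config,
        tokenizer=tokenizer,
        num_generations=8,
        tokenizer_config=tokenizer_config,  # optional; omit to fall back to the model path
    )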