fix logprob, add filtering, temperature annealing, lr descent

This commit is contained in:
YeAnbang
2025-03-21 10:24:24 +08:00
parent 7ee4452f8c
commit 0472f44163
7 changed files with 74 additions and 27 deletions

View File

@@ -53,7 +53,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
)
FORCE_GENERATE_CONFIG = dict(output_logits=True, return_dict_in_generate=True)
def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
def __init__(
self,
model_config: Dict[str, Any],
generate_config: Dict[str, Any],
tokenizer: PreTrainedTokenizer,
num_generations: int = 8,
):
model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
model_config.update(self.FORCE_MODEL_CONFIG)
path = model_config.pop("path")
@@ -61,7 +67,7 @@ class TransformersInferenceBackend(BaseInferenceBackend):
self.generate_config = generate_config.copy()
self.generate_config.update(self.FORCE_GENERATE_CONFIG)
self.tokenizer = tokenizer
self.num_generations = 8
self.num_generations = num_generations
@torch.no_grad()
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
@@ -120,7 +126,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
class SGLangInferenceBackend(BaseInferenceBackend):
def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
def __init__(
self,
model_config: Dict[str, Any],
generate_config: Dict[str, Any],
tokenizer: PreTrainedTokenizer,
num_generations: int = 8,
):
if sgl is None:
raise ImportError("sglang is not installed")
path = model_config.pop("path")
@@ -175,10 +187,15 @@ class VLLMInferenceBackend(BaseInferenceBackend):
)
FORCE_GENERATE_CONFIG = dict(
logprobs=0,
n=8,
)
def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
def __init__(
self,
model_config: Dict[str, Any],
generate_config: Dict[str, Any],
tokenizer: PreTrainedTokenizer,
num_generations: int = 8,
):
if LLM is None:
raise ImportError("vllm is not installed")
model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
@@ -186,9 +203,10 @@ class VLLMInferenceBackend(BaseInferenceBackend):
self.llm = LLM(model=path, **model_config)
generate_config = generate_config.copy()
generate_config.update(self.FORCE_GENERATE_CONFIG)
generate_config.update({"n": num_generations})
self.generate_config = SamplingParams(**generate_config)
self.tokenizer = tokenizer
self.num_generations = self.FORCE_GENERATE_CONFIG["n"]
self.num_generations = num_generations
@torch.no_grad()
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: