Fix a possible bug

This commit is contained in:
YeAnbang 2025-05-05 18:48:42 +08:00
parent 6fff36dd63
commit 4d18e7d772
3 changed files with 49 additions and 12 deletions

View File

@ -80,8 +80,9 @@ class TransformersInferenceBackend(BaseInferenceBackend):
if self.num_generations > 1:
input_ids = input_ids.repeat_interleave(self.num_generations, dim=0)
attention_mask = attention_mask.repeat_interleave(self.num_generations, dim=0)
generate_config = kwargs.get("generate_config", self.generate_config)
out = self.model.generate(
input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config, tokenizer=self.tokenizer
input_ids, attention_mask=attention_mask, **kwargs, **generate_config, tokenizer=self.tokenizer
)
input_len = input_ids.shape[-1]
new_token_ids = out.sequences[:, input_len:]

View File

@ -93,12 +93,14 @@ class BaseProducer:
)
self.eval_dataset_config = eval_dataset_config
if self.eval_dataset_config is not None and self.eval_interval > 0:
if self.eval_dataset_config is not None:
self.eval_dataloaders = {}
for eval_task_name in self.eval_dataset_config:
eval_dataset_path = eval_dataset_config[eval_task_name].pop("path")
eval_dataset = RawConversationDataset(
self.tokenizer, eval_dataset_path, **eval_dataset_config[eval_task_name]
self.tokenizer,
eval_dataset_config[eval_task_name]["path"],
eval_dataset_config[eval_task_name]["max_length"],
eval_dataset_config[eval_task_name]["system_prompt"],
)
print(f"[P{self.producer_idx}] eval dataset {eval_task_name} size: {len(eval_dataset)}")
self.eval_dataloaders[eval_task_name] = DataLoader(
@ -171,7 +173,14 @@ class BaseProducer:
for eval_batch_id, eval_batch in tqdm.tqdm(
enumerate(self.eval_dataloaders[eval_task_name]), desc=f"Evaluating: {eval_task_name}"
):
eval_outputs = self.rollout(**eval_batch, sample_params=self.eval_sample_params)
if isinstance(self.model, BACKEND_MAP["vllm"]):
eval_outputs = self.rollout(
**eval_batch, sample_params=self.eval_sample_params[eval_task_name]
)
elif isinstance(self.model, BACKEND_MAP["transformers"]):
eval_outputs = self.rollout(
**eval_batch, generate_config=self.eval_generation_config[eval_task_name]
)
eval_results = eval_results + [
self.evaluation_function(
eval_outputs["input_ids"][m][n],
@ -179,6 +188,7 @@ class BaseProducer:
eval_outputs["response_idx"][m][n],
tokenizer=self.tokenizer,
eval_mode=True,
max_new_tokens=self.eval_dataset_config[eval_task_name]["max_new_tokens"],
)
for m in range(eval_outputs["input_ids"].size(0))
for n in range(eval_outputs["input_ids"].size(1))
@ -302,11 +312,29 @@ class SimpleProducer(BaseProducer):
eval_save_dir=eval_save_dir,
)
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
self.eval_generation_config = copy.deepcopy(self.model.generate_config)
self.eval_generation_config.update(
{"n": 1, "temperature": 0.6, "top_p": 0.95}
) # use 1 generation for evaluation
self.eval_sample_params = SamplingParams(**self.eval_generation_config)
self.eval_sample_params = {}
for eval_task_name in eval_dataset_config:
eval_generation_config = copy.deepcopy(self.model.generate_config)
if isinstance(self.model, BACKEND_MAP["vllm"]):
eval_generation_config.update(
{
"n": 1,
"temperature": eval_dataset_config[eval_task_name]["temperature"],
"top_p": eval_dataset_config[eval_task_name]["top_p"],
"top_k": eval_dataset_config[eval_task_name]["top_k"],
}
) # use 1 generation for evaluation
self.eval_sample_params[eval_task_name] = SamplingParams(**eval_generation_config)
elif isinstance(self.model, BACKEND_MAP["transformers"]):
eval_generation_config.update(
{
"num_return_sequences": 1,
"temperature": eval_dataset_config[eval_task_name]["temperature"],
"top_p": eval_dataset_config[eval_task_name]["top_p"],
"top_k": eval_dataset_config[eval_task_name]["top_k"],
}
) # use 1 generation for evaluation
self.eval_generation_config[eval_task_name] = copy.deepcopy(eval_generation_config)
@torch.no_grad()
def rollout(self, input_ids, attention_mask, **kwargs):

View File

@ -246,7 +246,7 @@ if __name__ == "__main__":
# "zero_stage": 2,
# }, # for zero
plugin_config={
"tp_size": 4,
"tp_size": 2,
"pp_size": 2,
"microbatch_size": max(
1, args.train_microbatch_size // 2
@ -261,7 +261,15 @@ if __name__ == "__main__":
save_interval=args.save_interval,
save_dir=os.path.join(args.save_dir, args.project.replace(" ", "_")),
eval_dataset_config={
k: {"path": v, "max_length": args.max_prompt_tokens, "system_prompt": args.system_prompt}
k: {
"path": v["path"],
"max_length": args.max_prompt_tokens + args.max_new_tokens,
"max_new_tokens": args.max_new_tokens,
"system_prompt": args.system_prompt,
"temperature": v.get("temperature", 0.6),
"top_p": v.get("top_p", 0.95),
"top_k": v.get("top_k", 50),
}
for k, v in json.loads(args.eval_dataset).items()
},
eval_interval=args.eval_interval,