Support evaluation during training

commit 47a7dc7142 (parent b920af427b)
Author: YeAnbang
Date: 2025-04-30 18:13:40 +08:00

9 changed files with 234 additions and 65 deletions


@@ -40,6 +40,7 @@ class GRPOConsumer(BaseConsumer):
         project_name=None,
         save_interval: int = 100,
         save_dir="./model",
+        eval_interval: int = -1,
     ):
         print(f"Using GRPO config: {grpo_config}")
         if grpo_config.get("loss_variation", "sample_level") == "token_level":
@@ -72,6 +73,7 @@ class GRPOConsumer(BaseConsumer):
             minibatch_size,
             save_interval=save_interval,
             save_dir=save_dir,
+            eval_interval=eval_interval,
         )
         path = model_config.pop("path")
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
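
The new eval_interval argument is threaded through to BaseConsumer, which is where the periodic-evaluation scheduling would live. A minimal sketch of the gating this enables, assuming the base class stores the value and checks it once per training step (the class, attribute, and method names below are illustrative, not taken from this diff):

class BaseConsumerSketch:
    """Hypothetical stand-in for BaseConsumer's eval scheduling."""

    def __init__(self, eval_interval: int = -1):
        # eval_interval <= 0 (the default -1) disables in-training
        # evaluation, preserving the behavior before this commit.
        self.eval_interval = eval_interval

    def should_evaluate(self, step: int) -> bool:
        # Evaluate every eval_interval steps, e.g. steps 100, 200, ...
        # when eval_interval=100.
        return self.eval_interval > 0 and step % self.eval_interval == 0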
@@ -528,4 +530,5 @@ class GRPOConsumer(BaseConsumer):
         self.policy_model._force_wait_all_gather()
         model = self.policy_model.unwrap()
         state_dict = model.state_dict()
+        state_dict["consumer_global_step"] = torch.tensor([self.global_step], device=self.device)
         return state_dict
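
Embedding consumer_global_step in the returned state dict lets whoever receives the synced weights (e.g. the rollout producer) recover the trainer's step counter. Since the key is not a real model parameter, the receiving side would need to strip it before loading; a sketch under that assumption (the function name is hypothetical):

import torch
from torch import nn

def load_synced_state_dict(model: nn.Module, state_dict: dict) -> int:
    # Pop the step marker first: "consumer_global_step" is metadata,
    # not a parameter, and a strict load_state_dict would reject it.
    step = state_dict.pop("consumer_global_step", torch.tensor([0]))
    model.load_state_dict(state_dict)
    return int(step.item())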