Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 19:40:28 +00:00)
Support evaluation during training
This commit is contained in:
@@ -40,6 +40,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
project_name=None,
|
||||
save_interval: int = 100,
|
||||
save_dir="./model",
|
||||
eval_interval: int = -1,
|
||||
):
|
||||
print(f"Using GRPO config: {grpo_config}")
|
||||
if grpo_config.get("loss_variation", "sample_level") == "token_level":
|
||||
@@ -72,6 +73,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
minibatch_size,
|
||||
save_interval=save_interval,
|
||||
save_dir=save_dir,
|
||||
eval_interval=eval_interval,
|
||||
)
|
||||
path = model_config.pop("path")
|
||||
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
||||
@@ -528,4 +530,5 @@ class GRPOConsumer(BaseConsumer):
|
||||
self.policy_model._force_wait_all_gather()
|
||||
model = self.policy_model.unwrap()
|
||||
state_dict = model.state_dict()
|
||||
state_dict["consumer_global_step"] = torch.tensor([self.global_step], device=self.device)
|
||||
return state_dict
|
||||
|
Reference in New Issue
Block a user