Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 11:32:10 +00:00
[fix] revert reward update and evaluation (#6295)
* Revert "rewrite reward fn". This reverts commit d06042b434.
* Revert "upgrade reward math verification". This reverts commit a6085ff676.
* Revert "fix bug". This reverts commit 01640ebd65.
* Revert "reuse comm-group". This reverts commit bd61918dcf.
* Revert "Support evaluation during training". This reverts commit 57a88395fe.
@@ -40,7 +40,6 @@ class GRPOConsumer(BaseConsumer):
         project_name=None,
         save_interval: int = 100,
         save_dir="./model",
-        eval_interval: int = -1,
     ):
         print(f"Using GRPO config: {grpo_config}")
         if grpo_config.get("loss_variation", "sample_level") == "token_level":
@@ -73,7 +72,6 @@ class GRPOConsumer(BaseConsumer):
             minibatch_size,
             save_interval=save_interval,
             save_dir=save_dir,
-            eval_interval=eval_interval,
         )
         path = model_config.pop("path")
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
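For context, the first two hunks drop the eval_interval setting that the reverted "Support evaluation during training" change had threaded through this constructor. Below is a minimal sketch of that pattern; only GRPOConsumer, BaseConsumer, save_interval, save_dir, and eval_interval come from the diff, the rest of the signature and fields are assumptions for illustration, not the actual ColossalAI API:

class BaseConsumer:
    def __init__(self, save_interval: int = 100, save_dir: str = "./model",
                 eval_interval: int = -1):
        # Checkpointing settings used by the shared training loop.
        self.save_interval = save_interval
        self.save_dir = save_dir
        # Hypothetical convention: eval_interval <= 0 disables periodic evaluation.
        self.eval_interval = eval_interval


class GRPOConsumer(BaseConsumer):
    def __init__(self, save_interval: int = 100, save_dir: str = "./model",
                 eval_interval: int = -1):
        # Forward the settings so the base class decides when to save
        # checkpoints and (before the revert) when to run evaluation.
        super().__init__(save_interval=save_interval, save_dir=save_dir,
                         eval_interval=eval_interval)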
@@ -530,5 +528,4 @@ class GRPOConsumer(BaseConsumer):
         self.policy_model._force_wait_all_gather()
         model = self.policy_model.unwrap()
         state_dict = model.state_dict()
-        state_dict["consumer_global_step"] = torch.tensor([self.global_step], device=self.device)
         return state_dict
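The last hunk removes the consumer_global_step entry that was being attached to the state dict handed back by the consumer. A rough, standalone illustration of the mechanism being reverted (the helper name and its arguments are hypothetical; only the consumer_global_step key and the torch.tensor call mirror the diff):

import torch

def state_dict_with_step(model: torch.nn.Module, global_step: int, device="cpu"):
    # Attach the consumer's global step to the weights that get shipped out,
    # so the receiving side can stay in sync with training progress.
    state_dict = model.state_dict()
    state_dict["consumer_global_step"] = torch.tensor([global_step], device=device)
    return state_dict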