merge grpo-latest'

This commit is contained in:
YeAnbang 2025-08-04 11:38:14 +08:00
commit f067e778e9
3 changed files with 6 additions and 4 deletions

View File

@ -365,7 +365,7 @@ class SimpleConsumer(BaseConsumer):
self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.model.train()
self.model.gradient_checkpointing_enable()
self.optimizer = HybridAdam(self.model.parameters(), lr=1e-3)
self.optimizer = HybridAdam(self.model.parameters(), lr=1e-3, weight_decay=0.01)
self.accum_loss = torch.zeros(1, device=self.device)
def setup(self):

View File

@ -72,7 +72,11 @@ class GRPOConsumer(BaseConsumer):
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.policy_model.train()
self.policy_model.gradient_checkpointing_enable()
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
self.optimizer = HybridAdam(
self.policy_model.parameters(),
lr=grpo_config.get("lr", 1e-6),
weight_decay=grpo_config.get("weight_decay", 0.01),
)
self.accum_loss = torch.zeros(1, device=self.device)
self.accum_kl = torch.zeros(1, device=self.device)
self.accum_entropy = torch.zeros(1, device=self.device)

View File

@ -180,8 +180,6 @@ def run_test(in_outs, test=None, debug=False, timeout=15, run_all_tests=False):
tmp_test = new_test
sol += tmp_test
# if debug:
# print(f"sol = {sol}")
method_name = "code"
signal.alarm(timeout)
try: