[Fix] Add L2 Regularization (#6372)

* fix missing L2 regularization (weight decay was not being applied)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
YeAnbang 2025-07-29 16:56:52 +08:00 committed by GitHub
parent 57e92104a2
commit cd32236e53
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 2 deletions

View File

@@ -365,7 +365,7 @@ class SimpleConsumer(BaseConsumer):
self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.model.train()
self.model.gradient_checkpointing_enable()
self.optimizer = HybridAdam(self.model.parameters(), lr=1e-3)
self.optimizer = HybridAdam(self.model.parameters(), lr=1e-3, weight_decay=0.01)
self.accum_loss = torch.zeros(1, device=self.device)
def setup(self):

View File

@@ -72,7 +72,11 @@ class GRPOConsumer(BaseConsumer):
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.policy_model.train()
self.policy_model.gradient_checkpointing_enable()
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
self.optimizer = HybridAdam(
self.policy_model.parameters(),
lr=grpo_config.get("lr", 1e-6),
weight_decay=grpo_config.get("weight_decay", 0.01),
)
self.accum_loss = torch.zeros(1, device=self.device)
self.accum_kl = torch.zeros(1, device=self.device)
self.accum_entropy = torch.zeros(1, device=self.device)