Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-08-20 08:53:22 +00:00)
merge grpo-latest
This commit is contained in:
commit
f067e778e9
@ -365,7 +365,7 @@ class SimpleConsumer(BaseConsumer):
|
|||||||
self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
||||||
self.model.train()
|
self.model.train()
|
||||||
self.model.gradient_checkpointing_enable()
|
self.model.gradient_checkpointing_enable()
|
||||||
self.optimizer = HybridAdam(self.model.parameters(), lr=1e-3)
|
self.optimizer = HybridAdam(self.model.parameters(), lr=1e-3, weight_decay=0.01)
|
||||||
self.accum_loss = torch.zeros(1, device=self.device)
|
self.accum_loss = torch.zeros(1, device=self.device)
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
|
@ -72,7 +72,11 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
||||||
self.policy_model.train()
|
self.policy_model.train()
|
||||||
self.policy_model.gradient_checkpointing_enable()
|
self.policy_model.gradient_checkpointing_enable()
|
||||||
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
|
self.optimizer = HybridAdam(
|
||||||
|
self.policy_model.parameters(),
|
||||||
|
lr=grpo_config.get("lr", 1e-6),
|
||||||
|
weight_decay=grpo_config.get("weight_decay", 0.01),
|
||||||
|
)
|
||||||
self.accum_loss = torch.zeros(1, device=self.device)
|
self.accum_loss = torch.zeros(1, device=self.device)
|
||||||
self.accum_kl = torch.zeros(1, device=self.device)
|
self.accum_kl = torch.zeros(1, device=self.device)
|
||||||
self.accum_entropy = torch.zeros(1, device=self.device)
|
self.accum_entropy = torch.zeros(1, device=self.device)
|
||||||
|
@ -180,8 +180,6 @@ def run_test(in_outs, test=None, debug=False, timeout=15, run_all_tests=False):
|
|||||||
tmp_test = new_test
|
tmp_test = new_test
|
||||||
|
|
||||||
sol += tmp_test
|
sol += tmp_test
|
||||||
# if debug:
|
|
||||||
# print(f"sol = {sol}")
|
|
||||||
method_name = "code"
|
method_name = "code"
|
||||||
signal.alarm(timeout)
|
signal.alarm(timeout)
|
||||||
try:
|
try:
|
||||||
|
Loading…
Reference in New Issue
Block a user