mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-27 07:47:05 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
c0c0da2f26
commit
e397327de0
@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from typing import Optional, Any
|
from typing import Any, Optional
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
import torch
|
import torch
|
||||||
@ -212,7 +212,12 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
# update gradient only if at least 0.7*batch_size*num_generation valid samples are collected in case a lot of samples are invalid and got filtered out.
|
# update gradient only if at least 0.7*batch_size*num_generation valid samples are collected in case a lot of samples are invalid and got filtered out.
|
||||||
# balance between efficiency and accuracy
|
# balance between efficiency and accuracy
|
||||||
need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations * 0.75
|
need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations * 0.75
|
||||||
pbar.set_postfix({"Step": self.global_step + 1, "Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations * 0.75}"})
|
pbar.set_postfix(
|
||||||
|
{
|
||||||
|
"Step": self.global_step + 1,
|
||||||
|
"Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations * 0.75}",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500
|
# Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500
|
||||||
ctx = (
|
ctx = (
|
||||||
|
Loading…
Reference in New Issue
Block a user