remove redundant code and fix bugs

YeAnbang 2025-05-16 14:08:23 +08:00
parent a528921944
commit 11a5854b50
5 changed files with 27 additions and 59 deletions

View File

@@ -121,14 +121,14 @@ class BaseConsumer:
             raw_batch = unbind_batch(
                 ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
             )
-            filtered_batch = [
-                t
-                for t in [
-                    self.prompt_level_filtering(self.calculate_group_reward(group))
-                    for group in raw_batch
-                ]
-                if t is not None
+            processed_batch = [
+                self.prompt_level_filtering(self.calculate_group_reward(group)) for group in raw_batch
             ]
+            filtered_batch = [t for t in processed_batch if t is not None]
+            if self.filter_range is not None:
+                print(
+                    f"[T{dist.get_rank()}] Filter recv data: {len(processed_batch)} -> {len(filtered_batch)}"
+                )
             self.buffer.extend(filtered_batch)
             while len(self.buffer) >= self.dp_size * self.minibatch_size:
@@ -137,12 +137,7 @@ class BaseConsumer:
                 ]
                 batch = bind_batch(batches)
                 batch = post_recv(batch)
-                loss, excessive_prompts_idx = self.step(i, pbar, **batch)
-                if excessive_prompts_idx is not None:
-                    excessive_prompts = [self.buffer[idx] for idx in excessive_prompts_idx]
-                    self.buffer = excessive_prompts + self.buffer[self.dp_size * self.minibatch_size :]
-                else:
-                    self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
+                loss = self.step(i, pbar, **batch)
+                self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
                 if loss is not None:
                     allow_sync_model = True
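
The two hunks above change the consumer's receive path: incoming groups are first scored and prompt-filtered, the None results are dropped in a second pass so the pre-filter count is still available for logging, and step() now returns only the loss (the excessive-prompt recycling is gone). The following is a minimal standalone sketch of that two-pass filtering pattern, not the repository code; score_group and keep_group stand in for the consumer's calculate_group_reward and prompt_level_filtering methods.

from typing import Any, Callable, Dict, List, Optional

def filter_recv_batch(
    raw_batch: List[Dict[str, Any]],
    score_group: Callable[[Dict[str, Any]], Dict[str, Any]],            # stand-in for calculate_group_reward
    keep_group: Callable[[Dict[str, Any]], Optional[Dict[str, Any]]],   # stand-in for prompt_level_filtering
    filter_range=None,
    rank: int = 0,
) -> List[Dict[str, Any]]:
    # Pass 1: score every group; keep_group returns None for groups outside filter_range.
    processed_batch = [keep_group(score_group(group)) for group in raw_batch]
    # Pass 2: drop the rejected groups, keeping both counts around for logging.
    filtered_batch = [g for g in processed_batch if g is not None]
    if filter_range is not None:
        print(f"[T{rank}] Filter recv data: {len(processed_batch)} -> {len(filtered_batch)}")
    return filtered_batch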

View File

@@ -9,7 +9,7 @@ from coati.distributed.loss import PolicyLoss
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs
-from coati.trainer.utils import all_gather_tensors, all_reduce_mean, all_reduce_sum
+from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -201,10 +201,7 @@ class GRPOConsumer(BaseConsumer):
         reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
         # [minibatch_size x num_generations]
         advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
-        # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
-        group_ans_acc = (
-            ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0)
-        )
         # [minibatch_size x num_of_generation]
         loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()
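
The surviving context lines compute GRPO's group-relative advantages: each prompt's num_generations rewards are normalized by that group's own mean and standard deviation, while the per-group answer-accuracy tensor is dropped here (the receive-side prompt_level_filtering appears to cover that filtering now). Below is a self-contained sketch of the normalization, with the [minibatch_size, num_generations] shape taken as an assumption; it is illustrative, not the repository code.

import torch

def group_relative_advantages(group_reward: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    # group_reward: [minibatch_size, num_generations], one row of rewards per prompt
    minibatch_size, num_generations = group_reward.shape
    reward = group_reward.reshape(-1)                                          # [minibatch_size * num_generations]
    reward_mean = group_reward.mean(dim=1).repeat_interleave(num_generations)  # per-group mean, broadcast per sample
    reward_std = group_reward.std(dim=1).repeat_interleave(num_generations)    # per-group std, broadcast per sample
    return ((reward - reward_mean) / (reward_std + eps)).unsqueeze(dim=-1)     # [minibatch_size * num_generations, 1]

# Example: 2 prompts x 4 generations each
advantages = group_relative_advantages(torch.tensor([[1.0, 0.0, 1.0, 0.0], [0.5, 0.5, 1.0, 0.0]]))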
@@ -214,37 +211,14 @@ class GRPOConsumer(BaseConsumer):
             loss_mask,
             action_mask[:, -1] == False,
         )
-        prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
-        # [minibatch_size] -> calculate the number of effective prompts
-        effective_prompts_mask = prompt_level_mask.any(dim=1)
-        effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
-        self.effective_prompt_count += effective_prompts.item()
-        excessive_prompts_idx = None
+        self.effective_prompt_count += group_reward.size(0) * self.dp_size
         mean_kl, mean_loss = [], []
         if self.grpo_config.get("dynamic_batching", True):
             need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
             excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size
-            if excessive_prompts > 0:
-                excessive_prompts_per_rank = excessive_prompts // self.dp_size
-                # Only count excessive prompts if they are greater than 1 per rank.
-                # TODO: customize excessive prompts calculation.
-                if excessive_prompts_per_rank != 0:
-                    # Mask excessive prompts to False
-                    true_indices = torch.nonzero(effective_prompts_mask).squeeze()
-                    if excessive_prompts_per_rank <= len(true_indices):
-                        excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
-                    else:
-                        excessive_prompts_idx = true_indices
-                    effective_prompts_mask[excessive_prompts_idx] = False
-                    for mask_idx in range(len(effective_prompts_mask)):
-                        if effective_prompts_mask[mask_idx] == False:
-                            # Update loss mask.
-                            loss_mask[mask_idx] = False
+            assert excessive_prompts <= 0, "Debug: Excessive prompts should always be less than 0. Bug!!!!"
         else:
             # If dynamic batching is disabled, we need to use all samples for training.
             need_update = (step_idx + 1) % self.num_microbatches == 0
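
The net effect of this hunk is that effective prompts are counted locally (each rank sees group_reward.size(0) prompt groups, and all dp_size ranks receive the same batch), so the old all_reduce and the excessive-prompt masking disappear in favor of a debug assertion. A hedged sketch of the resulting update trigger, assuming the running count is reset after each optimizer step and using illustrative names rather than the repository's:

def should_update(effective_prompt_count: int, minibatch_size: int, batch_size: int, dp_size: int):
    """Return the new running prompt count and whether an optimizer step is due."""
    effective_prompt_count += minibatch_size * dp_size           # every rank receives minibatch_size prompt groups
    need_update = effective_prompt_count >= batch_size * dp_size
    excessive_prompts = effective_prompt_count - batch_size * dp_size
    # Because the batch size is divisible by the mini batch size (asserted in the argparse
    # script at the end of this commit), the count lands exactly on the target, never past it.
    assert excessive_prompts <= 0, "Excessive prompts should never be positive"
    return effective_prompt_count, need_update

count, update_now = should_update(effective_prompt_count=0, minibatch_size=32, batch_size=128, dp_size=4)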
@@ -460,9 +434,7 @@ class GRPOConsumer(BaseConsumer):
             self.optimizer.step()
             self.optimizer.zero_grad()
             self.global_step += 1
-            self.total_sample_count = all_reduce_sum(
-                torch.tensor(self.total_sample_count).to(self.accum_loss.device), self.plugin
-            ).item()
+            # no need to run all_reduce_sum on total_sample_count, because all dp ranks receive a complete inference batch from all producers.
             sample_utilization = self.effective_sample_count / self.total_sample_count
             self.effective_prompt_count = 0
             self.effective_sample_count = 0
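
Since every DP rank already holds the full inference batch, sample utilization reduces to a purely local ratio; a tiny illustrative sketch (not the repository code):

def sample_utilization(effective_sample_count: int, total_sample_count: int) -> float:
    # effective samples = responses that survive filtering and contribute to the loss
    return effective_sample_count / max(total_sample_count, 1)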
@@ -507,14 +479,9 @@ class GRPOConsumer(BaseConsumer):
                 self.accum_advantages.zero_()
                 self.accum_response_length.zero_()
                 self.accum_count = 0
-            if excessive_prompts_idx is not None:
-                # All gather excessive prompts index across DP ranks.
-                excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
-                excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
-            return loss_scalar, excessive_prompts_idx
+            return loss_scalar
         else:
-            return None, excessive_prompts_idx
+            return None

     def calculate_group_reward(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
         """

View File

@@ -66,7 +66,7 @@ def launch_distributed(
     dataset_path = train_dataset_config["path"]
     num_samples = get_jsonl_size_fast(dataset_path)
-    global_inference_batch_size = inference_batch_size * num_producers  # TODO: this doesn't support TP on producer
+    global_inference_batch_size = inference_batch_size * num_producers
     num_update_per_episode = num_samples // global_inference_batch_size
     num_recv_per_update = inference_batch_size // inference_microbatch_size
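
A worked example of the scheduling arithmetic in this hunk, with illustrative numbers that are not taken from any shipped config:

num_producers, inference_batch_size, inference_microbatch_size, num_samples = 4, 32, 8, 12800
global_inference_batch_size = inference_batch_size * num_producers        # 128 prompts generated per sync round
num_update_per_episode = num_samples // global_inference_batch_size       # 100 rounds per episode
num_recv_per_update = inference_batch_size // inference_microbatch_size   # 4 micro-batches received per round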

View File

@@ -187,7 +187,7 @@ class BaseProducer:
         for eval_task_name in self.eval_dataloaders:
             if self.producer_idx == 0:
                 print(
-                    f"[P{self.producer_idx}] Evaluate episode {episode} step {self.consumer_global_step} on task {eval_task_name}"
+                    f"[P{self.producer_idx}] Evaluate consumer step {self.consumer_global_step} on task {eval_task_name}"
                 )
             eval_results = []
             eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)

View File

@@ -104,7 +104,13 @@ if __name__ == "__main__":
         choices=["think_answer_tags", "boxed"],
         help="Reward type for GRPO.",
     )
-    parser.add_argument("-ei", "--eval-interval", type=int, default=100, help="Interval for evaluation.")
+    parser.add_argument(
+        "-ei",
+        "--eval-interval",
+        type=int,
+        default=100,
+        help="Interval for evaluation. Evaluate every ei consumer steps.",
+    )
     # Logging/Checkpointing parameters
     parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
@@ -125,8 +131,8 @@ if __name__ == "__main__":
         and args.train_microbatch_size > 0
     ), "Train micro batch size must be greater than 0 less than train mini batch size * num generations"
     assert (
-        args.train_minibatch_size <= args.train_batch_size
-    ), "Train mini batch size must be less than or equals to train batch size"
+        args.train_minibatch_size <= args.train_batch_size and args.train_batch_size % args.train_minibatch_size == 0
+    ), "Train mini batch size must be less than or equal to train batch size, and train batch size must be divisible by train mini batch size"
     if args.master_address is None:
         # Default settings: Using single machine
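
The strengthened assertion is what keeps the consumer-side debug assert (see the GRPOConsumer hunk above) from firing: the effective prompt count advances in multiples of the mini batch size, so the batch target is reached exactly. A quick worked check with illustrative values, not taken from any config:

train_batch_size, train_minibatch_size, dp_size = 128, 32, 4
assert train_minibatch_size <= train_batch_size and train_batch_size % train_minibatch_size == 0
steps_per_update = train_batch_size // train_minibatch_size          # 4 recv steps per optimizer update
# After steps_per_update receives, the running prompt count equals the target exactly:
assert steps_per_update * train_minibatch_size * dp_size == train_batch_size * dp_size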