remove redundant code and fix bugs

YeAnbang 2025-05-16 14:08:23 +08:00
parent a528921944
commit 11a5854b50
5 changed files with 27 additions and 59 deletions

View File

@@ -121,14 +121,14 @@ class BaseConsumer:
raw_batch = unbind_batch(
ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
)
filtered_batch = [
t
for t in [
self.prompt_level_filtering(self.calculate_group_reward(group))
for group in raw_batch
]
if t is not None
processed_batch = [
self.prompt_level_filtering(self.calculate_group_reward(group)) for group in raw_batch
]
filtered_batch = [t for t in processed_batch if t is not None]
if self.filter_range is not None:
print(
f"[T{dist.get_rank()}] Filter recv data: {len(processed_batch)} -> {len(filtered_batch)}"
)
self.buffer.extend(filtered_batch)
while len(self.buffer) >= self.dp_size * self.minibatch_size:
@@ -137,13 +137,8 @@ class BaseConsumer:
]
batch = bind_batch(batches)
batch = post_recv(batch)
loss, excessive_prompts_idx = self.step(i, pbar, **batch)
if excessive_prompts_idx is not None:
excessive_prompts = [self.buffer[idx] for idx in excessive_prompts_idx]
self.buffer = excessive_prompts + self.buffer[self.dp_size * self.minibatch_size :]
else:
self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
loss = self.step(i, pbar, **batch)
self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
if loss is not None:
allow_sync_model = True
pbar.set_postfix({"loss": loss})
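
The hunk above flattens the nested comprehension into an explicit processed/filtered split and drops the excessive-prompt re-queueing from the receive loop. A rough standalone sketch of the resulting flow, with stand-in callables in place of the consumer's methods (illustrative only, not the repository code):

from typing import Any, Callable, Dict, List, Optional

def recv_and_fill_buffer(
    raw_batch: List[Dict[str, Any]],
    buffer: List[Dict[str, Any]],
    score_group: Callable[[Dict[str, Any]], Dict[str, Any]],              # stands in for calculate_group_reward
    prompt_filter: Callable[[Dict[str, Any]], Optional[Dict[str, Any]]],  # stands in for prompt_level_filtering
    dp_size: int,
    minibatch_size: int,
) -> List[List[Dict[str, Any]]]:
    # score every group, then drop the ones the prompt-level filter rejects (None)
    processed_batch = [prompt_filter(score_group(group)) for group in raw_batch]
    filtered_batch = [t for t in processed_batch if t is not None]
    print(f"Filter recv data: {len(processed_batch)} -> {len(filtered_batch)}")
    buffer.extend(filtered_batch)

    # consume the buffer front-to-back; nothing is re-queued after a step anymore
    minibatches = []
    while len(buffer) >= dp_size * minibatch_size:
        minibatches.append(buffer[: dp_size * minibatch_size])
        del buffer[: dp_size * minibatch_size]
    return minibatches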

View File

@@ -9,7 +9,7 @@ from coati.distributed.loss import PolicyLoss
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from coati.distributed.utils import calc_action_log_probs
from coati.trainer.utils import all_gather_tensors, all_reduce_mean, all_reduce_sum
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
from transformers import AutoModelForCausalLM, AutoTokenizer
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -201,10 +201,7 @@ class GRPOConsumer(BaseConsumer):
reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
# [minibatch_size x num_generations]
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
# filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
group_ans_acc = (
ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0)
)
# [minibatch_size x num_of_generation]
loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()
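
For reference, the advantage computation kept above normalizes each prompt's rewards within its own group of generations; the now-unused group_ans_acc is what gets removed. A self-contained sketch of that normalization, assuming rewards arrive flattened as [minibatch_size * num_generations] (not the repository code):

import torch

def group_normalized_advantages(reward: torch.Tensor, num_generations: int) -> torch.Tensor:
    group_reward = reward.view(-1, num_generations)
    # per-prompt statistics, broadcast back to every generation of that prompt
    reward_mean = group_reward.mean(dim=1).repeat_interleave(num_generations, dim=0)
    reward_std = group_reward.std(dim=1).repeat_interleave(num_generations, dim=0)
    # [minibatch_size x num_generations, 1]
    return ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)

# two prompts with four generations each
adv = group_normalized_advantages(torch.tensor([1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0]), 4)
print(adv.shape)  # torch.Size([8, 1])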
@@ -214,37 +211,14 @@
loss_mask,
action_mask[:, -1] == False,
)
prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
# [minibatch_size] -> calculate the number of effective prompts
effective_prompts_mask = prompt_level_mask.any(dim=1)
effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
self.effective_prompt_count += effective_prompts.item()
excessive_prompts_idx = None
self.effective_prompt_count += group_reward.size(0) * self.dp_size
mean_kl, mean_loss = [], []
if self.grpo_config.get("dynamic_batching", True):
need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size
if excessive_prompts > 0:
excessive_prompts_per_rank = excessive_prompts // self.dp_size
# Only count excessive prompts if they are greater than 1 per rank.
# TODO: customize excessive prompts calculation.
if excessive_prompts_per_rank != 0:
# Mask excessive prompts to False
true_indices = torch.nonzero(effective_prompts_mask).squeeze()
if excessive_prompts_per_rank <= len(true_indices):
excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
else:
excessive_prompts_idx = true_indices
effective_prompts_mask[excessive_prompts_idx] = False
for mask_idx in range(len(effective_prompts_mask)):
if effective_prompts_mask[mask_idx] == False:
# Update loss mask.
loss_mask[mask_idx] = False
assert excessive_prompts <= 0, "Debug: excessive prompts should always be less than or equal to 0; a positive value indicates a bug."
else:
# If dynamic batching is disabled, we need to use all samples for training.
need_update = (step_idx + 1) % self.num_microbatches == 0
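
The replacement above trades the mask-based prompt accounting for a simple counter: every received group contributes group_reward.size(0) * dp_size prompts, and an update fires once a full global batch has accumulated. The new assert records that, together with the stricter divisibility check added in the example script further down, the counter should never overshoot. A toy version of that bookkeeping, with standalone variables instead of the consumer's attributes:

def update_prompt_count(effective_prompt_count: int, group_size: int, dp_size: int, batch_size: int):
    # each rank sees the same groups, so the global prompt count grows by group_size * dp_size
    effective_prompt_count += group_size * dp_size
    need_update = effective_prompt_count >= batch_size * dp_size
    # with the batch size divisible by the minibatch size, the count lands exactly on the target
    excessive_prompts = effective_prompt_count - batch_size * dp_size
    assert excessive_prompts <= 0, "excessive prompts should never be positive here"
    return effective_prompt_count, need_update

count, need_update = update_prompt_count(0, group_size=8, dp_size=4, batch_size=32)
print(count, need_update)  # 32 False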
@@ -460,9 +434,7 @@ class GRPOConsumer(BaseConsumer):
self.optimizer.step()
self.optimizer.zero_grad()
self.global_step += 1
self.total_sample_count = all_reduce_sum(
torch.tensor(self.total_sample_count).to(self.accum_loss.device), self.plugin
).item()
# no need to run all_reduce_sum on total_sample_count, because all dp ranks receive a complete inference batch from all producers.
sample_utilization = self.effective_sample_count / self.total_sample_count
self.effective_prompt_count = 0
self.effective_sample_count = 0
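
The all_reduce_sum on total_sample_count is dropped because, as the new comment notes, every DP rank already receives the complete inference batch from all producers, so the counter is already global. The utilization metric then reduces to a plain ratio, roughly:

# made-up numbers, for illustration only
effective_sample_count = 96      # samples that survived filtering / loss masking this window
total_sample_count = 128         # everything received from the producers
sample_utilization = effective_sample_count / total_sample_count
print(sample_utilization)        # 0.75
# the effective counters are then zeroed for the next accumulation window, as above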
@@ -507,14 +479,9 @@
self.accum_advantages.zero_()
self.accum_response_length.zero_()
self.accum_count = 0
if excessive_prompts_idx is not None:
# All gather excessive prompts index across DP ranks.
excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
return loss_scalar, excessive_prompts_idx
return loss_scalar
else:
return None, excessive_prompts_idx
return None
def calculate_group_reward(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
"""

View File

@@ -66,7 +66,7 @@ def launch_distributed(
dataset_path = train_dataset_config["path"]
num_samples = get_jsonl_size_fast(dataset_path)
global_inference_batch_size = inference_batch_size * num_producers # TODO: this doesn't support TP on producer
global_inference_batch_size = inference_batch_size * num_producers
num_update_per_episode = num_samples // global_inference_batch_size
num_recv_per_update = inference_batch_size // inference_microbatch_size
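
A worked example of the scheduling arithmetic in this hunk, with made-up sizes (the real values come from the CLI arguments and get_jsonl_size_fast):

inference_batch_size = 32
inference_microbatch_size = 8
num_producers = 4
num_samples = 4096                                                         # dataset size

global_inference_batch_size = inference_batch_size * num_producers         # 128
num_update_per_episode = num_samples // global_inference_batch_size        # 32 consumer updates per episode
num_recv_per_update = inference_batch_size // inference_microbatch_size    # 4 receives per update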

View File

@@ -187,7 +187,7 @@ class BaseProducer:
for eval_task_name in self.eval_dataloaders:
if self.producer_idx == 0:
print(
f"[P{self.producer_idx}] Evaluate episode {episode} step {self.consumer_global_step} on task {eval_task_name}"
f"[P{self.producer_idx}] Evaluate consumer step {self.consumer_global_step} on task {eval_task_name}"
)
eval_results = []
eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)

View File

@@ -104,7 +104,13 @@ if __name__ == "__main__":
choices=["think_answer_tags", "boxed"],
help="Reward type for GRPO.",
)
parser.add_argument("-ei", "--eval-interval", type=int, default=100, help="Interval for evaluation.")
parser.add_argument(
"-ei",
"--eval-interval",
type=int,
default=100,
help="Interval for evaluation. Evaluate every ei consumer steps.",
)
# Logging/Checkpointing parameters
parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
@@ -125,8 +131,8 @@ if __name__ == "__main__":
and args.train_microbatch_size > 0
), "Train micro batch size must be greater than 0 less than train mini batch size * num generations"
assert (
args.train_minibatch_size <= args.train_batch_size
), "Train mini batch size must be less than or equals to train batch size"
args.train_minibatch_size <= args.train_batch_size and args.train_batch_size % args.train_minibatch_size == 0
), "Train mini batch size must be less than or equals to train batch size and train batch size must be divisible by train mini batch size"
if args.master_address is None:
# Default settings: Using single machine
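
A quick standalone check mirroring the strengthened assertion in the last hunk; the values are made up, but it shows which configurations the added divisibility requirement now rejects:

train_batch_size = 32
train_minibatch_size = 8   # OK: 8 <= 32 and 32 % 8 == 0; a value of 12 would now fail

assert (
    train_minibatch_size <= train_batch_size and train_batch_size % train_minibatch_size == 0
), "Train mini batch size must be less than or equal to train batch size and train batch size must be divisible by train mini batch size"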