mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-12 05:11:56 +00:00
remove redundant code and fix bugs
This commit is contained in:
parent
a528921944
commit
11a5854b50
@ -121,14 +121,14 @@ class BaseConsumer:
|
||||
raw_batch = unbind_batch(
|
||||
ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
|
||||
)
|
||||
filtered_batch = [
|
||||
t
|
||||
for t in [
|
||||
self.prompt_level_filtering(self.calculate_group_reward(group))
|
||||
for group in raw_batch
|
||||
]
|
||||
if t is not None
|
||||
processed_batch = [
|
||||
self.prompt_level_filtering(self.calculate_group_reward(group)) for group in raw_batch
|
||||
]
|
||||
filtered_batch = [t for t in processed_batch if t is not None]
|
||||
if self.filter_range is not None:
|
||||
print(
|
||||
f"[T{dist.get_rank()}] Filter recv data: {len(processed_batch)} -> {len(filtered_batch)}"
|
||||
)
|
||||
|
||||
self.buffer.extend(filtered_batch)
|
||||
while len(self.buffer) >= self.dp_size * self.minibatch_size:
|
||||
@ -137,13 +137,8 @@ class BaseConsumer:
|
||||
]
|
||||
batch = bind_batch(batches)
|
||||
batch = post_recv(batch)
|
||||
loss, excessive_prompts_idx = self.step(i, pbar, **batch)
|
||||
|
||||
if excessive_prompts_idx is not None:
|
||||
excessive_prompts = [self.buffer[idx] for idx in excessive_prompts_idx]
|
||||
self.buffer = excessive_prompts + self.buffer[self.dp_size * self.minibatch_size :]
|
||||
else:
|
||||
self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
|
||||
loss = self.step(i, pbar, **batch)
|
||||
self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
|
||||
if loss is not None:
|
||||
allow_sync_model = True
|
||||
pbar.set_postfix({"loss": loss})
|
||||
|
@ -9,7 +9,7 @@ from coati.distributed.loss import PolicyLoss
|
||||
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
|
||||
from coati.distributed.reward.verifiable_reward import VerifiableReward
|
||||
from coati.distributed.utils import calc_action_log_probs
|
||||
from coati.trainer.utils import all_gather_tensors, all_reduce_mean, all_reduce_sum
|
||||
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
@ -201,10 +201,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
|
||||
# [minibatch_size x num_generations]
|
||||
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
|
||||
# filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
|
||||
group_ans_acc = (
|
||||
ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0)
|
||||
)
|
||||
|
||||
# [minibatch_size x num_of_generation]
|
||||
loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()
|
||||
|
||||
@ -214,37 +211,14 @@ class GRPOConsumer(BaseConsumer):
|
||||
loss_mask,
|
||||
action_mask[:, -1] == False,
|
||||
)
|
||||
prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
|
||||
|
||||
# [minibatch_size] -> calculate the number of effective prompts
|
||||
effective_prompts_mask = prompt_level_mask.any(dim=1)
|
||||
effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
|
||||
self.effective_prompt_count += effective_prompts.item()
|
||||
excessive_prompts_idx = None
|
||||
self.effective_prompt_count += group_reward.size(0) * self.dp_size
|
||||
|
||||
mean_kl, mean_loss = [], []
|
||||
|
||||
if self.grpo_config.get("dynamic_batching", True):
|
||||
need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
|
||||
excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size
|
||||
|
||||
if excessive_prompts > 0:
|
||||
excessive_prompts_per_rank = excessive_prompts // self.dp_size
|
||||
# Only count excessive prompts if they are greater than 1 per rank.
|
||||
# TODO: customize excessive prompts calculation.
|
||||
if excessive_prompts_per_rank != 0:
|
||||
# Mask excessive prompts to False
|
||||
true_indices = torch.nonzero(effective_prompts_mask).squeeze()
|
||||
if excessive_prompts_per_rank <= len(true_indices):
|
||||
excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
|
||||
else:
|
||||
excessive_prompts_idx = true_indices
|
||||
effective_prompts_mask[excessive_prompts_idx] = False
|
||||
|
||||
for mask_idx in range(len(effective_prompts_mask)):
|
||||
if effective_prompts_mask[mask_idx] == False:
|
||||
# Update loss mask.
|
||||
loss_mask[mask_idx] = False
|
||||
assert excessive_prompts <= 0, "Debug: Excessive prompts should always be less than 0. Bug!!!!"
|
||||
else:
|
||||
# If dynamic batching is disabled, we need to use all samples for training.
|
||||
need_update = (step_idx + 1) % self.num_microbatches == 0
|
||||
@ -460,9 +434,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
self.global_step += 1
|
||||
self.total_sample_count = all_reduce_sum(
|
||||
torch.tensor(self.total_sample_count).to(self.accum_loss.device), self.plugin
|
||||
).item()
|
||||
# no need to run all_reduce_sum on total_sample_count, because all dp ranks recieves a complete inference batch from all producers.
|
||||
sample_utilization = self.effective_sample_count / self.total_sample_count
|
||||
self.effective_prompt_count = 0
|
||||
self.effective_sample_count = 0
|
||||
@ -507,14 +479,9 @@ class GRPOConsumer(BaseConsumer):
|
||||
self.accum_advantages.zero_()
|
||||
self.accum_response_length.zero_()
|
||||
self.accum_count = 0
|
||||
|
||||
if excessive_prompts_idx is not None:
|
||||
# All gather excessive prompts index across DP ranks.
|
||||
excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
|
||||
excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
|
||||
return loss_scalar, excessive_prompts_idx
|
||||
return loss_scalar
|
||||
else:
|
||||
return None, excessive_prompts_idx
|
||||
return None
|
||||
|
||||
def calculate_group_reward(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
|
@ -66,7 +66,7 @@ def launch_distributed(
|
||||
|
||||
dataset_path = train_dataset_config["path"]
|
||||
num_samples = get_jsonl_size_fast(dataset_path)
|
||||
global_inference_batch_size = inference_batch_size * num_producers # TODO: this doesn't support TP on producer
|
||||
global_inference_batch_size = inference_batch_size * num_producers
|
||||
num_update_per_episode = num_samples // global_inference_batch_size
|
||||
num_recv_per_update = inference_batch_size // inference_microbatch_size
|
||||
|
||||
|
@ -187,7 +187,7 @@ class BaseProducer:
|
||||
for eval_task_name in self.eval_dataloaders:
|
||||
if self.producer_idx == 0:
|
||||
print(
|
||||
f"[P{self.producer_idx}] Evaluate episode {episode} step {self.consumer_global_step} on task {eval_task_name}"
|
||||
f"[P{self.producer_idx}] Evaluate consumer step {self.consumer_global_step} on task {eval_task_name}"
|
||||
)
|
||||
eval_results = []
|
||||
eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
|
||||
|
@ -104,7 +104,13 @@ if __name__ == "__main__":
|
||||
choices=["think_answer_tags", "boxed"],
|
||||
help="Reward type for GRPO.",
|
||||
)
|
||||
parser.add_argument("-ei", "--eval-interval", type=int, default=100, help="Interval for evaluation.")
|
||||
parser.add_argument(
|
||||
"-ei",
|
||||
"--eval-interval",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Interval for evaluation. Evaluate every ei consumer steps.",
|
||||
)
|
||||
|
||||
# Logging/Checkpointing parameters
|
||||
parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
|
||||
@ -125,8 +131,8 @@ if __name__ == "__main__":
|
||||
and args.train_microbatch_size > 0
|
||||
), "Train micro batch size must be greater than 0 less than train mini batch size * num generations"
|
||||
assert (
|
||||
args.train_minibatch_size <= args.train_batch_size
|
||||
), "Train mini batch size must be less than or equals to train batch size"
|
||||
args.train_minibatch_size <= args.train_batch_size and args.train_batch_size % args.train_minibatch_size == 0
|
||||
), "Train mini batch size must be less than or equals to train batch size and train batch size must be divisible by train mini batch size"
|
||||
|
||||
if args.master_address is None:
|
||||
# Default settings: Using single machine
|
||||
|
Loading…
Reference in New Issue
Block a user