Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-08-13 21:55:46 +00:00)

Commit 11a5854b50 (parent a528921944): remove redundant code and fix bugs
@@ -121,14 +121,14 @@ class BaseConsumer:
                 raw_batch = unbind_batch(
                     ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
                 )
-                filtered_batch = [
-                    t
-                    for t in [
-                        self.prompt_level_filtering(self.calculate_group_reward(group))
-                        for group in raw_batch
-                    ]
-                    if t is not None
-                ]
+                processed_batch = [
+                    self.prompt_level_filtering(self.calculate_group_reward(group)) for group in raw_batch
+                ]
+                filtered_batch = [t for t in processed_batch if t is not None]
+                if self.filter_range is not None:
+                    print(
+                        f"[T{dist.get_rank()}] Filter recv data: {len(processed_batch)} -> {len(filtered_batch)}"
+                    )

                 self.buffer.extend(filtered_batch)
                 while len(self.buffer) >= self.dp_size * self.minibatch_size:
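The hunk above splits the nested comprehension into a two-pass filter so the consumer can report how many received groups survive prompt-level filtering. A minimal, self-contained sketch of the same pattern; the two helper functions below are hypothetical stand-ins for the real calculate_group_reward / prompt_level_filtering methods:

    from typing import Any, Dict, List, Optional

    def calculate_group_reward(group: Dict[str, Any]) -> Dict[str, Any]:
        # Hypothetical stand-in: attach the mean score of the group as its reward.
        return {**group, "reward": sum(group["scores"]) / len(group["scores"])}

    def prompt_level_filtering(group: Dict[str, Any], filter_range=(0.01, 0.99)) -> Optional[Dict[str, Any]]:
        # Hypothetical stand-in: drop groups whose reward falls outside filter_range
        # (i.e. every rollout solved the prompt, or none did).
        low, high = filter_range
        return group if low <= group["reward"] <= high else None

    raw_batch: List[Dict[str, Any]] = [
        {"scores": [1, 1, 1]},  # too easy -> filtered out
        {"scores": [0, 1, 0]},  # kept
        {"scores": [0, 0, 0]},  # too hard -> filtered out
    ]

    # Two-pass version mirroring the diff: keep the intermediate list so the
    # before/after sizes can be logged.
    processed_batch = [prompt_level_filtering(calculate_group_reward(g)) for g in raw_batch]
    filtered_batch = [t for t in processed_batch if t is not None]
    print(f"Filter recv data: {len(processed_batch)} -> {len(filtered_batch)}")  # 3 -> 1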
@@ -137,13 +137,8 @@ class BaseConsumer:
                     ]
                     batch = bind_batch(batches)
                     batch = post_recv(batch)
-                    loss, excessive_prompts_idx = self.step(i, pbar, **batch)
-
-                    if excessive_prompts_idx is not None:
-                        excessive_prompts = [self.buffer[idx] for idx in excessive_prompts_idx]
-                        self.buffer = excessive_prompts + self.buffer[self.dp_size * self.minibatch_size :]
-                    else:
-                        self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
+                    loss = self.step(i, pbar, **batch)
+                    self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
                     if loss is not None:
                         allow_sync_model = True
                         pbar.set_postfix({"loss": loss})
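With step() no longer returning excessive prompt indices, the consumer simply drops one consumed minibatch per DP group from the front of its buffer. A small sketch of that bookkeeping under assumed sizes:

    # Sketch of the simplified buffer trimming; dp_size and minibatch_size are assumed values.
    dp_size, minibatch_size = 2, 4
    buffer = list(range(11))  # 11 queued rollout groups

    while len(buffer) >= dp_size * minibatch_size:
        consumed = buffer[: dp_size * minibatch_size]
        buffer = buffer[dp_size * minibatch_size :]
        print(f"train on {len(consumed)} groups, {len(buffer)} left in buffer")
    # -> train on 8 groups, 3 left in buffer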
@@ -9,7 +9,7 @@ from coati.distributed.loss import PolicyLoss
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs
-from coati.trainer.utils import all_gather_tensors, all_reduce_mean, all_reduce_sum
+from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -201,10 +201,7 @@ class GRPOConsumer(BaseConsumer):
         reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
         # [minibatch_size x num_generations]
         advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
-        # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
-        group_ans_acc = (
-            ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0)
-        )
         # [minibatch_size x num_of_generation]
         loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()

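The retained lines are the GRPO group-normalized advantage; the removed group_ans_acc computation is no longer used downstream, which is presumably why it was dropped as redundant. A small numeric sketch of the normalization, assuming a minibatch of 2 prompts with 3 generations each (reward_mean is recomputed here so the snippet is self-contained):

    import torch

    num_generations = 3
    # rewards for 2 prompts x 3 generations, flattened to [minibatch_size * num_generations]
    reward = torch.tensor([1.0, 0.0, 1.0, 0.0, 0.0, 0.0])

    group_reward = reward.view(-1, num_generations)  # [2, 3]
    reward_mean = group_reward.mean(dim=1).repeat_interleave(num_generations, dim=0)
    reward_std = group_reward.std(dim=1).repeat_interleave(num_generations, dim=0)
    advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
    print(advantages.squeeze(-1))
    # first group: roughly [0.577, -1.155, 0.577]; second group: all zeros (its std is 0)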
@@ -214,37 +211,14 @@ class GRPOConsumer(BaseConsumer):
             loss_mask,
             action_mask[:, -1] == False,
         )
-        prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
-
-        # [minibatch_size] -> calculate the number of effective prompts
-        effective_prompts_mask = prompt_level_mask.any(dim=1)
-        effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
-        self.effective_prompt_count += effective_prompts.item()
-        excessive_prompts_idx = None
+        self.effective_prompt_count += group_reward.size(0) * self.dp_size

         mean_kl, mean_loss = [], []

         if self.grpo_config.get("dynamic_batching", True):
             need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
             excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size
-
-            if excessive_prompts > 0:
-                excessive_prompts_per_rank = excessive_prompts // self.dp_size
-                # Only count excessive prompts if they are greater than 1 per rank.
-                # TODO: customize excessive prompts calculation.
-                if excessive_prompts_per_rank != 0:
-                    # Mask excessive prompts to False
-                    true_indices = torch.nonzero(effective_prompts_mask).squeeze()
-                    if excessive_prompts_per_rank <= len(true_indices):
-                        excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
-                    else:
-                        excessive_prompts_idx = true_indices
-                    effective_prompts_mask[excessive_prompts_idx] = False
-
-                    for mask_idx in range(len(effective_prompts_mask)):
-                        if effective_prompts_mask[mask_idx] == False:
-                            # Update loss mask.
-                            loss_mask[mask_idx] = False
+            assert excessive_prompts <= 0, "Debug: Excessive prompts should always be less than 0. Bug!!!!"
         else:
             # If dynamic batching is disabled, we need to use all samples for training.
             need_update = (step_idx + 1) % self.num_microbatches == 0
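Because every consumer rank now receives the complete inference batch from all producers (per the new comment later in this commit), the effective prompt count can be advanced locally by group_reward.size(0) * dp_size instead of all-reducing a per-rank mask, and the excessive-prompt trimming branch becomes unreachable, hence the debug assert. A sketch of the trigger logic with assumed sizes:

    # Sketch of the dynamic-batching trigger; all sizes are assumed, not the real trainer state.
    dp_size, batch_size, groups_per_recv = 2, 16, 4

    effective_prompt_count = 0
    for recv in range(1, 5):
        # every rank sees the same complete batch, so scale the local count by dp_size
        effective_prompt_count += groups_per_recv * dp_size
        need_update = effective_prompt_count >= batch_size * dp_size
        excessive_prompts = effective_prompt_count - batch_size * dp_size
        assert excessive_prompts <= 0  # mirrors the new debug assert
        print(recv, effective_prompt_count, need_update)
    # counts 8, 16, 24, 32 -> the update triggers exactly on the 4th receive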
@@ -460,9 +434,7 @@ class GRPOConsumer(BaseConsumer):
             self.optimizer.step()
             self.optimizer.zero_grad()
             self.global_step += 1
-            self.total_sample_count = all_reduce_sum(
-                torch.tensor(self.total_sample_count).to(self.accum_loss.device), self.plugin
-            ).item()
+            # no need to run all_reduce_sum on total_sample_count, because all dp ranks recieves a complete inference batch from all producers.
             sample_utilization = self.effective_sample_count / self.total_sample_count
             self.effective_prompt_count = 0
             self.effective_sample_count = 0
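The replacement comment gives the reason: every DP rank already holds the complete count, so summing identical values across ranks would have inflated total_sample_count by a factor of dp_size and understated utilization. A tiny sketch of that effect with assumed counts:

    # Assumed counts illustrating why the all-reduce is redundant here.
    dp_size = 4
    total_sample_count = 512       # already the complete count on every rank
    effective_sample_count = 384

    summed_across_ranks = total_sample_count * dp_size       # what summing identical counts would yield
    print(effective_sample_count / summed_across_ranks)      # 0.1875 (understated)
    print(effective_sample_count / total_sample_count)       # 0.75   (intended sample_utilization)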
@@ -507,14 +479,9 @@ class GRPOConsumer(BaseConsumer):
             self.accum_advantages.zero_()
             self.accum_response_length.zero_()
             self.accum_count = 0
-
-            if excessive_prompts_idx is not None:
-                # All gather excessive prompts index across DP ranks.
-                excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
-                excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
-            return loss_scalar, excessive_prompts_idx
+            return loss_scalar
         else:
-            return None, excessive_prompts_idx
+            return None

     def calculate_group_reward(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
         """
@@ -66,7 +66,7 @@ def launch_distributed(

     dataset_path = train_dataset_config["path"]
     num_samples = get_jsonl_size_fast(dataset_path)
-    global_inference_batch_size = inference_batch_size * num_producers  # TODO: this doesn't support TP on producer
+    global_inference_batch_size = inference_batch_size * num_producers
     num_update_per_episode = num_samples // global_inference_batch_size
     num_recv_per_update = inference_batch_size // inference_microbatch_size

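Apart from dropping the stale TODO, these lines derive the rollout/update schedule from the dataset size. A worked example with assumed values:

    # Worked example of the schedule arithmetic; all inputs are assumed values.
    num_samples = 10_000              # lines in the training jsonl
    inference_batch_size = 64         # prompts per producer per rollout round
    inference_microbatch_size = 16    # prompts shipped to the consumer at a time
    num_producers = 4

    global_inference_batch_size = inference_batch_size * num_producers       # 256
    num_update_per_episode = num_samples // global_inference_batch_size      # 39
    num_recv_per_update = inference_batch_size // inference_microbatch_size  # 4
    print(global_inference_batch_size, num_update_per_episode, num_recv_per_update)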
@@ -187,7 +187,7 @@ class BaseProducer:
                 for eval_task_name in self.eval_dataloaders:
                     if self.producer_idx == 0:
                         print(
-                            f"[P{self.producer_idx}] Evaluate episode {episode} step {self.consumer_global_step} on task {eval_task_name}"
+                            f"[P{self.producer_idx}] Evaluate consumer step {self.consumer_global_step} on task {eval_task_name}"
                         )
                     eval_results = []
                     eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
@@ -104,7 +104,13 @@ if __name__ == "__main__":
         choices=["think_answer_tags", "boxed"],
         help="Reward type for GRPO.",
     )
-    parser.add_argument("-ei", "--eval-interval", type=int, default=100, help="Interval for evaluation.")
+    parser.add_argument(
+        "-ei",
+        "--eval-interval",
+        type=int,
+        default=100,
+        help="Interval for evaluation. Evaluate every ei consumer steps.",
+    )

     # Logging/Checkpointing parameters
     parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
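The reflowed add_argument only changes the help text to make clear that the interval is measured in consumer steps. A quick standalone check of how the flag parses; the parser here is a minimal stand-in, not the full training script:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-ei",
        "--eval-interval",
        type=int,
        default=100,
        help="Interval for evaluation. Evaluate every ei consumer steps.",
    )
    args = parser.parse_args(["-ei", "50"])
    print(args.eval_interval)  # 50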
@@ -125,8 +131,8 @@ if __name__ == "__main__":
         and args.train_microbatch_size > 0
     ), "Train micro batch size must be greater than 0 less than train mini batch size * num generations"
     assert (
-        args.train_minibatch_size <= args.train_batch_size
-    ), "Train mini batch size must be less than or equals to train batch size"
+        args.train_minibatch_size <= args.train_batch_size and args.train_batch_size % args.train_minibatch_size == 0
+    ), "Train mini batch size must be less than or equals to train batch size and train batch size must be divisible by train mini batch size"

     if args.master_address is None:
         # Default settings: Using single machine
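The strengthened assertion now also requires train_batch_size to be an exact multiple of train_minibatch_size. A quick check of the new condition with assumed values:

    # Quick check of the strengthened batch-size assertion; the sizes are assumed values.
    def valid(train_batch_size: int, train_minibatch_size: int) -> bool:
        return (
            train_minibatch_size <= train_batch_size
            and train_batch_size % train_minibatch_size == 0
        )

    print(valid(64, 16))  # True: 16 <= 64 and 64 % 16 == 0
    print(valid(64, 24))  # False under the new rule: 24 <= 64 but 64 % 24 != 0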