mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-05 10:10:32 +00:00
fix bugs
This commit is contained in:
parent
ff6696a9bb
commit
c2561f826a
@ -107,6 +107,37 @@ class BaseConsumer:
|
|||||||
def step(self, step_idx: int, **kwargs) -> Optional[float]:
|
def step(self, step_idx: int, **kwargs) -> Optional[float]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def prepare_mini_batch(self, effective_group_to_raw_group_mapping: Dict[int, int]) -> Dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Prepare a mini-batch from the effective group to raw group mapping.
|
||||||
|
This method is used to create a mini-batch for training.
|
||||||
|
"""
|
||||||
|
batches = [
|
||||||
|
self.buffer[effective_group_to_raw_group_mapping[i]]
|
||||||
|
for i in range(self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size)
|
||||||
|
]
|
||||||
|
# every dp_rank will receive a complete mini-batch, no need to sync within step() later
|
||||||
|
# each mini-batch use the first self.dp_size * minibatch_size effective samples
|
||||||
|
raw_mini_batches = self.buffer[
|
||||||
|
: effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
|
||||||
|
] # include the last effective sample
|
||||||
|
raw_mini_batches_metric_dict = {
|
||||||
|
"raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
|
||||||
|
"raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
|
||||||
|
"raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
|
||||||
|
"raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
|
||||||
|
}
|
||||||
|
batch = bind_batch([t[0] for t in batches])
|
||||||
|
batch = post_recv(batch)
|
||||||
|
return batch, raw_mini_batches_metric_dict
|
||||||
|
|
||||||
|
def calculate_effective_group_to_raw_group_mapping(self):
|
||||||
|
effective_group_to_raw_group_mapping = {}
|
||||||
|
for buffer_idx in range(len(self.buffer)):
|
||||||
|
if self.buffer[buffer_idx][0] is not None:
|
||||||
|
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx
|
||||||
|
return effective_group_to_raw_group_mapping
|
||||||
|
|
||||||
def loop(self) -> None:
|
def loop(self) -> None:
|
||||||
print(
|
print(
|
||||||
f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
|
f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
|
||||||
@ -121,6 +152,38 @@ class BaseConsumer:
|
|||||||
torch.cuda.reset_peak_memory_stats()
|
torch.cuda.reset_peak_memory_stats()
|
||||||
i = 0
|
i = 0
|
||||||
for _ in range(self.num_recv_per_update):
|
for _ in range(self.num_recv_per_update):
|
||||||
|
# after sync model, do not wait for more data to arrive as rollout takes time, use buffered data
|
||||||
|
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
|
||||||
|
while len(effective_group_to_raw_group_mapping) > max(
|
||||||
|
self.dp_size * self.batch_size
|
||||||
|
- self.dp_size
|
||||||
|
* self.minibatch_size
|
||||||
|
* self.grpo_config.get("num_minibatch_during_rollout", 1),
|
||||||
|
self.dp_size * self.minibatch_size,
|
||||||
|
):
|
||||||
|
self.profiler.log(
|
||||||
|
f"Still have {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training"
|
||||||
|
)
|
||||||
|
batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
|
||||||
|
effective_group_to_raw_group_mapping
|
||||||
|
)
|
||||||
|
self.profiler.enter("step")
|
||||||
|
loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
|
||||||
|
self.profiler.exit("step")
|
||||||
|
self.buffer = self.buffer[
|
||||||
|
effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
|
||||||
|
]
|
||||||
|
# recalculate the effective group to raw group mapping
|
||||||
|
effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
|
||||||
|
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
|
||||||
|
assert (
|
||||||
|
len(effective_group_to_raw_group_mapping)
|
||||||
|
== effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
|
||||||
|
)
|
||||||
|
if loss is not None:
|
||||||
|
pbar.set_postfix({"loss": loss})
|
||||||
|
i += 1
|
||||||
|
|
||||||
# receive data from producers
|
# receive data from producers
|
||||||
for r in range(self.num_producers):
|
for r in range(self.num_producers):
|
||||||
print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
|
print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
|
||||||
@ -170,37 +233,20 @@ class BaseConsumer:
|
|||||||
f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
|
f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
|
||||||
)
|
)
|
||||||
# mapping the effective group to the raw group for indexing
|
# mapping the effective group to the raw group for indexing
|
||||||
effective_group_to_raw_group_mapping = {}
|
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
|
||||||
for buffer_idx in range(len(self.buffer)):
|
|
||||||
if self.buffer[buffer_idx][0] is not None:
|
|
||||||
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
|
|
||||||
buffer_idx
|
|
||||||
)
|
|
||||||
print(
|
print(
|
||||||
f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
|
f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
|
||||||
)
|
)
|
||||||
|
|
||||||
while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
|
while len(effective_group_to_raw_group_mapping) > self.dp_size * self.batch_size:
|
||||||
# on each dp_rank, we use minibatch_size effective samples to form a batch
|
self.profiler.log(
|
||||||
batches = [
|
f"Received {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.batch_size}, start training after recv"
|
||||||
self.buffer[effective_group_to_raw_group_mapping[i]]
|
)
|
||||||
for i in range(
|
# always keep at least dp_size * batch_size effective samples in the buffer for training during the rollout times after each sync model
|
||||||
self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size
|
# on each dp_rank, we use minibatch_size effective samples to form a batch
|
||||||
|
batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
|
||||||
|
effective_group_to_raw_group_mapping
|
||||||
)
|
)
|
||||||
]
|
|
||||||
# every dp_rank will receive a complete mini-batch, no need to sync within step() later
|
|
||||||
# each mini-batch use the first self.dp_size * minibatch_size effective samples
|
|
||||||
raw_mini_batches = self.buffer[
|
|
||||||
: effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
|
|
||||||
] # include the last effective sample
|
|
||||||
raw_mini_batches_metric_dict = {
|
|
||||||
"raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
|
|
||||||
"raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
|
|
||||||
"raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
|
|
||||||
"raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
|
|
||||||
}
|
|
||||||
batch = bind_batch([t[0] for t in batches])
|
|
||||||
batch = post_recv(batch)
|
|
||||||
self.profiler.enter("step")
|
self.profiler.enter("step")
|
||||||
loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
|
loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
|
||||||
self.profiler.exit("step")
|
self.profiler.exit("step")
|
||||||
@ -209,12 +255,7 @@ class BaseConsumer:
|
|||||||
]
|
]
|
||||||
# recalculate the effective group to raw group mapping
|
# recalculate the effective group to raw group mapping
|
||||||
effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
|
effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
|
||||||
effective_group_to_raw_group_mapping = {}
|
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
|
||||||
for buffer_idx in range(len(self.buffer)):
|
|
||||||
if self.buffer[buffer_idx][0] is not None:
|
|
||||||
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
|
|
||||||
buffer_idx
|
|
||||||
)
|
|
||||||
assert (
|
assert (
|
||||||
len(effective_group_to_raw_group_mapping)
|
len(effective_group_to_raw_group_mapping)
|
||||||
== effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
|
== effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
|
||||||
|
@ -379,7 +379,7 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
reference_model_logits / self.generate_config["temperature"],
|
reference_model_logits / self.generate_config["temperature"],
|
||||||
input_ids_forward_micro_batch,
|
input_ids_forward_micro_batch,
|
||||||
num_action,
|
num_action,
|
||||||
self.plugin.shard_config,
|
shard_config=self.plugin.shard_config,
|
||||||
)
|
)
|
||||||
per_token_kl = (
|
per_token_kl = (
|
||||||
torch.exp(reference_action_log_probs - action_log_probs)
|
torch.exp(reference_action_log_probs - action_log_probs)
|
||||||
|
@ -1,3 +1,7 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
class CustomProfiler:
|
class CustomProfiler:
|
||||||
def __init__(self, name, disabled=True):
|
def __init__(self, name, disabled=True):
|
||||||
self.disabled = disabled
|
self.disabled = disabled
|
||||||
|
@ -1,7 +1,13 @@
|
|||||||
export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
|
export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
|
||||||
|
|
||||||
# 8K context length
|
# 8K context length
|
||||||
|
# rm -rf *.prof
|
||||||
|
# MAX_NEW_TOKENS=$((8192-512))
|
||||||
|
# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 16 -tbs 16 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 4 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -mnt $MAX_NEW_TOKENS -nb 1 --enable_profiling 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_8192_GRPO_profile.txt
|
||||||
|
# python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_8192_GRPO_profile.png
|
||||||
|
|
||||||
|
# 4K context length
|
||||||
rm -rf *.prof
|
rm -rf *.prof
|
||||||
MAX_NEW_TOKENS=$((8192-512))
|
MAX_NEW_TOKENS=$((4096-512))
|
||||||
python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 16 -tbs 16 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 4 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -mnt $MAX_NEW_TOKENS -nb 1 --enable_profiling 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_8192_GRPO_profile.txt
|
python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 8 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tMbs 4 -tmbs 4 -p GRPO-Math-Profile -ei -5 -zero 2 -mnt $MAX_NEW_TOKENS -nb 1 --enable_profiling 2>&1| tee ibs_32_tbs_32_tmbs_4_zero_2_4096_GRPO_profile.txt
|
||||||
python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_8192_GRPO_profile.png
|
python visualization.py --visualization actor_timelines_ibs_32_tbs_32_tmbs_4_zero_2_4096_GRPO_profile.png
|
||||||
|
@ -263,6 +263,7 @@ if __name__ == "__main__":
|
|||||||
grpo_config = {
|
grpo_config = {
|
||||||
"lr": args.learning_rate,
|
"lr": args.learning_rate,
|
||||||
"train_microbatch_size": args.train_microbatch_size,
|
"train_microbatch_size": args.train_microbatch_size,
|
||||||
|
"num_minibatch_during_rollout": 1, # number of mini batches to pop out from buffer and used for training during rollout of the producer after it syncs the model. Hint, set to a proper value close to the number of mini batches for training that takes roughly the same time as the rollout of the producer. A value that is too large or too small will cause bubble time on the trainer or the producer.
|
||||||
"beta": args.kl_coeff, # KL penalty coefficient
|
"beta": args.kl_coeff, # KL penalty coefficient
|
||||||
"loss_variation": "sample_level",
|
"loss_variation": "sample_level",
|
||||||
"reward_fn_type": args.reward_type,
|
"reward_fn_type": args.reward_type,
|
||||||
|
@ -74,8 +74,8 @@ for idx, (actor, func_dict) in enumerate(actors.items()):
|
|||||||
yticks.append(y_val)
|
yticks.append(y_val)
|
||||||
yticklabels.append(f"{actor}:{func}")
|
yticklabels.append(f"{actor}:{func}")
|
||||||
for start, end in intervals:
|
for start, end in intervals:
|
||||||
if end - start < 100:
|
if end - start < 6:
|
||||||
end = start + 100 # Ensure minimum length of 100ms
|
end = start + 6 # Ensure minimum length of 100ms
|
||||||
ax.plot(
|
ax.plot(
|
||||||
[start, end],
|
[start, end],
|
||||||
[y_val, y_val],
|
[y_val, y_val],
|
||||||
|
Loading…
Reference in New Issue
Block a user