commit c2561f826a
Author: YeAnbang
Date: 2025-06-20 15:44:13 +08:00
Parent: ff6696a9bb

6 changed files with 90 additions and 38 deletions


@@ -107,6 +107,37 @@ class BaseConsumer:
     def step(self, step_idx: int, **kwargs) -> Optional[float]:
         raise NotImplementedError
 
+    def prepare_mini_batch(self, effective_group_to_raw_group_mapping: Dict[int, int]) -> Dict[str, torch.Tensor]:
+        """
+        Prepare a mini-batch from the effective-group-to-raw-group mapping.
+        This method is used to create a mini-batch for training.
+        """
+        batches = [
+            self.buffer[effective_group_to_raw_group_mapping[i]]
+            for i in range(self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size)
+        ]
+        # every dp_rank receives a complete mini-batch, so there is no need to sync within step() later
+        # each mini-batch uses the first self.dp_size * minibatch_size effective samples
+        raw_mini_batches = self.buffer[
+            : effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
+        ]  # include the last effective sample
+        raw_mini_batches_metric_dict = {
+            "raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
+            "raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
+            "raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
+            "raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
+        }
+        batch = bind_batch([t[0] for t in batches])
+        batch = post_recv(batch)
+        return batch, raw_mini_batches_metric_dict
+
+    def calculate_effective_group_to_raw_group_mapping(self):
+        # build a dense index over buffer slots whose group survived filtering (entry[0] is not None)
+        effective_group_to_raw_group_mapping = {}
+        for buffer_idx in range(len(self.buffer)):
+            if self.buffer[buffer_idx][0] is not None:
+                effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx
+        return effective_group_to_raw_group_mapping
+
     def loop(self) -> None:
         print(
             f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
@@ -121,6 +152,38 @@ class BaseConsumer:
                     torch.cuda.reset_peak_memory_stats()
                     i = 0
                     for _ in range(self.num_recv_per_update):
+                        # after syncing the model, do not wait for more rollout data to arrive; train on already-buffered data first
+                        effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
+                        while len(effective_group_to_raw_group_mapping) > max(
+                            self.dp_size * self.batch_size
+                            - self.dp_size * self.minibatch_size * self.grpo_config.get("num_minibatch_during_rollout", 1),
+                            self.dp_size * self.minibatch_size,
+                        ):
+                            self.profiler.log(
+                                f"Still have {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training"
+                            )
+                            batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(effective_group_to_raw_group_mapping)
+                            self.profiler.enter("step")
+                            loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
+                            self.profiler.exit("step")
+                            self.buffer = self.buffer[
+                                effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
+                            ]
+                            # recalculate the mapping over the trimmed buffer
+                            effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
+                            effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
+                            assert (
+                                len(effective_group_to_raw_group_mapping)
+                                == effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
+                            )
+                            if loss is not None:
+                                pbar.set_postfix({"loss": loss})
+                            i += 1
                         # receive data from producers
                         for r in range(self.num_producers):
                             print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
@@ -170,37 +233,20 @@ class BaseConsumer:
f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
)
# mapping the effective group to the raw group for indexing
effective_group_to_raw_group_mapping = {}
for buffer_idx in range(len(self.buffer)):
if self.buffer[buffer_idx][0] is not None:
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
buffer_idx
)
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
print(
f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
)
while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
while len(effective_group_to_raw_group_mapping) > self.dp_size * self.batch_size:
self.profiler.log(
f"Received {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.batch_size}, start training after recv"
)
# always keep at least dp_size * batch_size effective samples in the buffer for training during the rollout times after each sync model
# on each dp_rank, we use minibatch_size effective samples to form a batch
batches = [
self.buffer[effective_group_to_raw_group_mapping[i]]
for i in range(
self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size
)
]
# every dp_rank will receive a complete mini-batch, no need to sync within step() later
# each mini-batch use the first self.dp_size * minibatch_size effective samples
raw_mini_batches = self.buffer[
: effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
] # include the last effective sample
raw_mini_batches_metric_dict = {
"raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
"raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
"raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
"raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
}
batch = bind_batch([t[0] for t in batches])
batch = post_recv(batch)
batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
effective_group_to_raw_group_mapping
)
self.profiler.enter("step")
loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
self.profiler.exit("step")
@@ -209,12 +255,7 @@ class BaseConsumer:
                                 ]
                                 # recalculate the effective group to raw group mapping
                                 effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
-                                effective_group_to_raw_group_mapping = {}
-                                for buffer_idx in range(len(self.buffer)):
-                                    if self.buffer[buffer_idx][0] is not None:
-                                        effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
-                                            buffer_idx
-                                        )
+                                effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
                                 assert (
                                     len(effective_group_to_raw_group_mapping)
                                     == effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
                                 )
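
The assert at the end of this hunk encodes the buffer invariant: trimming the buffer through the raw index of the last consumed effective group removes exactly dp_size * minibatch_size effective groups, regardless of interleaved None entries. A toy trace under assumed sizes (toy data, not the repository's buffer contents):

    def mapping_of(buf):
        m = {}
        for i, entry in enumerate(buf):
            if entry[0] is not None:
                m[len(m)] = i
        return m

    dp_size, minibatch_size = 1, 2  # toy sizes: consume 2 effective groups per step
    buffer = [("g0",), (None,), ("g1",), ("g2",), (None,), ("g3",)]

    m = mapping_of(buffer)  # {0: 0, 1: 2, 2: 3, 3: 5}
    before = len(m)         # 4 effective groups
    buffer = buffer[m[dp_size * minibatch_size - 1] + 1 :]  # trim through raw index 2 (drops one None too)
    assert len(mapping_of(buffer)) == before - dp_size * minibatch_size  # 2 == 4 - 2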


@@ -379,7 +379,7 @@ class GRPOConsumer(BaseConsumer):
                     reference_model_logits / self.generate_config["temperature"],
                     input_ids_forward_micro_batch,
                     num_action,
-                    self.plugin.shard_config,
+                    shard_config=self.plugin.shard_config,
                 )
                 per_token_kl = (
                     torch.exp(reference_action_log_probs - action_log_probs)
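
The one-line GRPOConsumer change switches shard_config from a positional to a keyword argument. This guards against silent misbinding if the callee's parameter list has, or later gains, another parameter in that position. A generic sketch with a hypothetical signature (illustration only, not the repository's actual function):

    def calc_log_probs(logits, input_ids, num_action, attention_mask=None, shard_config=None):
        # hypothetical callee for illustration
        return {"attention_mask": attention_mask, "shard_config": shard_config}

    out = calc_log_probs("logits", "ids", 4, "CFG")  # 4th positional arg binds to attention_mask
    assert out["shard_config"] is None               # silently misbound
    out = calc_log_probs("logits", "ids", 4, shard_config="CFG")
    assert out["shard_config"] == "CFG"              # keyword form binds correctly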


@@ -1,3 +1,7 @@
+import os
+import time
+
+
 class CustomProfiler:
     def __init__(self, name, disabled=True):
         self.disabled = disabled
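
The diff only shows the head of CustomProfiler, but the consumer code above calls profiler.log(msg), profiler.enter(name), and profiler.exit(name). A minimal sketch consistent with those call sites, assuming the rest of the class (a guess at its shape, not the actual implementation):

    import time

    class MinimalProfiler:  # hypothetical stand-in for CustomProfiler's remaining body
        def __init__(self, name, disabled=True):
            self.disabled = disabled
            self.name = name
            self._starts = {}

        def log(self, message):
            if not self.disabled:
                print(f"[{self.name}] {message}")

        def enter(self, section):
            if not self.disabled:
                self._starts[section] = time.perf_counter()

        def exit(self, section):
            if not self.disabled and section in self._starts:
                elapsed = time.perf_counter() - self._starts.pop(section)
                self.log(f"{section} took {elapsed:.3f}s")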