mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-17 07:00:37 +00:00
fix bugs
This commit is contained in:
@@ -107,6 +107,37 @@ class BaseConsumer:
|
||||
def step(self, step_idx: int, **kwargs) -> Optional[float]:
|
||||
raise NotImplementedError
|
||||
|
||||
def prepare_mini_batch(self, effective_group_to_raw_group_mapping: Dict[int, int]) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Prepare a mini-batch from the effective group to raw group mapping.
|
||||
This method is used to create a mini-batch for training.
|
||||
"""
|
||||
batches = [
|
||||
self.buffer[effective_group_to_raw_group_mapping[i]]
|
||||
for i in range(self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size)
|
||||
]
|
||||
# every dp_rank will receive a complete mini-batch, no need to sync within step() later
|
||||
# each mini-batch use the first self.dp_size * minibatch_size effective samples
|
||||
raw_mini_batches = self.buffer[
|
||||
: effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
|
||||
] # include the last effective sample
|
||||
raw_mini_batches_metric_dict = {
|
||||
"raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
|
||||
"raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
|
||||
"raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
|
||||
"raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
|
||||
}
|
||||
batch = bind_batch([t[0] for t in batches])
|
||||
batch = post_recv(batch)
|
||||
return batch, raw_mini_batches_metric_dict
|
||||
|
||||
def calculate_effective_group_to_raw_group_mapping(self):
|
||||
effective_group_to_raw_group_mapping = {}
|
||||
for buffer_idx in range(len(self.buffer)):
|
||||
if self.buffer[buffer_idx][0] is not None:
|
||||
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx
|
||||
return effective_group_to_raw_group_mapping
|
||||
|
||||
def loop(self) -> None:
|
||||
print(
|
||||
f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
|
||||
@@ -121,6 +152,38 @@ class BaseConsumer:
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
i = 0
|
||||
for _ in range(self.num_recv_per_update):
|
||||
# after sync model, do not wait for more data to arrive as rollout takes time, use buffered data
|
||||
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
|
||||
while len(effective_group_to_raw_group_mapping) > max(
|
||||
self.dp_size * self.batch_size
|
||||
- self.dp_size
|
||||
* self.minibatch_size
|
||||
* self.grpo_config.get("num_minibatch_during_rollout", 1),
|
||||
self.dp_size * self.minibatch_size,
|
||||
):
|
||||
self.profiler.log(
|
||||
f"Still have {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training"
|
||||
)
|
||||
batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
|
||||
effective_group_to_raw_group_mapping
|
||||
)
|
||||
self.profiler.enter("step")
|
||||
loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
|
||||
self.profiler.exit("step")
|
||||
self.buffer = self.buffer[
|
||||
effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
|
||||
]
|
||||
# recalculate the effective group to raw group mapping
|
||||
effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
|
||||
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
|
||||
assert (
|
||||
len(effective_group_to_raw_group_mapping)
|
||||
== effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
|
||||
)
|
||||
if loss is not None:
|
||||
pbar.set_postfix({"loss": loss})
|
||||
i += 1
|
||||
|
||||
# receive data from producers
|
||||
for r in range(self.num_producers):
|
||||
print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
|
||||
@@ -170,37 +233,20 @@ class BaseConsumer:
|
||||
f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
|
||||
)
|
||||
# mapping the effective group to the raw group for indexing
|
||||
effective_group_to_raw_group_mapping = {}
|
||||
for buffer_idx in range(len(self.buffer)):
|
||||
if self.buffer[buffer_idx][0] is not None:
|
||||
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
|
||||
buffer_idx
|
||||
)
|
||||
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
|
||||
print(
|
||||
f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
|
||||
)
|
||||
|
||||
while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
|
||||
while len(effective_group_to_raw_group_mapping) > self.dp_size * self.batch_size:
|
||||
self.profiler.log(
|
||||
f"Received {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.batch_size}, start training after recv"
|
||||
)
|
||||
# always keep at least dp_size * batch_size effective samples in the buffer for training during the rollout times after each sync model
|
||||
# on each dp_rank, we use minibatch_size effective samples to form a batch
|
||||
batches = [
|
||||
self.buffer[effective_group_to_raw_group_mapping[i]]
|
||||
for i in range(
|
||||
self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size
|
||||
)
|
||||
]
|
||||
# every dp_rank will receive a complete mini-batch, no need to sync within step() later
|
||||
# each mini-batch use the first self.dp_size * minibatch_size effective samples
|
||||
raw_mini_batches = self.buffer[
|
||||
: effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
|
||||
] # include the last effective sample
|
||||
raw_mini_batches_metric_dict = {
|
||||
"raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
|
||||
"raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
|
||||
"raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
|
||||
"raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
|
||||
}
|
||||
batch = bind_batch([t[0] for t in batches])
|
||||
batch = post_recv(batch)
|
||||
batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
|
||||
effective_group_to_raw_group_mapping
|
||||
)
|
||||
self.profiler.enter("step")
|
||||
loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
|
||||
self.profiler.exit("step")
|
||||
@@ -209,12 +255,7 @@ class BaseConsumer:
|
||||
]
|
||||
# recalculate the effective group to raw group mapping
|
||||
effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
|
||||
effective_group_to_raw_group_mapping = {}
|
||||
for buffer_idx in range(len(self.buffer)):
|
||||
if self.buffer[buffer_idx][0] is not None:
|
||||
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
|
||||
buffer_idx
|
||||
)
|
||||
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
|
||||
assert (
|
||||
len(effective_group_to_raw_group_mapping)
|
||||
== effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
|
||||
|
@@ -379,7 +379,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
reference_model_logits / self.generate_config["temperature"],
|
||||
input_ids_forward_micro_batch,
|
||||
num_action,
|
||||
self.plugin.shard_config,
|
||||
shard_config=self.plugin.shard_config,
|
||||
)
|
||||
per_token_kl = (
|
||||
torch.exp(reference_action_log_probs - action_log_probs)
|
||||
|
@@ -1,3 +1,7 @@
|
||||
import os
|
||||
import time
|
||||
|
||||
|
||||
class CustomProfiler:
|
||||
def __init__(self, name, disabled=True):
|
||||
self.disabled = disabled
|
||||
|
Reference in New Issue
Block a user