mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-07-23 19:49:30 +00:00

fix behind

commit 8abf186ce2, parent db8baeeaf2
@@ -131,10 +131,10 @@ class BaseConsumer:
                 batch = post_recv(batch)
         return batch, raw_mini_batches_metric_dict

-    def calculate_effective_group_to_raw_group_mapping(self):
+    def calculate_effective_group_to_raw_group_mapping(self, step):
         effective_group_to_raw_group_mapping = {}
         for buffer_idx in range(len(self.buffer)):
-            if self.buffer[buffer_idx][0] is not None:
+            if self.buffer[buffer_idx][0] is not None and (self.buffer[buffer_idx][-1] <= step - self.n_behind):
                 effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx
         return effective_group_to_raw_group_mapping

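Note: a minimal sketch of the new staleness gate, assuming each buffer entry is a list whose last element is the rollout step appended on receive (the `step,` added in the @@ -226 hunk below) and that `n_behind` is how many rollout steps the trainer is allowed to lag. The entry layout and values are illustrative, not the exact buffer schema.

    # Illustrative entries mimicking [group, format_acc, ans_acc, response_len, step];
    # the real buffer holds tensors, not strings.
    buffer = [
        ["group_a", 1.0, 0.9, 128, 0],  # produced at rollout step 0
        [None, 0.0, 0.0, 0, 1],         # group filtered out on receive
        ["group_b", 1.0, 0.5, 96, 2],   # produced at rollout step 2
    ]
    n_behind, step = 1, 2

    mapping = {}
    for buffer_idx in range(len(buffer)):
        # same predicate as the hunk: entry is effective AND at least n_behind steps old
        if buffer[buffer_idx][0] is not None and buffer[buffer_idx][-1] <= step - n_behind:
            mapping[len(mapping)] = buffer_idx

    print(mapping)  # {0: 0} -- only group_a is both effective and old enough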
@@ -152,37 +152,40 @@ class BaseConsumer:
                 torch.cuda.reset_peak_memory_stats()
                 i = 0
                 for _ in range(self.num_recv_per_update):
-                    # after sync model, do not wait for more data to arrive as rollout takes time, use buffered data
-                    effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
-                    while len(effective_group_to_raw_group_mapping) > max(
-                        self.dp_size * self.batch_size
-                        - self.dp_size
-                        * self.minibatch_size
-                        * self.grpo_config.get("num_minibatch_during_rollout", 1),
-                        self.dp_size * self.minibatch_size,
-                    ):
-                        self.profiler.log(
-                            f"Still have {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training"
-                        )
-                        batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
-                            effective_group_to_raw_group_mapping
-                        )
-                        self.profiler.enter("step")
-                        loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
-                        self.profiler.exit("step")
-                        self.buffer = self.buffer[
-                            effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
-                        ]
-                        # recalculate the effective group to raw group mapping
-                        effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
-                        effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
-                        assert (
-                            len(effective_group_to_raw_group_mapping)
-                            == effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
-                        )
-                        if loss is not None:
-                            pbar.set_postfix({"loss": loss})
-                        i += 1
+                    if self.n_behind > 0:
+                        # after sync model, do not wait for more data to arrive as rollout takes time, use buffered data
+                        effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping(
+                            step=step
+                        )
+                        while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
+                            self.profiler.log(
+                                f"Still have {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training"
+                            )
+                            batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
+                                effective_group_to_raw_group_mapping
+                            )
+                            self.profiler.enter("step")
+                            loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
+                            self.profiler.exit("step")
+                            self.buffer = self.buffer[
+                                effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
+                            ]
+                            # recalculate the effective group to raw group mapping
+                            effective_group_to_raw_group_mapping_size_before = len(
+                                effective_group_to_raw_group_mapping
+                            )
+                            effective_group_to_raw_group_mapping = (
+                                self.calculate_effective_group_to_raw_group_mapping(step=step)
+                            )
+                            assert (
+                                len(effective_group_to_raw_group_mapping)
+                                == effective_group_to_raw_group_mapping_size_before
+                                - self.dp_size * self.minibatch_size
+                            )
+                            if loss is not None:
+                                pbar.set_postfix({"loss": loss})
+                            i += 1

                     # receive data from producers
                     for r in range(self.num_producers):
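Note: the restructured loop above introduces two mutually exclusive training paths. A paraphrased, runnable sketch of the control flow; `train_on_ready_minibatches` and `receive_from_producers` are hypothetical stand-ins for the inlined while-loop and the recv loop, not functions in the codebase:

    def train_on_ready_minibatches(step):
        ...  # stand-in for the while-loop over prepare_mini_batch / self.step

    def receive_from_producers(step):
        ...  # stand-in for the recv loop; new buffer entries are tagged with `step`

    n_behind, num_recv_per_update, step = 1, 2, 0
    for _ in range(num_recv_per_update):
        if n_behind > 0:
            # pipelined path: train on buffered groups at least n_behind steps old,
            # without waiting for the current rollout to finish
            train_on_ready_minibatches(step)
        receive_from_producers(step)
        if n_behind == 0:
            # synchronous path: train only after fresh data has arrived
            train_on_ready_minibatches(step)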
@@ -226,6 +229,7 @@ class BaseConsumer:
                                    format_acc[group_idx],
                                    ans_acc[group_idx],
                                    response_len[group_idx],
+                                    step,
                                ]
                            )
                        if effective_group_mask is not None:
@@ -233,17 +237,22 @@ class BaseConsumer:
                                f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
                            )
                        # mapping the effective group to the raw group for indexing
-                        effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
+                        effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping(
+                            step=step
+                        )
                        print(
                            f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
                        )

-                        while len(effective_group_to_raw_group_mapping) > self.dp_size * self.batch_size:
-                            self.profiler.log(
-                                f"Received {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.batch_size}, start training after recv"
-                            )
-                            # always keep at least dp_size * batch_size effective samples in the buffer for training during the rollout times after each sync model
-                            # on each dp_rank, we use minibatch_size effective samples to form a batch
-                            batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
-                                effective_group_to_raw_group_mapping
-                            )
+                        if self.n_behind == 0:
+                            # If n_behind is 0, we start training after receiving data from producers.
+                            effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping(
+                                step=step
+                            )
+                            while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
+                                self.profiler.log(
+                                    f"Still have {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training"
+                                )
+                                batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
+                                    effective_group_to_raw_group_mapping
+                                )
@@ -255,7 +264,9 @@ class BaseConsumer:
                                ]
                                # recalculate the effective group to raw group mapping
                                effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
-                                effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
+                                effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping(
+                                    step=step
+                                )
                                assert (
                                    len(effective_group_to_raw_group_mapping)
                                    == effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
@@ -263,6 +274,7 @@ class BaseConsumer:
                                if loss is not None:
                                    pbar.set_postfix({"loss": loss})
                                i += 1
+
                if self.lr_scheduler is not None:
                    self.lr_scheduler.step()
                if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode:
@@ -337,9 +337,8 @@ class BaseProducer:
                ):
                    self.model.llm.sleep()  # revict KV_cache to avoid OOM
                # don't sync model for last iteration
-                self.profiler.enter("sync_model")
                torch.cuda.empty_cache()
+                self.profiler.enter("sync_model")
                if self.consumer_pp_size > 1:
                    for pp_idx in range(self.consumer_pp_size):
                        print(
@@ -361,9 +360,9 @@ class BaseProducer:
                    if "consumer_global_step" in state_dict:
                        self.consumer_global_step = state_dict.pop("consumer_global_step").item()
                    self.load_state_dict(state_dict)
+                    self.profiler.exit("sync_model")
                del state_dict
                torch.cuda.empty_cache()
-                self.profiler.exit("sync_model")
                if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
                    "enable_sleep_mode", False
                ):
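Note: the two producer hunks above move the profiler markers so that the "sync_model" window no longer includes torch.cuda.empty_cache() on either side; only the weight transfer and load are timed. A standalone sketch of the resulting ordering, where `receive_weights` and `NullProfiler` are stand-ins rather than codebase APIs:

    import torch

    def receive_weights():
        return {}  # stand-in for the broadcast state_dict

    class NullProfiler:  # stub so the sketch runs anywhere
        def enter(self, name): ...
        def exit(self, name): ...

    profiler = NullProfiler()
    torch.cuda.empty_cache()        # before the timed window
    profiler.enter("sync_model")
    state_dict = receive_weights()  # timed: transfer and load only
    # self.load_state_dict(state_dict) sits here in the real producer
    profiler.exit("sync_model")
    del state_dict
    torch.cuda.empty_cache()        # after the timed window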
@@ -263,7 +263,6 @@ if __name__ == "__main__":
        grpo_config = {
            "lr": args.learning_rate,
            "train_microbatch_size": args.train_microbatch_size,
-            "num_minibatch_during_rollout": 1,  # number of mini batches to pop out from buffer and used for training during rollout of the producer after it syncs the model. Hint, set to a proper value close to the number of mini batches for training that takes roughly the same time as the rollout of the producer. A value that is too large or too small will cause bubble time on the trainer or the producer.
            "beta": args.kl_coeff,  # KL penalty coefficient
            "loss_variation": "sample_level",
            "reward_fn_type": args.reward_type,
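Note: with `num_minibatch_during_rollout` removed, the buffered-training threshold that consumed it (the `max(...)` expression deleted in the @@ -152 hunk) appears to be superseded by the `n_behind` staleness gate. Both predicates restated with plain, illustrative variables:

    dp_size, batch_size, minibatch_size = 2, 8, 4
    num_minibatch_during_rollout = 1  # the removed knob
    num_effective = 12                # len(effective_group_to_raw_group_mapping)

    # old gate (removed): keep training while the buffer stays above this floor
    old_gate = num_effective > max(
        dp_size * batch_size - dp_size * minibatch_size * num_minibatch_during_rollout,
        dp_size * minibatch_size,
    )

    # new gate: train whenever a full minibatch of effective, sufficiently old groups exists
    new_gate = num_effective >= dp_size * minibatch_size

    print(old_gate, new_gate)  # True True with these illustrative numbers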