tested multi-node training; fix bind_batch bug

This commit is contained in:
YeAnbang 2025-04-21 15:00:28 +08:00
parent e397327de0
commit 39a1faac87
6 changed files with 68 additions and 15 deletions

View File

@ -16,7 +16,7 @@ from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from .comm import ray_broadcast_tensor_dict
from .utils import bind_batch, post_recv, unbind_batch
from .utils import bind_batch, pad_batch, post_recv, unbind_batch
class BaseConsumer:
@ -110,6 +110,9 @@ class BaseConsumer:
self.dp_rank * self.microbatch_size : (self.dp_rank + 1) * self.microbatch_size
]
self.buffer = self.buffer[self.dp_size * self.microbatch_size :]
batch = pad_batch(
batches
) # when `imbs` is smaller than `tMbs`, samples may have differ in size, need to pad before stacking
batch = bind_batch(batches)
batch = post_recv(batch)
loss = self.step(i, pbar, **batch)

View File

@ -1,5 +1,6 @@
import json
import os
import warnings
from contextlib import nullcontext
from typing import Any, Optional
@ -41,6 +42,14 @@ class GRPOConsumer(BaseConsumer):
grpo_config={},
project_name=None,
):
print(f"Using GRPO config: {grpo_config}")
if grpo_config.get("loss_variation", "sample_level") == "token_level":
if batch_size != microbatch_size:
warnings.warn(
f"Applied token_level loss, force overwrite mini-batch-size with batch-size: mini-batch-size: {microbatch_size}->{batch_size}",
UserWarning,
)
microbatch_size = batch_size
super().__init__(
num_producers,
num_episodes,
@ -201,8 +210,10 @@ class GRPOConsumer(BaseConsumer):
loss_mask,
action_mask[:, -1] == False,
)
effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
self.effective_sample_count += effective_samples.item()
self.total_sample_count += total_samples.item()
@ -211,11 +222,11 @@ class GRPOConsumer(BaseConsumer):
# update gradient only if at least 0.7*batch_size*num_generation valid samples are collected in case a lot of samples are invalid and got filtered out.
# balance between efficiency and accuracy
need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations * 0.75
need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations * 0.95
pbar.set_postfix(
{
"Step": self.global_step + 1,
"Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations * 0.75}",
"Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations * 0.95}",
}
)
@ -226,7 +237,6 @@ class GRPOConsumer(BaseConsumer):
else self.booster.no_sync(self.policy_model, self.optimizer)
)
with ctx:
for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
input_ids_forward_micro_batch = data["input_ids"][
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
@ -324,6 +334,7 @@ class GRPOConsumer(BaseConsumer):
per_token_kl,
inputs["action_mask"],
loss_mask=inputs["loss_mask"],
total_effective_tokens_in_batch=total_effective_tokens_count,
)
return loss
@ -386,6 +397,7 @@ class GRPOConsumer(BaseConsumer):
per_token_kl,
action_mask_forward_micro_batch,
loss_mask=loss_mask_forward_micro_batch,
total_effective_tokens_in_batch=total_effective_tokens_count,
)
self.booster.backward(loss, self.optimizer)

View File

@ -32,6 +32,7 @@ class PolicyLoss(nn.Module):
per_token_kl: torch.Tensor,
action_mask: Optional[torch.Tensor] = None,
loss_mask: Optional[torch.Tensor] = None,
total_effective_tokens_in_batch: torch.Tensor = None,
) -> torch.Tensor:
if action_mask is None:
ratio = (log_probs - log_probs.detach()).exp()
@ -54,17 +55,13 @@ class PolicyLoss(nn.Module):
loss = loss * loss_mask
loss = loss.mean()
elif self.loss_variation == "token_level":
total_tokens = 0
if action_mask is not None:
loss = masked_sum(loss, action_mask)
total_tokens = action_mask.sum(dim=1)
else:
loss = loss.sum(dim=1)
total_tokens = torch.ones_like(loss, device=loss.device) * log_probs.size(1)
if loss_mask is not None:
loss = loss * loss_mask
total_tokens = total_tokens * loss_mask
loss = loss.sum() / (total_tokens.sum() + 1e-8)
loss = loss.sum() / (total_effective_tokens_in_batch + 1e-8)
else:
raise ValueError(f"Unsupported loss variation: {self.loss_variation}")

View File

@ -117,6 +117,7 @@ class BaseProducer:
print(
f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
)
torch.cuda.empty_cache()
state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name="sync_model"

View File

@ -1,3 +1,4 @@
from collections import defaultdict
from typing import Any, Dict, List
import torch
@ -26,6 +27,27 @@ def bind_batch(batches: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor
return batch
def pad_batch(batches: List[Dict[str, torch.Tensor]], tokenizer: Any = None) -> List[Dict[str, torch.Tensor]]:
max_len = defaultdict(int)
for sample in batches:
for k in sample:
if k in ["input_ids", "attention_mask", "action_log_probs", "action_mask"]:
max_len[k] = max(max_len[k], sample[k].size(-1))
for idx, sample in enumerate(batches):
for k in sample:
if k in ["input_ids", "attention_mask", "action_log_probs", "action_mask"]:
# right pad with 0s
if k in ["attention_mask", "action_mask"]:
batches[idx][k] = torch.nn.functional.pad(
batches[idx][k], (0, max_len[k] - batches[idx][k].size(-1)), "constant", False
)
else:
batches[idx][k] = torch.nn.functional.pad(
batches[idx][k], (0, max_len[k] - batches[idx][k].size(-1)), "constant", 0
)
return batches
def pre_send(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
# compress mask to save bandwidth
if "attention_mask" in batch:

View File

@ -49,6 +49,12 @@ if __name__ == "__main__":
)
parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
parser.add_argument(
"--ray_dir", type=str, default=None, help="Custom temperary directory for storing ray cluster data, Optional"
)
parser.add_argument(
"--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional"
)
args = parser.parse_args()
assert args.train_minibatch_size > 0, "Train mini batch size must be greater than 0"
@ -57,7 +63,12 @@ if __name__ == "__main__":
and args.train_microbatch_size > 0
), "Train micro batch size must be greater than 0 less than train mini batch size * num generations"
ray.init(address="local", namespace="ray-example")
if args.master_address is None:
# Default settings: Using single machine
ray.init(address="local", namespace="ray-example")
else:
# For ray distributed multi-machine training, Please change _node_ip_address to your IP address of your master node
ray.init(_node_ip_address=args.master_address, namespace="ray-example", _temp_dir=args.ray_dir)
inference_model_config = dict(path=args.model)
train_model_config = dict(path=args.model, use_flash_attention_2=False, use_cache=False)
@ -80,10 +91,18 @@ if __name__ == "__main__":
)
)
elif args.backend == "vllm":
inference_model_config.update(dict(gpu_memory_utilization=0.7, enforce_eager=True, enable_chunked_prefill=True))
inference_model_config.update(
dict(
gpu_memory_utilization=0.7,
enforce_eager=True,
enable_chunked_prefill=True,
max_model_len=1024 * 10 + 510,
tensor_parallel_size=1,
)
)
generate_config.update(
dict(
max_tokens=4096,
max_tokens=1024 * 10,
ignore_eos=True,
include_stop_str_in_output=True,
stop=["</answer>"],
@ -120,14 +139,14 @@ if __name__ == "__main__":
"beta": 0.0, # no KL penalty
"loss_variation": "token_level",
"soft_over_length_punishment": True,
"max_length": 4096,
"max_length": 1024 * 10,
"cache_length": 512,
"filter_truncated_response": True,
}
launch_distributed(
num_producers=args.num_inferencer,
num_proc_per_producer=1,
num_proc_per_producer=inference_model_config.get("tensor_parallel_size", 1),
num_consumer_procs=args.num_trainers,
num_episodes=10,
inference_batch_size=args.inference_batch_size,
@ -146,7 +165,6 @@ if __name__ == "__main__":
"zero_stage": 2,
}, # for zero
# plugin_config={
# "pp_size": 2,
# "tp_size": 2,
# "microbatch_size": args.train_microbatch_size // 2,
# "zero_stage": 0,