mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-25 15:01:43 +00:00
tested multi-node training; fix bind_batch bug
This commit is contained in:
parent
e397327de0
commit
39a1faac87
@ -16,7 +16,7 @@ from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .comm import ray_broadcast_tensor_dict
|
||||
from .utils import bind_batch, post_recv, unbind_batch
|
||||
from .utils import bind_batch, pad_batch, post_recv, unbind_batch
|
||||
|
||||
|
||||
class BaseConsumer:
|
||||
@ -110,6 +110,9 @@ class BaseConsumer:
|
||||
self.dp_rank * self.microbatch_size : (self.dp_rank + 1) * self.microbatch_size
|
||||
]
|
||||
self.buffer = self.buffer[self.dp_size * self.microbatch_size :]
|
||||
batch = pad_batch(
|
||||
batches
|
||||
) # when `imbs` is smaller than `tMbs`, samples may have differ in size, need to pad before stacking
|
||||
batch = bind_batch(batches)
|
||||
batch = post_recv(batch)
|
||||
loss = self.step(i, pbar, **batch)
|
||||
|
@ -1,5 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, Optional
|
||||
|
||||
@ -41,6 +42,14 @@ class GRPOConsumer(BaseConsumer):
|
||||
grpo_config={},
|
||||
project_name=None,
|
||||
):
|
||||
print(f"Using GRPO config: {grpo_config}")
|
||||
if grpo_config.get("loss_variation", "sample_level") == "token_level":
|
||||
if batch_size != microbatch_size:
|
||||
warnings.warn(
|
||||
f"Applied token_level loss, force overwrite mini-batch-size with batch-size: mini-batch-size: {microbatch_size}->{batch_size}",
|
||||
UserWarning,
|
||||
)
|
||||
microbatch_size = batch_size
|
||||
super().__init__(
|
||||
num_producers,
|
||||
num_episodes,
|
||||
@ -201,8 +210,10 @@ class GRPOConsumer(BaseConsumer):
|
||||
loss_mask,
|
||||
action_mask[:, -1] == False,
|
||||
)
|
||||
effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
|
||||
|
||||
effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
|
||||
total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
|
||||
total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
|
||||
self.effective_sample_count += effective_samples.item()
|
||||
self.total_sample_count += total_samples.item()
|
||||
@ -211,11 +222,11 @@ class GRPOConsumer(BaseConsumer):
|
||||
|
||||
# update gradient only if at least 0.7*batch_size*num_generation valid samples are collected in case a lot of samples are invalid and got filtered out.
|
||||
# balance between efficiency and accuracy
|
||||
need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations * 0.75
|
||||
need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations * 0.95
|
||||
pbar.set_postfix(
|
||||
{
|
||||
"Step": self.global_step + 1,
|
||||
"Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations * 0.75}",
|
||||
"Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations * 0.95}",
|
||||
}
|
||||
)
|
||||
|
||||
@ -226,7 +237,6 @@ class GRPOConsumer(BaseConsumer):
|
||||
else self.booster.no_sync(self.policy_model, self.optimizer)
|
||||
)
|
||||
with ctx:
|
||||
|
||||
for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
|
||||
input_ids_forward_micro_batch = data["input_ids"][
|
||||
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
|
||||
@ -324,6 +334,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
per_token_kl,
|
||||
inputs["action_mask"],
|
||||
loss_mask=inputs["loss_mask"],
|
||||
total_effective_tokens_in_batch=total_effective_tokens_count,
|
||||
)
|
||||
return loss
|
||||
|
||||
@ -386,6 +397,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
per_token_kl,
|
||||
action_mask_forward_micro_batch,
|
||||
loss_mask=loss_mask_forward_micro_batch,
|
||||
total_effective_tokens_in_batch=total_effective_tokens_count,
|
||||
)
|
||||
|
||||
self.booster.backward(loss, self.optimizer)
|
||||
|
@ -32,6 +32,7 @@ class PolicyLoss(nn.Module):
|
||||
per_token_kl: torch.Tensor,
|
||||
action_mask: Optional[torch.Tensor] = None,
|
||||
loss_mask: Optional[torch.Tensor] = None,
|
||||
total_effective_tokens_in_batch: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
if action_mask is None:
|
||||
ratio = (log_probs - log_probs.detach()).exp()
|
||||
@ -54,17 +55,13 @@ class PolicyLoss(nn.Module):
|
||||
loss = loss * loss_mask
|
||||
loss = loss.mean()
|
||||
elif self.loss_variation == "token_level":
|
||||
total_tokens = 0
|
||||
if action_mask is not None:
|
||||
loss = masked_sum(loss, action_mask)
|
||||
total_tokens = action_mask.sum(dim=1)
|
||||
else:
|
||||
loss = loss.sum(dim=1)
|
||||
total_tokens = torch.ones_like(loss, device=loss.device) * log_probs.size(1)
|
||||
if loss_mask is not None:
|
||||
loss = loss * loss_mask
|
||||
total_tokens = total_tokens * loss_mask
|
||||
loss = loss.sum() / (total_tokens.sum() + 1e-8)
|
||||
loss = loss.sum() / (total_effective_tokens_in_batch + 1e-8)
|
||||
else:
|
||||
raise ValueError(f"Unsupported loss variation: {self.loss_variation}")
|
||||
|
||||
|
@ -117,6 +117,7 @@ class BaseProducer:
|
||||
print(
|
||||
f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
state_dict = ray_broadcast_tensor_dict(
|
||||
None, self.num_producers, device=self.device, group_name="sync_model"
|
||||
|
@ -1,3 +1,4 @@
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
@ -26,6 +27,27 @@ def bind_batch(batches: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor
|
||||
return batch
|
||||
|
||||
|
||||
def pad_batch(batches: List[Dict[str, torch.Tensor]], tokenizer: Any = None) -> List[Dict[str, torch.Tensor]]:
|
||||
max_len = defaultdict(int)
|
||||
for sample in batches:
|
||||
for k in sample:
|
||||
if k in ["input_ids", "attention_mask", "action_log_probs", "action_mask"]:
|
||||
max_len[k] = max(max_len[k], sample[k].size(-1))
|
||||
for idx, sample in enumerate(batches):
|
||||
for k in sample:
|
||||
if k in ["input_ids", "attention_mask", "action_log_probs", "action_mask"]:
|
||||
# right pad with 0s
|
||||
if k in ["attention_mask", "action_mask"]:
|
||||
batches[idx][k] = torch.nn.functional.pad(
|
||||
batches[idx][k], (0, max_len[k] - batches[idx][k].size(-1)), "constant", False
|
||||
)
|
||||
else:
|
||||
batches[idx][k] = torch.nn.functional.pad(
|
||||
batches[idx][k], (0, max_len[k] - batches[idx][k].size(-1)), "constant", 0
|
||||
)
|
||||
return batches
|
||||
|
||||
|
||||
def pre_send(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
# compress mask to save bandwidth
|
||||
if "attention_mask" in batch:
|
||||
|
@ -49,6 +49,12 @@ if __name__ == "__main__":
|
||||
)
|
||||
parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
|
||||
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
|
||||
parser.add_argument(
|
||||
"--ray_dir", type=str, default=None, help="Custom temperary directory for storing ray cluster data, Optional"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
assert args.train_minibatch_size > 0, "Train mini batch size must be greater than 0"
|
||||
@ -57,7 +63,12 @@ if __name__ == "__main__":
|
||||
and args.train_microbatch_size > 0
|
||||
), "Train micro batch size must be greater than 0 less than train mini batch size * num generations"
|
||||
|
||||
ray.init(address="local", namespace="ray-example")
|
||||
if args.master_address is None:
|
||||
# Default settings: Using single machine
|
||||
ray.init(address="local", namespace="ray-example")
|
||||
else:
|
||||
# For ray distributed multi-machine training, Please change _node_ip_address to your IP address of your master node
|
||||
ray.init(_node_ip_address=args.master_address, namespace="ray-example", _temp_dir=args.ray_dir)
|
||||
|
||||
inference_model_config = dict(path=args.model)
|
||||
train_model_config = dict(path=args.model, use_flash_attention_2=False, use_cache=False)
|
||||
@ -80,10 +91,18 @@ if __name__ == "__main__":
|
||||
)
|
||||
)
|
||||
elif args.backend == "vllm":
|
||||
inference_model_config.update(dict(gpu_memory_utilization=0.7, enforce_eager=True, enable_chunked_prefill=True))
|
||||
inference_model_config.update(
|
||||
dict(
|
||||
gpu_memory_utilization=0.7,
|
||||
enforce_eager=True,
|
||||
enable_chunked_prefill=True,
|
||||
max_model_len=1024 * 10 + 510,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
)
|
||||
generate_config.update(
|
||||
dict(
|
||||
max_tokens=4096,
|
||||
max_tokens=1024 * 10,
|
||||
ignore_eos=True,
|
||||
include_stop_str_in_output=True,
|
||||
stop=["</answer>"],
|
||||
@ -120,14 +139,14 @@ if __name__ == "__main__":
|
||||
"beta": 0.0, # no KL penalty
|
||||
"loss_variation": "token_level",
|
||||
"soft_over_length_punishment": True,
|
||||
"max_length": 4096,
|
||||
"max_length": 1024 * 10,
|
||||
"cache_length": 512,
|
||||
"filter_truncated_response": True,
|
||||
}
|
||||
|
||||
launch_distributed(
|
||||
num_producers=args.num_inferencer,
|
||||
num_proc_per_producer=1,
|
||||
num_proc_per_producer=inference_model_config.get("tensor_parallel_size", 1),
|
||||
num_consumer_procs=args.num_trainers,
|
||||
num_episodes=10,
|
||||
inference_batch_size=args.inference_batch_size,
|
||||
@ -146,7 +165,6 @@ if __name__ == "__main__":
|
||||
"zero_stage": 2,
|
||||
}, # for zero
|
||||
# plugin_config={
|
||||
# "pp_size": 2,
|
||||
# "tp_size": 2,
|
||||
# "microbatch_size": args.train_microbatch_size // 2,
|
||||
# "zero_stage": 0,
|
||||
|
Loading…
Reference in New Issue
Block a user