mirror of https://github.com/hpcaitech/ColossalAI.git

commit 39a1faac87
parent e397327de0

tested multi-node training; fix bind_batch bug
@@ -16,7 +16,7 @@ from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
 
 from .comm import ray_broadcast_tensor_dict
-from .utils import bind_batch, post_recv, unbind_batch
+from .utils import bind_batch, pad_batch, post_recv, unbind_batch
 
 
 class BaseConsumer:
@@ -110,6 +110,9 @@ class BaseConsumer:
                         self.dp_rank * self.microbatch_size : (self.dp_rank + 1) * self.microbatch_size
                     ]
                     self.buffer = self.buffer[self.dp_size * self.microbatch_size :]
+                    batch = pad_batch(
+                        batches
+                    )  # when `imbs` is smaller than `tMbs`, samples may differ in size; pad before stacking
                     batch = bind_batch(batches)
                     batch = post_recv(batch)
                     loss = self.step(i, pbar, **batch)
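Review note on the hunk above: `pad_batch` mutates the list in place, which is why the subsequent `bind_batch(batches)` still sees the padded tensors even though the return value assigned to `batch` is immediately overwritten. The padding is needed because stacking requires identical shapes: when the inference micro-batch size (`imbs`) is smaller than the train micro-batch size (`tMbs`), samples collected from different generation rounds can differ in sequence length. A minimal standalone sketch of the idea (illustrative tensors, not the repo's API):

```python
import torch
import torch.nn.functional as F

# Two samples with different sequence lengths, as might arrive
# from separate inference rounds.
a = torch.tensor([[1, 2, 3]])
b = torch.tensor([[4, 5, 6, 7, 8]])

# torch.cat/stack over the batch dim fails unless lengths match,
# so right-pad everything to the longest sequence first.
max_len = max(a.size(-1), b.size(-1))
a = F.pad(a, (0, max_len - a.size(-1)), "constant", 0)
b = F.pad(b, (0, max_len - b.size(-1)), "constant", 0)

batch = torch.cat([a, b], dim=0)  # shapes now agree: (2, 5)
print(batch.shape)
```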
@@ -1,5 +1,6 @@
 import json
 import os
+import warnings
 from contextlib import nullcontext
 from typing import Any, Optional
 
@@ -41,6 +42,14 @@ class GRPOConsumer(BaseConsumer):
         grpo_config={},
         project_name=None,
     ):
+        print(f"Using GRPO config: {grpo_config}")
+        if grpo_config.get("loss_variation", "sample_level") == "token_level":
+            if batch_size != microbatch_size:
+                warnings.warn(
+                    f"Applied token_level loss; force-overwriting mini-batch size with batch size: {microbatch_size} -> {batch_size}",
+                    UserWarning,
+                )
+                microbatch_size = batch_size
         super().__init__(
             num_producers,
             num_episodes,
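Review note: a plausible reading of this guard (inferred from the surrounding changes, not stated in the commit) is that the token-level loss is normalized by a token count computed over the whole training batch, so the mini-batch used for an optimizer step must cover the full batch; the constructor force-aligns the two rather than silently training with a mismatched denominator.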
@@ -201,8 +210,10 @@ class GRPOConsumer(BaseConsumer):
             loss_mask,
             action_mask[:, -1] == False,
         )
+        effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
 
         effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
+        total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
         total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
         self.effective_sample_count += effective_samples.item()
         self.total_sample_count += total_samples.item()
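`effective_tokens_count` zeroes out the token counts of filtered samples, and the new all-reduce gives every data-parallel rank the same global denominator for the token-level loss. `all_reduce_sum` is ColossalAI's helper; a sketch of the same reduction with plain `torch.distributed` (degrading to the local value when no process group is initialized, so it stays runnable single-process):

```python
import torch
import torch.distributed as dist

def all_reduce_sum_sketch(t: torch.Tensor) -> torch.Tensor:
    # Sum a tensor across all data-parallel ranks so every rank
    # sees the same global value.
    if dist.is_initialized():
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return t

# Valid action tokens per sample; the second sample is filtered
# out by loss_mask, so its tokens do not count.
action_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
loss_mask = torch.tensor([1.0, 0.0])
effective_tokens = action_mask.sum(dim=-1) * loss_mask   # -> [3., 0.]
total = all_reduce_sum_sketch(effective_tokens.sum())    # 3.0 on one rank
print(total)
```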
@@ -211,11 +222,11 @@ class GRPOConsumer(BaseConsumer):
 
         # update gradient only if at least 0.95*batch_size*num_generations valid samples are collected,
         # in case many samples are invalid and get filtered out. balance between efficiency and accuracy
-        need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations * 0.75
+        need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations * 0.95
         pbar.set_postfix(
             {
                 "Step": self.global_step + 1,
-                "Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations * 0.75}",
+                "Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations * 0.95}",
             }
         )
 
@@ -226,7 +237,6 @@ class GRPOConsumer(BaseConsumer):
             else self.booster.no_sync(self.policy_model, self.optimizer)
         )
         with ctx:
-
             for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
                 input_ids_forward_micro_batch = data["input_ids"][
                     forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
@@ -324,6 +334,7 @@ class GRPOConsumer(BaseConsumer):
                     per_token_kl,
                     inputs["action_mask"],
                     loss_mask=inputs["loss_mask"],
+                    total_effective_tokens_in_batch=total_effective_tokens_count,
                 )
                 return loss
 
@@ -386,6 +397,7 @@ class GRPOConsumer(BaseConsumer):
                     per_token_kl,
                     action_mask_forward_micro_batch,
                     loss_mask=loss_mask_forward_micro_batch,
+                    total_effective_tokens_in_batch=total_effective_tokens_count,
                 )
 
                 self.booster.backward(loss, self.optimizer)
@@ -32,6 +32,7 @@ class PolicyLoss(nn.Module):
         per_token_kl: torch.Tensor,
         action_mask: Optional[torch.Tensor] = None,
         loss_mask: Optional[torch.Tensor] = None,
+        total_effective_tokens_in_batch: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if action_mask is None:
             ratio = (log_probs - log_probs.detach()).exp()
@@ -54,17 +55,13 @@ class PolicyLoss(nn.Module):
                 loss = loss * loss_mask
             loss = loss.mean()
         elif self.loss_variation == "token_level":
-            total_tokens = 0
             if action_mask is not None:
                 loss = masked_sum(loss, action_mask)
-                total_tokens = action_mask.sum(dim=1)
             else:
                 loss = loss.sum(dim=1)
-                total_tokens = torch.ones_like(loss, device=loss.device) * log_probs.size(1)
             if loss_mask is not None:
                 loss = loss * loss_mask
-                total_tokens = total_tokens * loss_mask
-            loss = loss.sum() / (total_tokens.sum() + 1e-8)
+            loss = loss.sum() / (total_effective_tokens_in_batch + 1e-8)
         else:
             raise ValueError(f"Unsupported loss variation: {self.loss_variation}")
 
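The substance of this hunk: each rank previously divided its token-level loss by its own local token count (`total_tokens.sum()`); now all ranks divide by the batch-global `total_effective_tokens_in_batch` computed above. A global denominator also keeps gradient accumulation consistent, since micro-batch losses normalized by the same constant sum to the full-batch value. A small self-contained check of that identity (illustrative shapes, not the repo's API):

```python
import torch

per_token_loss = torch.rand(4, 6)          # (batch, seq_len)
mask = (torch.rand(4, 6) > 0.3).float()    # valid-token mask
global_tokens = mask.sum()                 # batch-wide denominator

# Full-batch token-level loss.
full = (per_token_loss * mask).sum() / (global_tokens + 1e-8)

# Two micro-batches, each normalized by the SAME global
# denominator, then summed (as gradient accumulation would).
acc = torch.tensor(0.0)
for mb_loss, mb_mask in zip(per_token_loss.chunk(2), mask.chunk(2)):
    acc = acc + (mb_loss * mb_mask).sum() / (global_tokens + 1e-8)

assert torch.allclose(full, acc)  # equal up to float error
```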
@@ -117,6 +117,7 @@ class BaseProducer:
                     print(
                         f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
                     )
+                    torch.cuda.empty_cache()
 
                     state_dict = ray_broadcast_tensor_dict(
                         None, self.num_producers, device=self.device, group_name="sync_model"
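Review note: calling `torch.cuda.empty_cache()` before receiving the broadcast state dict presumably frees headroom for the incoming weight tensors. It only returns cached, unused allocator blocks to the driver; live tensors are untouched. A quick way to observe the effect (assumes a CUDA device):

```python
import torch

if torch.cuda.is_available():
    x = torch.randn(1024, 1024, device="cuda")
    del x  # tensor is dead, but its block stays in PyTorch's cache
    print(torch.cuda.memory_reserved())  # nonzero: still cached
    torch.cuda.empty_cache()             # hand cached blocks back
    print(torch.cuda.memory_reserved())  # smaller, often 0
```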
@@ -1,3 +1,4 @@
+from collections import defaultdict
 from typing import Any, Dict, List
 
 import torch
@@ -26,6 +27,27 @@ def bind_batch(batches: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
     return batch
 
 
+def pad_batch(batches: List[Dict[str, torch.Tensor]], tokenizer: Any = None) -> List[Dict[str, torch.Tensor]]:
+    max_len = defaultdict(int)
+    for sample in batches:
+        for k in sample:
+            if k in ["input_ids", "attention_mask", "action_log_probs", "action_mask"]:
+                max_len[k] = max(max_len[k], sample[k].size(-1))
+    for idx, sample in enumerate(batches):
+        for k in sample:
+            if k in ["input_ids", "attention_mask", "action_log_probs", "action_mask"]:
+                # right pad with 0s
+                if k in ["attention_mask", "action_mask"]:
+                    batches[idx][k] = torch.nn.functional.pad(
+                        batches[idx][k], (0, max_len[k] - batches[idx][k].size(-1)), "constant", False
+                    )
+                else:
+                    batches[idx][k] = torch.nn.functional.pad(
+                        batches[idx][k], (0, max_len[k] - batches[idx][k].size(-1)), "constant", 0
+                    )
+    return batches
+
+
 def pre_send(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
     # compress mask to save bandwidth
     if "attention_mask" in batch:
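A usage sketch for the new `pad_batch` together with the existing `bind_batch` (hypothetical tensors; the import path is an assumption, and `tokenizer` is accepted but unused by the implementation above):

```python
import torch

from coati.distributed.utils import bind_batch, pad_batch  # assumed path

batches = [
    {"input_ids": torch.tensor([[1, 2, 3]]),
     "attention_mask": torch.tensor([[True, True, True]])},
    {"input_ids": torch.tensor([[4, 5, 6, 7]]),
     "attention_mask": torch.tensor([[True, True, True, True]])},
]

batches = pad_batch(batches)  # right-pads ids with 0, masks with False
batch = bind_batch(batches)   # stacking is now safe: all lengths match
print(batch["input_ids"].shape)
```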
@@ -49,6 +49,12 @@ if __name__ == "__main__":
     )
     parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
     parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
+    parser.add_argument(
+        "--ray_dir", type=str, default=None, help="Custom temporary directory for storing Ray cluster data. Optional."
+    )
+    parser.add_argument(
+        "--master_address", type=str, default=None, help="Master address for multi-node distributed training. Optional."
+    )
     args = parser.parse_args()
 
     assert args.train_minibatch_size > 0, "Train mini batch size must be greater than 0"
@@ -57,7 +63,12 @@ if __name__ == "__main__":
         and args.train_microbatch_size > 0
     ), "Train micro batch size must be greater than 0 and less than train mini batch size * num generations"
 
-    ray.init(address="local", namespace="ray-example")
+    if args.master_address is None:
+        # default: single-machine cluster
+        ray.init(address="local", namespace="ray-example")
+    else:
+        # multi-node training: set _node_ip_address to the IP address of the master node
+        ray.init(_node_ip_address=args.master_address, namespace="ray-example", _temp_dir=args.ray_dir)
 
     inference_model_config = dict(path=args.model)
     train_model_config = dict(path=args.model, use_flash_attention_2=False, use_cache=False)
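Usage note (an assumption about the intended workflow, not spelled out in the diff): leave `--master_address` unset for a local single-machine cluster; for multi-node runs, join the workers to the head with something like `ray start --address=<master_ip>:6379`, then launch the script on the master with `--master_address <master_ip>` and, optionally, `--ray_dir` pointing at a writable temporary directory for Ray's cluster data.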
@@ -80,10 +91,18 @@ if __name__ == "__main__":
             )
         )
     elif args.backend == "vllm":
-        inference_model_config.update(dict(gpu_memory_utilization=0.7, enforce_eager=True, enable_chunked_prefill=True))
+        inference_model_config.update(
+            dict(
+                gpu_memory_utilization=0.7,
+                enforce_eager=True,
+                enable_chunked_prefill=True,
+                max_model_len=1024 * 10 + 510,
+                tensor_parallel_size=1,
+            )
+        )
         generate_config.update(
             dict(
-                max_tokens=4096,
+                max_tokens=1024 * 10,
                 ignore_eos=True,
                 include_stop_str_in_output=True,
                 stop=["</answer>"],
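The budget arithmetic here, as far as the diff shows: generation is capped at `max_tokens = 1024 * 10 = 10240`, while the engine context window is `max_model_len = 1024 * 10 + 510 = 10750`, leaving roughly 510 tokens of headroom for the prompt. Presumably prompts are expected to stay under that length; anything longer would eat into the generation budget or be rejected by vLLM, depending on how the request is validated.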
@@ -120,14 +139,14 @@ if __name__ == "__main__":
         "beta": 0.0,  # no KL penalty
         "loss_variation": "token_level",
         "soft_over_length_punishment": True,
-        "max_length": 4096,
+        "max_length": 1024 * 10,
         "cache_length": 512,
         "filter_truncated_response": True,
     }
 
     launch_distributed(
         num_producers=args.num_inferencer,
-        num_proc_per_producer=1,
+        num_proc_per_producer=inference_model_config.get("tensor_parallel_size", 1),
         num_consumer_procs=args.num_trainers,
         num_episodes=10,
         inference_batch_size=args.inference_batch_size,
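Design note: deriving `num_proc_per_producer` from `inference_model_config.get("tensor_parallel_size", 1)` ties the processes reserved per producer to the vLLM tensor-parallel degree, one process per shard; with the `tensor_parallel_size=1` default set above, this reduces to the previous hard-coded `1`.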
@@ -146,7 +165,6 @@ if __name__ == "__main__":
             "zero_stage": 2,
         },  # for zero
         # plugin_config={
-        #   "pp_size": 2,
         #   "tp_size": 2,
         #   "microbatch_size": args.train_microbatch_size // 2,
         #   "zero_stage": 0,