[feat] support hybrid parallel model sync

This commit is contained in:
duanjunwen 2025-04-29 17:00:31 +08:00
parent 2ca1e3c630
commit 2f293248f7
4 changed files with 70 additions and 27 deletions

View File

@ -59,10 +59,6 @@ class BaseConsumer:
self.lr_scheduler = None self.lr_scheduler = None
def setup(self) -> None: def setup(self) -> None:
for i in range(self.num_producers):
cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
if self.rank == 0:
cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0) launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2) plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
@ -77,8 +73,26 @@ class BaseConsumer:
self.booster = Booster(plugin=self.plugin) self.booster = Booster(plugin=self.plugin)
self.dp_rank = dist.get_rank(self.plugin.dp_group) self.dp_rank = dist.get_rank(self.plugin.dp_group)
self.tp_rank = dist.get_rank(self.plugin.tp_group) self.tp_rank = dist.get_rank(self.plugin.tp_group)
self.pp_rank = dist.get_rank(self.plugin.pp_group)
self.dp_size = dist.get_world_size(self.plugin.dp_group) self.dp_size = dist.get_world_size(self.plugin.dp_group)
self.tp_size = dist.get_world_size(self.plugin.tp_group)
self.pp_size = dist.get_world_size(self.plugin.pp_group)
# Init Hybrid ray process group
for i in range(self.num_producers):
cc.init_collective_group(self.world_size + 1, self.rank + 1, backend="hccl", group_name=f"sync_data_{i}")
if self.pp_size > 1:
# use hybrid tp + pp
if self.tp_rank == 0 and self.dp_rank == 0:
cc.init_collective_group(
self.num_producers + 1, self.num_producers, backend="hccl", group_name=f"sync_model_{self.pp_rank}"
)
else:
if self.rank == 0:
cc.init_collective_group(
self.num_producers + 1, self.num_producers, backend="hccl", group_name="sync_model"
)
self.buffer = [] self.buffer = []
@ -143,6 +157,15 @@ class BaseConsumer:
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}") print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
torch.cuda.empty_cache() torch.cuda.empty_cache()
state_dict = self.state_dict() state_dict = self.state_dict()
if self.pp_size > 1:
if self.tp_rank == 0 and self.dp_rank == 0:
ray_broadcast_tensor_dict(
state_dict,
src=self.num_producers,
device=self.device,
group_name=f"sync_model_{self.pp_rank}",
)
else:
if self.rank == 0: if self.rank == 0:
ray_broadcast_tensor_dict( ray_broadcast_tensor_dict(
state_dict, src=self.num_producers, device=self.device, group_name="sync_model" state_dict, src=self.num_producers, device=self.device, group_name="sync_model"

View File

@ -57,7 +57,9 @@ def launch_distributed(
else: else:
core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer) core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer)
train_dp_size = get_dp_size_fast(num_producers, plugin_config) train_dp_size = (
get_dp_size_fast(num_producers, plugin_config) if get_dp_size_fast(num_producers, plugin_config) else 1
)
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0 assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
dataset_path = dataset_config["path"] dataset_path = dataset_config["path"]
@ -82,6 +84,7 @@ def launch_distributed(
microbatch_size=inference_microbatch_size, microbatch_size=inference_microbatch_size,
backend=inference_backend, backend=inference_backend,
num_generations=num_generations, num_generations=num_generations,
consumer_plugin_config=plugin_config,
) )
procs.append(producer) procs.append(producer)
generate_config_consumer = copy.deepcopy(generate_config) generate_config_consumer = copy.deepcopy(generate_config)

View File

@ -29,6 +29,7 @@ class BaseProducer:
tokenizer_config: Optional[Dict[str, Any]] = None, tokenizer_config: Optional[Dict[str, Any]] = None,
microbatch_size: int = 1, microbatch_size: int = 1,
backend: str = "transformers", backend: str = "transformers",
consumer_plugin_config: Dict[str, Any] = None,
): ):
self.producer_idx = producer_idx self.producer_idx = producer_idx
self.num_producers = num_producers self.num_producers = num_producers
@ -78,9 +79,19 @@ class BaseProducer:
else: else:
raise ValueError(f"Unexpected backend {backend}") raise ValueError(f"Unexpected backend {backend}")
self.consumer_pp_size = consumer_plugin_config["pp_size"] # consumer pp size
def setup(self) -> None: def setup(self) -> None:
cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}") cc.init_collective_group(
cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model") 1 + self.num_consumer_procs, 0, backend="hccl", group_name=f"sync_data_{self.producer_idx}"
)
if self.consumer_pp_size > 1:
for i in range(self.consumer_pp_size):
cc.init_collective_group(
self.num_producers + 1, self.producer_idx, backend="hccl", group_name=f"sync_model_{i}"
)
else:
cc.init_collective_group(self.num_producers + 1, self.producer_idx, backend="hccl", group_name="sync_model")
def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
raise NotImplementedError raise NotImplementedError
@ -130,6 +141,14 @@ class BaseProducer:
) )
torch.cuda.empty_cache() torch.cuda.empty_cache()
if self.consumer_pp_size > 1:
# TODO: loop load
for i in range(self.consumer_pp_size):
state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name=f"sync_model_{i}"
)
self.load_state_dict(state_dict)
else:
state_dict = ray_broadcast_tensor_dict( state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name="sync_model" None, self.num_producers, device=self.device, group_name="sync_model"
) )
@ -170,6 +189,7 @@ class SimpleProducer(BaseProducer):
microbatch_size=1, microbatch_size=1,
backend="transformers", backend="transformers",
num_generations: int = 8, num_generations: int = 8,
consumer_plugin_config=None,
): ):
super().__init__( super().__init__(
producer_idx, producer_idx,
@ -184,6 +204,7 @@ class SimpleProducer(BaseProducer):
tokenizer_config, tokenizer_config,
microbatch_size, microbatch_size,
backend, backend,
consumer_plugin_config,
) )
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations) self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)

View File

@ -121,7 +121,7 @@ if __name__ == "__main__":
args.top_k = -1 args.top_k = -1
inference_model_config = dict(path=args.model) inference_model_config = dict(path=args.model)
train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False) train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False, attn_implementation="eager")
generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature) generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)
if args.backend == "transformers": if args.backend == "transformers":
@ -147,7 +147,7 @@ if __name__ == "__main__":
enforce_eager=True, enforce_eager=True,
enable_chunked_prefill=True, enable_chunked_prefill=True,
max_model_len=args.max_new_tokens + args.max_prompt_tokens, max_model_len=args.max_new_tokens + args.max_prompt_tokens,
tensor_parallel_size=1, tensor_parallel_size=2,
) )
) )
generate_config.update( generate_config.update(
@ -210,16 +210,12 @@ if __name__ == "__main__":
train_model_config=train_model_config, train_model_config=train_model_config,
grpo_config=grpo_config, grpo_config=grpo_config,
plugin_config={ plugin_config={
"zero_stage": 2, "pp_size": 2,
}, # for zero "tp_size": 2,
# currently not support tp/pp "microbatch_size": args.train_microbatch_size // 2,
# plugin_config={ "zero_stage": 1,
# "tp_size": 2, "max_norm": 1.0,
# "pp_size": 2, }, # for tp + pp
# "microbatch_size": max(1, args.train_microbatch_size // 2),
# "zero_stage": 0,
# "max_norm": 1.0,
# }, # for pp
inference_backend=args.backend, inference_backend=args.backend,
master_addr="localhost", master_addr="localhost",
master_port=args.master_port, master_port=args.master_port,