[feat] Sync shard model (#6289)

* [feat] support hybrid parallel model sync

* update consumer and producer

* update files

* update producer

* remove print

* update

---------

Co-authored-by: duanjunwen <935724073@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li35271158@gmail.com>
This commit is contained in:
Tong Li 2025-04-30 14:47:01 +08:00 committed by GitHub
parent 14f237ce7e
commit 5fd4bcb9d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 66 additions and 20 deletions

View File

@ -59,10 +59,6 @@ class BaseConsumer:
self.lr_scheduler = None self.lr_scheduler = None
def setup(self) -> None: def setup(self) -> None:
for i in range(self.num_producers):
cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
if self.rank == 0:
cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0) launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2) plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
@ -77,8 +73,24 @@ class BaseConsumer:
self.booster = Booster(plugin=self.plugin) self.booster = Booster(plugin=self.plugin)
self.dp_rank = dist.get_rank(self.plugin.dp_group) self.dp_rank = dist.get_rank(self.plugin.dp_group)
self.tp_rank = dist.get_rank(self.plugin.tp_group) self.tp_rank = dist.get_rank(self.plugin.tp_group)
self.pp_rank = dist.get_rank(self.plugin.pp_group)
self.dp_size = dist.get_world_size(self.plugin.dp_group) self.dp_size = dist.get_world_size(self.plugin.dp_group)
self.tp_size = dist.get_world_size(self.plugin.tp_group)
self.pp_size = dist.get_world_size(self.plugin.pp_group)
# Init Hybrid ray process group
for i in range(self.num_producers):
cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
if self.pp_size > 1:
# use hybrid tp + pp
if self.tp_rank == 0 and self.dp_rank == 0:
cc.init_collective_group(
self.num_producers + 1, self.num_producers, group_name=f"sync_model_{self.pp_rank}"
)
else:
if self.rank == 0:
cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
self.buffer = [] self.buffer = []
@ -140,13 +152,27 @@ class BaseConsumer:
print(f"Saved model checkpoint at step {step + 1} in folder {save_path}") print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1: if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}") if self.pp_size > 1:
print(
f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
)
else:
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
torch.cuda.empty_cache() torch.cuda.empty_cache()
state_dict = self.state_dict() state_dict = self.state_dict()
if self.rank == 0: if self.pp_size > 1:
ray_broadcast_tensor_dict( if self.tp_rank == 0 and self.dp_rank == 0:
state_dict, src=self.num_producers, device=self.device, group_name="sync_model" ray_broadcast_tensor_dict(
) state_dict,
src=self.num_producers,
device=self.device,
group_name=f"sync_model_{self.pp_rank}",
)
else:
if self.rank == 0:
ray_broadcast_tensor_dict(
state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
)
del state_dict del state_dict
torch.cuda.empty_cache() torch.cuda.empty_cache()

View File

@ -57,7 +57,7 @@ def launch_distributed(
else: else:
core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer) core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer)
train_dp_size = get_dp_size_fast(num_producers, plugin_config) train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0 assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
dataset_path = dataset_config["path"] dataset_path = dataset_config["path"]
@ -82,6 +82,7 @@ def launch_distributed(
microbatch_size=inference_microbatch_size, microbatch_size=inference_microbatch_size,
backend=inference_backend, backend=inference_backend,
num_generations=num_generations, num_generations=num_generations,
consumer_plugin_config=plugin_config,
) )
procs.append(producer) procs.append(producer)
generate_config_consumer = copy.deepcopy(generate_config) generate_config_consumer = copy.deepcopy(generate_config)

View File

@ -29,6 +29,7 @@ class BaseProducer:
tokenizer_config: Optional[Dict[str, Any]] = None, tokenizer_config: Optional[Dict[str, Any]] = None,
microbatch_size: int = 1, microbatch_size: int = 1,
backend: str = "transformers", backend: str = "transformers",
consumer_plugin_config: Dict[str, Any] = None,
): ):
self.producer_idx = producer_idx self.producer_idx = producer_idx
self.num_producers = num_producers self.num_producers = num_producers
@ -78,9 +79,15 @@ class BaseProducer:
else: else:
raise ValueError(f"Unexpected backend {backend}") raise ValueError(f"Unexpected backend {backend}")
self.consumer_pp_size = consumer_plugin_config["pp_size"] # consumer pp size
def setup(self) -> None: def setup(self) -> None:
cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}") cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model") if self.consumer_pp_size > 1:
for i in range(self.consumer_pp_size):
cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=f"sync_model_{i}")
else:
cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
raise NotImplementedError raise NotImplementedError
@ -125,15 +132,25 @@ class BaseProducer:
): ):
self.model.llm.sleep() # revict KV_cache to avoid OOM self.model.llm.sleep() # revict KV_cache to avoid OOM
# don't sync model for last iteration # don't sync model for last iteration
print(
f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
)
torch.cuda.empty_cache() torch.cuda.empty_cache()
state_dict = ray_broadcast_tensor_dict( if self.consumer_pp_size > 1:
None, self.num_producers, device=self.device, group_name="sync_model" for pp_idx in range(self.consumer_pp_size):
) print(
self.load_state_dict(state_dict) f"[P{self.producer_idx}] Sync model PP stage {pp_idx} episode {episode} step {(i + 1) // self.num_microbatches - 1}"
)
state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
)
self.load_state_dict(state_dict)
else:
print(
f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
)
state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name="sync_model"
)
self.load_state_dict(state_dict)
del state_dict del state_dict
torch.cuda.empty_cache() torch.cuda.empty_cache()
if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get( if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
@ -170,6 +187,7 @@ class SimpleProducer(BaseProducer):
microbatch_size=1, microbatch_size=1,
backend="transformers", backend="transformers",
num_generations: int = 8, num_generations: int = 8,
consumer_plugin_config=None,
): ):
super().__init__( super().__init__(
producer_idx, producer_idx,
@ -184,6 +202,7 @@ class SimpleProducer(BaseProducer):
tokenizer_config, tokenizer_config,
microbatch_size, microbatch_size,
backend, backend,
consumer_plugin_config,
) )
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations) self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)

View File

@ -58,7 +58,7 @@ if __name__ == "__main__":
"--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional" "--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional"
) )
parser.add_argument( parser.add_argument(
"--master_port", type=int, default=29506, help="Master port for multi-node distributed training, Optional" "--master_port", type=int, default=29505, help="Master port for multi-node distributed training, Optional"
) )
# Sampling parameters # Sampling parameters
@ -223,7 +223,7 @@ if __name__ == "__main__":
"zero_stage": 2, "zero_stage": 2,
}, # for zero }, # for zero
# plugin_config={ # plugin_config={
# "tp_size": 2, # "tp_size": 1,
# "pp_size": 2, # "pp_size": 2,
# "microbatch_size": max( # "microbatch_size": max(
# 1, args.train_microbatch_size // 2 # 1, args.train_microbatch_size // 2