Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-03 10:06:44 +00:00)
clean code
@@ -151,19 +151,6 @@ class GRPOConsumer(BaseConsumer):
             eta_min=0.1 * grpo_config.get("lr", 1e-6),
         )
 
-    def get_device_mesh_mapping(self):
-        return {
-            "rank": self.rank,
-            "tp_rank": self.tp_rank,
-            "tp_size": self.tp_size,
-            "dp_size": self.dp_size,
-            "dp_rank": self.dp_rank,
-            "pp_size": self.booster.plugin.stage_manager.num_stages,
-            "pp_stage": self.booster.plugin.stage_manager.stage,
-            "is_last_stage": self.booster.plugin.stage_manager.is_last_stage(),
-            "world_size": self.world_size,
-        }
-
     def get_model_state_dict_keys(self):
         return self.orig_state_dict_key
 
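The removed get_device_mesh_mapping exposed every parallelism coordinate of a consumer rank, but the producer only ever consulted pp_stage, dp_rank, and tp_rank, and the set of pipeline stages is recoverable from the consumer plugin config alone. A minimal sketch of that observation (hypothetical helper and example config values, not part of the repo):

# Minimal sketch: for model syncing the producer only needs to know which pipeline
# stages exist, which the plugin config already encodes without a per-rank handshake.
def pp_stages(consumer_plugin_config: dict) -> range:
    # pp_size defaults to 1 when the consumer runs without pipeline parallelism
    return range(consumer_plugin_config.get("pp_size", 1))

# Example: a consumer launched with pp_size=2 yields two broadcast groups.
for pp_stage in pp_stages({"pp_size": 2, "tp_size": 4}):
    print(f"sync_model_pp_stage_{pp_stage}")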
@@ -116,11 +116,10 @@ def launch_distributed(
         consumer_procs.append(consumer)
     # setup the consumer procs first
     ray.get([p.setup.remote() for p in consumer_procs])
-    # get the device mesh mapping from consumer procs
-    consumer_device_mesh_mapping = ray.get([p.get_device_mesh_mapping.remote() for p in consumer_procs])
+    # get state dict key for checking syncing integrity
     model_state_dict_keys = ray.get(consumer_procs[0].get_model_state_dict_keys.remote())
     # setup the producer procs
-    ray.get([p.setup.remote(consumer_device_mesh_mapping, model_state_dict_keys) for p in producer_procs])
+    ray.get([p.setup.remote(model_state_dict_keys) for p in producer_procs])
     # loop
     procs = producer_procs + consumer_procs
     ray.get([p.loop.remote() for p in procs])
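After this change the handoff between the two actor groups reduces to a single list of state-dict keys, but the ordering still matters: consumers must finish setup before a producer can ask for the keys, and producers must be set up before anyone enters its loop. A self-contained toy with dummy Ray actors (class names and key values are illustrative, not from the repo) that reproduces this ordering:

# Toy reproduction of the setup ordering above: consumers first, one of them
# reports the model state-dict keys, producers receive only that key list.
import ray

@ray.remote
class ToyConsumer:
    def setup(self):
        # stand-in for building the model; real consumers collect these keys
        # from the model's original state dict
        self.state_dict_keys = ["embed.weight", "lm_head.weight"]

    def get_model_state_dict_keys(self):
        return self.state_dict_keys

@ray.remote
class ToyProducer:
    def setup(self, model_state_dict_keys):
        # keep the keys so a later weight broadcast can be checked for completeness
        self.model_state_dict_keys = model_state_dict_keys
        return len(model_state_dict_keys)

if __name__ == "__main__":
    ray.init()
    consumers = [ToyConsumer.remote() for _ in range(2)]
    producers = [ToyProducer.remote() for _ in range(2)]
    ray.get([c.setup.remote() for c in consumers])               # consumers first
    keys = ray.get(consumers[0].get_model_state_dict_keys.remote())
    print(ray.get([p.setup.remote(keys) for p in producers]))    # then producers
    ray.shutdown()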
@@ -80,18 +80,12 @@ class BaseProducer:
         else:
             raise ValueError(f"Unexpected backend {backend}")
 
-    def setup(self, consumer_device_mesh_mapping: Dict[str, Any] = None, model_state_dict_keys: List = None) -> None:
-        self.consumer_device_mesh_mapping = consumer_device_mesh_mapping
+    def setup(self, model_state_dict_keys: List = None) -> None:
         self.model_state_dict_keys = model_state_dict_keys
         cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
-        # cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model_pp_stage_0")
-        for i in range(self.num_consumer_procs):
-            device_mesh_mapping = self.consumer_device_mesh_mapping[i]
-            device_mesh_mapping["rank"]
-            # TODO: support ep, sp
-            if device_mesh_mapping["dp_rank"] == 0 and device_mesh_mapping["tp_rank"] == 0:
-                group_name = f"sync_model_pp_stage_{device_mesh_mapping['pp_stage']}"
-                cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=group_name)
+        for pp_stage in range(self.consumer_plugin_config.get("pp_size", 1)):
+            group_name = f"sync_model_pp_stage_{pp_stage}"
+            cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=group_name)
 
     def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         raise NotImplementedError
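The new setup makes the collective-group layout static: each producer joins one data group (itself plus all consumer processes) and one model-broadcast group per consumer pipeline stage (all producers plus the consumer rank that owns that stage, i.e. the dp_rank == 0, tp_rank == 0 rank in the old per-rank scheme). A small sketch (illustrative helper, not in the repo) that enumerates the groups and their expected sizes:

# Enumerate the group names and world sizes implied by the new setup().
def collective_groups(num_producers: int, num_consumer_procs: int, consumer_plugin_config: dict):
    groups = {}
    for producer_idx in range(num_producers):
        # data group: this producer plus all consumer processes
        groups[f"sync_data_{producer_idx}"] = 1 + num_consumer_procs
    for pp_stage in range(consumer_plugin_config.get("pp_size", 1)):
        # model group: all producers plus the consumer rank that owns this stage
        groups[f"sync_model_pp_stage_{pp_stage}"] = num_producers + 1
    return groups

# Example: 2 producers, 4 consumer processes, pp_size=2
print(collective_groups(2, 4, {"pp_size": 2}))
# {'sync_data_0': 5, 'sync_data_1': 5, 'sync_model_pp_stage_0': 3, 'sync_model_pp_stage_1': 3}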
@@ -141,17 +135,13 @@ class BaseProducer:
                     print(
                         f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
                     )
-                    for consumer_rank_id in range(self.num_consumer_procs):
-                        device_mesh_mapping = self.consumer_device_mesh_mapping[consumer_rank_id]
-                        device_mesh_mapping["rank"]
-                        # TODO: support ep, sp
-                        if device_mesh_mapping["dp_rank"] == 0 and device_mesh_mapping["tp_rank"] == 0:
-                            group_name = f"sync_model_pp_stage_{device_mesh_mapping['pp_stage']}"
-                            state_dict.update(
-                                ray_broadcast_tensor_dict(
-                                    None, src=self.num_producers, device=self.device, group_name=group_name
-                                )
-                            )
+                    for pp_stage in range(self.consumer_plugin_config.get("pp_size", 1)):
+                        group_name = f"sync_model_pp_stage_{pp_stage}"
+                        state_dict.update(
+                            ray_broadcast_tensor_dict(
+                                None, src=self.num_producers, device=self.device, group_name=group_name
+                            )
+                        )
                     # check model sync integrity
                     assert len(state_dict) == len(
                         self.model_state_dict_keys
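The only use left for the state-dict key list is the assertion at the end: after merging the broadcasts from every pipeline stage, the producer must hold exactly as many tensors as the consumer's model has parameters, otherwise a stage was skipped or stages broadcast overlapping keys. A minimal, runnable sketch of that check with dummy tensors (key names, shapes, and the two-stage split are illustrative, not from the repo):

# Merge per-stage payloads, then verify the merged dict covers every model key.
import torch

model_state_dict_keys = ["embed.weight", "layer.0.weight", "lm_head.weight"]

# pretend each pipeline stage broadcasts only the parameters it owns
per_stage_payloads = [
    {"embed.weight": torch.zeros(4, 4), "layer.0.weight": torch.zeros(4, 4)},  # pp stage 0
    {"lm_head.weight": torch.zeros(4, 4)},                                     # pp stage 1
]

state_dict = {}
for payload in per_stage_payloads:
    state_dict.update(payload)

# check model sync integrity: every parameter of the consumer's model must have arrived
assert len(state_dict) == len(model_state_dict_keys), (
    f"incomplete sync: got {len(state_dict)} tensors, expected {len(model_state_dict_keys)}"
)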