fix pp state dict incomplete issue

commit 6c1b3b694f (parent 064be50946), mirror of https://github.com/hpcaitech/ColossalAI.git
@@ -59,13 +59,8 @@ class BaseConsumer:
         self.lr_scheduler = None
 
     def setup(self) -> None:
-        for i in range(self.num_producers):
-            cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
-        if self.rank == 0:
-            cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
         launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
-
-        plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
+        plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)  # default config
         if (
             self.plugin_config.get("pp_size", 1) > 1
             and "num_microbatches" not in self.plugin_config
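Note: the early group init has to move because every member of a ray.util.collective group must call init_collective_group with the same world size and group name, and the consumer's group membership now depends on the device mesh that only exists after the plugin is built (next hunk). A minimal sketch of that rendezvous contract, with illustrative names and sizes:

import ray.util.collective as cc

def join_group(world_size: int, rank: int, group_name: str) -> None:
    # Every participant calls this with identical world_size/group_name;
    # collectives on the group are only valid once all members have joined.
    cc.init_collective_group(world_size, rank, backend="nccl", group_name=group_name)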
@@ -79,10 +74,16 @@ class BaseConsumer:
         self.tp_rank = dist.get_rank(self.plugin.tp_group)
 
         self.dp_size = dist.get_world_size(self.plugin.dp_group)
+        self.tp_size = dist.get_world_size(self.plugin.tp_group)
 
         self.buffer = []
 
         self.recv_cnt = 0
+        for i in range(self.num_producers):
+            cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
+        if self.dp_rank == 0 and self.tp_rank == 0:
+            group_name = f"sync_model_pp_stage_{self.plugin.stage_manager.stage}"
+            cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name=group_name)
 
     def state_dict(self) -> Dict[str, torch.Tensor]:
         raise NotImplementedError
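With pipeline parallelism a single "sync_model" group can no longer carry the whole model: each pp stage holds only its own shard, so each stage's lead rank (dp_rank == 0 and tp_rank == 0) gets its own broadcast group. A small runnable sketch of the naming scheme, with illustrative values:

def stage_group_name(pp_stage: int) -> str:
    # One broadcast group per pipeline stage; each group has
    # num_producers + 1 members (all producers plus that stage's lead consumer).
    return f"sync_model_pp_stage_{pp_stage}"

print([stage_group_name(s) for s in range(2)])
# ['sync_model_pp_stage_0', 'sync_model_pp_stage_1']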
@@ -140,12 +141,13 @@ class BaseConsumer:
                         print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
 
                 if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
-                    print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
                     torch.cuda.empty_cache()
                     state_dict = self.state_dict()
-                    if self.rank == 0:
+                    if self.dp_rank == 0 and self.tp_rank == 0:
+                        group_name = f"sync_model_pp_stage_{self.plugin.stage_manager.stage}"
+                        print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
                         ray_broadcast_tensor_dict(
                             state_dict, src=self.num_producers, device=self.device, group_name=group_name
                         )
                     del state_dict
                     torch.cuda.empty_cache()
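Each stage's lead consumer now broadcasts only its own shard on its stage group. The sketch below is an illustrative re-implementation of the broadcast pattern, not ColossalAI's ray_broadcast_tensor_dict itself; it assumes tensor_dict holds the payload on the src rank and preallocated buffers of matching shape/dtype on receivers, whereas the real helper exchanges that metadata internally and lets receivers pass None:

import ray.util.collective as cc

def broadcast_tensor_dict_sketch(tensor_dict, src, device, group_name):
    # Broadcast every tensor in the dict from src to all group members.
    out = {}
    for key, tensor in tensor_dict.items():
        buf = tensor.to(device)
        cc.broadcast(buf, src_rank=src, group_name=group_name)
        out[key] = buf
    return out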
@@ -191,6 +193,9 @@ class SimpleConsumer(BaseConsumer):
         self.optimizer = HybridAdam(self.model.parameters(), lr=1e-3)
         self.accum_loss = torch.zeros(1, device=self.device)
 
+    def get_plugin(self):
+        return self.plugin
+
     def setup(self):
         super().setup()
         self.model, self.optimizer, *_ = self.booster.boost(self.model, self.optimizer)
@@ -75,6 +75,7 @@ class GRPOConsumer(BaseConsumer):
         )
         path = model_config.pop("path")
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
+        self.orig_state_dict_key = [k for k in self.policy_model.state_dict()]
         self.policy_model.train()
         self.policy_model.gradient_checkpointing_enable()
         self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
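The key list is captured before booster.boost shards the model: after pipeline sharding, each stage's state_dict() exposes only its own layers, so the pre-boost key set is the ground truth the producer later checks against. A tiny runnable illustration with a stand-in model:

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))  # stand-in model
full_keys = list(model.state_dict().keys())  # recorded before any sharding
# After pp sharding each stage would hold a subset of these keys; the union
# of the per-stage shards must equal full_keys for a complete sync.
print(full_keys)  # ['0.weight', '0.bias', '1.weight', '1.bias']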
@@ -150,6 +151,22 @@ class GRPOConsumer(BaseConsumer):
             eta_min=0.1 * grpo_config.get("lr", 1e-6),
         )
 
+    def get_device_mesh_mapping(self):
+        return {
+            "rank": self.rank,
+            "tp_rank": self.tp_rank,
+            "tp_size": self.tp_size,
+            "dp_size": self.dp_size,
+            "dp_rank": self.dp_rank,
+            "pp_size": self.booster.plugin.stage_manager.num_stages,
+            "pp_stage": self.booster.plugin.stage_manager.stage,
+            "is_last_stage": self.booster.plugin.stage_manager.is_last_stage(),
+            "world_size": self.world_size,
+        }
+
+    def get_model_state_dict_keys(self):
+        return self.orig_state_dict_key
+
     def setup(self):
         super().setup()
         if self.use_wandb and (
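For reference, what one consumer's device-mesh mapping might look like; the values below are assumed (tp=2, dp=1, pp=2 on 4 consumer ranks), and producers later keep only the lead ranks, one per pipeline stage:

example_mapping = {
    "rank": 2,
    "tp_rank": 0,
    "tp_size": 2,
    "dp_size": 1,
    "dp_rank": 0,
    "pp_size": 2,
    "pp_stage": 1,
    "is_last_stage": True,
    "world_size": 4,
}
# dp_rank == 0 and tp_rank == 0 holds for ranks 0 (pp_stage 0) and 2
# (pp_stage 1), so exactly one sync_model group is formed per stage.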
@@ -66,7 +66,7 @@ def launch_distributed(
     num_update_per_episode = num_samples // global_inference_batch_size
     num_recv_per_update = inference_batch_size // inference_microbatch_size
 
-    procs = []
+    producer_procs = []
     for i in range(num_producers):
         producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
             producer_idx=i,
@@ -78,18 +78,20 @@ def launch_distributed(
             dataloaders_config=dataloaders_config,
             model_config=inference_model_config,
             generate_config=generate_config,
+            consumer_plugin_config=plugin_config,
             tokenizer_config=tokenizer_config,
             microbatch_size=inference_microbatch_size,
             backend=inference_backend,
             num_generations=num_generations,
         )
-        procs.append(producer)
+        producer_procs.append(producer)
     generate_config_consumer = copy.deepcopy(generate_config)
     generate_config_consumer.update(
         dict(
             backend=inference_backend,
         )
     )
+    consumer_procs = []
     for i in range(num_consumer_procs):
         consumer = core_consumer.options(num_gpus=1).remote(
             num_producers=num_producers,
@@ -111,6 +113,14 @@ def launch_distributed(
             save_interval=save_interval,
             save_dir=save_dir,
         )
-        procs.append(consumer)
-    ray.get([p.setup.remote() for p in procs])
+        consumer_procs.append(consumer)
+    # setup the consumer procs first
+    ray.get([p.setup.remote() for p in consumer_procs])
+    # get the device mesh mapping from consumer procs
+    consumer_device_mesh_mapping = ray.get([p.get_device_mesh_mapping.remote() for p in consumer_procs])
+    model_state_dict_keys = ray.get(consumer_procs[0].get_model_state_dict_keys.remote())
+    # setup the producer procs
+    ray.get([p.setup.remote(consumer_device_mesh_mapping, model_state_dict_keys) for p in producer_procs])
+    # loop
+    procs = producer_procs + consumer_procs
     ray.get([p.loop.remote() for p in procs])
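The ordering here matters: consumers must set up first so their device mesh and key list exist, and producers must join the per-stage groups with that information before anyone enters loop(). A condensed sketch of the handshake, mirroring the hunk above (actor handles are illustrative):

import ray

def orchestrate(producer_procs, consumer_procs):
    ray.get([c.setup.remote() for c in consumer_procs])            # 1. consumers first
    mesh = ray.get([c.get_device_mesh_mapping.remote() for c in consumer_procs])
    keys = ray.get(consumer_procs[0].get_model_state_dict_keys.remote())
    ray.get([p.setup.remote(mesh, keys) for p in producer_procs])  # 2. then producers
    ray.get([w.loop.remote() for w in producer_procs + consumer_procs])  # 3. run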
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 
 import ray
 import ray.util.collective as cc
@@ -26,6 +26,7 @@ class BaseProducer:
         dataloaders_config: Dict[str, Any],
         model_config: Dict[str, Any],
         generate_config: Dict[str, Any],
+        consumer_plugin_config: Dict[str, Any] = None,
         tokenizer_config: Optional[Dict[str, Any]] = None,
         microbatch_size: int = 1,
         backend: str = "transformers",
@@ -43,6 +44,7 @@ class BaseProducer:
         self.model_config = model_config
         self.generate_config = generate_config
         self.tokenizer_config = tokenizer_config
+        self.consumer_plugin_config = consumer_plugin_config
 
         # init tokenizer
         if tokenizer_config is None:
@@ -78,9 +80,18 @@ class BaseProducer:
         else:
             raise ValueError(f"Unexpected backend {backend}")
 
-    def setup(self) -> None:
+    def setup(self, consumer_device_mesh_mapping: Dict[str, Any] = None, model_state_dict_keys: List = None) -> None:
+        self.consumer_device_mesh_mapping = consumer_device_mesh_mapping
+        self.model_state_dict_keys = model_state_dict_keys
         cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
-        cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
+        # cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model_pp_stage_0")
+        for i in range(self.num_consumer_procs):
+            device_mesh_mapping = self.consumer_device_mesh_mapping[i]
+            device_mesh_mapping["rank"]
+            # TODO: support ep, sp
+            if device_mesh_mapping["dp_rank"] == 0 and device_mesh_mapping["tp_rank"] == 0:
+                group_name = f"sync_model_pp_stage_{device_mesh_mapping['pp_stage']}"
+                cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=group_name)
 
     def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         raise NotImplementedError
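A runnable sketch of which sync_model groups a producer ends up joining, given the consumers' mesh mappings (the inputs below are illustrative, tp=2 and pp=2):

def stage_groups(consumer_device_mesh_mapping):
    groups = []
    for mapping in consumer_device_mesh_mapping:
        # exactly one lead consumer (dp_rank == 0, tp_rank == 0) per pp stage
        if mapping["dp_rank"] == 0 and mapping["tp_rank"] == 0:
            groups.append(f"sync_model_pp_stage_{mapping['pp_stage']}")
    return groups

print(stage_groups([
    {"dp_rank": 0, "tp_rank": 0, "pp_stage": 0},
    {"dp_rank": 0, "tp_rank": 1, "pp_stage": 0},
    {"dp_rank": 0, "tp_rank": 0, "pp_stage": 1},
    {"dp_rank": 0, "tp_rank": 1, "pp_stage": 1},
]))  # ['sync_model_pp_stage_0', 'sync_model_pp_stage_1']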
@@ -125,14 +136,27 @@ class BaseProducer:
                 ):
                     self.model.llm.sleep()  # revict KV_cache to avoid OOM
                 # don't sync model for last iteration
+                torch.cuda.empty_cache()
+                state_dict = {}
                 print(
                     f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
                 )
-                torch.cuda.empty_cache()
+                for consumer_rank_id in range(self.num_consumer_procs):
+                    device_mesh_mapping = self.consumer_device_mesh_mapping[consumer_rank_id]
+                    device_mesh_mapping["rank"]
+                    # TODO: support ep, sp
+                    if device_mesh_mapping["dp_rank"] == 0 and device_mesh_mapping["tp_rank"] == 0:
+                        group_name = f"sync_model_pp_stage_{device_mesh_mapping['pp_stage']}"
+                        state_dict.update(
+                            ray_broadcast_tensor_dict(
+                                None, src=self.num_producers, device=self.device, group_name=group_name
+                            )
+                        )
+                # check model sync integrity
+                assert len(state_dict) == len(
+                    self.model_state_dict_keys
+                ), f"state dict keys has {len(state_dict)} unique keys not equal original model with {len(self.model_state_dict_keys)} keys. Missing keys: {set(self.model_state_dict_keys)-set(state_dict.keys())}. Please kindly inform the developer."
 
-                state_dict = ray_broadcast_tensor_dict(
-                    None, self.num_producers, device=self.device, group_name="sync_model"
-                )
                 self.load_state_dict(state_dict)
                 del state_dict
                 torch.cuda.empty_cache()
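This hunk is the fix behind the commit title: the producer now rebuilds the full state dict as the union of one broadcast per pipeline stage and asserts nothing is missing, instead of assuming a single "sync_model" broadcast carries everything. A runnable sketch of that union-plus-check, with stand-in shards and keys:

import torch

full_keys = ["embed.weight", "layer0.weight", "layer1.weight", "head.weight"]
stage_shards = [
    {"embed.weight": torch.zeros(1), "layer0.weight": torch.zeros(1)},  # pp stage 0
    {"layer1.weight": torch.zeros(1), "head.weight": torch.zeros(1)},   # pp stage 1
]
state_dict = {}
for shard in stage_shards:
    state_dict.update(shard)  # union of the per-stage shards
missing = set(full_keys) - set(state_dict.keys())
assert not missing, f"incomplete state dict, missing: {missing}"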
@@ -166,6 +190,7 @@ class SimpleProducer(BaseProducer):
         dataloaders_config,
         model_config,
         generate_config,
+        consumer_plugin_config=None,
         tokenizer_config=None,
         microbatch_size=1,
         backend="transformers",
@@ -181,6 +206,7 @@ class SimpleProducer(BaseProducer):
             dataloaders_config,
             model_config,
             generate_config,
+            consumer_plugin_config,
             tokenizer_config,
             microbatch_size,
             backend,