mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-25 06:52:46 +00:00
Merge pull request #6288 from duanjunwen/support_hybrid_model_sync
[Feat] support hybrid parallel model sync
This commit is contained in:
commit
87bac841ea
@ -59,10 +59,6 @@ class BaseConsumer:
|
|||||||
self.lr_scheduler = None
|
self.lr_scheduler = None
|
||||||
|
|
||||||
def setup(self) -> None:
|
def setup(self) -> None:
|
||||||
for i in range(self.num_producers):
|
|
||||||
cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
|
|
||||||
if self.rank == 0:
|
|
||||||
cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
|
|
||||||
launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
|
launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
|
||||||
|
|
||||||
plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
|
plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
|
||||||
@ -77,8 +73,26 @@ class BaseConsumer:
|
|||||||
self.booster = Booster(plugin=self.plugin)
|
self.booster = Booster(plugin=self.plugin)
|
||||||
self.dp_rank = dist.get_rank(self.plugin.dp_group)
|
self.dp_rank = dist.get_rank(self.plugin.dp_group)
|
||||||
self.tp_rank = dist.get_rank(self.plugin.tp_group)
|
self.tp_rank = dist.get_rank(self.plugin.tp_group)
|
||||||
|
self.pp_rank = dist.get_rank(self.plugin.pp_group)
|
||||||
|
|
||||||
self.dp_size = dist.get_world_size(self.plugin.dp_group)
|
self.dp_size = dist.get_world_size(self.plugin.dp_group)
|
||||||
|
self.tp_size = dist.get_world_size(self.plugin.tp_group)
|
||||||
|
self.pp_size = dist.get_world_size(self.plugin.pp_group)
|
||||||
|
|
||||||
|
# Init Hybrid ray process group
|
||||||
|
for i in range(self.num_producers):
|
||||||
|
cc.init_collective_group(self.world_size + 1, self.rank + 1, backend="hccl", group_name=f"sync_data_{i}")
|
||||||
|
if self.pp_size > 1:
|
||||||
|
# use hybrid tp + pp
|
||||||
|
if self.tp_rank == 0 and self.dp_rank == 0:
|
||||||
|
cc.init_collective_group(
|
||||||
|
self.num_producers + 1, self.num_producers, backend="hccl", group_name=f"sync_model_{self.pp_rank}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if self.rank == 0:
|
||||||
|
cc.init_collective_group(
|
||||||
|
self.num_producers + 1, self.num_producers, backend="hccl", group_name="sync_model"
|
||||||
|
)
|
||||||
|
|
||||||
self.buffer = []
|
self.buffer = []
|
||||||
|
|
||||||
@ -143,10 +157,19 @@ class BaseConsumer:
|
|||||||
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
|
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
if self.rank == 0:
|
if self.pp_size > 1:
|
||||||
ray_broadcast_tensor_dict(
|
if self.tp_rank == 0 and self.dp_rank == 0:
|
||||||
state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
|
ray_broadcast_tensor_dict(
|
||||||
)
|
state_dict,
|
||||||
|
src=self.num_producers,
|
||||||
|
device=self.device,
|
||||||
|
group_name=f"sync_model_{self.pp_rank}",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if self.rank == 0:
|
||||||
|
ray_broadcast_tensor_dict(
|
||||||
|
state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
|
||||||
|
)
|
||||||
del state_dict
|
del state_dict
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
@ -57,7 +57,9 @@ def launch_distributed(
|
|||||||
else:
|
else:
|
||||||
core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer)
|
core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer)
|
||||||
|
|
||||||
train_dp_size = get_dp_size_fast(num_producers, plugin_config)
|
train_dp_size = (
|
||||||
|
get_dp_size_fast(num_producers, plugin_config) if get_dp_size_fast(num_producers, plugin_config) else 1
|
||||||
|
)
|
||||||
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
|
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
|
||||||
|
|
||||||
dataset_path = dataset_config["path"]
|
dataset_path = dataset_config["path"]
|
||||||
@ -82,6 +84,7 @@ def launch_distributed(
|
|||||||
microbatch_size=inference_microbatch_size,
|
microbatch_size=inference_microbatch_size,
|
||||||
backend=inference_backend,
|
backend=inference_backend,
|
||||||
num_generations=num_generations,
|
num_generations=num_generations,
|
||||||
|
consumer_plugin_config=plugin_config,
|
||||||
)
|
)
|
||||||
procs.append(producer)
|
procs.append(producer)
|
||||||
generate_config_consumer = copy.deepcopy(generate_config)
|
generate_config_consumer = copy.deepcopy(generate_config)
|
||||||
|
@ -29,6 +29,7 @@ class BaseProducer:
|
|||||||
tokenizer_config: Optional[Dict[str, Any]] = None,
|
tokenizer_config: Optional[Dict[str, Any]] = None,
|
||||||
microbatch_size: int = 1,
|
microbatch_size: int = 1,
|
||||||
backend: str = "transformers",
|
backend: str = "transformers",
|
||||||
|
consumer_plugin_config: Dict[str, Any] = None,
|
||||||
):
|
):
|
||||||
self.producer_idx = producer_idx
|
self.producer_idx = producer_idx
|
||||||
self.num_producers = num_producers
|
self.num_producers = num_producers
|
||||||
@ -78,9 +79,19 @@ class BaseProducer:
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unexpected backend {backend}")
|
raise ValueError(f"Unexpected backend {backend}")
|
||||||
|
|
||||||
|
self.consumer_pp_size = consumer_plugin_config["pp_size"] # consumer pp size
|
||||||
|
|
||||||
def setup(self) -> None:
|
def setup(self) -> None:
|
||||||
cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
|
cc.init_collective_group(
|
||||||
cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
|
1 + self.num_consumer_procs, 0, backend="hccl", group_name=f"sync_data_{self.producer_idx}"
|
||||||
|
)
|
||||||
|
if self.consumer_pp_size > 1:
|
||||||
|
for i in range(self.consumer_pp_size):
|
||||||
|
cc.init_collective_group(
|
||||||
|
self.num_producers + 1, self.producer_idx, backend="hccl", group_name=f"sync_model_{i}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cc.init_collective_group(self.num_producers + 1, self.producer_idx, backend="hccl", group_name="sync_model")
|
||||||
|
|
||||||
def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
|
def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@ -130,10 +141,18 @@ class BaseProducer:
|
|||||||
)
|
)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
state_dict = ray_broadcast_tensor_dict(
|
if self.consumer_pp_size > 1:
|
||||||
None, self.num_producers, device=self.device, group_name="sync_model"
|
# TODO: loop load
|
||||||
)
|
for i in range(self.consumer_pp_size):
|
||||||
self.load_state_dict(state_dict)
|
state_dict = ray_broadcast_tensor_dict(
|
||||||
|
None, self.num_producers, device=self.device, group_name=f"sync_model_{i}"
|
||||||
|
)
|
||||||
|
self.load_state_dict(state_dict)
|
||||||
|
else:
|
||||||
|
state_dict = ray_broadcast_tensor_dict(
|
||||||
|
None, self.num_producers, device=self.device, group_name="sync_model"
|
||||||
|
)
|
||||||
|
self.load_state_dict(state_dict)
|
||||||
del state_dict
|
del state_dict
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
|
if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
|
||||||
@ -170,6 +189,7 @@ class SimpleProducer(BaseProducer):
|
|||||||
microbatch_size=1,
|
microbatch_size=1,
|
||||||
backend="transformers",
|
backend="transformers",
|
||||||
num_generations: int = 8,
|
num_generations: int = 8,
|
||||||
|
consumer_plugin_config=None,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
producer_idx,
|
producer_idx,
|
||||||
@ -184,6 +204,7 @@ class SimpleProducer(BaseProducer):
|
|||||||
tokenizer_config,
|
tokenizer_config,
|
||||||
microbatch_size,
|
microbatch_size,
|
||||||
backend,
|
backend,
|
||||||
|
consumer_plugin_config,
|
||||||
)
|
)
|
||||||
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
|
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
|
||||||
|
|
||||||
|
@ -129,7 +129,7 @@ if __name__ == "__main__":
|
|||||||
args.top_k = -1
|
args.top_k = -1
|
||||||
|
|
||||||
inference_model_config = dict(path=args.model)
|
inference_model_config = dict(path=args.model)
|
||||||
train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
|
train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False, attn_implementation="eager")
|
||||||
generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)
|
generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)
|
||||||
|
|
||||||
if args.backend == "transformers":
|
if args.backend == "transformers":
|
||||||
@ -155,7 +155,7 @@ if __name__ == "__main__":
|
|||||||
enforce_eager=True,
|
enforce_eager=True,
|
||||||
enable_chunked_prefill=True,
|
enable_chunked_prefill=True,
|
||||||
max_model_len=args.max_new_tokens + args.max_prompt_tokens,
|
max_model_len=args.max_new_tokens + args.max_prompt_tokens,
|
||||||
tensor_parallel_size=1,
|
tensor_parallel_size=2,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
generate_config.update(
|
generate_config.update(
|
||||||
|
Loading…
Reference in New Issue
Block a user