Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-08-11 12:51:55 +00:00
update consumer and producer
This commit is contained in:
parent 87bac841ea
commit bb8d370b44
@@ -81,18 +81,16 @@ class BaseConsumer:
         # Init Hybrid ray process group
         for i in range(self.num_producers):
-            cc.init_collective_group(self.world_size + 1, self.rank + 1, backend="hccl", group_name=f"sync_data_{i}")
+            cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
         if self.pp_size > 1:
             # use hybrid tp + pp
             if self.tp_rank == 0 and self.dp_rank == 0:
                 cc.init_collective_group(
-                    self.num_producers + 1, self.num_producers, backend="hccl", group_name=f"sync_model_{self.pp_rank}"
+                    self.num_producers + 1, self.num_producers, group_name=f"sync_model_{self.pp_rank}"
                 )
         else:
             if self.rank == 0:
-                cc.init_collective_group(
-                    self.num_producers + 1, self.num_producers, backend="hccl", group_name="sync_model"
-                )
+                cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")

         self.buffer = []

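For context, a minimal sketch of what these calls do, assuming `cc` is `ray.util.collective` (which the call signatures suggest): with the explicit `backend="hccl"` (Ascend HCCL) removed, `init_collective_group` falls back to the library's default backend ("nccl"). The actor class and group name below are illustrative placeholders, not code from this commit.

import ray
import ray.util.collective as cc


@ray.remote(num_gpus=1)
class ConsumerSketch:
    def setup(self, world_size: int, rank: int, group_name: str = "sync_data_0") -> None:
        # Omitting backend= lets Ray collective pick its default ("nccl");
        # the removed backend="hccl" had pinned the Ascend NPU backend instead.
        cc.init_collective_group(world_size, rank, group_name=group_name)
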
@@ -154,7 +152,12 @@ class BaseConsumer:
                             print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")

                 if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
-                    print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+                    if self.pp_size > 1:
+                        print(
+                            f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
+                        )
+                    else:
+                        print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
                     torch.cuda.empty_cache()
                     state_dict = self.state_dict()
                     if self.pp_size > 1:

@@ -82,16 +82,12 @@ class BaseProducer:
         self.consumer_pp_size = consumer_plugin_config["pp_size"]  # consumer pp size

     def setup(self) -> None:
-        cc.init_collective_group(
-            1 + self.num_consumer_procs, 0, backend="hccl", group_name=f"sync_data_{self.producer_idx}"
-        )
+        cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
         if self.consumer_pp_size > 1:
             for i in range(self.consumer_pp_size):
-                cc.init_collective_group(
-                    self.num_producers + 1, self.producer_idx, backend="hccl", group_name=f"sync_model_{i}"
-                )
+                cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=f"sync_model_{i}")
         else:
-            cc.init_collective_group(self.num_producers + 1, self.producer_idx, backend="hccl", group_name="sync_model")
+            cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")

     def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         raise NotImplementedError

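Reading the group sizes off these calls (an interpretation of the diff, not repository code): each `sync_data_{i}` group holds one producer at rank 0 plus every consumer process at ranks 1..N, and each `sync_model_{i}` group holds all producers (rank = producer_idx) plus one consumer endpoint at rank `num_producers`. Below is a hedged sketch of a producer joining and using its data group with `ray.util.collective` directly; the actor class and the broadcast call are illustrative assumptions, not this repo's code.

import ray
import torch
import ray.util.collective as cc


@ray.remote(num_gpus=1)
class ProducerSketch:
    def setup(self, num_consumer_procs: int, producer_idx: int) -> None:
        # The producer takes rank 0 in its own sync_data group;
        # consumer processes occupy ranks 1..num_consumer_procs.
        cc.init_collective_group(
            1 + num_consumer_procs, 0, group_name=f"sync_data_{producer_idx}"
        )

    def push(self, batch: torch.Tensor, producer_idx: int) -> None:
        # Broadcast a rollout batch from this producer (rank 0) to all consumers.
        cc.broadcast(batch, src_rank=0, group_name=f"sync_data_{producer_idx}")
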
@@ -142,8 +138,8 @@ class BaseProducer:
                     torch.cuda.empty_cache()

                     if self.consumer_pp_size > 1:
-                        # TODO: loop load
                         for i in range(self.consumer_pp_size):
+                            print(f"[P{self.producer_idx}] Sync model PP stage {i}")
                             state_dict = ray_broadcast_tensor_dict(
                                 None, self.num_producers, device=self.device, group_name=f"sync_model_{i}"
                             )

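The loop above is the receive side; a matching send has to exist on the consumer. The following is only an inferred sketch of that counterpart: the signature of `ray_broadcast_tensor_dict` (tensor_dict, src_rank, device=..., group_name=...) is read off the producer-side call above, while the method name and consumer attributes are assumptions, and the helper's import is omitted because it lives elsewhere in this repo.

def broadcast_model_to_producers(self) -> None:
    # Hypothetical consumer-side counterpart (not part of this commit).
    state_dict = self.state_dict()
    if self.pp_size > 1:
        # One endpoint per PP stage: only the (tp_rank 0, dp_rank 0) replica sends,
        # using its rank `num_producers` in the per-stage group set up earlier.
        if self.tp_rank == 0 and self.dp_rank == 0:
            ray_broadcast_tensor_dict(
                state_dict,
                self.num_producers,
                device=self.device,
                group_name=f"sync_model_{self.pp_rank}",
            )
    elif self.rank == 0:
        ray_broadcast_tensor_dict(
            state_dict, self.num_producers, device=self.device, group_name="sync_model"
        )
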
@@ -58,7 +58,7 @@ if __name__ == "__main__":
         "--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional"
     )
     parser.add_argument(
-        "--master_port", type=int, default=29506, help="Master port for multi-node distributed training, Optional"
+        "--master_port", type=int, default=29505, help="Master port for multi-node distributed training, Optional"
     )

     # Sampling parameters

@@ -129,7 +129,7 @@ if __name__ == "__main__":
         args.top_k = -1

     inference_model_config = dict(path=args.model)
-    train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False, attn_implementation="eager")
+    train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
     generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)

     if args.backend == "transformers":

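One likely reading of this change (hedged, not stated in the commit message): `attn_implementation="eager"` sits oddly next to `use_flash_attention_2=True`, since the latter is transformers' deprecated shorthand for `attn_implementation="flash_attention_2"`. The sketch below shows how such a config dict is typically handed to Hugging Face transformers; the model id is a placeholder and the pop/unpack pattern is an assumption about how the script consumes it.

from transformers import AutoModelForCausalLM

# Mirrors the dict above; "path" is a placeholder standing in for args.model.
train_model_config = dict(path="org/model-name", use_flash_attention_2=True, use_cache=False)

path = train_model_config.pop("path")
# use_flash_attention_2=True is the deprecated equivalent of
# attn_implementation="flash_attention_2", so also pinning "eager" conflicted with it.
model = AutoModelForCausalLM.from_pretrained(path, **train_model_config)
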
@@ -155,7 +155,7 @@ if __name__ == "__main__":
             enforce_eager=True,
             enable_chunked_prefill=True,
             max_model_len=args.max_new_tokens + args.max_prompt_tokens,
-            tensor_parallel_size=2,
+            tensor_parallel_size=1,
         )
     )
     generate_config.update(

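For reference, a rough sketch of the vLLM engine these options describe after the change; the model id and token budgets are placeholders standing in for args.model, args.max_prompt_tokens, and args.max_new_tokens, and this is not the script's own construction code.

from vllm import LLM

llm = LLM(
    model="org/model-name",       # placeholder for args.model
    enforce_eager=True,           # skip CUDA graph capture
    enable_chunked_prefill=True,  # let long prefills be split across scheduler steps
    max_model_len=2048 + 512,     # placeholder for args.max_prompt_tokens + args.max_new_tokens
    tensor_parallel_size=1,       # one GPU per inference engine after this change
)
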
@@ -223,7 +223,7 @@ if __name__ == "__main__":
             "zero_stage": 2,
         },  # for zero
         # plugin_config={
-        #   "tp_size": 2,
+        #   "tp_size": 1,
         #   "pp_size": 2,
         #   "microbatch_size": max(
         #       1, args.train_microbatch_size // 2

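Written out in full, the commented-out hybrid alternative now reads roughly as below. This is only a sketch assembled from the visible comment: the dict's closing and the placeholder microbatch value are filled in here and do not appear in the hunk.

# Hypothetical hybrid tp + pp plugin_config implied by the comment above.
train_microbatch_size = 8  # placeholder for args.train_microbatch_size

plugin_config = {
    "tp_size": 1,
    "pp_size": 2,
    "microbatch_size": max(1, train_microbatch_size // 2),
}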