mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 20:40:34 +00:00)
Commit: merge
@@ -106,9 +106,14 @@ class BaseConsumer:
             f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
         )
         for episode in range(self.num_episodes):
-            with tqdm(range(self.num_update_per_episode), desc=f"Episode {episode}", disable=self.rank != 0) as pbar:
+            with tqdm(
+                range(self.num_update_per_episode),
+                desc=f"Episode {episode} with rollout step(s)",
+                disable=self.rank != 0,
+            ) as pbar:
                 for step in pbar:
                     i = 0
+                    allow_sync_model = False
                     for _ in range(self.num_recv_per_update):
                         # receive data from producers
                         for r in range(self.num_producers):
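Note on the hunk above: the progress bar is rendered only on rank 0 (disable=self.rank != 0), so a multi-process run does not print one bar per worker. Below is a minimal standalone sketch of that pattern, assuming a plain integer rank in place of the consumer's self.rank; it is an illustration, not the repository's code.

# Sketch: rank-gated tqdm progress bar (rank is a stand-in for self.rank).
from tqdm import tqdm

def run_episode(rank: int, num_update_per_episode: int, episode: int) -> None:
    with tqdm(
        range(num_update_per_episode),
        desc=f"Episode {episode} with rollout step(s)",
        disable=rank != 0,  # non-zero ranks iterate silently, without drawing a bar
    ) as pbar:
        for step in pbar:
            pass  # placeholder for the receive / train / sync work done per step

run_episode(rank=0, num_update_per_episode=3, episode=0)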
@@ -126,15 +131,15 @@ class BaseConsumer:
                             ]
                             batch = bind_batch(batches)
                             batch = post_recv(batch)
-                            loss, num_excessive_prompts = self.step(i, pbar, **batch)
-                            self.buffer = (
-                                self.buffer[
-                                    (self.dp_rank + 1) * self.minibatch_size
-                                    - num_excessive_prompts : (self.dp_rank + 1) * self.minibatch_size
-                                ]
-                                + self.buffer[self.dp_size * self.minibatch_size :]
-                            )
+                            loss, excessive_prompts_idx = self.step(i, pbar, **batch)
+
+                            if excessive_prompts_idx is not None:
+                                excessive_prompts = [self.buffer[idx] for idx in excessive_prompts_idx]
+                                self.buffer = excessive_prompts + self.buffer[self.dp_size * self.minibatch_size :]
+                            else:
+                                self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
                             if loss is not None:
+                                allow_sync_model = True
                                 pbar.set_postfix({"loss": loss})
                             i += 1
                 if self.lr_scheduler is not None:
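To make the buffer bookkeeping in the hunk above concrete: self.step() now returns the indices of excessive prompts rather than a count, those entries are moved to the front of the buffer, and the minibatches consumed by all DP ranks are dropped. The sketch below shows just that slicing with made-up sizes and prompt names; only the list manipulation mirrors the diff.

# Sketch of the new buffer carry-over logic (data and sizes are hypothetical).
dp_size = 2
minibatch_size = 2
buffer = ["p0", "p1", "p2", "p3", "p4", "p5"]  # queued rollout prompts
excessive_prompts_idx = [1, 3]  # pretend step() flagged these entries for re-use

if excessive_prompts_idx is not None:
    excessive_prompts = [buffer[idx] for idx in excessive_prompts_idx]
    buffer = excessive_prompts + buffer[dp_size * minibatch_size :]
else:
    buffer = buffer[dp_size * minibatch_size :]

print(buffer)  # ['p1', 'p3', 'p4', 'p5']: flagged prompts kept, consumed batch dropped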
@@ -148,29 +153,31 @@ class BaseConsumer:
                             print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
 
                 if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
-                    if self.pp_size > 1:
-                        print(
-                            f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
-                        )
-                    else:
-                        print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
-                    torch.cuda.empty_cache()
-                    state_dict = self.state_dict()
-                    if self.pp_size > 1:
-                        if self.tp_rank == 0 and self.dp_rank == 0:
-                            ray_broadcast_tensor_dict(
-                                state_dict,
-                                src=self.num_producers,
-                                device=self.device,
-                                group_name=f"sync_model_{self.pp_rank}",
-                            )
-                    else:
-                        if self.rank == 0:
-                            ray_broadcast_tensor_dict(
-                                state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
-                            )
-                    del state_dict
-                    torch.cuda.empty_cache()
+                    if allow_sync_model:
+                        if self.pp_size > 1:
+                            print(
+                                f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
+                            )
+                        else:
+                            print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+                        torch.cuda.empty_cache()
+                        state_dict = self.state_dict()
+                        if self.pp_size > 1:
+                            if self.tp_rank == 0 and self.dp_rank == 0:
+                                ray_broadcast_tensor_dict(
+                                    state_dict,
+                                    src=self.num_producers,
+                                    device=self.device,
+                                    group_name=f"sync_model_{self.pp_rank}",
+                                )
+                        else:
+                            if self.rank == 0:
+                                ray_broadcast_tensor_dict(
+                                    state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
+                                )
+                        del state_dict
+                        torch.cuda.empty_cache()
+                    allow_sync_model = False
 
 
 @ray.remote
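Taken together, the hunks introduce an allow_sync_model flag that gates the weight broadcast: it starts False at each rollout step, flips to True only when self.step() returns a loss, and the state-dict broadcast runs only when the flag is set, so the sync can be skipped when no update happened. The sketch below shows just that control flow; train_one_minibatch() and broadcast_weights() are hypothetical stand-ins for self.step() and the ray_broadcast_tensor_dict() call.

# Sketch of the allow_sync_model gating; stand-in functions, not the repository's API.
import random
from typing import Optional

def train_one_minibatch() -> Optional[float]:
    # Pretend an optimizer step happens only some of the time and returns a loss.
    return random.random() if random.random() > 0.3 else None

def broadcast_weights() -> None:
    print("broadcasting updated weights to producers")

for step in range(3):
    allow_sync_model = False
    for _ in range(2):  # pretend two data receives per update step
        loss = train_one_minibatch()
        if loss is not None:
            allow_sync_model = True  # a real update happened, so a sync is worthwhile
    if allow_sync_model:
        broadcast_weights()  # skip the broadcast entirely if nothing was updated
    allow_sync_model = False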