Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-07-05 11:37:14 +00:00)
[fix] fix redundant detach & clone; add buffer assertion at the end
parent fed8b1587d
commit 7568b34626
@@ -108,6 +108,27 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # dy buffer for local send bwd
         self.local_send_backward_buffer = []
 
+    def assert_buffer_empty(self):
+        # assert buffer is empty at end
+        assert len(self.input_tensors[0]) == 0
+        assert len(self.input_tensors[1]) == 0
+        assert len(self.output_tensors[0]) == 0
+        assert len(self.output_tensors[1]) == 0
+        assert len(self.output_tensors_dw[0]) == 0
+        assert len(self.output_tensors_dw[1]) == 0
+        assert len(self.output_tensors_grad_dw[0]) == 0
+        assert len(self.output_tensors_grad_dw[1]) == 0
+        assert len(self.send_forward_buffer[0]) == 0
+        assert len(self.send_forward_buffer[1]) == 0
+        assert len(self.recv_forward_buffer[0]) == 0
+        assert len(self.recv_forward_buffer[1]) == 0
+        assert len(self.send_backward_buffer[0]) == 0
+        assert len(self.send_backward_buffer[1]) == 0
+        assert len(self.recv_backward_buffer[0]) == 0
+        assert len(self.recv_backward_buffer[1]) == 0
+        assert len(self.local_send_forward_buffer) == 0
+        assert len(self.local_send_backward_buffer) == 0
+
     def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
         """Load a batch from data iterator.
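Note: the new method simply asserts that every per-chunk and local buffer has drained by the time the schedule finishes. For reference, a compact equivalent written as a loop over the same attributes is sketched below; this is illustrative only, not part of the commit.

def assert_buffer_empty(self):
    # Illustrative compact equivalent of the method added above; assumes the
    # same buffer attributes exist on ZeroBubbleVPipeScheduler.
    per_chunk_buffers = [
        self.input_tensors,
        self.output_tensors,
        self.output_tensors_dw,
        self.output_tensors_grad_dw,
        self.send_forward_buffer,
        self.recv_forward_buffer,
        self.send_backward_buffer,
        self.recv_backward_buffer,
    ]
    for buf in per_chunk_buffers:
        assert len(buf[0]) == 0 and len(buf[1]) == 0
    assert len(self.local_send_forward_buffer) == 0
    assert len(self.local_send_backward_buffer) == 0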
@@ -546,7 +567,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         )
         if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
             # We should not detach bwd LOSS
-            detached_output_obj = output_obj.clone()
+            pass
         else:
             detached_output_obj = output_obj.clone().detach()
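Why the chunk-1 first-stage output (the loss) is no longer cloned here: clone() keeps the result attached to the autograd graph, while clone().detach() creates a leaf with no grad history, so detaching the loss would break the backward pass, and cloning it was an unnecessary copy. A minimal standalone PyTorch sketch of the distinction (illustrative only, not from the commit):

import torch

x = torch.ones(2, requires_grad=True)
loss = (x * 3).sum()

kept = loss.clone()          # still attached to the graph; backward() reaches x
cut = loss.clone().detach()  # new leaf with no grad history

print(kept.requires_grad)  # True
print(cut.requires_grad)   # False
kept.backward()            # works; cut.backward() would raise a RuntimeError
print(x.grad)              # tensor([3., 3.])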
@@ -555,7 +576,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         if model_chunk_id == 0:
             # is last stage; send to local_send_forward_buffer
             if self.stage_manager.is_last_stage(ignore_chunk=True):
-                detached_output_obj = detached_output_obj.detach()
                 self.local_send_forward_buffer.append(detached_output_obj)
             else:
                 self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
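The deleted line was redundant: in this branch detached_output_obj already comes from clone().detach() above, and calling detach() on an already-detached tensor just produces another grad-free view. A tiny check (illustrative only):

import torch

t = torch.randn(2, requires_grad=True)
d1 = t.clone().detach()  # already a leaf with no grad history
d2 = d1.detach()         # a second detach changes nothing observable here
print(d1.requires_grad, d2.requires_grad)  # False False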
@@ -816,4 +836,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             model_chunk, data_iter, criterion, optimizer, return_loss, return_outputs
         )
 
+        self.assert_buffer_empty()
+
         return result
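Calling the new check at the very end of forward_backward_step means any tensor left behind in a send/recv buffer now fails fast instead of silently holding GPU memory across steps. The sketch below shows the pattern in isolation with hypothetical names; it is not the scheduler's actual API.

# Hypothetical stand-in for the "buffers must drain by the end of a step" pattern.
class TinySchedule:
    def __init__(self):
        self.send_forward_buffer = [[], []]  # one list per model chunk

    def assert_buffer_empty(self):
        assert len(self.send_forward_buffer[0]) == 0
        assert len(self.send_forward_buffer[1]) == 0

    def step(self, leak: bool = False):
        self.send_forward_buffer[0].append("activation")
        if not leak:
            self.send_forward_buffer[0].pop()  # consumer drains the buffer
        self.assert_buffer_empty()             # raises only if something leaked

TinySchedule().step()             # passes
# TinySchedule().step(leak=True)  # would raise AssertionError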
@@ -558,7 +558,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     batch_size = test_config["batch_size"]
     num_layers = 8
     assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layer can not dist on {num_model_chunk} chunk"
-    in_dim = out_dim = 8192
+    in_dim = out_dim = 4096
     before_init_memory = torch.cuda.memory_allocated() / 1024**3
     print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};")
     model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
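Shrinking the hidden size from 8192 to 4096 cuts each square fp32 weight matrix by 4x (8192 * 8192 * 4 B = 256 MiB versus 4096 * 4096 * 4 B = 64 MiB), which keeps the memory assertions below well within a single device. A quick arithmetic check, assuming MlpModel stacks num_layers square in_dim x out_dim Linear layers (illustrative only):

# Per-layer fp32 weight size for a square Linear layer, before and after the change.
for dim in (8192, 4096):
    weight_bytes = dim * dim * 4                        # fp32 = 4 bytes per element
    print(f"{dim}: {weight_bytes / 1024**3:.4f} GiB")   # 0.2500 GiB vs 0.0625 GiB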
@@ -619,7 +619,6 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
         # output hid_dim * hid_dim * 4(fp32) / 1024**3
         print(f"rank {rank}: {(after_pp_step_memory - after_init_memory)} <= {(in_dim * in_dim * 4 * 3 / 1024**3)}")
         assert (after_pp_step_memory - after_init_memory) <= (in_dim * in_dim * 4 * 3 / 1024**3)
-        # pass
     else:
         # rank0 will also hold output;
         print(
@@ -628,7 +627,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
         assert round((after_pp_step_memory - after_init_memory), 5) <= round(
             (in_dim * in_dim * 4 * 3 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5
         )
-        # pass
     ##########################
     # Fwd bwd for base
     ##########################
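For reference, the two bounds asserted above are easy to evaluate at in_dim = 4096: the shared term in_dim * in_dim * 4 * 3 / 1024**3 is 0.1875 GB, and the rank-0 branch adds batch_size * in_dim * in_dim * 4 / 1024**3, i.e. 0.0625 GB per sample, for the outputs it keeps. A quick check mirroring the test's expressions (batch_size here is a hypothetical value; the test reads it from test_config):

in_dim = 4096
batch_size = 4  # hypothetical; the test takes this from test_config["batch_size"]

bound_other_ranks = in_dim * in_dim * 4 * 3 / 1024**3
bound_rank0 = bound_other_ranks + batch_size * in_dim * in_dim * 4 / 1024**3

print(bound_other_ranks)  # 0.1875
print(bound_rank0)        # 0.4375 with batch_size = 4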