Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-25 06:52:46 +00:00)
[fix] fix bwd w input;
parent 349272c71f
commit a115106f8d
@@ -89,8 +89,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self.input_tensors = [[], []]
         self.output_tensors = [[], []]
 
-        # x & y & dy buffer for schedule w
-        self.input_tensors_dw = [[], []]
+        # y & dy buffer for schedule w
         self.output_tensors_dw = [[], []]
         self.output_tensors_grad_dw = [[], []]
 
@@ -111,8 +110,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         assert len(self.input_tensors[1]) == 0
         assert len(self.output_tensors[0]) == 0
         assert len(self.output_tensors[1]) == 0
-        assert len(self.input_tensors_dw[0]) == 0
-        assert len(self.input_tensors_dw[1]) == 0
         assert len(self.output_tensors_dw[0]) == 0
         assert len(self.output_tensors_dw[1]) == 0
         assert len(self.output_tensors_grad_dw[0]) == 0
@@ -528,7 +525,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
         optimizer: OptimizerWrapper,
-        input_obj: Optional[dict],
         output_obj: Union[dict, torch.Tensor],
         output_obj_grad: Optional[dict],
     ):
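
Editorial note, not part of the diff: `input_obj` can be dropped from the W-step signature because a weight-gradient (W) pass only needs the autograd graph attached to the stage output and the upstream gradient. A minimal standalone sketch of that idea using plain torch.autograd.grad; the helper name `backward_w_sketch` and the toy two-layer chunk are illustrative, not the scheduler's actual code:

import torch
import torch.nn as nn

def backward_w_sketch(model_chunk, output_obj_, output_obj_grad_):
    # Weight grads come from the graph attached to the outputs (y) and the
    # upstream gradients (dy); the stage inputs x are not needed here.
    grads = torch.autograd.grad(
        outputs=output_obj_,
        inputs=list(model_chunk.parameters()),
        grad_outputs=output_obj_grad_,
    )
    for p, g in zip(model_chunk.parameters(), grads):
        p.grad = g if p.grad is None else p.grad + g

chunk = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
x = torch.randn(4, 8)
y = chunk(x)                 # forward output; its grad_fn graph is what W needs
dy = torch.ones_like(y)      # stand-in for the gradient from the next stage
backward_w_sketch(chunk, [y], [dy])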
@@ -555,7 +551,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             output_obj_.append(output_obj)  # LOSS
             output_obj_grad_.append(None)  # None
         else:
-            for k, v in input_obj.items():
+            # for k, v in input_obj.items():
+            #     if v.requires_grad:
+            #         output_obj_.append(output_obj[k])
+            #         output_obj_grad_.append(output_obj_grad[k])
+            for k, v in output_obj.items():
                 if v.requires_grad:
                     output_obj_.append(output_obj[k])
                     output_obj_grad_.append(output_obj_grad[k])
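
Editorial note, not part of the diff: the old loop walked `input_obj.items()` but indexed `output_obj[k]`, which assumes the two dicts share keys and filters on the input's requires_grad flag rather than the output's. Iterating `output_obj` directly avoids both assumptions. A small runnable illustration with made-up keys (`hidden_states` and `attention_mask` are hypothetical, not taken from the scheduler):

import torch

x = torch.randn(2, 4, requires_grad=True)
mask = torch.ones(2, 4)                       # requires_grad=False
y = (x * mask).sum(dim=-1)

input_obj = {"hidden_states": x, "attention_mask": mask}
output_obj = {"hidden_states": y}
output_obj_grad = {"hidden_states": torch.ones_like(y)}

output_obj_, output_obj_grad_ = [], []
# Walking input_obj here would touch "attention_mask", which has no output
# counterpart; walking output_obj only touches tensors this stage produced,
# filtered by their own requires_grad flag.
for k, v in output_obj.items():
    if v.requires_grad:
        output_obj_.append(output_obj[k])
        output_obj_grad_.append(output_obj_grad[k])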
@@ -636,10 +636,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # add input and output object for backward b
         if input_obj is not None:
             self.input_tensors[model_chunk_id].append(input_obj)
-            self.input_tensors_dw[model_chunk_id].append(input_obj)
         else:
             self.input_tensors[model_chunk_id].append(micro_batch)
-            self.input_tensors_dw[model_chunk_id].append(micro_batch)
 
         # for bwd b&w, we only need the graph(grad_fn) of output_obj
         # Do not deallocate loss, deallocate other output_obj;
@@ -760,7 +758,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         """
 
         # get y & dy from buffer
-        input_obj = self.input_tensors_dw[model_chunk_id].pop(0)
         output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
         output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0)
 
@@ -768,7 +765,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             model_chunk=model_chunk,
             model_chunk_id=model_chunk_id,
             optimizer=optimizer,
-            input_obj=input_obj,
             output_obj=output_obj,
             output_obj_grad=output_obj_grad,
         )
@@ -596,7 +596,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     batch_size = test_config["batch_size"]
     num_layers = 8
     assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layer can not dist on {num_model_chunk} chunk"
-    in_dim = out_dim = 4096
+    in_dim = out_dim = 1024
     before_init_memory = torch.cuda.memory_allocated() / 1024**3
     print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};")
     model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
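
Editorial note, not part of the diff: shrinking the test MLP width from 4096 to 1024 cuts the raw weight memory by roughly 16x, since each Linear layer holds in_dim * out_dim parameters. Back-of-envelope arithmetic, treating MlpModel as num_layers square fp32 Linear layers and ignoring biases, activations, and optimizer state:

def mlp_weight_gb(dim, num_layers=8, bytes_per_param=4):
    # weights of num_layers square Linear layers, fp32, bias ignored
    return num_layers * dim * dim * bytes_per_param / 1024**3

print(f"in_dim = out_dim = 4096: {mlp_weight_gb(4096):.3f} GB")  # 0.500 GB
print(f"in_dim = out_dim = 1024: {mlp_weight_gb(1024):.3f} GB")  # 0.031 GB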