Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-08-25 19:21:17 +00:00)
[fix] fix fwd branch, fwd pass both micro_batch & internal_inputs
This commit is contained in: parent b6616f544e, commit 1739df423c
@@ -429,18 +429,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # Only attention_mask from micro_batch is used
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
             # fwd calculate
-            if isinstance(model_chunk, ModuleList):
-                # fwd for ModuleList model
-                if input_obj is None:
-                    output_obj = model_chunk[model_chunk_id](**micro_batch)
-                else:
-                    output_obj = model_chunk[model_chunk_id](**input_obj)
-            else:
-                # fwd for shardformer
-                # NOTE: in shardformer, each device still has the entire model, so we need to use relevant stage layers
-                internal_inputs = {} if input_obj is None else input_obj
-                # internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
-                output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs)
+            internal_inputs = {} if input_obj is None else input_obj
+            # internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
+            output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs)

         # last layer in model
         if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
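With the branch removed, every chunk goes through the same call, `model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs)`, so the forward pass always sees both the original micro-batch and whatever the previous stage handed over. The helper's body is not part of this diff; below is a minimal sketch of the calling convention it implies, assuming both arguments are keyword dicts and that stage-to-stage inputs override micro-batch keys on a clash:

    from typing import Any, Dict, Optional

    import torch.nn as nn

    def model_forward(model: nn.Module,
                      micro_batch: Optional[Dict[str, Any]],
                      internal_inputs: Optional[Dict[str, Any]]) -> Any:
        # Merge the caller-provided micro-batch with the activations received
        # from the previous pipeline stage; internal inputs win on key clashes.
        inputs = {**(micro_batch or {}), **(internal_inputs or {})}
        return model(**inputs)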
@@ -48,7 +48,7 @@ def pp_linear_fwd(
         return forward(hidden_states)
     # fwd start
     elif stage_mgr.is_first_stage() and model_chunk_id == 0:
-        return {"hidden_states": forward(hidden_states)}
+        return {"hidden_states": forward(data)}
     # fwd middle
     else:
         return {"hidden_states": forward(hidden_states)}
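This one-line change matters because the very first stage has no upstream `hidden_states` yet; the raw micro-batch arrives under the key `data` (see the test change below). A hypothetical reconstruction of the whole dispatch, with the unseen first branch's condition assumed from the `# last layer in model` check in the scheduler hunk:

    def pp_linear_fwd(forward, data=None, hidden_states=None,
                      stage_mgr=None, model_chunk_id=None):
        # fwd end: condition not visible in this diff; assumed from the
        # scheduler's "last layer" check (chunk 1 on the first stage).
        if stage_mgr.is_first_stage(ignore_chunk=True) and model_chunk_id == 1:
            return forward(hidden_states)
        # fwd start: the first chunk consumes the raw micro-batch key "data"
        elif stage_mgr.is_first_stage() and model_chunk_id == 0:
            return {"hidden_states": forward(data)}
        # fwd middle: later chunks consume the previous stage's activations
        else:
            return {"hidden_states": forward(hidden_states)}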
@@ -601,7 +601,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};")
     model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
     # data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
-    data_iter = {"hidden_states": torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)}
+    data_iter = {"data": torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)}
     # input_base = [t.clone() for t in data_iter]
     input_base = {k: v.clone() for k, v in data_iter.items()}
     model_base = deepcopy(model)
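The key rename from "hidden_states" to "data" only works if the test model's first forward parameter is named accordingly. `MlpModel` is defined elsewhere in the test file; below is a minimal sketch consistent with this setup (chaining identical linear layers assumes in_dim == out_dim in the test config):

    import torch
    import torch.nn as nn

    class MlpModel(nn.Module):
        def __init__(self, in_dim: int, out_dim: int, num_layers: int):
            super().__init__()
            self.layers = nn.ModuleList(nn.Linear(in_dim, out_dim) for _ in range(num_layers))

        def forward(self, data: torch.Tensor) -> torch.Tensor:
            # Named `data` so the first pipeline stage can be fed with
            # data_iter = {"data": ...}, matching the rename in this commit.
            hidden_states = data
            for layer in self.layers:
                hidden_states = layer(hidden_states)
            return hidden_states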
@@ -694,7 +694,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     # Fwd bwd for base
     ##########################
     # fwd & bwd
-    output_base = model_base(input_base["hidden_states"])
+    output_base = model_base(input_base["data"])
     loss_base = criterion_base(output_base)
     loss_base.backward()
     optimizer_base.step()
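With both paths now fed through the same `data` key, the rest of the test can compare the pipeline run against this plain fwd/bwd reference. The assertion code is outside this diff; a sketch of the kind of check such tests typically end with, reusing the hypothetical names above:

    import torch

    # After the scheduler's fwd/bwd and the reference fwd/bwd above, the
    # stage-local parameters should match their counterparts in model_base.
    for (name, p), (_, p_base) in zip(model.named_parameters(), model_base.named_parameters()):
        assert torch.allclose(p, p_base, rtol=1e-5, atol=1e-5), f"param mismatch: {name}"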