Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-08-25 19:21:17 +00:00

commit 9e0bd1af00 (parent 283c9ff5d2)

    [fix] fix ci test; add pytest;
@@ -37,7 +37,15 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         overlap_p2p: bool = True,
     ):
         super().__init__(stage_manager)
+        # batch info
         self.num_microbatch = num_microbatch
+        self.microbatch_size = microbatch_size
+        self.num_model_chunks = num_model_chunks
+        self.batch: Any
+        self.batch_size: int
+        self.last_batch_size: Optional[int] = None
+        self.microbatch_offset: List[int]
+
         self.collect_non_loss_data = None
         self.forward_only = None
         self.schedules = schedule
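For orientation, here is a minimal, self-contained sketch of how the batch-info fields introduced in this hunk typically fit together once a batch is loaded. The `_BatchInfo` class and its `load_batch` arithmetic are assumptions for illustration, not the scheduler's actual implementation:

    from typing import Any, List, Optional

    class _BatchInfo:
        # Hypothetical container mirroring the fields added in the hunk above.
        def __init__(self, num_microbatch: int, microbatch_size: int, num_model_chunks: int):
            self.num_microbatch = num_microbatch
            self.microbatch_size = microbatch_size
            self.num_model_chunks = num_model_chunks
            self.batch: Any = None
            self.batch_size: int = 0
            self.last_batch_size: Optional[int] = None
            # One read offset per model chunk, advanced as microbatches are consumed.
            self.microbatch_offset: List[int] = []

        def load_batch(self, batch: Any, batch_size: int) -> None:
            # Assumed invariant: the batch splits evenly into microbatches.
            assert batch_size == self.num_microbatch * self.microbatch_size
            self.last_batch_size = self.batch_size or None
            self.batch = batch
            self.batch_size = batch_size
            self.microbatch_offset = [0] * self.num_model_chunks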
@@ -45,7 +53,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self.do_post_validation = False
         self.is_first_run = True
         self.optimizer = None
-        self.num_model_chunks = num_model_chunks

         # P2PMeta cache
         # self.enable_metadata_cache = enable_metadata_cache
@@ -674,6 +681,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         """
         Runs Zerobubble schedule, with communication between pipeline stages.
         """
+        # # prepare batch
+        self.load_batch(data_iter)
+        # print(f"self.batch {self.batch}; self.batch_size {self.batch_size}; self.num_microbatch {self.num_microbatch}; self.microbatch_size {self.microbatch_size}")
+
         it = 0
         # while we still have schedules_node in self.schedules
         while it < len(self.schedules):
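The loop added here follows the usual zero-bubble pattern: prepare the batch, then consume a precomputed list of schedule nodes one by one. Below is a hedged, self-contained sketch of that pattern; the `Node` fields and the `handlers` dispatch are assumptions for illustration (zero-bubble schedules typically distinguish forward "F", input-gradient "B", and weight-gradient "W" nodes):

    from dataclasses import dataclass
    from typing import Callable, Dict, List

    @dataclass
    class Node:
        type: str        # assumed node kinds: "F", "B", or "W"
        chunk: int       # which model chunk this node runs on
        microbatch: int  # which microbatch it processes

    def run_schedule(schedules: List[Node], handlers: Dict[str, Callable[[Node], None]]) -> None:
        # Mirror of the hunk's loop: walk the schedule while nodes remain.
        it = 0
        while it < len(schedules):
            node = schedules[it]
            handlers[node.type](node)  # dispatch, e.g. forward vs. backward work
            it += 1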
@@ -157,6 +157,7 @@ def test_run_fwd_bwd_base(
     print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
     model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
     input0 = torch.rand(in_dim, out_dim, requires_grad=True).to(rank)
+    # data_iter = [input0]

     input_base = input0.clone()
     model_base = deepcopy(model)
@@ -193,6 +194,7 @@ def test_run_fwd_bwd_base(
     scheduler.run_forward_backward(
         model_chunk=local_chunk,
         input_obj=input0,
+        # data_iter=iter(data_iter),
         data_iter=None,
         criterion=criterion,
         optimizer=None,
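The two test hunks show a baseline-comparison setup: the input is cloned and the model deep-copied, the pipeline schedule runs on one copy (fed a tensor directly via `input_obj`, with `data_iter=None`), and a plain forward/backward runs on the other so results can be compared. A minimal sketch of the non-pipelined baseline half, with a toy criterion as an assumption:

    from copy import deepcopy
    import torch

    def run_baseline(model: torch.nn.Module, input0: torch.Tensor):
        # Same setup as the test: independent copies so outputs/grads can be compared.
        input_base = input0.clone()
        model_base = deepcopy(model)
        output = model_base(input_base)
        loss = output.mean()  # toy criterion (assumption); the real test's may differ
        loss.backward()
        return model_base, output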