mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 04:50:17 +00:00
[fix] fix ci test; add pytest;
This commit is contained in:
@@ -37,7 +37,15 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
overlap_p2p: bool = True,
|
||||
):
|
||||
super().__init__(stage_manager)
|
||||
# batch info
|
||||
self.num_microbatch = num_microbatch
|
||||
self.microbatch_size = microbatch_size
|
||||
self.num_model_chunks = num_model_chunks
|
||||
self.batch: Any
|
||||
self.batch_size: int
|
||||
self.last_batch_size: Optional[int] = None
|
||||
self.microbatch_offset: List[int]
|
||||
|
||||
self.collect_non_loss_data = None
|
||||
self.forward_only = None
|
||||
self.schedules = schedule
|
||||
@@ -45,7 +53,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
self.do_post_validation = False
|
||||
self.is_first_run = True
|
||||
self.optimizer = None
|
||||
self.num_model_chunks = num_model_chunks
|
||||
|
||||
# P2PMeta cache
|
||||
# self.enable_metadata_cache = enable_metadata_cache
|
||||
@@ -674,6 +681,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
"""
|
||||
Runs Zerobubble schedule, with communication between pipeline stages.
|
||||
"""
|
||||
# # prepare batch
|
||||
self.load_batch(data_iter)
|
||||
# print(f"self.batch {self.batch}; self.batch_size {self.batch_size}; self.num_microbatch {self.num_microbatch}; self.microbatch_size {self.microbatch_size}")
|
||||
|
||||
it = 0
|
||||
# while we still have schedules_node in self.schedules
|
||||
while it < len(self.schedules):
|
||||
|
Reference in New Issue
Block a user