mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
[fix] fix handle name; rm useless comments;
This commit is contained in:
@@ -107,7 +107,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
self.local_send_backward_buffer = []
|
||||
|
||||
# wait pp buffer
|
||||
self.send_handles = []
|
||||
self.wait_handles = []
|
||||
|
||||
def assert_buffer_empty(self):
|
||||
# assert buffer is empty at end
|
||||
@@ -129,7 +129,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
assert len(self.recv_backward_buffer[1]) == 0
|
||||
assert len(self.local_send_forward_buffer) == 0
|
||||
assert len(self.local_send_backward_buffer) == 0
|
||||
# assert len(self.send_handles) == 0
|
||||
|
||||
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
|
||||
"""Load a batch from data iterator.
|
||||
@@ -891,7 +890,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
# communication
|
||||
communication_func = self.communication_map[scheduled_node.type]
|
||||
wait_handle = communication_func(scheduled_node.chunk)
|
||||
self.send_handles.append(wait_handle)
|
||||
self.wait_handles.append(wait_handle)
|
||||
elif scheduled_node.type == "F":
|
||||
self.schedule_f(
|
||||
scheduled_node=scheduled_node,
|
||||
@@ -915,7 +914,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
model_chunk_id=scheduled_node.chunk,
|
||||
optimizer=optimizer,
|
||||
)
|
||||
for h in self.send_handles:
|
||||
for h in self.wait_handles:
|
||||
for hh in h:
|
||||
hh.wait()
|
||||
|
||||
|
Reference in New Issue
Block a user