[fix] fix handle name; rm useless comments;

This commit is contained in:
duanjunwen
2024-10-29 03:24:15 +00:00
parent 5aee4261a6
commit fafe049b83
4 changed files with 4 additions and 78 deletions

View File

@@ -107,7 +107,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
self.local_send_backward_buffer = []
# wait pp buffer
self.send_handles = []
self.wait_handles = []
def assert_buffer_empty(self):
# assert buffer is empty at end
@@ -129,7 +129,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
assert len(self.recv_backward_buffer[1]) == 0
assert len(self.local_send_forward_buffer) == 0
assert len(self.local_send_backward_buffer) == 0
# assert len(self.send_handles) == 0
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator.
@@ -891,7 +890,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# communication
communication_func = self.communication_map[scheduled_node.type]
wait_handle = communication_func(scheduled_node.chunk)
self.send_handles.append(wait_handle)
self.wait_handles.append(wait_handle)
elif scheduled_node.type == "F":
self.schedule_f(
scheduled_node=scheduled_node,
@@ -915,7 +914,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk_id=scheduled_node.chunk,
optimizer=optimizer,
)
for h in self.send_handles:
for h in self.wait_handles:
for hh in h:
hh.wait()