[fix] fix send recv signature;

This commit is contained in:
duanjunwen 2024-10-29 03:33:58 +00:00
parent fafe049b83
commit fa3ccda8ee

View File

@ -1,5 +1,5 @@
from functools import partial from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Callable, Dict, Iterable, List, Optional, Union
import torch import torch
import torch.cuda import torch.cuda
@ -206,7 +206,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk_id = self.num_model_chunks - model_chunk_id - 1 model_chunk_id = self.num_model_chunks - model_chunk_id - 1
return model_chunk_id return model_chunk_id
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Tuple[Any, List]: def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> List:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage. """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
For ZBV. For ZBV.
@ -267,7 +267,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# return input_tensor, wait_handles # return input_tensor, wait_handles
return wait_handles return wait_handles
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Tuple[Any, List]: def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> List:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
For ZBV. For ZBV.