[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2024-07-12 07:33:44 +00:00
parent 1f1b856354
commit 51f916b11d
3 changed files with 9 additions and 9 deletions

View File

@@ -10,8 +10,8 @@ from colossalai.accelerator import get_accelerator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline
from colossalai.utils import get_current_device
from colossalai.quantization.fp8 import cast_to_fp8_pipeline, cast_from_fp8_pipeline
from ._utils import (
detach,
@@ -172,6 +172,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
if self.fp8_communication:
cast_from_fp8_pipeline(output_tensor, del_metadata=False)
def send_backward(self, input_tensor_grad: Any, prev_rank: int = None) -> None:
"""Sends the gradient tensor to the previous stage in pipeline.
For 1F1B.