mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 20:40:34 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
@@ -10,8 +10,8 @@ from colossalai.accelerator import get_accelerator
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.quantization.fp8 import cast_to_fp8_pipeline, cast_from_fp8_pipeline
|
||||
|
||||
from ._utils import (
|
||||
detach,
|
||||
@@ -172,6 +172,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||
|
||||
if self.fp8_communication:
|
||||
cast_from_fp8_pipeline(output_tensor, del_metadata=False)
|
||||
|
||||
def send_backward(self, input_tensor_grad: Any, prev_rank: int = None) -> None:
|
||||
"""Sends the gradient tensor to the previous stage in pipeline.
|
||||
For 1F1B.
|
||||
|
Reference in New Issue
Block a user