mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-15 06:00:07 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
@@ -11,8 +11,8 @@ from colossalai.accelerator import get_accelerator
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.quantization.fp8 import cast_to_fp8_pipeline, cast_from_fp8_pipeline
|
||||
|
||||
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
|
||||
from .base import PipelineSchedule
|
||||
@@ -59,6 +59,7 @@ class InterleavedSchedule(PipelineSchedule):
|
||||
self.grad_metadata_recv = None
|
||||
|
||||
self.fp8_communication = fp8_communication
|
||||
|
||||
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
|
||||
"""Load a batch from data iterator.
|
||||
|
||||
|
@@ -10,8 +10,8 @@ from colossalai.accelerator import get_accelerator
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.quantization.fp8 import cast_to_fp8_pipeline, cast_from_fp8_pipeline
|
||||
|
||||
from ._utils import (
|
||||
detach,
|
||||
@@ -172,6 +172,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||
|
||||
if self.fp8_communication:
|
||||
cast_from_fp8_pipeline(output_tensor, del_metadata=False)
|
||||
|
||||
def send_backward(self, input_tensor_grad: Any, prev_rank: int = None) -> None:
|
||||
"""Sends the gradient tensor to the previous stage in pipeline.
|
||||
For 1F1B.
|
||||
|
Reference in New Issue
Block a user