[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2024-07-12 07:33:44 +00:00
parent 1f1b856354
commit 51f916b11d
3 changed files with 9 additions and 9 deletions

View File

@@ -11,8 +11,8 @@ from colossalai.accelerator import get_accelerator
from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline
from colossalai.utils import get_current_device
from colossalai.quantization.fp8 import cast_to_fp8_pipeline, cast_from_fp8_pipeline
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
from .base import PipelineSchedule
@@ -59,6 +59,7 @@ class InterleavedSchedule(PipelineSchedule):
self.grad_metadata_recv = None
self.fp8_communication = fp8_communication
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator.