mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-18 07:57:46 +00:00
[NFC] polish colossalai/utils/multi_tensor_apply/multi_tensor_apply.py code style (#1559)
This commit is contained in:
parent
b0f4c0bddf
commit
318fbf1145
@ -778,4 +778,4 @@ class OneFOneBPipelineEngine(PipelineEngineBase):
|
|||||||
criterion: Callable = None,
|
criterion: Callable = None,
|
||||||
checkpoint: bool = False) -> None:
|
checkpoint: bool = False) -> None:
|
||||||
use_1F1B = True
|
use_1F1B = True
|
||||||
super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, checkpoint)
|
super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, checkpoint)
|
||||||
|
@ -26,13 +26,9 @@ class MultiTensorApply(object):
|
|||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Attempted to call MultiTensorApply method, but MultiTensorApply "
|
"Attempted to call MultiTensorApply method, but MultiTensorApply "
|
||||||
"is not available, possibly because Apex was installed without "
|
"is not available, possibly because Apex was installed without "
|
||||||
"--cpp_ext --cuda_ext. Original import error message:",
|
"--cpp_ext --cuda_ext. Original import error message:", MultiTensorApply.import_err)
|
||||||
MultiTensorApply.import_err)
|
|
||||||
|
|
||||||
def __call__(self, op, noop_flag_buffer, tensor_lists, *args):
|
def __call__(self, op, noop_flag_buffer, tensor_lists, *args):
|
||||||
self.check_avail()
|
self.check_avail()
|
||||||
|
|
||||||
return op(self.chunk_size,
|
return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args)
|
||||||
noop_flag_buffer,
|
|
||||||
tensor_lists,
|
|
||||||
*args)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user