mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-21 10:50:56 +00:00
[hotfix]change to fit latest p2p (#1100)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
This reverts commit df7e6506d4
.
* [hotfix]change to fit latest p2p
* polish
* polish
This commit is contained in:
parent
72bd7c696b
commit
1e9f9c227f
@ -86,7 +86,14 @@ class PipelineSchedule(BaseSchedule):
|
|||||||
|
|
||||||
self.num_microbatches = num_microbatches
|
self.num_microbatches = num_microbatches
|
||||||
self.dtype = torch.float
|
self.dtype = torch.float
|
||||||
self.tensor_shape = tensor_shape
|
assert not isinstance(tensor_shape,
|
||||||
|
int), "tensor_shape type should be one of Union[torch.Size, List[int], Tuple[int]]."
|
||||||
|
if tensor_shape is None:
|
||||||
|
self.tensor_shape = tensor_shape
|
||||||
|
elif isinstance(tensor_shape, torch.Size):
|
||||||
|
self.tensor_shape = tensor_shape
|
||||||
|
else:
|
||||||
|
self.tensor_shape = torch.Size(tensor_shape)
|
||||||
self.scatter_gather_tensors = False
|
self.scatter_gather_tensors = False
|
||||||
if gpc.is_initialized(ParallelMode.PARALLEL_1D) and gpc.get_world_size(ParallelMode.PARALLEL_1D) > 1:
|
if gpc.is_initialized(ParallelMode.PARALLEL_1D) and gpc.get_world_size(ParallelMode.PARALLEL_1D) > 1:
|
||||||
self.scatter_gather_tensors = scatter_gather_tensors
|
self.scatter_gather_tensors = scatter_gather_tensors
|
||||||
|
Loading…
Reference in New Issue
Block a user