mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
[nfc] fix typo colossalai/ applications/ (#3831)
* fix typo colossalai/autochunk auto_parallel amp * fix typo colossalai/auto_parallel nn utils etc. * fix typo colossalai/auto_parallel autochunk fx/passes etc. * fix typo docs/ * change placememt_policy to placement_policy in docs/ and examples/ * fix typo colossalai/ applications/
This commit is contained in:
@@ -21,7 +21,7 @@ __all__ = [
|
||||
|
||||
class BroadcastType(Enum):
|
||||
EQUAL = auto()
|
||||
PADDDING = auto()
|
||||
PADDING = auto()
|
||||
MULTIPLE = auto()
|
||||
|
||||
|
||||
@@ -69,18 +69,18 @@ def get_broadcast_dim_info(logical_shape, physical_shape):
|
||||
for i in range(logical_num_dims):
|
||||
# get the trailing dim size
|
||||
logical_dim_idx = logical_num_dims - i - 1
|
||||
phyiscal_dim_idx = physical_num_dims - i - 1
|
||||
physical_dim_idx = physical_num_dims - i - 1
|
||||
logical_dim_size = logical_shape[logical_dim_idx]
|
||||
|
||||
if phyiscal_dim_idx >= 0:
|
||||
physical_dim_size = physical_shape[phyiscal_dim_idx]
|
||||
if physical_dim_idx >= 0:
|
||||
physical_dim_size = physical_shape[physical_dim_idx]
|
||||
|
||||
if physical_dim_size == logical_dim_size:
|
||||
logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.EQUAL
|
||||
elif physical_dim_size == 1 and physical_dim_size != logical_dim_size:
|
||||
logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.MULTIPLE
|
||||
else:
|
||||
logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.PADDDING
|
||||
logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.PADDING
|
||||
|
||||
return logical_dim_broadcast_info
|
||||
|
||||
@@ -117,7 +117,7 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe
|
||||
for shape_dim, mesh_dim in logical_dim_partition.items():
|
||||
logical_broadcast_type = logical_dim_broadcast_info[shape_dim]
|
||||
|
||||
if logical_broadcast_type == BroadcastType.PADDDING or logical_broadcast_type == BroadcastType.MULTIPLE:
|
||||
if logical_broadcast_type == BroadcastType.PADDING or logical_broadcast_type == BroadcastType.MULTIPLE:
|
||||
removed_dims.extend(mesh_dim)
|
||||
else:
|
||||
# get the corresponding physical dim
|
||||
|
Reference in New Issue
Block a user