mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 21:40:02 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,13 +1,12 @@
|
||||
from colossalai.legacy.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec
|
||||
from colossalai.legacy.tensor import ComputePattern, ProcessGroup, ShardSpec
|
||||
|
||||
from .colo_module import ColoModule
|
||||
|
||||
|
||||
class ColoLinear(ColoModule):
|
||||
|
||||
def __init__(self):
|
||||
super(ColoLinear, self).__init__()
|
||||
self._register_shard_params(['weight', 'bias'])
|
||||
self._register_shard_params(["weight", "bias"])
|
||||
|
||||
def register(self, compute_pattern, pg: ProcessGroup):
|
||||
if not compute_pattern in self._allowed_patterns:
|
||||
@@ -19,21 +18,15 @@ class ColoLinear(ColoModule):
|
||||
_compute_pattern = ComputePattern.TP1D
|
||||
self._register_allowed_patterns(
|
||||
compute_pattern=_compute_pattern,
|
||||
dist_specs={
|
||||
'weight': ShardSpec([-1], [pg.tp_world_size()]),
|
||||
'bias': None
|
||||
},
|
||||
mode='row',
|
||||
dist_specs={"weight": ShardSpec([-1], [pg.tp_world_size()]), "bias": None},
|
||||
mode="row",
|
||||
)
|
||||
|
||||
# TP1D Col Linear
|
||||
self._register_allowed_patterns(
|
||||
compute_pattern=_compute_pattern,
|
||||
dist_specs={
|
||||
'weight': ShardSpec([0], [pg.tp_world_size()]),
|
||||
'bias': ShardSpec([0], [pg.tp_world_size()])
|
||||
},
|
||||
mode='col',
|
||||
dist_specs={"weight": ShardSpec([0], [pg.tp_world_size()]), "bias": ShardSpec([0], [pg.tp_world_size()])},
|
||||
mode="col",
|
||||
)
|
||||
|
||||
self._set_default(compute_pattern=_compute_pattern, target_mode='row')
|
||||
self._set_default(compute_pattern=_compute_pattern, target_mode="row")
|
||||
|
Reference in New Issue
Block a user