Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-07 03:52:01 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
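Running the updated hooks over the whole tree is what produced the mechanical quote-style edits in the diff below. A minimal sketch of that step, assuming pre-commit is installed and a .pre-commit-config.yaml sits at the repository root:

# Sketch: run every configured pre-commit hook across all tracked files.
# Assumes `pre-commit` is installed and the working directory is the repo root.
import subprocess
import sys


def run_all_hooks(repo_root: str = ".") -> int:
    """Invoke `pre-commit run --all-files` and return its exit code."""
    result = subprocess.run(["pre-commit", "run", "--all-files"], cwd=repo_root)
    return result.returncode


if __name__ == "__main__":
    sys.exit(run_all_hooks())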
@@ -1,13 +1,12 @@
-from colossalai.legacy.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec
+from colossalai.legacy.tensor import ComputePattern, ProcessGroup, ShardSpec

 from .colo_module import ColoModule


 class ColoEmbedding(ColoModule):
-
     def __init__(self):
         super(ColoEmbedding, self).__init__()
-        self._register_shard_params(['weight'])
+        self._register_shard_params(["weight"])

     def register(self, compute_pattern, pg: ProcessGroup):
         if not compute_pattern in self._allowed_patterns:
@@ -20,18 +19,18 @@ class ColoEmbedding(ColoModule):
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight': ShardSpec([0], [pg.tp_world_size()]),
+                "weight": ShardSpec([0], [pg.tp_world_size()]),
             },
-            mode='row',
+            mode="row",
         )

         # TP1D Col Linear
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight': ShardSpec([-1], [pg.tp_world_size()]),
+                "weight": ShardSpec([-1], [pg.tp_world_size()]),
             },
-            mode='col',
+            mode="col",
         )

-        self._set_default(compute_pattern=_compute_pattern, target_mode='row')
+        self._set_default(compute_pattern=_compute_pattern, target_mode="row")
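For context on the two patterns registered above: "row" mode shards the embedding weight along dim 0 (the vocabulary dimension) and "col" mode along dim -1 (the embedding dimension), each split across pg.tp_world_size() ranks. A standalone sketch of that partitioning in plain PyTorch, illustrative only and not the ColossalAI API, with made-up sizes:

# Illustration (not ColossalAI code): how ShardSpec([0], [world_size]) and
# ShardSpec([-1], [world_size]) would partition an embedding weight for 1D tensor parallelism.
import torch

vocab_size, embed_dim, tp_world_size = 8, 4, 2
weight = torch.arange(vocab_size * embed_dim, dtype=torch.float32).reshape(vocab_size, embed_dim)

# "row" mode: split dim 0 (vocabulary rows) across tensor-parallel ranks.
row_shards = torch.chunk(weight, tp_world_size, dim=0)
assert row_shards[0].shape == (vocab_size // tp_world_size, embed_dim)

# "col" mode: split dim -1 (embedding columns) across tensor-parallel ranks.
col_shards = torch.chunk(weight, tp_world_size, dim=-1)
assert col_shards[0].shape == (vocab_size, embed_dim // tp_world_size)

print(row_shards[0].shape, col_shards[0].shape)  # torch.Size([4, 4]) torch.Size([8, 2])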