mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 02:26:51 +00:00
[tensor] hijack addmm for colo tensor (#923)
* hijack addmm for colo tensor * fix bugs * polish unit test * polish comments
This commit is contained in:
@@ -142,12 +142,15 @@ class ColoTensor(object):
|
||||
# Model Parameters
|
||||
if self._shard_spec.num_action == 1:
|
||||
parallel_action = self._shard_spec.get_action_by_compute_pattern(self._shard_spec.compute_patterns[0])
|
||||
if parallel_action.compute_pattern in [ComputePattern.TP1DRow_Linear, \
|
||||
ComputePattern.TP1DCol_Embedding]:
|
||||
if parallel_action.compute_pattern in [
|
||||
ComputePattern.TP1DRow_Linear, ComputePattern.TP1DCol_Embedding, ComputePattern.TP1DCol_mm
|
||||
]:
|
||||
self._shard_1d(parallel_action=parallel_action, dim=-1)
|
||||
self._shard_pattern = ShardPattern.Col # We bind our ComputePattern on weight, which has to be transposed when linear().
|
||||
elif parallel_action.compute_pattern in [ComputePattern.TP1DCol_Linear, \
|
||||
ComputePattern.TP1DRow_Embedding]:
|
||||
# We bind our ComputePattern on weight, which has to be transposed when linear().
|
||||
self._shard_pattern = ShardPattern.Col
|
||||
elif parallel_action.compute_pattern in [
|
||||
ComputePattern.TP1DCol_Linear, ComputePattern.TP1DRow_Embedding, ComputePattern.TP1DRow_mm
|
||||
]:
|
||||
self._shard_1d(parallel_action=parallel_action, dim=0)
|
||||
self._shard_pattern = ShardPattern.Row
|
||||
else:
|
||||
|
Reference in New Issue
Block a user