mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-23 18:39:56 +00:00
[Tensor] remove ParallelAction, use ComputeSpec instread (#1166)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from typing import Dict
|
||||
from colossalai.tensor import ColoParameter, ParallelAction, TensorSpec
|
||||
from colossalai.tensor import ColoParameter, ComputeSpec, TensorSpec
|
||||
from . import ColoModule
|
||||
import torch
|
||||
|
||||
@@ -39,7 +39,7 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
|
||||
if not isinstance(param, ColoParameter):
|
||||
raise Exception(f'Invalid ColoParameter spec: {param} in {module} is not a ColoParameter.')
|
||||
if param.has_spec():
|
||||
cur_compute_pattern = param.spec.parallel_action.compute_pattern
|
||||
cur_compute_pattern = param.spec.compute_spec.compute_pattern
|
||||
if compute_pattern is None:
|
||||
compute_pattern = cur_compute_pattern
|
||||
else:
|
||||
@@ -79,11 +79,11 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
|
||||
check_colo_module(submodule, recursive=True)
|
||||
|
||||
|
||||
def init_colo_module(module: torch.nn.Module, parallel_action: ParallelAction, recursive=True, mode='default'):
|
||||
def init_colo_module(module: torch.nn.Module, parallel_action: ComputeSpec, recursive=True, mode='default'):
|
||||
compute_pattern = parallel_action.compute_pattern
|
||||
if is_colo_module(module):
|
||||
# for each param
|
||||
# set DistSpec and ParallelAction
|
||||
# set DistSpec and ComputeSpec
|
||||
colo_module = get_colo_module(module)
|
||||
colo_module.register(compute_pattern)
|
||||
if not colo_module.has_compute_pattern_with_mode(compute_pattern, mode=mode):
|
||||
|
Reference in New Issue
Block a user