mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-13 05:01:44 +00:00
add DistSpec for loss and test_model (#947)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from .linear import colo_linear
|
||||
from .element_wise import *
|
||||
from .layernorm import colo_layernorm
|
||||
# from .loss import colo_cross_entropy
|
||||
from .loss import colo_cross_entropy
|
||||
from .embedding import colo_embedding
|
||||
from .addmm import colo_addmm
|
||||
|
@@ -28,7 +28,7 @@ def colo_layernorm(types, args=(), kwargs=None, pg=None):
|
||||
|
||||
if isinstance(input_tensor, ColoTensor):
|
||||
# TODO (ver217): check input dist spec
|
||||
input_tensor.to_dist_spec(dist_spec.replicate())
|
||||
input_tensor.to_dist_spec(dist_spec.replicate(input_tensor.spec.get_process_group()))
|
||||
input_tensor = input_tensor.torch_tensor()
|
||||
if isinstance(weight, ColoTensor):
|
||||
weight = weight.torch_tensor()
|
||||
|
@@ -1,4 +1,4 @@
|
||||
from colossalai.tensor.spec import ShardPattern
|
||||
from colossalai.tensor.dist_spec import DistPlacementPattern
|
||||
import torch
|
||||
from colossalai.tensor.op_wrapper import colo_op_impl
|
||||
from colossalai.tensor import ColoTensor
|
||||
@@ -27,12 +27,11 @@ def colo_cross_entropy(types, args=(), kwargs=None, pg=None):
|
||||
if isinstance(target, ColoTensor):
|
||||
target = target.torch_tensor()
|
||||
|
||||
if input_tensor.is_gathered(): # Input is gathered
|
||||
# TODO(jzy) Shall we make the result of loss function a ColoTensor?
|
||||
if input_tensor.spec.is_gathered(): # Input is gathered
|
||||
return ColoTensor.init_from_torch_tensor(torch.nn.functional.cross_entropy(
|
||||
input_tensor.torch_tensor(), target, weight))
|
||||
elif input_tensor.has_spec() and input_tensor.shard_spec.num_action == 1: # Single Model Parallel Applied
|
||||
if input_tensor.shard_pattern == ShardPattern.Col:
|
||||
elif input_tensor.has_spec() and input_tensor.spec.num_action == 1: # Single Model Parallel Applied
|
||||
if input_tensor.spec.is_1Dcol():
|
||||
return ColoTensor.init_from_torch_tensor(
|
||||
VocabParallelCrossEntropyLoss1D()(input_tensor.torch_tensor(), target))
|
||||
else:
|
||||
|
@@ -53,7 +53,8 @@ class DistSpecManager:
|
||||
|
||||
@staticmethod
|
||||
def _r2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
|
||||
if old_dist_spec.process_group is not None and old_dist_spec.process_group != dist_spec.process_group:
|
||||
if old_dist_spec.process_group is not None and old_dist_spec.process_group != dist_spec.process_group \
|
||||
and dist_spec.process_group is not None:
|
||||
raise NotImplementedError
|
||||
return tensor
|
||||
|
||||
@@ -65,7 +66,8 @@ class DistSpecManager:
|
||||
|
||||
@staticmethod
|
||||
def _s2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
|
||||
if old_dist_spec.process_group != dist_spec.process_group:
|
||||
if old_dist_spec.process_group != dist_spec.process_group \
|
||||
and dist_spec.process_group is not None:
|
||||
raise NotImplementedError
|
||||
return DistSpecManager._gather(tensor, old_dist_spec)
|
||||
|
||||
|
@@ -1,7 +1,7 @@
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.tensor.dist_spec import _DistSpec
|
||||
from colossalai.tensor.dist_spec import _DistSpec, DistPlacementPattern
|
||||
|
||||
|
||||
class ComputePattern(Enum):
|
||||
@@ -84,3 +84,16 @@ class TensorSpec(object):
|
||||
|
||||
def get_process_group(self):
|
||||
return self.dist_spec.process_group
|
||||
|
||||
def get_placement(self):
|
||||
return self.dist_spec.placement
|
||||
|
||||
def is_gathered(self):
|
||||
return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
|
||||
or (len(self.dist_spec.num_partitions) == 1
|
||||
and self.dist_spec.num_partitions[0] == 1) \
|
||||
or (self.dist_spec.process_group.size() == 1)
|
||||
|
||||
def is_1Dcol(self):
|
||||
return self.dist_spec.placement == DistPlacementPattern.SHARD \
|
||||
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
|
Reference in New Issue
Block a user