Mirror of https://github.com/hpcaitech/ColossalAI.git
[refactor] remove gpc dependency in colotensor's _ops (#1189)
@@ -17,11 +17,12 @@ class TensorSpec(object):
         self.compute_spec = compute_spec
         self.dist_spec = dist_spec
 
+    # TODO(jiaruifang) actually need tp process group
     def get_process_group(self):
         return self.dist_spec.process_group
 
     def get_process_group_size(self):
-        return dist.get_world_size(self.dist_spec.process_group)
+        return dist.get_world_size(self.dist_spec.process_group.tp_process_group())
 
     def get_placement(self):
         return self.dist_spec.placement
@@ -30,7 +31,7 @@ class TensorSpec(object):
         return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
             or (len(self.dist_spec.num_partitions) == 1
                 and self.dist_spec.num_partitions[0] == 1) \
-            or (self.dist_spec.process_group.size() == 1)
+            or (self.dist_spec.process_group.tp_world_size() == 1)
 
     def is_shard_1dcol(self):
         return self.dist_spec.placement == DistPlacementPattern.SHARD \
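For context, both changed call sites assume a ProcessGroup object that carries its own tensor-parallel view instead of querying the global parallel context (gpc). Below is a minimal sketch of such a wrapper, assuming torch.distributed has already been initialized; the method names tp_process_group and tp_world_size come from the patch, but the constructor and rank bookkeeping here are illustrative only, not ColossalAI's actual implementation.

import torch.distributed as dist

class ProcessGroup:
    """Sketch: wraps a torch.distributed group and exposes a tensor-parallel view."""

    def __init__(self, tp_ranks):
        # Hypothetical constructor: build a dedicated communication group
        # for the tensor-parallel ranks of the current rank.
        self._tp_ranks = tp_ranks
        self._tp_group = dist.new_group(ranks=tp_ranks)

    def tp_process_group(self):
        # The torch.distributed group used for tensor-parallel collectives,
        # as consumed by get_process_group_size() in the patch.
        return self._tp_group

    def tp_world_size(self):
        # Number of tensor-parallel ranks, as consumed by the replicate check.
        return len(self._tp_ranks)

With a wrapper like this, get_process_group_size() reduces to dist.get_world_size(pg.tp_process_group()), and the replicate check tp_world_size() == 1 needs no gpc lookup, which is what lets the _ops drop their gpc dependency.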