diff --git a/colossalai/tensor/spec.py b/colossalai/tensor/spec.py index 40a4a2c51..40169aaa4 100644 --- a/colossalai/tensor/spec.py +++ b/colossalai/tensor/spec.py @@ -18,6 +18,9 @@ class ParallelAction(object): self.compute_pattern = compute_pattern self.gather_out = gather_out + def __repr__(self): + return f'compute pattern: {self.compute_pattern}, gather out: {self.gather_out}' + class TensorSpec(object): """ @@ -72,3 +75,6 @@ class TensorSpec(object): def has_compute_pattern(self, compute_pattern: ComputePattern): return self.parallel_action.compute_pattern == compute_pattern + + def __repr__(self): + return f'parallel action: {self.parallel_action}, dist_spec: {self.dist_spec}'