[tensor] dist spec s2s uses all-to-all (#1136)

* dist spec s2s uses all-to-all

* update unit test

* add sanity check

* polish unit test with titans

* add sanity check for DistMgr

* add sanity check

Co-authored-by: jiaruifang <fangjiarui123@gmail.com>
This commit is contained in:
ver217
2022-06-22 11:32:38 +08:00
committed by GitHub
parent c77da0dc81
commit ffa025e120
2 changed files with 60 additions and 11 deletions

View File

@@ -25,7 +25,7 @@ def run():
row_shard = DistSpecManager._shard_as(x, old_dist_spec, row_spec)
assert torch.equal(x.chunk(size, 0)[rank], row_shard)
assert torch.equal(x, DistSpecManager._gather(row_shard, row_spec))
col_shard = DistSpecManager._shard_as(x, old_dist_spec, col_spec)
col_shard = DistSpecManager._all_to_all(row_shard, row_spec, col_spec)
assert torch.equal(x.chunk(size, -1)[rank], col_shard)
assert torch.equal(x, DistSpecManager._gather(col_shard, col_spec))
mat_shard = DistSpecManager._shard_as(x, old_dist_spec, mat_spec)