[shardformer] support vision transformer (#4096)

* first v of vit shardformer

* keep vit

* update

* vit shard add vitattention vitlayer

* update num head shard para

* finish test for vit

* add new_model_class & postprocess

* add vit readme

* delete old files & fix the conflict

* fix sth
This commit is contained in:
Kun Lin
2023-06-28 13:28:18 +08:00
committed by Frank Lee
parent ac80937138
commit 8af29ee47a
10 changed files with 159 additions and 8 deletions

View File

@@ -287,4 +287,4 @@ def reduce_forward(input_, process_group):
def reduce_backward(input_, process_group):
return _ReduceBackward.apply(input_, process_group)
return _ReduceBackward.apply(input_, process_group)