[shardformer] Align bert value (#3907)

* add bert align test, fix dist loss bug

* forward and backward align

* add ignore index

* add shardformer CI

* add optional gather_output setting for users in ShardConfig

* update readme with optional gather_output

* add dist crossentropy loss test, remove unused files

* remove unused file

* remove unused file

* rename the file

* polish code
This commit is contained in:
FoolPlayer
2023-06-09 14:36:54 +08:00
committed by Frank Lee
parent 79f8d5d54b
commit f1cb5ac6bf
11 changed files with 174 additions and 197 deletions

View File

@@ -5,16 +5,14 @@ __all__ = ['ShardConfig']
@dataclass
class ShardConfig:
"""
The config for sharding the huggingface model for test
r"""
The config for sharding the huggingface model
Args:
rank (int): The rank of local process
world_size (int): The world size of the distributed process
gather_output (bool): Whether to gather the output of the model of the last layer
"""
rank: int
fp16: bool = True
num_gpus: int = 2
world_size: int = 2
backend = "nccl"
verbose: str = 'simple'
seed: int = None
require_grad: bool = False
master_addr: str = "127.0.0.1"
master_port: int = 29500
gather_output: bool = True