mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-26 04:03:58 +00:00
[shardformer] Align bert value (#3907)
* add bert align test, fix dist loss bug * forward and backward align * add ignore index * add shardformer CI * add gather_output optional for user in shardconfig * update readme with optional gather_ouput * add dist crossentropy loss test, remove unused files * remove unused file * remove unused file * rename the file * polish code
This commit is contained in:
@@ -5,16 +5,14 @@ __all__ = ['ShardConfig']
|
||||
|
||||
@dataclass
|
||||
class ShardConfig:
|
||||
"""
|
||||
The config for sharding the huggingface model for test
|
||||
r"""
|
||||
The config for sharding the huggingface model
|
||||
|
||||
Args:
|
||||
rank (int): The rank of local process
|
||||
world_size (int): The world size of the distributed process
|
||||
gather_output (bool): Whether to gather the output of the model of the last layer
|
||||
"""
|
||||
rank: int
|
||||
fp16: bool = True
|
||||
num_gpus: int = 2
|
||||
world_size: int = 2
|
||||
backend = "nccl"
|
||||
verbose: str = 'simple'
|
||||
seed: int = None
|
||||
require_grad: bool = False
|
||||
master_addr: str = "127.0.0.1"
|
||||
master_port: int = 29500
|
||||
gather_output: bool = True
|
||||
|
Reference in New Issue
Block a user