[shardformer] fix type hint

This commit is contained in:
ver217 2023-07-05 15:20:59 +08:00
parent 1a87dd737d
commit d4b96abe5c

View File

@ -15,8 +15,8 @@ class ShardConfig:
The config for sharding the huggingface model
Args:
tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group.
pipeline_stage_manager (PipelineStageManager): The pipeline stage manager, defaults to None, which means no pipeline.
tensor_parallel_process_group (Optional[ProcessGroup]): The process group for tensor parallelism, defaults to None, which is the global process group.
pipeline_stage_manager (Optional[PipelineStageManager]): The pipeline stage manager, defaults to None, which means no pipeline.
enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True.
enable_fused_normalization (bool): Whether to use fused layernorm, default is False.
enable_all_optimization (bool): Whether to turn on all optimization, default is False.