mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-25 03:31:56 +00:00
Fixed docstring in colossalai (#171)
This commit is contained in:
@@ -176,16 +176,16 @@ def build_pipeline_model_from_cfg(config, num_chunks: int = 1, partition_method:
|
||||
...
|
||||
)
|
||||
|
||||
:param config: configuration of the model
|
||||
:param config: Configuration of the model
|
||||
:type config: dict
|
||||
:param num_chunks: the number of chunks you want to have on the current stage. This value should be 1
|
||||
:param num_chunks: The number of chunks you want to have on the current stage. This value should be 1
|
||||
in most cases unless you are using virutal pipeline parallelism.
|
||||
:type num_chunks: int
|
||||
:param partition_method: this parameter determines how you want to split your model layers into stages,
|
||||
:type num_chunks: int, optional
|
||||
:param partition_method: This parameter determines how you want to split your model layers into stages,
|
||||
you can set it as 'layer' or 'parameter'
|
||||
:type partition_method: str
|
||||
:param verbose: whether to print the logs
|
||||
:type verbose: bool
|
||||
:type partition_method: str, optional
|
||||
:param verbose: Whether to print the logs
|
||||
:type verbose: bool, optional
|
||||
"""
|
||||
ori_model = build_model(config)
|
||||
layers = ori_model.layers_cfg
|
||||
@@ -240,13 +240,13 @@ def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bo
|
||||
"""An intializer to split the model into different stages for pipeline parallelism.
|
||||
Note that `layer` must be `torch.nn.Sequential`.
|
||||
|
||||
:param layers: layers of model
|
||||
:type config: `torch.nn.Sequential`
|
||||
:param num_chunks: the number of chunks you want to have on the current stage. This value should be 1
|
||||
:param layers: Layers of model
|
||||
:type layers: `torch.nn.Sequential`
|
||||
:param num_chunks: The number of chunks you want to have on the current stage. This value should be 1
|
||||
in most cases unless you are using virutal pipeline parallelism.
|
||||
:type num_chunks: int
|
||||
:param verbose: whether to print the logs
|
||||
:type verbose: bool
|
||||
:type num_chunks: int, optional
|
||||
:param verbose: Whether to print the logs
|
||||
:type verbose: bool, optional
|
||||
"""
|
||||
pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
|
Reference in New Issue
Block a user