[shardformer, pipeline] add gradient_checkpointing_ratio and heterogeneous shard policy for llama (#5508)

* feat: add `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig`

* feat: apply `GradientCheckpointConfig` to policy and llama_forward

* feat: move `distribute_layer` and `get_stage_index` to PipelineStageManager

* fix: add optional args for `distribute_layer` and `get_stage_index`

* fix: fix changed API calls

* test: update llama tests

* style: polish `GradientCheckpointConfig`

* fix: fix pipeline utils tests
Author: Wenhao Chen
Date: 2024-04-01 11:34:58 +08:00 (committed by GitHub)
Parent: df5e9c53cf
Commit: e614aa34f3
28 changed files with 396 additions and 213 deletions
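The commit message introduces GradientCheckpointConfig / PipelineGradientCheckpointConfig and threads them through the shard config into the Llama forward. Below is a minimal usage sketch; the gradient_checkpointing_ratio field (taken from the commit title), the gradient_checkpoint_config keyword on ShardConfig, and the class's import location are assumptions inferred from the commit message, not confirmed by the hunks shown further down.

# Sketch only: names below are inferred from the commit title/message.
from colossalai.shardformer.shard.shard_config import ShardConfig

stage_manager = ...  # a PipelineStageManager built by the pipeline plugin at runtime

# Assumed class and field: checkpoint activations for only half of the
# layers handled by each pipeline stage.
gc_config = PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5)

shard_config = ShardConfig(
    pipeline_stage_manager=stage_manager,
    gradient_checkpoint_config=gc_config,  # assumed keyword on ShardConfig
)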

T5 pipeline policy test:

@@ -1,4 +1,23 @@
import random

from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.t5 import T5BasePolicy
from colossalai.shardformer.shard.shard_config import ShardConfig


class _ShardConfig(ShardConfig):
    # Skip ShardConfig's __post_init__ checks so the test can build a bare config.
    def __post_init__(self):
        pass


class _PipelineStageManager(PipelineStageManager):
    # Minimal stub: no process groups, just enough state for layer distribution.
    def __init__(self):
        self.is_interleave = False
        self.num_layers_per_stage = None

    @property
    def num_stages(self):
        return random.randint(5, 10)


def test_t5_pipeline_distribution():
@@ -10,7 +29,10 @@ def test_t5_pipeline_distribution():
        "decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2],
    }

    stage_manager = _PipelineStageManager()
    shard_config = _ShardConfig(pipeline_stage_manager=stage_manager)
    policy = T5BasePolicy()
    policy.set_shard_config(shard_config)
    for i in range(num_test_cases):
        _, decoder_starting_stage = policy.distribute_t5_layers(
            test_dict["num_encoder_layers"][i],
@@ -35,7 +57,10 @@ def test_t5_pipeline_layers():
    }

    for i in range(num_test_cases):
        stage_manager = _PipelineStageManager()
        shard_config = _ShardConfig(pipeline_stage_manager=stage_manager)
        policy = T5BasePolicy()
        policy.set_shard_config(shard_config)
        layers_per_stage, decoder_starting_stage = policy.distribute_t5_layers(
            test_dict["num_encoder_layers"][i],
            test_dict["num_decoder_layers"][i],
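The stub _ShardConfig / _PipelineStageManager classes above let the test drive the layer-distribution logic without a real distributed setup; per the commit message, distribute_layer and get_stage_index now live on PipelineStageManager (with optional arguments), and the policies reach them through the shard config. A rough sketch of that stage-manager-side call pattern follows; the method and argument names are assumptions based on the commit message bullets, not taken from the diff.

# Assumed API, per the bullets "move distribute_layer and get_stage_index to
# PipelineStageManager" and "add optional args for distribute_layer and
# get_stage_index" (the plural spelling here is a guess).
stage_manager = _PipelineStageManager()  # stub from the test above

# Split 24 layers across the stub's (randomly chosen) number of stages.
layers_per_stage = stage_manager.distribute_layers(24)

# Resolve which layer indices the current stage owns.
stage_index = stage_manager.get_stage_index(layers_per_stage)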

Whisper pipeline policy test:

@@ -1,4 +1,23 @@
import random

from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.whisper import WhisperPolicy
from colossalai.shardformer.shard.shard_config import ShardConfig


class _ShardConfig(ShardConfig):
    # Skip ShardConfig's __post_init__ checks so the test can build a bare config.
    def __post_init__(self):
        pass


class _PipelineStageManager(PipelineStageManager):
    # Minimal stub: no process groups, just enough state for layer distribution.
    def __init__(self):
        self.is_interleave = False
        self.num_layers_per_stage = None

    @property
    def num_stages(self):
        return random.randint(5, 10)


def test_whisper_pipeline_distribution():
@@ -10,7 +29,10 @@ def test_whisper_pipeline_distribution():
        "decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2],
    }

    stage_manager = _PipelineStageManager()
    shard_config = _ShardConfig(pipeline_stage_manager=stage_manager)
    policy = WhisperPolicy()
    policy.set_shard_config(shard_config)
    for i in range(num_test_cases):
        _, decoder_starting_stage = policy.distribute_whisper_layers(
            test_dict["num_encoder_layers"][i],
@@ -34,7 +56,10 @@ def test_whisper_pipeline_layers():
        ],
    }

    stage_manager = _PipelineStageManager()
    shard_config = _ShardConfig(pipeline_stage_manager=stage_manager)
    policy = WhisperPolicy()
    policy.set_shard_config(shard_config)
    for i in range(num_test_cases):
        layers_per_stage, decoder_starting_stage = policy.distribute_whisper_layers(
            test_dict["num_encoder_layers"][i],