mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 14:41:53 +00:00
[hotfix] Add layer norm gradients all-reduce for sequence parallel (#4926)
* [hotfix] Add layer norm gradients all-reduce for sequence parallel. (#4915) * Add layer norm gradients all-reduce for sequence parallel. * skip pipeline inference test * [hotfix] fixing polices of sequence parallel (#4922) * Add layer norm gradients all-reduce for sequence parallel. * fix parameter passing when calling get_autopolicy --------- Co-authored-by: littsk <1214689160@qq.com> * Hotfix/add grad all reduce for sequence parallel (#4927) * Add layer norm gradients all-reduce for sequence parallel. * fix parameter passing when calling get_autopolicy * fix bug using wrong variables --------- Co-authored-by: littsk <1214689160@qq.com> * fix policy initialization * fix bloom and chatglm policices * polish code of handling layernorm * fix moe module * polish code of class initializing --------- Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
This commit is contained in:
@@ -24,7 +24,6 @@ for k, v in inputs.items():
|
||||
new_shape[0] = 16
|
||||
inputs[k] = v.to("cuda").repeat(*new_shape)
|
||||
|
||||
|
||||
def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
|
||||
model = transformers.LlamaForCausalLM(
|
||||
transformers.LlamaConfig(
|
||||
@@ -59,6 +58,7 @@ def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_si
|
||||
@parameterize("pp_size", [2])
|
||||
@parameterize("max_output_len", [4])
|
||||
@parameterize("micro_batch_size", [1])
|
||||
|
||||
@clear_cache_before_run()
|
||||
def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
|
||||
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
|
||||
@@ -76,6 +76,7 @@ def check_tp_pipeline_inference(rank, world_size, port):
|
||||
|
||||
|
||||
@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
|
Reference in New Issue
Block a user