From 85946d423607835359ecef4bebb0cd74828cba7c Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Fri, 24 May 2024 10:03:05 +0800 Subject: [PATCH 1/2] [Inference]Fix readme and example for API server (#5742) * fix chatapi readme and example * updating doc * add an api and change the doc * remove * add credits and del 'API' heading * readme * readme --- colossalai/inference/README.md | 54 +++++++++++++++++++++-- colossalai/inference/server/README.md | 27 ------------ colossalai/inference/server/api_server.py | 21 ++++++--- examples/inference/client/locustfile.py | 9 ++-- requirements/requirements.txt | 2 + 5 files changed, 73 insertions(+), 40 deletions(-) delete mode 100644 colossalai/inference/server/README.md diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md index cdb32a0f8..b46222d80 100644 --- a/colossalai/inference/README.md +++ b/colossalai/inference/README.md @@ -207,13 +207,13 @@ Learnt from [PagedAttention](https://arxiv.org/abs/2309.06180) by [vLLM](https:/ Request handler is responsible for managing requests and scheduling a proper batch from exisiting requests. Based on [Orca's](https://www.usenix.org/conference/osdi22/presentation/yu) and [vLLM's](https://github.com/vllm-project/vllm) research and work on batching requests, we applied continuous batching with unpadded sequences, which enables various number of sequences to pass projections (i.e. Q, K, and V) together in different steps by hiding the dimension of number of sequences, and decrement the latency of incoming sequences by inserting a prefill batch during a decoding step and then decoding together.

- +
Naive Batching: decode until each sequence encounters eos in a batch

- +
Continuous Batching: dynamically adjust the batch size by popping out finished sequences and inserting prefill batch

@@ -222,6 +222,54 @@ Request handler is responsible for managing requests and scheduling a proper bat Modeling contains models, layers, and policy, which are hand-crafted for better performance easier usage. Integrated with `shardformer`, users can define their own policy or use our preset policies for specific models. Our modeling files are aligned with [Transformers](https://github.com/huggingface/transformers). For more details about the usage of modeling and policy, please check `colossalai/shardformer`. +## Online Service +Colossal-Inference supports fast-api based online service. Simple completion and chat are both supported. Follow the commands below and you can simply construct a server with both completion and chat functionalities. For now we support `Llama2`,`Llama3` and `Baichuan2` model, etc. we will fullfill the blank quickly. + +### API + +- GET '/ping': +Ping is used to check if the server can receive and send information. +- GET '/engine_check': +Check is the background engine is working. +- POST '/completion': +Completion api is used for single sequence request, like answer a question or complete words. +- POST '/chat': +Chat api is used for conversation-style request, which often includes dialogue participants(i.e. roles) and corresponding words. Considering the input data are very different from normal inputs, we introduce Chat-Template to match the data format in chat models. +#### chat-template +Followed `transformers`, we add the chat-template argument. As chat models have been trained with very different formats for converting conversations into a single tokenizable string. Using a format that matches the training data is extremely important. This attribute(chat_template) is inclueded in HuggingFace tokenizers, containing a Jinja template that converts conversation histories into a correctly formatted string. You can refer to the [HuggingFace-blog](https://huggingface.co/blog/chat-templates) for more information. We also provide a simple example temlate bellow. Both str or file style chat template are supported. +### Usage +#### Args for customizing your server +The configuration for api server contains both serving interface and engine backend. +For Interface: +- `--host`: The host url on your device for the server. +- `--port`: The port for service +- `--model`: The model that backend engine uses, both path and transformers model card are supported. +- `--chat-template` The file path of chat template or the template string. +- `--response-role` The role that colossal-inference plays. +For Engine Backend: +- `--block_size`: The memory usage for each block. +- `--max_batch_size`: The max batch size for engine to infer. This changes the speed of inference, +- `--max_input_len`: The max input length of a request. +- `--max_output_len`: The output length of response. +- `--dtype` and `--use_cuda_kernel`: Deciding the precision and kernel usage. +For more detailed arguments, please refer to source code. + +### Examples +```bash +# First, Lauch an API locally. +python3 -m colossalai.inference.server.api_server --model path of your model --chat-template "{% for message in messages %}{{'<|im_start|>'+message['role']+'\n'+message['content']+'<|im_end|>'+'\n'}}{% endfor %}" + +# Second, you can turn to the page `http://127.0.0.1:8000/docs` to check the api + +# For completion service, you can invoke it +curl -X POST http://127.0.0.1:8000/completion -H 'Content-Type: application/json' -d '{"prompt":"hello, who are you? "}' + +# For chat service, you can invoke it +curl -X POST http://127.0.0.1:8000/chat -H 'Content-Type: application/json' -d '{"messages":[{"role":"system","content":"you are a helpful assistant"},{"role":"user","content":"what is 1+1?"}]}' + +# You can check the engine status now +curl http://localhost:8000/engine_check +``` ## 🌟 Acknowledgement @@ -229,7 +277,7 @@ This project was written from scratch but we learned a lot from several other gr - [vLLM](https://github.com/vllm-project/vllm) - [flash-attention](https://github.com/Dao-AILab/flash-attention) - +- [HuggingFace](https://huggingface.co) If you wish to cite relevant research papars, you can find the reference below. ```bibtex diff --git a/colossalai/inference/server/README.md b/colossalai/inference/server/README.md deleted file mode 100644 index 8b5f29fc0..000000000 --- a/colossalai/inference/server/README.md +++ /dev/null @@ -1,27 +0,0 @@ -# Online Service -Colossal-Inference supports fast-api based online service. Simple completion and chat are both supported. Follow the commands below and -you can simply construct a server with both completion and chat functionalities. For now we only support `Llama` model, we will fullfill -the blank quickly. - -# Usage -```bash -# First, Lauch an API locally. -python3 -m colossalai.inference.server.api_server --model path of your llama2 model --chat_template "{% for message in messages %} -{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}" - - -# Second, you can turn to the page `http://127.0.0.1:8000/docs` to check the api - -# For completion service, you can invoke it -curl -X POST http://127.0.0.1:8000/completion -H 'Content-Type: application/json' -d '{"prompt":"hello, who are you? ","stream":"False"}' - -# For chat service, you can invoke it -curl -X POST http://127.0.0.1:8000/completion -H 'Content-Type: application/json' -d '{"converation": - [{"role": "system", "content": "you are a helpful assistant"}, - {"role": "user", "content": "what is 1+1?"},], - "stream": "False",}' -# If you just want to test a simple generation, turn to generate api -curl -X POST http://127.0.0.1:8000/generate -H 'Content-Type: application/json' -d '{"prompt":"hello, who are you? ","stream":"False"}' - -``` -We also support streaming output, simply change the `stream` to `True` in the request body. diff --git a/colossalai/inference/server/api_server.py b/colossalai/inference/server/api_server.py index 91c77ed35..dbc816df5 100644 --- a/colossalai/inference/server/api_server.py +++ b/colossalai/inference/server/api_server.py @@ -30,7 +30,6 @@ from colossalai.inference.utils import find_available_ports from colossalai.inference.core.async_engine import AsyncInferenceEngine, InferenceEngine # noqa TIMEOUT_KEEP_ALIVE = 5 # seconds. -supported_models_dict = {"Llama_Models": ("llama2-7b",)} prompt_template_choices = ["llama", "vicuna"] async_engine = None chat_serving = None @@ -39,15 +38,25 @@ completion_serving = None app = FastAPI() -# NOTE: (CjhHa1) models are still under development, need to be updated -@app.get("/models") -def get_available_models() -> Response: - return JSONResponse(supported_models_dict) +@app.get("/ping") +def health_check() -> JSONResponse: + """Health Check for server.""" + return JSONResponse({"status": "Healthy"}) + + +@app.get("/engine_check") +def engine_check() -> bool: + """Check if the background loop is running.""" + loop_status = async_engine.background_loop_status + if loop_status == False: + return JSONResponse({"status": "Error"}) + return JSONResponse({"status": "Running"}) @app.post("/generate") async def generate(request: Request) -> Response: """Generate completion for the request. + NOTE: THIS API IS USED ONLY FOR TESTING, DO NOT USE THIS IF YOU ARE IN ACTUAL APPLICATION. A request should be a JSON object with the following fields: - prompts: the prompts to use for the generation. @@ -133,7 +142,7 @@ def add_engine_config(parser): # Parallel arguments not supported now # KV cache arguments - parser.add_argument("--block-size", type=int, default=16, choices=[8, 16, 32], help="token block size") + parser.add_argument("--block_size", type=int, default=16, choices=[16, 32], help="token block size") parser.add_argument("--max_batch_size", type=int, default=8, help="maximum number of batch size") diff --git a/examples/inference/client/locustfile.py b/examples/inference/client/locustfile.py index a65c8b667..437713eb0 100644 --- a/examples/inference/client/locustfile.py +++ b/examples/inference/client/locustfile.py @@ -20,7 +20,7 @@ class QuickstartUser(HttpUser): self.client.post( "/chat", json={ - "converation": [ + "messages": [ {"role": "system", "content": "you are a helpful assistant"}, {"role": "user", "content": "what is 1+1?"}, ], @@ -34,7 +34,7 @@ class QuickstartUser(HttpUser): self.client.post( "/chat", json={ - "converation": [ + "messages": [ {"role": "system", "content": "you are a helpful assistant"}, {"role": "user", "content": "what is 1+1?"}, ], @@ -42,6 +42,7 @@ class QuickstartUser(HttpUser): }, ) + # offline-generation is only for showing the usage, it will never be used in actual serving. @tag("offline-generation") @task(5) def generate_streaming(self): @@ -54,5 +55,5 @@ class QuickstartUser(HttpUser): @tag("online-generation", "offline-generation") @task - def get_models(self): - self.client.get("/models") + def health_check(self): + self.client.get("/ping") diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 741975942..d30b26dbc 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -20,4 +20,6 @@ transformers==4.36.2 peft>=0.7.1 bitsandbytes>=0.39.0 rpyc==6.0.0 +fastapi +uvicorn==0.29.0 galore_torch From 2fc85abf43111f9cc08906b8cdf7068970ec91dd Mon Sep 17 00:00:00 2001 From: botbw Date: Fri, 24 May 2024 10:31:16 +0800 Subject: [PATCH 2/2] [gemini] async grad chunk reduce (all-reduce&reduce-scatter) (#5713) * [gemini] async grad chunk reduce (all-reduce&reduce-scatter) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [gemini] add test * [gemini] rename func * [gemini] update llama benchmark * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [gemini] use tensor counter * [gemini] change default config in GeminiPlugin and GeminiDDP * [chore] typo * [gemini] fix sync issue & add test cases * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- colossalai/booster/plugin/gemini_plugin.py | 2 + colossalai/zero/gemini/chunk/chunk.py | 32 +++++--- colossalai/zero/gemini/chunk/manager.py | 8 +- colossalai/zero/gemini/gemini_ddp.py | 79 +++++++++++++------ colossalai/zero/gemini/gemini_optimizer.py | 4 +- examples/language/llama/benchmark.py | 3 + tests/test_zero/test_gemini/test_chunkv2.py | 8 +- tests/test_zero/test_gemini/test_fwd_bwd.py | 10 ++- .../test_zero/test_gemini/test_grad_accum.py | 13 ++- tests/test_zero/test_gemini/test_grad_clip.py | 4 +- tests/test_zero/test_gemini/test_optim.py | 12 ++- 11 files changed, 130 insertions(+), 45 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 964cd302a..eb8db6212 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -361,6 +361,7 @@ class GeminiPlugin(DPPluginBase): enable_sequence_parallelism: bool = False, enable_jit_fused: bool = False, enable_sequence_overlap: bool = False, + enable_async_reduce: bool = True, verbose: bool = False, ) -> None: super().__init__() @@ -386,6 +387,7 @@ class GeminiPlugin(DPPluginBase): memstats=memstats, mixed_precision=PRECISION_STR_TO_DTYPE[precision], master_weights=master_weights, + enable_async_reduce=enable_async_reduce, ) self.zero_optim_config = dict( gpu_margin_mem_ratio=gpu_margin_mem_ratio, diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py index cad2622f2..8f048f0b7 100644 --- a/colossalai/zero/gemini/chunk/chunk.py +++ b/colossalai/zero/gemini/chunk/chunk.py @@ -164,6 +164,8 @@ class Chunk: self.l2_norm = None self.grad_chunk = None + # the async all-reduce/reduce-scatter work of this grad chunk (None means sync) + self.grad_reduce_work = None @property def memory_usage(self) -> Dict[str, int]: @@ -244,7 +246,7 @@ class Chunk: assert self.cuda_shard is not None # only check on CUDA valid_tensor = self.cuda_shard[: self.valid_end] - return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item() + return torch.isinf(valid_tensor).any() | torch.isnan(valid_tensor).any() def set_l2_norm(self) -> None: """Record l2 norm of this chunks on CUDA.""" @@ -374,37 +376,49 @@ class Chunk: if self.is_gathered: self.__scatter() - def reduce(self): + def reduce(self, async_op: bool = False): """Reduce scatter all the gradients. It's an operation done in CUDA.""" # sanity check assert self.is_gathered - + assert self.grad_reduce_work is None if self.pg_size == 1: # tricky code here # just move cuda_global_chunk to cuda_shard # the communication is not necessary self.__scatter() if self.extra_dp_group is not None: - dist.all_reduce(self.cuda_shard, group=self.extra_dp_group) + self.grad_reduce_work = dist.all_reduce(self.cuda_shard, group=self.extra_dp_group, async_op=async_op) elif self.keep_gathered: # we use all-reduce here - dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg) - if self.extra_dp_group is not None: - dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group) + self.grad_reduce_work = dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg, async_op=async_op) + if self.extra_dp_group is not None: # cannot guranatee the order of multiple all-reduce + self.wait_async_reduce() + self.grad_reduce_work = dist.all_reduce( + self.cuda_global_chunk, group=self.extra_dp_group, async_op=async_op + ) else: self.cuda_shard = torch.empty( self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device() ) input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0)) - dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg) + self.grad_reduce_work = dist.reduce_scatter( + self.cuda_shard, input_list, group=self.torch_pg, async_op=async_op + ) + if self.extra_dp_group is not None: - dist.all_reduce(self.cuda_shard, group=self.extra_dp_group) + self.wait_async_reduce() + self.grad_reduce_work = dist.all_reduce(self.cuda_shard, group=self.extra_dp_group, async_op=async_op) free_storage(self.cuda_global_chunk) self.is_gathered = False self.__update_tensors_state(TensorState.HOLD) + def wait_async_reduce(self) -> None: + if self.grad_reduce_work is not None: + self.grad_reduce_work.wait() + self.grad_reduce_work = None + def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None: """ Make a transition of the tensor into the next state. diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py index 333a3f224..6ec595914 100644 --- a/colossalai/zero/gemini/chunk/manager.py +++ b/colossalai/zero/gemini/chunk/manager.py @@ -41,7 +41,7 @@ class ChunkManager: self.reuse_fp16_chunk = reuse_fp16_chunk # Whether model is accumulating gradients, self.accumulating_grads = False - self.overflow_counter = 0 + self.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_accelerator().get_current_device()) def register_tensor( self, @@ -143,12 +143,12 @@ class ChunkManager: chunk = self.tensor_chunk_map[tensor] chunk.tensor_trans_state(tensor, state) - def reduce_chunk(self, chunk: Chunk) -> bool: + def reduce_chunk(self, chunk: Chunk, async_op: bool = False) -> bool: """Reduce or all reduce the chunk.""" if not chunk.can_reduce: return False self.__sub_memory_usage(chunk.memory_usage) - chunk.reduce() + chunk.reduce(async_op=async_op) self.__sub_accessed_chunk(chunk) self.__add_memory_usage(chunk.memory_usage) return True @@ -272,7 +272,7 @@ class ChunkManager: return grad_chunk def rearrange_accumulated_grad_chunk(self, chunk: Chunk) -> Chunk: - """Rearrange gradients accumulated in chunk.grad_chunk, and getP prepared for gradient reduction.""" + """Rearrange gradients accumulated in chunk.grad_chunk, and get prepared for gradient reduction.""" assert chunk.grad_chunk is not None diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index c1029097a..23f6ee683 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -96,6 +96,7 @@ class GeminiDDP(ModelWrapper): master_weights: bool = True, extra_dp_group: Optional[ProcessGroup] = None, verbose: bool = False, + enable_async_reduce: bool = True, ) -> None: assert mixed_precision in (torch.float16, torch.bfloat16) reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False @@ -178,6 +179,7 @@ class GeminiDDP(ModelWrapper): if is_ddp_ignored(p): continue if p.requires_grad: + assert not hasattr(p, "_grad_handle") p._grad_handle = p.register_hook( partial( GeminiDDP.grad_handle, @@ -187,6 +189,7 @@ class GeminiDDP(ModelWrapper): master_weights=self.master_weights, enable_gradient_accumulation=self.enable_gradient_accumulation, p=p, + async_reduce=enable_async_reduce, ) ) @@ -334,6 +337,11 @@ class GeminiDDP(ModelWrapper): setattr(param, "_gemini_reduced", False) def _post_backward(self): + for param in self.param2name: + if hasattr(param, "_release_grad_chunk_cb"): + param._release_grad_chunk_cb() + delattr(param, "_release_grad_chunk_cb") + if self.chunk_manager.accessed_mem != 0: error_params = ["Reduction failed at followed parameters:"] for param in self.param2name: @@ -371,6 +379,7 @@ class GeminiDDP(ModelWrapper): master_weights: bool, enable_gradient_accumulation: bool, p: nn.Parameter, + async_reduce: bool, ): setattr(p, "_gemini_reduced", True) empty_grad = torch.empty_like(grad) @@ -406,31 +415,57 @@ class GeminiDDP(ModelWrapper): grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=chunk_manager.reuse_fp16_chunk) else: grad_chunk.add_tensor_to_chunk_slice(p, grad) - reduced = chunk_manager.reduce_chunk(grad_chunk) - if reduced: - if not chunk_manager.reuse_fp16_chunk: - if chunk.keep_gathered: - chunk_manager.fake_release_chunk(chunk) - else: - chunk_manager.release_chunk(chunk) - if grad_chunk.is_gathered: - grad_chunk.cuda_global_chunk.div_(chunk.pg_size) - if chunk.extra_dp_group is not None: - grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size) + reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=async_reduce) + if reduced: # if not async, can release immediately, else release in when work finished + if async_reduce: + # dirty fix by installing callback + assert not hasattr(p, "_release_grad_chunk_cb") + + def _release_grad_chunk_cb(): + grad_chunk.wait_async_reduce() + GeminiDDP.release_grad_chunk_handle( + chunk_manager, + grads_device, + master_weights, + enable_gradient_accumulation, + p, + chunk, + grad_chunk, + ) + + p._release_grad_chunk_cb = _release_grad_chunk_cb else: - grad_chunk.cuda_shard.div_(chunk.pg_size) - if chunk.extra_dp_group is not None: - grad_chunk.cuda_shard.div_(chunk.extra_dp_size) - # check overflow elements - chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan - # record l2 norm for gradient clipping. flag is bound to fp16 chunk - if chunk.l2_norm_flag: - grad_chunk.set_l2_norm() - chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True) - if not (master_weights) or (enable_gradient_accumulation): - chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True) + GeminiDDP.release_grad_chunk_handle( + chunk_manager, grads_device, master_weights, enable_gradient_accumulation, p, chunk, grad_chunk + ) return empty_grad + @staticmethod + def release_grad_chunk_handle( + chunk_manager, grads_device, master_weights, enable_gradient_accumulation, p, chunk, grad_chunk + ): + if not chunk_manager.reuse_fp16_chunk: + if chunk.keep_gathered: + chunk_manager.fake_release_chunk(chunk) + else: + chunk_manager.release_chunk(chunk) + if grad_chunk.is_gathered: + grad_chunk.cuda_global_chunk.div_(chunk.pg_size) + if chunk.extra_dp_group is not None: + grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size) + else: + grad_chunk.cuda_shard.div_(chunk.pg_size) + if chunk.extra_dp_group is not None: + grad_chunk.cuda_shard.div_(chunk.extra_dp_size) + # check overflow elements + chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan + # record l2 norm for gradient clipping. flag is bound to fp16 chunk + if chunk.l2_norm_flag: + grad_chunk.set_l2_norm() + chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True) + if not (master_weights) or (enable_gradient_accumulation): + chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True) + def zero_grad(self, set_to_none: bool = False) -> None: self.module.zero_grad(set_to_none=True) diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 18918eabc..1d755c417 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -62,10 +62,10 @@ class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): self.module = module def check_local_overflow(self) -> bool: - return self.module.chunk_manager.overflow_counter > 0 + return self.module.chunk_manager.overflow_counter.item() > 0 def pre_zero_grad(self) -> None: - self.module.chunk_manager.overflow_counter = 0 + self.module.chunk_manager.overflow_counter.zero_() class GeminiOptimizer(OptimizerWrapper): diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index 5cc602181..6f91ff7b7 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -76,6 +76,8 @@ def main(): parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel") parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled") parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False) + parser.add_argument("--disable-async-reduce", action="store_true", help="Customize checkpoint", default=False) + args = parser.parse_args() colossalai.launch_from_torch() @@ -110,6 +112,7 @@ def main(): extra_dp_size=args.extra_dp, enable_fused_normalization=torch.cuda.is_available(), enable_flash_attention=args.xformers, + enable_async_reduce=not args.disable_async_reduce, ) elif args.plugin == "gemini_auto": plugin = GeminiPlugin( diff --git a/tests/test_zero/test_gemini/test_chunkv2.py b/tests/test_zero/test_gemini/test_chunkv2.py index 257311328..51b20c400 100644 --- a/tests/test_zero/test_gemini/test_chunkv2.py +++ b/tests/test_zero/test_gemini/test_chunkv2.py @@ -34,7 +34,8 @@ def check_equal(param, param_cp): @parameterize("init_device", [None, torch.device("cpu")]) @parameterize("keep_gathered", [True, False]) @parameterize("pin_memory", [True, False]) -def exam_chunk_basic(init_device, keep_gathered, pin_memory): +@parameterize("async_op", [True, False]) +def exam_chunk_basic(init_device, keep_gathered, pin_memory, async_op): world_size = torch.distributed.get_world_size() pg = _get_default_group() my_chunk = Chunk( @@ -94,9 +95,12 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory): assert my_chunk.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == 4 assert my_chunk.can_reduce - my_chunk.reduce() + my_chunk.reduce(async_op) assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 4 + if async_op: + my_chunk.wait_async_reduce() + if keep_gathered is False: assert my_chunk.cuda_shard.size(0) == 1024 // world_size assert my_chunk.device_type == "cuda" diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py index 570a0aa42..4279793d7 100644 --- a/tests/test_zero/test_gemini/test_fwd_bwd.py +++ b/tests/test_zero/test_gemini/test_fwd_bwd.py @@ -40,12 +40,14 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module): @parameterize("model_name", ["transformers_gpt_lm"]) @parameterize("use_grad_checkpoint", [False, True]) @parameterize("master_weights", [False, True]) +@parameterize("enable_async_reduce", [False, True]) def exam_gpt_fwd_bwd( placement_config, keep_gather, model_name: str, use_grad_checkpoint: bool = False, master_weights: bool = True, + enable_async_reduce=True, ): init_device = get_accelerator().get_current_device() model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next( @@ -69,7 +71,13 @@ def exam_gpt_fwd_bwd( config_dict[world_size]["chunk_size"] = 5000 config_dict[world_size]["keep_gathered"] = keep_gather model = GeminiDDP( - model, config_dict, init_device, pin_memory=True, **placement_config, master_weights=master_weights + model, + config_dict, + init_device, + pin_memory=True, + **placement_config, + master_weights=master_weights, + enable_async_reduce=enable_async_reduce, ) optimizer = HybridAdam(model.parameters(), lr=1e-3) zero_optim = GeminiOptimizer(optimizer, model, initial_scale=1) diff --git a/tests/test_zero/test_gemini/test_grad_accum.py b/tests/test_zero/test_gemini/test_grad_accum.py index fd0e9fd7c..6e6c27e3f 100644 --- a/tests/test_zero/test_gemini/test_grad_accum.py +++ b/tests/test_zero/test_gemini/test_grad_accum.py @@ -50,8 +50,14 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module): @parameterize("model_name", ["transformers_gpt_lm"]) @parameterize("master_weights", [False, True]) @parameterize("use_grad_checkpoint", [False, True]) +@parameterize("enable_async_reduce", [False, True]) def exam_gemini_grad_acc( - placement_config, keep_gathered: bool, model_name: str, master_weights: bool, use_grad_checkpoint: bool + placement_config, + keep_gathered: bool, + model_name: str, + master_weights: bool, + use_grad_checkpoint: bool, + enable_async_reduce: bool, ): init_device = get_accelerator().get_current_device() model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next( @@ -81,10 +87,13 @@ def exam_gemini_grad_acc( pin_memory=True, enable_gradient_accumulation=True, master_weights=master_weights, + enable_async_reduce=enable_async_reduce, **placement_config, ) optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3) - gemini_optim = GeminiOptimizer(optimizer, gemini_model, initial_scale=1, max_norm=1.0) + gemini_optim = GeminiOptimizer( + optimizer, gemini_model, initial_scale=1, max_norm=1.0, enable_async_reduce=enable_async_reduce + ) rank = dist.get_rank() diff --git a/tests/test_zero/test_gemini/test_grad_clip.py b/tests/test_zero/test_gemini/test_grad_clip.py index 0a9bac092..7a1609ca5 100644 --- a/tests/test_zero/test_gemini/test_grad_clip.py +++ b/tests/test_zero/test_gemini/test_grad_clip.py @@ -52,7 +52,8 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module): @parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("model_name", ["transformers_gpt_lm"]) @parameterize("master_weights", [True, False]) -def exam_grad_clipping(placement_config, model_name: str, master_weights: bool): +@parameterize("enable_async_reduce", [False, True]) +def exam_grad_clipping(placement_config, model_name: str, master_weights: bool, enable_async_reduce: bool): set_seed(1912) model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next( iter(model_zoo.get_sub_registry(model_name).values()) @@ -84,6 +85,7 @@ def exam_grad_clipping(placement_config, model_name: str, master_weights: bool): chunk_init_device=init_device, pin_memory=True, master_weights=master_weights, + enable_async_reduce=enable_async_reduce, **placement_config, ) diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index a9366e7bc..c610259b2 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -73,7 +73,10 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dty @parameterize("model_name", TEST_MODELS) @parameterize("mixed_precision", [torch.half, torch.bfloat16]) @parameterize("master_weights", [True, False]) -def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool): +@parameterize("enable_async_reduce", [False, True]) +def exam_model_step( + placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool, enable_async_reduce=True +): set_seed(42) model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next( iter(model_zoo.get_sub_registry(model_name).values()) @@ -96,7 +99,12 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt config_dict[world_size]["chunk_size"] = 5000 config_dict[world_size]["keep_gathered"] = False model = GeminiDDP( - model, config_dict, **placement_config, mixed_precision=mixed_precision, master_weights=master_weights + model, + config_dict, + **placement_config, + mixed_precision=mixed_precision, + master_weights=master_weights, + enable_async_reduce=enable_async_reduce, ) optimizer = HybridAdam(model.parameters(), lr=1e-3)