From a88bc828d5b20cc177e49bcfdc7e253a49646597 Mon Sep 17 00:00:00 2001 From: ver217 Date: Thu, 16 Feb 2023 20:09:34 +0800 Subject: [PATCH] [chatgpt] disable shard init for colossalai (#2767) --- .../ChatGPT/chatgpt/trainer/strategies/colossalai.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py b/applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py index 665bfa913..578844bdb 100644 --- a/applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py +++ b/applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py @@ -1,3 +1,4 @@ +import warnings from typing import Optional import torch @@ -23,6 +24,7 @@ class ColossalAIStrategy(DDPStrategy): stage(int): The stage to use in ZeRO. Choose in (1, 2, 3) seed(int): The seed for the random number generator. shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3. + This is not compativle with `from_pretrained()`. We temporarily disable this and will support it in the future. placement_policy(str): The placement policy for gemini. Choose in ('cpu', 'cuda') If it is “cpu”, parameters, gradients and optimizer states will be offloaded to CPU, If it is “cuda”, they will not be offloaded, which means max CUDA memory will be used. It is the fastest. @@ -50,7 +52,7 @@ class ColossalAIStrategy(DDPStrategy): self, stage: int = 3, seed: int = 42, - shard_init: bool = True, # only for stage 3 + shard_init: bool = False, # only for stage 3 placement_policy: str = 'cuda', pin_memory: bool = True, # only for stage 3 force_outputs_fp32: bool = False, # only for stage 3 @@ -72,6 +74,10 @@ class ColossalAIStrategy(DDPStrategy): super().__init__(seed) assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"' self.stage = stage + # TODO(ver217): support shard_init when using from_pretrained() + if shard_init: + warnings.warn(f'Shard init is not supported yet. Ignore.') + shard_init = False self.shard_init = shard_init self.gemini_config = dict(device=get_current_device(), placement_policy=placement_policy,