[chat] remove naive strategy and split colossalai strategy (#4094)

* feat: remove on_learn_epoch fn as not used

* revert: add _on_learn_epoch fn

* to: remove the use of NaiveStrategy

* test: remove NaiveStrategy tests

* feat: remove NaiveStrategy

* style: modify comments and params

* feat: split ColossalAIStrategy into LowLevelZeroStrategy and GeminiStrategy

* fix: remove naive

* fix: align with modified colossal strategy

* fix: fix ddp _try_init_dist arg
Authored by Wenhao Chen on 2023-06-29 18:11:00 +08:00 · committed by GitHub
parent b03d64d010
commit edd75a59ea
25 changed files with 323 additions and 349 deletions

View File

@@ -6,7 +6,7 @@ from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic
from coati.models.loss import PolicyLoss, ValueLoss
from coati.trainer.callbacks import Callback
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy, Strategy
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy, Strategy
from torch.optim import Adam
from colossalai.nn.optimizer import HybridAdam
@@ -85,7 +85,7 @@ class DetachedPPOTrainer(DetachedTrainer):
evaluator = TrainerPerformanceEvaluator(actor_numel, critic_numel)
callbacks = callbacks + [evaluator]
if isinstance(self.strategy, ColossalAIStrategy):
if isinstance(self.strategy, (LowLevelZeroStrategy, GeminiStrategy)):
self.actor_optim = HybridAdam(self.actor.parameters(), lr=1e-7)
self.critic_optim = HybridAdam(self.critic.parameters(), lr=1e-7)
else:
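
The isinstance check above is this commit's recurring migration pattern: wherever the old code tested for ColossalAIStrategy, the new code tests for the pair of split classes. A minimal sketch of that dispatch, with the helper name and arguments assumed for illustration:

```python
# Hedged sketch of the optimizer dispatch above; make_actor_optim is a
# hypothetical helper, and the learning rate mirrors the diff.
from coati.trainer.strategies import GeminiStrategy, LowLevelZeroStrategy
from colossalai.nn.optimizer import HybridAdam
from torch.optim import Adam

def make_actor_optim(strategy, actor):
    # ZeRO-based strategies pair with ColossalAI's HybridAdam;
    # plain DDP falls back to torch.optim.Adam.
    if isinstance(strategy, (LowLevelZeroStrategy, GeminiStrategy)):
        return HybridAdam(actor.parameters(), lr=1e-7)
    return Adam(actor.parameters(), lr=1e-7)
```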

View File

@@ -1,6 +1,6 @@
import os
from typing import Any, Callable, Dict, List, Optional
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional
import torch
import torch.distributed as dist
@@ -10,7 +10,7 @@ from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.models.roberta import RoBERTaActor, RoBERTaCritic, RoBERTaRM
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from coati.utils import prepare_llama_tokenizer_and_embedding
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer, RobertaTokenizer
@@ -76,18 +76,16 @@ def get_reward_model_from_args(model: str, pretrained: str = None, config=None):
def get_strategy_from_args(strategy: str):
if strategy == 'naive':
strategy_ = NaiveStrategy()
elif strategy == 'ddp':
if strategy == 'ddp':
strategy_ = DDPStrategy()
elif strategy == 'colossalai_gemini':
strategy_ = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
strategy_ = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
elif strategy == 'colossalai_zero2':
strategy_ = ColossalAIStrategy(stage=2, placement_policy='cuda')
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
elif strategy == 'colossalai_gemini_cpu':
strategy_ = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5)
strategy_ = GeminiStrategy(placement_policy='cpu', initial_scale=2**5)
elif strategy == 'colossalai_zero2_cpu':
strategy_ = ColossalAIStrategy(stage=2, placement_policy='cpu')
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy='cpu')
else:
raise ValueError(f'Unsupported strategy "{strategy}"')
return strategy_
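
Scripts that passed `--strategy naive` have no direct replacement; the closest equivalent is DDP launched with a single process. A hedged migration sketch:

```python
# Hedged migration sketch for callers of get_strategy_from_args.
# Before this commit:  strategy = get_strategy_from_args('naive')
# After it, 'naive' raises ValueError; the closest replacement is DDP with
# one process, e.g. `torchrun --standalone --nproc_per_node 1 train.py`.
strategy = get_strategy_from_args('ddp')
```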

View File

@@ -1,7 +1,7 @@
import os
import torch.distributed as dist
from coati.trainer.strategies import ColossalAIStrategy, Strategy
from coati.trainer.strategies import GeminiStrategy, LowLevelZeroStrategy, Strategy
from coati.trainer.utils import is_rank_0
from torch import nn
from torch.optim import Optimizer
@@ -69,7 +69,7 @@ class SaveCheckpoint(Callback):
# save optimizer
if self.model_dict[model][1] is None:
continue
only_rank0 = not isinstance(self.strategy, ColossalAIStrategy)
only_rank0 = not isinstance(self.strategy, (LowLevelZeroStrategy, GeminiStrategy))
rank = 0 if is_rank_0() else dist.get_rank()
optim_path = os.path.join(base_path, f'{model}-optim-rank-{rank}.pt')
self.strategy.save_optimizer(optimizer=self.model_dict[model][1], path=optim_path, only_rank0=only_rank0)
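
The flipped `only_rank0` reflects how ZeRO partitions optimizer state: each rank holds a distinct shard and must write its own file, whereas plain DDP replicates state and rank 0 alone suffices. A hedged sketch of the resulting layout (the helper name is hypothetical):

```python
# Hedged sketch: per-rank optimizer checkpoint paths, mirroring the logic above.
import os
import torch.distributed as dist

def optim_ckpt_path(base_path: str, model_name: str, zero_sharded: bool) -> str:
    # ZeRO strategies: every rank writes its own shard.
    # Plain DDP: only rank 0 writes, so the filename pins rank 0.
    rank = dist.get_rank() if zero_sharded else 0
    return os.path.join(base_path, f'{model_name}-optim-rank-{rank}.pt')
```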

View File

@@ -15,7 +15,7 @@ from colossalai.utils import get_current_device
from .base import OnPolicyTrainer
from .callbacks import Callback
from .strategies import ColossalAIStrategy, Strategy
from .strategies import GeminiStrategy, Strategy
from .utils import is_rank_0, to_device
@@ -82,9 +82,8 @@ class PPOTrainer(OnPolicyTrainer):
callbacks: List[Callback] = [],
**generate_kwargs
) -> None:
if isinstance(strategy, ColossalAIStrategy):
from colossalai.booster.plugin import GeminiPlugin
assert not (isinstance(strategy.plugin, GeminiPlugin) and offload_inference_models), \
if isinstance(strategy, GeminiStrategy):
assert not offload_inference_models, \
"GeminiPlugin is not compatible with manual model.to('cpu')"
buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
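
The old guard had to reach into `strategy.plugin` to detect Gemini; after the split, the strategy type alone carries that information. The incompatibility it protects against: Gemini's chunk manager owns parameter placement, so a manual `model.to('cpu')` on the inference models would fight its bookkeeping. A hedged sketch of the check as a standalone helper (the name is hypothetical):

```python
# Hedged sketch: reject manual CPU offload when Gemini manages placement.
def validate_offload(strategy, offload_inference_models: bool) -> None:
    if isinstance(strategy, GeminiStrategy) and offload_inference_models:
        raise ValueError("GeminiPlugin is not compatible with manual model.to('cpu')")
```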

View File

@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader
from colossalai.logging import DistributedLogger
from .base import SLTrainer
from .strategies import ColossalAIStrategy, Strategy
from .strategies import GeminiStrategy, Strategy
from .utils import is_rank_0, to_device
@@ -38,9 +38,8 @@ class SFTTrainer(SLTrainer):
max_epochs: int = 2,
accumulation_steps: int = 8,
) -> None:
if accumulation_steps > 1 and isinstance(strategy, ColossalAIStrategy):
from colossalai.booster.plugin import GeminiPlugin
assert not isinstance(strategy.plugin, GeminiPlugin), \
if accumulation_steps > 1:
assert not isinstance(strategy, GeminiStrategy), \
"Accumulation steps are not supported in stage 3 of ColossalAI"
super().__init__(strategy, max_epochs, model, optim)
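
The SFT trainer gets the same simplification: gradient accumulation remains available with DDP and LowLevelZero, but not with Gemini (the old ZeRO stage 3). A hedged one-liner capturing the rule (the helper is hypothetical):

```python
# Hedged sketch: which (strategy, accumulation_steps) pairs the check accepts.
def accumulation_supported(strategy, accumulation_steps: int) -> bool:
    return accumulation_steps == 1 or not isinstance(strategy, GeminiStrategy)
```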

View File

@@ -1,6 +1,8 @@
from .base import Strategy
from .colossalai import ColossalAIStrategy
from .colossalai import GeminiStrategy, LowLevelZeroStrategy
from .ddp import DDPStrategy
from .naive import NaiveStrategy
__all__ = ['Strategy', 'NaiveStrategy', 'DDPStrategy', 'ColossalAIStrategy']
__all__ = [
'Strategy', 'DDPStrategy',
'LowLevelZeroStrategy', 'GeminiStrategy'
]
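
Downstream code needs a one-line import change; a hedged before/after sketch:

```python
# Before this commit (now an ImportError):
#   from coati.trainer.strategies import ColossalAIStrategy, NaiveStrategy
# After:
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
```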

View File

@@ -18,25 +18,17 @@ from colossalai.zero.gemini.gemini_ddp import GeminiDDP
from .ddp import DDPStrategy
class ColossalAIStrategy(DDPStrategy):
class LowLevelZeroStrategy(DDPStrategy):
"""
The strategy for training with ColossalAI's low-level ZeRO plugin (ZeRO stage 1 or 2).
Args:
stage(int): The stage to use in ZeRO. Choose in (1, 2, 3)
precision(str): The precision to use. Choose in ('fp32', 'fp16'). Stage 3 only supports fp16.
stage(int): The stage to use in ZeRO. Choose in (1, 2)
precision(str): The precision to use. Choose in ('fp32', 'fp16').
seed(int): The seed for the random number generator.
shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3.
This is not compatible with `from_pretrained()`. We temporarily disable this and will support it in the future.
placement_policy(str): The placement policy for gemini. Choose in ('cpu', 'cuda')
If it is “cpu”, parameters, gradients and optimizer states will be offloaded to CPU,
If it is “cuda”, they will not be offloaded, which means max CUDA memory will be used. It is the fastest.
pin_memory(bool): Whether to pin the memory for the data loader. Only for ZeRO-3.
force_outputs_fp32(bool): Whether to force the outputs to be fp32. Only for ZeRO-3.
search_range_m(int): The number of search range for the chunk size, divided by 2^20. Only for ZeRO-3.
hidden_dim(optional, int): The hidden dimension for the gemini. Only for ZeRO-3.
min_chunk_size_m(float): The minimum chunk size divided by 2^20. Only for ZeRO-3.
gpu_margin_mem_ratio(float): The margin memory ratio for the GPU. Only for ZeRO-3.
reduce_bucket_size(int): The reduce bucket size in bytes. Only for ZeRO-1 and ZeRO-2.
overlap_communication(bool): Whether to overlap communication and computation. Only for ZeRO-1 and ZeRO-2.
initial_scale(float): The initial scale for the optimizer.
@@ -51,132 +43,185 @@ class ColossalAIStrategy(DDPStrategy):
"""
def __init__(
self,
stage: int = 3,
precision: str = 'fp16',
seed: int = 42,
shard_init: bool = False, # only for stage 3
placement_policy: str = 'cuda',
pin_memory: bool = True, # only for stage 3
force_outputs_fp32: bool = False, # only for stage 3
search_range_m: int = 32, # only for stage 3
hidden_dim: Optional[int] = None, # only for stage 3
min_chunk_size_m: float = 32, # only for stage 3
gpu_margin_mem_ratio: float = 0.0, # only for stage 3
reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2
overlap_communication: bool = True, # only for stage 1&2
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0) -> None:
def __init__(self,
stage: int = 2,  # the assert below only allows stages 1 and 2
precision: str = 'fp16',
seed: int = 42,
placement_policy: str = 'cuda',
reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2
overlap_communication: bool = True, # only for stage 1&2
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0
) -> None:
assert stage in (1, 2, 3), f'Unsupported stage "{stage}"'
assert stage in (1, 2), f'Unsupported stage "{stage}"'
assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"'
# TODO(ver217): support shard_init when using from_pretrained()
if shard_init:
warnings.warn(f'Shard init is not supported model.from_pretrained() yet. '
'Please load weights after strategy.prepare()')
if stage == 3 and precision == 'fp32':
warnings.warn(f'Stage 3 only supports fp16. Precision is set to fp16.')
precision = 'fp16'
self.precision = precision
self.shard_init = shard_init
optim_kwargs = dict(initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type)
# NOTE: dist should be initialized before calling get_current_device()
if stage == 3:
plugin_initializer = lambda: GeminiPlugin(
# gemini_config
device=get_current_device(),
placement_policy=placement_policy,
precision=precision,
pin_memory=pin_memory,
force_outputs_fp32=force_outputs_fp32,
strict_ddp_mode=shard_init,
search_range_m=search_range_m,
hidden_dim=hidden_dim,
min_chunk_size_m=min_chunk_size_m,
# zero_optim_config
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
# optim_config
**optim_kwargs)
else:
plugin_initializer = lambda: LowLevelZeroPlugin(
# zero_config
stage=stage,
precision=precision,
# zero_optim_config
reduce_bucket_size_in_m=reduce_bucket_size,
overlap_communication=overlap_communication,
cpu_offload=(placement_policy == 'cpu'),
# optim_config
**optim_kwargs)
plugin_initializer = lambda: LowLevelZeroPlugin(
# zero_config
stage=stage,
precision=precision,
# zero_optim_config
reduce_bucket_size_in_m=reduce_bucket_size,
overlap_communication=overlap_communication,
cpu_offload=(placement_policy == 'cpu'),
# optim_config
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type
)
super().__init__(seed, plugin_initializer)
def _post_init(self) -> None:
assert isinstance(self.plugin, (LowLevelZeroPlugin, GeminiPlugin)), \
assert isinstance(self.plugin, LowLevelZeroPlugin), \
f'{type(self).__name__}\'s plugin is not initialized properly.'
def setup_distributed(self) -> None:
colossalai.launch_from_torch({}, seed=self.seed)
def unwrap_model(self, model: nn.Module) -> nn.Module:
assert isinstance(model, LowLevelZeroModel)
return model.module
def get_model_state_dict_shard(self, model: nn.Module, **config):
assert isinstance(model, LowLevelZeroModel)
yield from model.state_dict_shard(max_shard_size=1024, only_rank_0=False)
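
A hedged instantiation sketch for the new class (values mirror the factory above; constructing a strategy sets up distributed state, so this assumes a torchrun-style launch):

```python
# Hedged sketch: ZeRO-2 with parameters/gradients/optimizer states offloaded
# to CPU. Run under a launcher, e.g. `torchrun --nproc_per_node 8 train.py`.
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cpu', precision='fp16')
```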
class GeminiStrategy(DDPStrategy):
"""
The strategy for training with ColossalAI's Gemini plugin (ZeRO stage 3).
Args:
seed(int): The seed for the random number generator.
shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3.
This is not compatible with `from_pretrained()`. We temporarily disable this and will support it in the future.
placement_policy(str): The placement policy for gemini. Choose in ('cpu', 'cuda')
If it is “cpu”, parameters, gradients and optimizer states will be offloaded to CPU,
If it is “cuda”, they will not be offloaded, which means max CUDA memory will be used. It is the fastest.
pin_memory(bool): Whether to pin the memory for the data loader. Only for ZeRO-3.
force_outputs_fp32(bool): Whether to force the outputs to be fp32. Only for ZeRO-3.
search_range_m(int): The search range for the chunk size, divided by 2^20. Only for ZeRO-3.
hidden_dim(optional, int): The hidden dimension for the gemini. Only for ZeRO-3.
min_chunk_size_m(float): The minimum chunk size divided by 2^20. Only for ZeRO-3.
gpu_margin_mem_ratio(float): The margin memory ratio for the GPU. Only for ZeRO-3.
initial_scale(float): The initial scale for the optimizer.
growth_factor(float): The growth factor for the optimizer.
backoff_factor(float): The backoff factor for the optimizer.
growth_interval(int): The growth interval for the optimizer.
hysteresis(int): The hysteresis for the optimizer.
min_scale(float): The minimum scale for the optimizer.
max_scale(float): The maximum scale for the optimizer.
max_norm(float): The maximum norm for the optimizer.
norm_type(float): The norm type for the optimizer.
"""
def __init__(self,
seed: int = 42,
shard_init: bool = False, # only for stage 3
placement_policy: str = 'cuda',
pin_memory: bool = True, # only for stage 3
force_outputs_fp32: bool = False, # only for stage 3
search_range_m: int = 32, # only for stage 3
hidden_dim: Optional[int] = None, # only for stage 3
min_chunk_size_m: float = 32, # only for stage 3
gpu_margin_mem_ratio: float = 0.0, # only for stage 3
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0
) -> None:
assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
# TODO(ver217): support shard_init when using from_pretrained()
if shard_init:
warnings.warn(
f'Shard init is not supported with model.from_pretrained() yet. '
'Please load weights after strategy.prepare()'
)
self.shard_init = shard_init
warnings.warn('Gemini (ZeRO stage 3) only supports fp16; precision is set to fp16.')
# NOTE: dist should be initialized before calling get_current_device()
plugin_initializer = lambda: GeminiPlugin(
# gemini_config
device=get_current_device(),
placement_policy=placement_policy,
precision='fp16',
pin_memory=pin_memory,
force_outputs_fp32=force_outputs_fp32,
strict_ddp_mode=shard_init,
search_range_m=search_range_m,
hidden_dim=hidden_dim,
min_chunk_size_m=min_chunk_size_m,
# zero_optim_config
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
# optim_config
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type
)
super().__init__(seed, plugin_initializer)
def _post_init(self) -> None:
assert isinstance(self.plugin, GeminiPlugin), \
f'{type(self).__name__}\'s plugin is not initialized properly.'
def setup_distributed(self) -> None:
colossalai.launch_from_torch({}, seed=self.seed)
def model_init_context(self):
if isinstance(self.plugin, GeminiPlugin):
world_size = dist.get_world_size()
shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None
default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None
return ColoInitContext(device=get_current_device(),
dtype=torch.half,
default_pg=shard_pg,
default_dist_spec=default_dist_spec)
return super().model_init_context()
world_size = dist.get_world_size()
shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None
default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None
return ColoInitContext(device=get_current_device(),
dtype=torch.half,
default_pg=shard_pg,
default_dist_spec=default_dist_spec)
def unwrap_model(self, model: nn.Module) -> nn.Module:
if isinstance(self.plugin, GeminiPlugin):
assert isinstance(model, GeminiModel)
ddp_model = model.unwrap()
assert isinstance(ddp_model, GeminiDDP)
return ddp_model.module
elif isinstance(self.plugin, LowLevelZeroPlugin):
assert isinstance(model, LowLevelZeroModel)
return model.module
else:
raise RuntimeError(f'Unsupported plugin {type(self.plugin)}')
assert isinstance(model, GeminiModel)
ddp_model = model.unwrap()
assert isinstance(ddp_model, GeminiDDP)
return ddp_model.module
def save_pretrained(self,
model: nn.Module,
path: str,
only_rank0: bool = True,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
if isinstance(self.plugin, GeminiPlugin):
raise RuntimeError('ColossalAI strategy with stage-3 does not support save_pretrained() now')
super().save_pretrained(model, path, only_rank0, tokenizer)
raise RuntimeError('ColossalAI strategy with stage-3 does not support save_pretrained() yet')
def get_model_state_dict_shard(self, model: nn.Module, **config):
if not isinstance(self.plugin, GeminiPlugin):
yield from super().get_model_state_dict_shard(model, **config)
else:
# unwrapped_model = self._unwrap_model(model)
# for module in unwrapped_model.modules():
# if isinstance(module, LoraLinear):
# module.merge_weights = True
# module.eval()
assert isinstance(model, LowLevelZeroModel)
yield from model.state_dict_shard(max_shard_size=1024, only_rank_0=False)
assert isinstance(self.plugin, GeminiPlugin)
yield from super().get_model_state_dict_shard(model, **config)
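
GeminiStrategy keeps every former stage-3 knob and forces fp16. A hedged sketch pairing it with the `model_init_context()` defined above (the model is a placeholder):

```python
# Hedged sketch: Gemini (ZeRO-3) with parameters kept on GPU; assumes a
# distributed launch so colossalai.launch_from_torch can read the torchrun env.
import torch.nn as nn

strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
with strategy.model_init_context():  # ColoInitContext; shards params if shard_init=True
    model = nn.Linear(1024, 1024)    # placeholder for the real actor/critic
```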

View File

@@ -1,4 +1,6 @@
import os
import random
from collections import OrderedDict
from typing import Callable, Optional
import numpy as np
@@ -6,18 +8,27 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from coati.replay_buffer import ReplayBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.booster.plugin.torch_ddp_plugin import TorchDDPModel
from .naive import NaiveStrategy
from .base import Strategy
from .sampler import DistributedSampler
class DDPStrategy(NaiveStrategy):
# TODO: move this to a util.py (moving it into ray.util introduces a circular import)
def get_grad_required_state_dict(model: nn.Module):
state_dict = OrderedDict()
for name, parameter in model.named_parameters():
if parameter.requires_grad:
state_dict[name] = parameter.detach()
return state_dict
class DDPStrategy(Strategy):
"""
Strategy for distributed training using torch.distributed.
"""
@@ -29,6 +40,24 @@ class DDPStrategy(NaiveStrategy):
self.seed = seed
super().__init__(plugin_initializer)
def _try_init_dist(self, force: bool = False) -> None:
try:
rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])
host = os.environ['MASTER_ADDR']
port = int(os.environ['MASTER_PORT'])
dist.init_process_group('nccl', init_method=f'tcp://[{host}]:{port}', world_size=world_size, rank=rank)
torch.cuda.set_device(local_rank)
except KeyError as e:
if force:
raise RuntimeError(
f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
)
except Exception as e:
if force:
raise e
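
`_try_init_dist` reads the standard torch launcher variables and, with `force=False`, degrades gracefully when they are absent. A hedged sketch of the environment it expects, as torchrun would set it for a single-process run:

```python
# Hedged sketch: the env vars _try_init_dist consumes, with single-process
# defaults (torchrun sets these automatically).
import os

os.environ.setdefault('RANK', '0')
os.environ.setdefault('LOCAL_RANK', '0')
os.environ.setdefault('WORLD_SIZE', '1')
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
```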
def _post_init(self) -> None:
assert isinstance(self.plugin, TorchDDPPlugin), \
f'{type(self).__name__}\'s plugin is not initialized properly.'
@@ -42,9 +71,6 @@ class DDPStrategy(NaiveStrategy):
np.random.seed(seed)
torch.manual_seed(seed)
def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
self.booster.backward(loss, optimizer)
def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
return self.plugin.prepare_dataloader(replay_buffer,
batch_size=replay_buffer.sample_batch_size,
@@ -68,4 +94,32 @@ class DDPStrategy(NaiveStrategy):
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
if only_rank0 and dist.get_rank() != 0:
return
super().save_pretrained(model, path, only_rank0, tokenizer)
unwrapped_model = self.unwrap_model(model)
assert isinstance(unwrapped_model, PreTrainedModel)
unwrapped_model.save_pretrained(path)
if tokenizer is not None:
tokenizer.save_pretrained(path)
def get_model_state_dict_shard(self, model: nn.Module, **config):
# TODO: implement true sharding here; this currently materializes the full state dict before slicing it
model = self.unwrap_model(model)
if 'requires_grad_only' in config and config['requires_grad_only'] == True:
state_dict = get_grad_required_state_dict(model)
else:
state_dict = model.state_dict()
if 'shard_size' in config:
shard_size = config['shard_size']
accumulate_size = 0
state_dict_shard = OrderedDict()
for name, param in state_dict.items():
state_dict_shard[name] = param
accumulate_size += param.numel() * param.element_size()
if accumulate_size >= shard_size:
accumulate_size = 0
yield state_dict_shard
state_dict_shard = OrderedDict()
if accumulate_size > 0:
yield state_dict_shard
else:
yield state_dict
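
This generator (carried over verbatim from the deleted naive strategy) yields OrderedDict shards of roughly `shard_size` bytes, or the whole state dict when no size is given. A hedged consumption sketch (`strategy` and `model` assumed in scope):

```python
# Hedged sketch: stream weights in ~512 MiB shards, e.g. to ship them over RPC
# in the detached-trainer setup; pass requires_grad_only=True to keep only
# trainable parameters.
for i, shard in enumerate(strategy.get_model_state_dict_shard(model, shard_size=512 * 1024**2)):
    print(f'shard {i}: {len(shard)} tensors')
```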

View File

@@ -1,103 +0,0 @@
import os
from collections import OrderedDict
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
from coati.replay_buffer import ReplayBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from .base import Strategy
# TODO Move this to a util.py (Moving to ray.util introduces ringed import)
def get_grad_required_state_dict(model: nn.Module):
state_dict = OrderedDict()
for name, parameter in model.named_parameters():
if parameter.requires_grad:
state_dict[name] = parameter.detach()
return state_dict
class NaiveStrategy(Strategy):
"""
Strategy for single GPU. No parallelism is used.
"""
def _post_init(self) -> None:
assert self.plugin is None, \
f'{type(self).__name__}\'s plugin is not initialized properly.'
def setup_distributed(self) -> None:
self._try_init_dist(force=False)
def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
# HACK: self.booster.backward(loss, optimizer) can't work if plugin is None,
# it would run `optimizer.backward(loss)`, which is not compatible with torch.optim.Optimizer
assert self.plugin is None, "DO NOT call this method if plugin is not None"
loss.backward()
def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
return DataLoader(replay_buffer,
batch_size=replay_buffer.sample_batch_size,
shuffle=True,
drop_last=True,
pin_memory=pin_memory,
collate_fn=replay_buffer.collate_fn)
def save_pretrained(self,
model: nn.Module,
path: str,
only_rank0: bool = True,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
unwrapped_model = self.unwrap_model(model)
assert isinstance(unwrapped_model, PreTrainedModel)
unwrapped_model.save_pretrained(path)
if tokenizer is not None:
tokenizer.save_pretrained(path)
def get_model_state_dict_shard(self, model: nn.Module, **config):
# TODO: implement sharding on naive strategy
model = self.unwrap_model(model)
if 'requires_grad_only' in config and config['requires_grad_only'] == True:
state_dict = get_grad_required_state_dict(model)
else:
state_dict = model.state_dict()
if 'shard_size' in config:
shard_size = config['shard_size']
accumulate_size = 0
state_dict_shard = OrderedDict()
for name, param in state_dict.items():
state_dict_shard[name] = param
accumulate_size += param.numel() * param.element_size()
if accumulate_size >= shard_size:
accumulate_size = 0
yield state_dict_shard
state_dict_shard = OrderedDict()
if accumulate_size > 0:
yield state_dict_shard
else:
yield state_dict
def _try_init_dist(self, force: bool = False) -> None:
try:
rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])
host = os.environ['MASTER_ADDR']
port = int(os.environ['MASTER_PORT'])
dist.init_process_group('nccl', init_method=f'tcp://[{host}]:{port}', world_size=world_size, rank=rank)
torch.cuda.set_device(local_rank)
except KeyError as e:
if force:
raise RuntimeError(
f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
)
except Exception as e:
if force:
raise e
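
With naive.py deleted, its single-GPU role falls to DDPStrategy under a one-process launch, and backward always routes through the booster instead of the bare `loss.backward()` hack above. A hedged equivalence sketch:

```python
# Hedged sketch: the post-commit replacement for NaiveStrategy on one GPU.
# Launch with `torchrun --standalone --nproc_per_node 1 train.py`.
from coati.trainer.strategies import DDPStrategy

strategy = DDPStrategy()  # default seed assumed; the plugin is TorchDDPPlugin
```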