Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 11:02:05 +00:00
[example] llama2 add fine-tune example (#4673)
* [shardformer] update shardformer readme
* [shardformer] update llama2/opt finetune example and shardformer update to llama2
* [shardformer] update llama2/opt finetune example and shardformer update to llama2
* [shardformer] update llama2/opt finetune example and shardformer update to llama2
* [shardformer] change dataset
* [shardformer] change dataset
* [shardformer] fix CI
* [shardformer] fix
* [shardformer] fix
* [shardformer] fix
* [shardformer] fix
* [shardformer] fix
  [example] update opt example
  [example] resolve comments
  fix
  fix
* [example] llama2 add finetune example
* [example] llama2 add finetune example
* [example] llama2 add finetune example
* [example] llama2 add finetune example
* fix
* update llama2 example
* update llama2 example
* fix
* update llama2 example
* update llama2 example
* update llama2 example
* update llama2 example
* update llama2 example
* update llama2 example
* Update requirements.txt
* update llama2 example
* update llama2 example
* update llama2 example
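The headline change of this commit is a LLaMA-2 fine-tuning example driven by shardformer through ColossalAI's Booster API, alongside touch-ups to the hybrid-parallel checkpoint IO shown in the diff below. For orientation only, here is a minimal sketch of what such a fine-tuning setup looks like; the plugin sizes, model name, and hyperparameters are illustrative assumptions, not values taken from this commit (see the example directory added by the commit for the real script).

```python
# Illustrative fine-tuning skeleton built on ColossalAI's Booster API.
# Plugin sizes, model name, and learning rate below are assumptions for
# demonstration; they are not taken from the commit itself.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from torch.optim import AdamW
from transformers import LlamaForCausalLM

colossalai.launch_from_torch(config={})

# Hybrid parallelism: tensor parallelism of 2, no pipeline stages (assumed sizes).
plugin = HybridParallelPlugin(tp_size=2, pp_size=1, precision="bf16")
booster = Booster(plugin=plugin)

model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
optimizer = AdamW(model.parameters(), lr=2e-5)

# boost() applies the shardformer policy and wires the model/optimizer to the
# plugin's checkpoint IO (the HybridParallelCheckpointIO touched below).
model, optimizer, *_ = booster.boost(model, optimizer)

# Sharded saving later goes through the booster as well, e.g.:
# booster.save_model(model, "ckpt_dir", shard=True)
```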
@@ -13,6 +13,7 @@ from torch.distributed import ProcessGroup
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 
+from colossalai.cluster import DistCoordinator
 from colossalai.interface import OptimizerWrapper
 
 from .general_checkpoint_io import GeneralCheckpointIO
@@ -71,6 +72,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         self.verbose = verbose
         self.working_to_master_map = None
         self.master_to_working_map = None
+        self.coordinator = DistCoordinator()
 
     @staticmethod
     def _model_sharder(model: nn.Module,
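The second hunk gives the checkpoint IO a DistCoordinator handle. DistCoordinator (from colossalai.cluster) is ColossalAI's thin wrapper around torch.distributed rank bookkeeping; a typical use is to gate master-only work such as writing a shared index file. The sketch below is illustrative usage, not code from this commit, and write_index_file is a hypothetical stand-in for whatever rank-0-only work the checkpoint IO performs.

```python
from colossalai.cluster import DistCoordinator

coordinator = DistCoordinator()

def write_index_file() -> None:
    # Hypothetical stand-in for master-only work (e.g. writing a checkpoint index).
    ...

# Only the master rank performs the shared write; other ranks skip it.
if coordinator.is_master():
    write_index_file()

# Convenience logging that prints from rank 0 only.
coordinator.print_on_master("checkpoint index written")
```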
@@ -655,7 +657,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 dist.all_gather(gather_tensor, v, group=tp_group)
                 v = torch.cat(gather_tensor, dim=partition_dim)
 
-            state_[k] = v.detach().clone().cpu()
+            state_[k] = v.detach().clone().cpu()
 
         return state_
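The last hunk sits in the path that de-shards tensor-parallel state before it is written out: each rank all-gathers its shard over the tensor-parallel process group, the shards are concatenated along the partition dimension, and the full tensor is detached and moved to CPU. Below is a self-contained sketch of that pattern; the function name is illustrative, while tp_group and partition_dim mirror the names in the hunk.

```python
import torch
import torch.distributed as dist

def gather_tp_sharded_tensor(v: torch.Tensor, tp_group, partition_dim: int) -> torch.Tensor:
    """Illustrative helper: reassemble a tensor sharded along `partition_dim`
    across a tensor-parallel group and offload the full copy to CPU."""
    world_size = dist.get_world_size(group=tp_group)
    # One buffer per rank, filled by all_gather with every rank's shard.
    gather_tensor = [torch.empty_like(v) for _ in range(world_size)]
    dist.all_gather(gather_tensor, v, group=tp_group)
    # Stitch the shards back together along the dimension they were split on.
    v = torch.cat(gather_tensor, dim=partition_dim)
    # Detach from the graph and keep the checkpoint copy in host memory.
    return v.detach().clone().cpu()
```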