[example] llama2 add fine-tune example (#4673)

* [shardformer] update shardformer readme

* [shardformer] update llama2/opt finetune example and update shardformer to llama2

* [shardformer] change dataset

* [shardformer] fix CI

* [shardformer] fix

[example] update opt example

[example] resolve comments

* [example] llama2 add finetune example

* fix

* update llama2 example

* Update requirements.txt

Author: flybird11111
Date: 2023-09-15 18:45:44 +08:00 (committed by GitHub)
Parent: ac2797996b
Commit: 4c4482f3ad

8 changed files with 402 additions and 35 deletions

@@ -13,6 +13,7 @@ from torch.distributed import ProcessGroup
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from colossalai.cluster import DistCoordinator
 from colossalai.interface import OptimizerWrapper
 from .general_checkpoint_io import GeneralCheckpointIO
@@ -71,6 +72,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         self.verbose = verbose
         self.working_to_master_map = None
         self.master_to_working_map = None
+        self.coordinator = DistCoordinator()

     @staticmethod
     def _model_sharder(model: nn.Module,
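
The two hunks above give `HybridParallelCheckpointIO` a `DistCoordinator` handle, a rank-aware helper useful for acting or logging only on the master process. A minimal sketch of how such a coordinator can be used, assuming a distributed environment has already been launched; the helper function below is hypothetical, while `DistCoordinator` and its `is_master()` method come from `colossalai.cluster`:

    from colossalai.cluster import DistCoordinator

    def report_checkpoint_saved(path: str, verbose: bool = True) -> None:
        # Hypothetical helper: print a status message only on the master rank.
        # DistCoordinator assumes torch.distributed is already initialized,
        # e.g. via colossalai.launch_from_torch().
        coordinator = DistCoordinator()
        if verbose and coordinator.is_master():
            print(f"checkpoint written to {path}")
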
@@ -655,7 +657,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                     dist.all_gather(gather_tensor, v, group=tp_group)
                     v = torch.cat(gather_tensor, dim=partition_dim)
-                    state_[k] = v.detach().clone().cpu()
+                state_[k] = v.detach().clone().cpu()
         return state_
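
The one-line dedent in the last hunk is the functional fix: previously `state_[k] = v.detach().clone().cpu()` ran only inside the branch that gathers tensor-parallel shards, so optimizer states that were not partitioned along any TP dimension never made it into the saved dict. A simplified sketch of the corrected control flow; the loop structure and the meaning of `partition_dim` are assumptions reconstructed from the context lines, not the exact implementation:

    import torch
    import torch.distributed as dist

    def gather_state_to_cpu(state: dict, tp_size: int, tp_group, partition_dim):
        # Simplified sketch: gather TP-sharded optimizer states, then move
        # every state tensor to CPU. `partition_dim` is assumed to be None
        # for tensors replicated (not sharded) across the TP group.
        state_ = {}
        for k, v in state.items():
            if isinstance(v, torch.Tensor) and k != "step":
                if partition_dim is not None:
                    # Tensor is sharded along `partition_dim`: gather all TP shards.
                    gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)]
                    dist.all_gather(gather_tensor, v, group=tp_group)
                    v = torch.cat(gather_tensor, dim=partition_dim)
                # The fix: store the state unconditionally, so replicated
                # (non-partitioned) tensors are no longer silently skipped.
                state_[k] = v.detach().clone().cpu()
            else:
                state_[k] = v
        return state_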