Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 10:34:41 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
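The hunks below are purely mechanical: the updated pre-commit hooks rewrite the Python sources into black-style formatting (double quotes, magic trailing commas, wrapped call arguments) without changing behavior. As a rough illustration of that rewrite, a hypothetical snippet that is not taken from this repository:

import argparse

parser = argparse.ArgumentParser()
# Before the hooks ran (yapf-style hanging indent, single quotes):
#   parser.add_argument('-b',
#                       '--batch_size',
#                       type=int,
#                       default=2,
#                       help='Batch size')
# After (black-style, double quotes, one line where it fits):
parser.add_argument("-b", "--batch_size", type=int, default=2, help="Batch size")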
@@ -9,12 +9,14 @@ SUPPORT_XFORMERS = False
SUPPORT_FLASH2 = False

try:
    import xformers.ops as xops

    SUPPORT_XFORMERS = True
except ImportError:
    pass

try:
    from flash_attn import flash_attn_func

    SUPPORT_FLASH2 = True
except ImportError:
    pass
@@ -62,10 +64,9 @@ def llama_flash_attention(
    if SUPPORT_FLASH2:
        attn_output = flash_attn_func(query_states, key_states, value_states, causal=True)
    else:
        attn_output = xops.memory_efficient_attention(query_states,
                                                      key_states,
                                                      value_states,
                                                      attn_bias=xops.LowerTriangularMask())
        attn_output = xops.memory_efficient_attention(
            query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask()
        )

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
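For context (not part of the diff itself): both kernels dispatched above consume query/key/value tensors in (batch, seq_len, num_heads, head_dim) layout. A minimal standalone sketch of the same fallback pattern, assuming a CUDA device and that at least one of the two packages is installed; the names below are illustrative, not taken from the repository:

import torch

try:
    from flash_attn import flash_attn_func
    HAS_FLASH2 = True
except ImportError:
    HAS_FLASH2 = False

try:
    import xformers.ops as xops
    HAS_XFORMERS = True
except ImportError:
    HAS_XFORMERS = False


def causal_attention(q, k, v):
    # q, k, v: (batch, seq_len, num_heads, head_dim), fp16/bf16 tensors on GPU
    if HAS_FLASH2:
        return flash_attn_func(q, k, v, causal=True)
    if HAS_XFORMERS:
        return xops.memory_efficient_attention(q, k, v, attn_bias=xops.LowerTriangularMask())
    raise RuntimeError("neither flash-attn nor xformers is available")


q = k = v = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
out = causal_attention(q, k, v)  # same shape as q: (1, 128, 8, 64)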
@@ -25,21 +25,22 @@ from colossalai.utils import get_current_device
# ==============================

MODEL_CONFIGS = {
    '7b':
        LlamaConfig(max_position_embeddings=4096),
    '13b':
        LlamaConfig(hidden_size=5120,
                    intermediate_size=13824,
                    num_hidden_layers=40,
                    num_attention_heads=40,
                    max_position_embeddings=4096),
    '70b':
        LlamaConfig(hidden_size=8192,
                    intermediate_size=28672,
                    num_hidden_layers=80,
                    num_attention_heads=64,
                    max_position_embeddings=4096,
                    num_key_value_heads=8),
    "7b": LlamaConfig(max_position_embeddings=4096),
    "13b": LlamaConfig(
        hidden_size=5120,
        intermediate_size=13824,
        num_hidden_layers=40,
        num_attention_heads=40,
        max_position_embeddings=4096,
    ),
    "70b": LlamaConfig(
        hidden_size=8192,
        intermediate_size=28672,
        num_hidden_layers=80,
        num_attention_heads=64,
        max_position_embeddings=4096,
        num_key_value_heads=8,
    ),
}

@@ -48,31 +49,31 @@ def main():
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default='7b', help='Model configuration')
    parser.add_argument('-p',
                        '--plugin',
                        choices=['gemini', 'gemini_auto', 'fsdp', 'fsdp_cpu', '3d', '3d_cpu'],
                        default='gemini',
                        help='Choose which plugin to use')
    parser.add_argument('-b', '--batch_size', type=int, default=2, help='Batch size')
    parser.add_argument('-s', '--num_steps', type=int, default=5, help='Number of steps to run')
    parser.add_argument('-i', '--ignore_steps', type=int, default=2, help='Number of steps to ignore')
    parser.add_argument('-g', '--grad_checkpoint', action='store_true', help='Use gradient checkpointing')
    parser.add_argument('-l', '--max_length', type=int, default=4096, help='Max sequence length')
    parser.add_argument('-w',
                        '--warmup_ratio',
                        type=float,
                        default=0.8,
                        help='warm up ratio of non-model data. Only for gemini-auto')
    parser.add_argument('-m', '--memory_limit', type=int, help='Gemini memory limit in mb')
    parser.add_argument('-x', '--xformers', action='store_true', help='Use xformers')
    parser.add_argument('--shard_param_frac', type=float, default=1.0, help='Shard param fraction. Only for gemini')
    parser.add_argument('--offload_optim_frac', type=float, default=0.0, help='Offload optim fraction. Only for gemini')
    parser.add_argument('--offload_param_frac', type=float, default=0.0, help='Offload param fraction. Only for gemini')
    parser.add_argument('--tp', type=int, default=1, help='Tensor parallel size')
    parser.add_argument('--pp', type=int, default=1, help='Pipeline parallel size')
    parser.add_argument('--mbs', type=int, default=1)
    parser.add_argument('--zero', type=int, default=0)
    parser.add_argument("-c", "--config", type=str, default="7b", help="Model configuration")
    parser.add_argument(
        "-p",
        "--plugin",
        choices=["gemini", "gemini_auto", "fsdp", "fsdp_cpu", "3d", "3d_cpu"],
        default="gemini",
        help="Choose which plugin to use",
    )
    parser.add_argument("-b", "--batch_size", type=int, default=2, help="Batch size")
    parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run")
    parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore")
    parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing")
    parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length")
    parser.add_argument(
        "-w", "--warmup_ratio", type=float, default=0.8, help="warm up ratio of non-model data. Only for gemini-auto"
    )
    parser.add_argument("-m", "--memory_limit", type=int, help="Gemini memory limit in mb")
    parser.add_argument("-x", "--xformers", action="store_true", help="Use xformers")
    parser.add_argument("--shard_param_frac", type=float, default=1.0, help="Shard param fraction. Only for gemini")
    parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini")
    parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini")
    parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
    parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
    parser.add_argument("--mbs", type=int, default=1)
    parser.add_argument("--zero", type=int, default=0)
    args = parser.parse_args()

    colossalai.launch_from_torch({})
@@ -85,56 +86,67 @@ def main():
    # Initialize Booster
    # ==============================
    use_empty_init = True
    if args.plugin == 'gemini':
        plugin = GeminiPlugin(precision='bf16',
                              shard_param_frac=args.shard_param_frac,
                              offload_optim_frac=args.offload_optim_frac,
                              offload_param_frac=args.offload_param_frac)
    elif args.plugin == 'gemini_auto':
        plugin = GeminiPlugin(placement_policy='auto', precision='bf16', warmup_non_model_data_ratio=args.warmup_ratio)
    elif args.plugin == 'fsdp':
    if args.plugin == "gemini":
        plugin = GeminiPlugin(
            precision="bf16",
            shard_param_frac=args.shard_param_frac,
            offload_optim_frac=args.offload_optim_frac,
            offload_param_frac=args.offload_param_frac,
        )
    elif args.plugin == "gemini_auto":
        plugin = GeminiPlugin(placement_policy="auto", precision="bf16", warmup_non_model_data_ratio=args.warmup_ratio)
    elif args.plugin == "fsdp":
        if use_empty_init:
            plugin = TorchFSDPPlugin(
                mixed_precision=MixedPrecision(param_dtype=torch.float16,
                                               reduce_dtype=torch.float16,
                                               buffer_dtype=torch.float16),
                mixed_precision=MixedPrecision(
                    param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
                ),
                param_init_fn=empty_init(),
            )
        else:
            plugin = TorchFSDPPlugin(mixed_precision=MixedPrecision(
                param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16))
    elif args.plugin == 'fsdp_cpu':
            plugin = TorchFSDPPlugin(
                mixed_precision=MixedPrecision(
                    param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
                )
            )
    elif args.plugin == "fsdp_cpu":
        if use_empty_init:
            plugin = TorchFSDPPlugin(
                mixed_precision=MixedPrecision(param_dtype=torch.float16,
                                               reduce_dtype=torch.float16,
                                               buffer_dtype=torch.float16),
                mixed_precision=MixedPrecision(
                    param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
                ),
                cpu_offload=CPUOffload(offload_params=True),
                param_init_fn=empty_init(),
            )
        else:
            plugin = TorchFSDPPlugin(mixed_precision=MixedPrecision(param_dtype=torch.float16,
                                                                    reduce_dtype=torch.float16,
                                                                    buffer_dtype=torch.float16),
                                     cpu_offload=CPUOffload(offload_params=True))
    elif args.plugin == '3d':
        plugin = HybridParallelPlugin(tp_size=args.tp,
                                      pp_size=args.pp,
                                      zero_stage=args.zero,
                                      enable_fused_normalization=True,
                                      num_microbatches=args.mbs,
                                      precision='bf16')
    elif args.plugin == '3d_cpu':
        plugin = HybridParallelPlugin(tp_size=args.tp,
                                      pp_size=args.pp,
                                      zero_stage=args.zero,
                                      cpu_offload=True,
                                      enable_fused_normalization=True,
                                      num_microbatches=args.mbs,
                                      initial_scale=2**8,
                                      precision='bf16')
            plugin = TorchFSDPPlugin(
                mixed_precision=MixedPrecision(
                    param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
                ),
                cpu_offload=CPUOffload(offload_params=True),
            )
    elif args.plugin == "3d":
        plugin = HybridParallelPlugin(
            tp_size=args.tp,
            pp_size=args.pp,
            zero_stage=args.zero,
            enable_fused_normalization=True,
            num_microbatches=args.mbs,
            precision="bf16",
        )
    elif args.plugin == "3d_cpu":
        plugin = HybridParallelPlugin(
            tp_size=args.tp,
            pp_size=args.pp,
            zero_stage=args.zero,
            cpu_offload=True,
            enable_fused_normalization=True,
            num_microbatches=args.mbs,
            initial_scale=2**8,
            precision="bf16",
        )
    else:
        raise ValueError(f'Unknown plugin {args.plugin}')
        raise ValueError(f"Unknown plugin {args.plugin}")

    booster = Booster(plugin=plugin)
@@ -144,17 +156,19 @@ def main():
    dp_size = plugin.dp_size if isinstance(plugin, HybridParallelPlugin) else coordinator.world_size

    config = MODEL_CONFIGS[args.config]
    dataset = RandomDataset(num_samples=args.batch_size * args.num_steps * dp_size,
                            max_length=args.max_length,
                            vocab_size=config.vocab_size)
    dataset = RandomDataset(
        num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
    )
    dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)

    # ==============================
    # Initialize Model and Optimizer
    # ==============================
    init_ctx = LazyInitContext(
        default_device=get_current_device()) if isinstance(plugin,
                                                            (GeminiPlugin, HybridParallelPlugin)) else nullcontext()
    init_ctx = (
        LazyInitContext(default_device=get_current_device())
        if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
        else nullcontext()
    )

    with init_ctx:
        model = LlamaForCausalLM(config)
@@ -163,38 +177,36 @@ def main():
        model.gradient_checkpointing_enable()

    if args.xformers:
        assert SUPPORT_FLASH, 'Use flash attention while xfomers is not installed'
        assert SUPPORT_FLASH, "Use flash attention while xfomers is not installed"
        replace_xformers(model)

    model_numel = get_model_numel(model)
    coordinator.print_on_master(f'Model params: {format_numel_str(model_numel)}')
    performance_evaluator = PerformanceEvaluator(model_numel,
                                                 args.grad_checkpoint,
                                                 args.ignore_steps,
                                                 dp_world_size=dp_size)
    coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
    performance_evaluator = PerformanceEvaluator(
        model_numel, args.grad_checkpoint, args.ignore_steps, dp_world_size=dp_size
    )

    optimizer = HybridAdam(model.parameters())
    torch.set_default_dtype(torch.bfloat16)
    model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
    torch.set_default_dtype(torch.float)
    coordinator.print_on_master(f'Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB')
    coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
    coordinator.print_on_master(
        f'Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB')
        f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
    )

    if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
        data_iter = iter(dataloader)
        for step in tqdm(range(len(dataloader)), desc='Step', disable=not coordinator.is_master()):
        for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()):
            performance_evaluator.on_step_start(step)
            booster.execute_pipeline(data_iter,
                                     model,
                                     criterion=lambda outputs, inputs: outputs[0],
                                     optimizer=optimizer,
                                     return_loss=False)
            booster.execute_pipeline(
                data_iter, model, criterion=lambda outputs, inputs: outputs[0], optimizer=optimizer, return_loss=False
            )
            optimizer.step()
            optimizer.zero_grad()
            performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
    else:
        for step, batch in enumerate(tqdm(dataloader, desc='Step', disable=not coordinator.is_master())):
        for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())):
            performance_evaluator.on_step_start(step)
            outputs = model(**batch)
            loss = outputs[0]
@@ -204,8 +216,8 @@ def main():
            performance_evaluator.on_step_end(**batch)

    performance_evaluator.on_fit_end()
    coordinator.print_on_master(f'Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB')
    coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")


if __name__ == '__main__':
if __name__ == "__main__":
    main()
@@ -12,21 +12,22 @@ from colossalai.utils import get_current_device


class StatefulDistributedSampler(DistributedSampler):

    def __init__(self,
                 dataset: Dataset,
                 num_replicas: Optional[int] = None,
                 rank: Optional[int] = None,
                 shuffle: bool = True,
                 seed: int = 0,
                 drop_last: bool = False) -> None:
    def __init__(
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
    ) -> None:
        super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
        self.start_index: int = 0

    def __iter__(self) -> Iterator:
        iterator = super().__iter__()
        indices = list(iterator)
        indices = indices[self.start_index:]
        indices = indices[self.start_index :]
        return iter(indices)

    def __len__(self) -> int:
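The start_index slicing in __iter__ is what enables mid-epoch resume: after a checkpoint is loaded, the training script calls set_start_index (defined in the continuation of this class just below) so that already-consumed samples are skipped. A minimal single-process sketch, assuming the class above is in scope and using a toy dataset:

from torch.utils.data import Dataset


class ToyDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, i):
        return i


# num_replicas=1 / rank=0 avoids needing an initialized process group
sampler = StatefulDistributedSampler(ToyDataset(), num_replicas=1, rank=0, shuffle=False)
print(list(iter(sampler)))  # [0, 1, 2, 3, 4, 5, 6, 7]
sampler.set_start_index(5)  # pretend 5 samples were consumed before the checkpoint
print(list(iter(sampler)))  # [5, 6, 7]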
@@ -36,15 +37,17 @@ class StatefulDistributedSampler(DistributedSampler):
        self.start_index = start_index


def prepare_dataloader(dataset,
                       batch_size,
                       shuffle=False,
                       seed=1024,
                       drop_last=False,
                       pin_memory=False,
                       num_workers=0,
                       process_group: Optional[ProcessGroup] = None,
                       **kwargs):
def prepare_dataloader(
    dataset,
    batch_size,
    shuffle=False,
    seed=1024,
    drop_last=False,
    pin_memory=False,
    num_workers=0,
    process_group: Optional[ProcessGroup] = None,
    **kwargs,
):
    r"""
    Prepare a dataloader for distributed training. The dataloader will be wrapped by
    `torch.utils.data.DataLoader` and `StatefulDistributedSampler`.
@@ -68,10 +71,9 @@ def prepare_dataloader(dataset,
    """
    _kwargs = kwargs.copy()
    process_group = process_group or _get_default_group()
    sampler = StatefulDistributedSampler(dataset,
                                         num_replicas=process_group.size(),
                                         rank=process_group.rank(),
                                         shuffle=shuffle)
    sampler = StatefulDistributedSampler(
        dataset, num_replicas=process_group.size(), rank=process_group.rank(), shuffle=shuffle
    )

    # Deterministic dataloader
    def seed_worker(worker_id):
@@ -80,28 +82,29 @@ def prepare_dataloader(dataset,
        torch.manual_seed(worker_seed)
        random.seed(worker_seed)

    return DataLoader(dataset,
                      batch_size=batch_size,
                      sampler=sampler,
                      worker_init_fn=seed_worker,
                      drop_last=drop_last,
                      pin_memory=pin_memory,
                      num_workers=num_workers,
                      **_kwargs)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        worker_init_fn=seed_worker,
        drop_last=drop_last,
        pin_memory=pin_memory,
        num_workers=num_workers,
        **_kwargs,
    )


def load_json(file_path: str):
    with open(file_path, 'r') as f:
    with open(file_path, "r") as f:
        return json.load(f)


def save_json(data, file_path: str):
    with open(file_path, 'w') as f:
    with open(file_path, "w") as f:
        json.dump(data, f, indent=4)


class RandomDataset(Dataset):

    def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000):
        self.num_samples = num_samples
        self.max_length = max_length
@@ -113,7 +116,7 @@ class RandomDataset(Dataset):

    def __getitem__(self, idx):
        return {
            'input_ids': self.input_ids[idx],
            'attention_mask': self.attention_mask[idx],
            'labels': self.input_ids[idx]
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.input_ids[idx],
        }
@@ -39,20 +39,20 @@ def format_numel_str(numel: int) -> str:
    M = 1024**2
    K = 1024
    if numel >= B:
        return f'{numel / B:.2f} B'
        return f"{numel / B:.2f} B"
    elif numel >= M:
        return f'{numel / M:.2f} M'
        return f"{numel / M:.2f} M"
    elif numel >= K:
        return f'{numel / K:.2f} K'
        return f"{numel / K:.2f} K"
    else:
        return f'{numel}'
        return f"{numel}"


def tokenize_batch_for_finetune(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048):
    texts = [sample['prompt'] + sample['completion'] for sample in batch]
    data = tokenizer(texts, return_tensors="pt", padding='max_length', truncation=True, max_length=max_length)
    texts = [sample["prompt"] + sample["completion"] for sample in batch]
    data = tokenizer(texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length)
    data = {k: v.cuda() for k, v in data.items()}
    data['labels'] = data['input_ids'].clone()
    data["labels"] = data["input_ids"].clone()
    return data
@@ -62,30 +62,40 @@ def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    return tensor


def save(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, epoch: int, step: int,
         batch_size: int, coordinator: DistCoordinator, save_dir: str):
    save_dir = os.path.join(save_dir, f'epoch{epoch}-step{step}')
    os.makedirs(os.path.join(save_dir, 'model'), exist_ok=True)
def save(
    booster: Booster,
    model: nn.Module,
    optimizer: Optimizer,
    lr_scheduler: _LRScheduler,
    epoch: int,
    step: int,
    batch_size: int,
    coordinator: DistCoordinator,
    save_dir: str,
):
    save_dir = os.path.join(save_dir, f"epoch{epoch}-step{step}")
    os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)

    booster.save_model(model, os.path.join(save_dir, 'model'), shard=True)
    booster.save_optimizer(optimizer, os.path.join(save_dir, 'optimizer'), shard=True)
    booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, 'lr_scheduler'))
    booster.save_model(model, os.path.join(save_dir, "model"), shard=True)
    booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
    booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
    running_states = {
        'epoch': epoch,
        'step': step,
        'sample_start_index': step * batch_size,
        "epoch": epoch,
        "step": step,
        "sample_start_index": step * batch_size,
    }
    if coordinator.is_master():
        save_json(running_states, os.path.join(save_dir, 'running_states.json'))
        save_json(running_states, os.path.join(save_dir, "running_states.json"))


def load(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler,
         load_dir: str) -> Tuple[int, int, int]:
    booster.load_model(model, os.path.join(load_dir, 'model'))
    booster.load_optimizer(optimizer, os.path.join(load_dir, 'optimizer'))
    booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, 'lr_scheduler'))
    running_states = load_json(os.path.join(load_dir, 'running_states.json'))
    return running_states['epoch'], running_states['step'], running_states['sample_start_index']
def load(
    booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, load_dir: str
) -> Tuple[int, int, int]:
    booster.load_model(model, os.path.join(load_dir, "model"))
    booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer"))
    booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler"))
    running_states = load_json(os.path.join(load_dir, "running_states.json"))
    return running_states["epoch"], running_states["step"], running_states["sample_start_index"]


def _criterion(outputs, inputs):
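For orientation, a checkpoint written by the save() helper above ends up with roughly this layout. The directory names come straight from the code; epoch 0, step 1000 and batch_size 2 are only example values, so sample_start_index = 1000 * 2 = 2000:

checkpoint/epoch0-step1000/
    model/               sharded model weights (save_model(..., shard=True))
    optimizer/           sharded optimizer state (save_optimizer(..., shard=True))
    lr_scheduler         LR scheduler state
    running_states.json  {"epoch": 0, "step": 1000, "sample_start_index": 2000}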
@@ -97,27 +107,29 @@ def main():
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, help="pretrained checkpoint path, used with mode==finetune")
    parser.add_argument('-p',
                        '--plugin',
                        choices=['gemini', 'gemini_auto', 'zero2', 'zero2_cpu', 'hybrid_parallel'],
                        default='gemini',
                        help='Choose which plugin to use')
    parser.add_argument('-d', '--dataset', type=str, default='yizhongw/self_instruct', help='Data set path')
    parser.add_argument('--task_name', type=str, default="super_natural_instructions", help='task to run')
    parser.add_argument('-e', '--num_epochs', type=int, default=1, help='Number of epochs')
    parser.add_argument('-b', '--batch_size', type=int, default=2, help='Local batch size')
    parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate')
    parser.add_argument('-w', '--weigth_decay', type=float, default=0.1, help='Weight decay')
    parser.add_argument('-g', '--grad_checkpoint', action='store_true', help='Use gradient checkpointing')
    parser.add_argument('-l', '--max_length', type=int, default=4096, help='Max sequence length')
    parser.add_argument('-x', '--mixed_precision', default='fp16', choices=['fp16', 'bf16'], help='Mixed precision')
    parser.add_argument('-i', '--save_interval', type=int, default=1000, help='Save interval')
    parser.add_argument('-o', '--save_dir', type=str, default='checkpoint', help='Checkpoint directory')
    parser.add_argument('-f', '--load', type=str, default=None, help='Load checkpoint')
    parser.add_argument('--grad_clip', type=float, default=1.0, help='Gradient clipping')
    parser.add_argument('-t', '--tensorboard_dir', type=str, default='tb_logs', help='Tensorboard directory')
    parser.add_argument('-a', '--flash_attention', action='store_true', help='Use Flash Attention')
    parser.add_argument("--model_path", type=str, help="pretrained checkpoint path, used with mode==finetune")
    parser.add_argument(
        "-p",
        "--plugin",
        choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "hybrid_parallel"],
        default="gemini",
        help="Choose which plugin to use",
    )
    parser.add_argument("-d", "--dataset", type=str, default="yizhongw/self_instruct", help="Data set path")
    parser.add_argument("--task_name", type=str, default="super_natural_instructions", help="task to run")
    parser.add_argument("-e", "--num_epochs", type=int, default=1, help="Number of epochs")
    parser.add_argument("-b", "--batch_size", type=int, default=2, help="Local batch size")
    parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
    parser.add_argument("-w", "--weigth_decay", type=float, default=0.1, help="Weight decay")
    parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing")
    parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length")
    parser.add_argument("-x", "--mixed_precision", default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
    parser.add_argument("-i", "--save_interval", type=int, default=1000, help="Save interval")
    parser.add_argument("-o", "--save_dir", type=str, default="checkpoint", help="Checkpoint directory")
    parser.add_argument("-f", "--load", type=str, default=None, help="Load checkpoint")
    parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping")
    parser.add_argument("-t", "--tensorboard_dir", type=str, default="tb_logs", help="Tensorboard directory")
    parser.add_argument("-a", "--flash_attention", action="store_true", help="Use Flash Attention")
    args = parser.parse_args()

    # ==============================
@@ -129,36 +141,34 @@ def main():
    # ==============================
    # Initialize Booster
    # ==============================
    if args.plugin == 'gemini':
    if args.plugin == "gemini":
        plugin = GeminiPlugin(precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip)
    elif args.plugin == 'gemini_auto':
        plugin = GeminiPlugin(precision=args.mixed_precision,
                              placement_policy='auto',
                              initial_scale=2**16,
                              max_norm=args.grad_clip)
    elif args.plugin == 'zero2':
        plugin = LowLevelZeroPlugin(stage=2,
                                    precision=args.mixed_precision,
                                    initial_scale=2**16,
                                    max_norm=args.grad_clip)
    elif args.plugin == 'zero2_cpu':
        plugin = LowLevelZeroPlugin(stage=2,
                                    precision=args.mixed_precision,
                                    initial_scale=2**16,
                                    cpu_offload=True,
                                    max_norm=args.grad_clip)
    elif args.plugin == 'hybrid_parallel':
    elif args.plugin == "gemini_auto":
        plugin = GeminiPlugin(
            precision=args.mixed_precision, placement_policy="auto", initial_scale=2**16, max_norm=args.grad_clip
        )
    elif args.plugin == "zero2":
        plugin = LowLevelZeroPlugin(
            stage=2, precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip
        )
    elif args.plugin == "zero2_cpu":
        plugin = LowLevelZeroPlugin(
            stage=2, precision=args.mixed_precision, initial_scale=2**16, cpu_offload=True, max_norm=args.grad_clip
        )
    elif args.plugin == "hybrid_parallel":
        # modify the param accordingly, default configuration is for llama2-7b
        plugin = HybridParallelPlugin(tp_size=4,
                                      pp_size=2,
                                      num_microbatches=None,
                                      microbatch_size=1,
                                      enable_jit_fused=False,
                                      zero_stage=0,
                                      precision='fp32',
                                      initial_scale=1)
        plugin = HybridParallelPlugin(
            tp_size=4,
            pp_size=2,
            num_microbatches=None,
            microbatch_size=1,
            enable_jit_fused=False,
            zero_stage=0,
            precision="fp32",
            initial_scale=1,
        )
    else:
        raise ValueError(f'Unknown plugin {args.plugin}')
        raise ValueError(f"Unknown plugin {args.plugin}")

    booster = Booster(plugin=plugin)
@@ -179,8 +189,9 @@ def main():

    config = LlamaConfig.from_pretrained(args.model_path)
    # use lazy init when using GeminiPlugin
    init_ctx = LazyInitContext(
        default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext()
    init_ctx = (
        LazyInitContext(default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext()
    )

    with init_ctx:
        model = LlamaForCausalLM(config)
@@ -188,57 +199,56 @@ def main():
    # ==============================
    # Initialize Tokenizer, Dataset and Dataloader
    # ==============================
    tokenizer = LlamaTokenizer.from_pretrained('hf-internal-testing/llama-tokenizer')
    tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    # follows fast chat: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L257
    tokenizer.pad_token = tokenizer.unk_token

    dataset = load_dataset(args.dataset, args.task_name)
    train_ds = dataset['train']
    dataloader = prepare_dataloader(train_ds,
                                    batch_size=args.batch_size,
                                    shuffle=True,
                                    drop_last=True,
                                    collate_fn=partial(tokenize_batch_for_finetune,
                                                       tokenizer=tokenizer,
                                                       max_length=args.max_length))
    train_ds = dataset["train"]
    dataloader = prepare_dataloader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=partial(tokenize_batch_for_finetune, tokenizer=tokenizer, max_length=args.max_length),
    )

    if args.grad_checkpoint:
        model.gradient_checkpointing_enable()
    if args.flash_attention:
        assert SUPPORT_XFORMERS, 'Use flash attention while xfomers is not installed'
        assert SUPPORT_XFORMERS, "Use flash attention while xfomers is not installed"
        replace_xformers(model)

    model_numel = get_model_numel(model)
    coordinator.print_on_master(f'Model params: {format_numel_str(model_numel)}')
    coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")

    optimizer = HybridAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weigth_decay)
    total_step = args.num_epochs * len(dataloader)
    lr_scheduler = CosineAnnealingWarmupLR(optimizer,
                                           total_steps=total_step,
                                           warmup_steps=math.ceil(total_step * 0.03),
                                           eta_min=0.1 * args.lr)
    default_dtype = torch.float16 if args.mixed_precision == 'fp16' else torch.bfloat16
    lr_scheduler = CosineAnnealingWarmupLR(
        optimizer, total_steps=total_step, warmup_steps=math.ceil(total_step * 0.03), eta_min=0.1 * args.lr
    )
    default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
    torch.set_default_dtype(default_dtype)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(model,
                                                                  optimizer,
                                                                  dataloader=dataloader,
                                                                  lr_scheduler=lr_scheduler)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
        model, optimizer, dataloader=dataloader, lr_scheduler=lr_scheduler
    )
    torch.set_default_dtype(torch.float)

    booster.load_model(model, args.model_path)

    coordinator.print_on_master(f'Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB')
    coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
    coordinator.print_on_master(
        f'Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB')
        f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
    )

    # load checkpoint if specified
    start_epoch = 0
    start_step = 0
    sampler_start_idx = 0
    if args.load is not None:
        coordinator.print_on_master('Loading checkpoint')
        coordinator.print_on_master("Loading checkpoint")
        start_epoch, start_step, sampler_start_idx = load(booster, model, optimizer, lr_scheduler, args.load)
        coordinator.print_on_master(f'Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}')
        coordinator.print_on_master(f"Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}")

    num_steps_per_epoch = len(dataloader)
@@ -249,19 +259,18 @@ def main():
        step_nums = num_steps_per_epoch - start_step
        dataloader_iter = iter(dataloader)

        with tqdm(range(step_nums),
                  desc=f'Epoch {epoch}',
                  disable=not print_flag,
                  total=num_steps_per_epoch,
                  initial=start_step) as pbar:
        with tqdm(
            range(step_nums),
            desc=f"Epoch {epoch}",
            disable=not print_flag,
            total=num_steps_per_epoch,
            initial=start_step,
        ) as pbar:
            for step in pbar:
                if use_pipeline:
                    outputs = booster.execute_pipeline(dataloader_iter,
                                                       model,
                                                       _criterion,
                                                       optimizer,
                                                       return_loss=True,
                                                       return_outputs=True)
                    outputs = booster.execute_pipeline(
                        dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True
                    )
                    loss = outputs["loss"]
                else:
                    batch = next(dataloader_iter)
@@ -276,20 +285,29 @@ def main():
                if not use_pipeline:
                    all_reduce_mean(loss)
                if print_flag:
                    pbar.set_postfix({'loss': loss.item()})
                    writer.add_scalar('loss', loss.item(), epoch * num_steps_per_epoch + step)
                    pbar.set_postfix({"loss": loss.item()})
                    writer.add_scalar("loss", loss.item(), epoch * num_steps_per_epoch + step)

                if args.save_interval > 0 and (step + 1) % args.save_interval == 0:
                    coordinator.print_on_master(f'Saving checkpoint')
                    save(booster, model, optimizer, lr_scheduler, epoch, step + 1, args.batch_size, coordinator,
                         args.save_dir)
                    coordinator.print_on_master(f'Saved checkpoint at epoch {epoch} step {step + 1}')
                    coordinator.print_on_master(f"Saving checkpoint")
                    save(
                        booster,
                        model,
                        optimizer,
                        lr_scheduler,
                        epoch,
                        step + 1,
                        args.batch_size,
                        coordinator,
                        args.save_dir,
                    )
                    coordinator.print_on_master(f"Saved checkpoint at epoch {epoch} step {step + 1}")
        # the continue epochs are not resumed, so we need to reset the sampler start index and start step
        dataloader.sampler.set_start_index(0)
        start_step = 0

    coordinator.print_on_master(f'Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB')
    coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")


if __name__ == '__main__':
if __name__ == "__main__":
    main()
@@ -23,10 +23,10 @@ def format_numel_str(numel: int) -> str:
    M = 1024**2
    K = 1024
    if numel >= B:
        return f'{numel / B:.2f} B'
        return f"{numel / B:.2f} B"
    elif numel >= M:
        return f'{numel / M:.2f} M'
        return f"{numel / M:.2f} M"
    elif numel >= K:
        return f'{numel / K:.2f} K'
        return f"{numel / K:.2f} K"
    else:
        return f'{numel}'
        return f"{numel}"
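A quick sanity check of the helper above, assuming format_numel_str is in scope and using the 1024-based units shown (K = 1024, M = 1024**2):

print(format_numel_str(512))         # 512
print(format_numel_str(70_000))      # 68.36 K
print(format_numel_str(70_000_000))  # 66.76 M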
@@ -10,9 +10,9 @@ from colossalai.cluster import DistCoordinator

def divide(x: float, y: float) -> float:
    if y == 0:
        return float('inf')
    elif y == float('inf'):
        return float('nan')
        return float("inf")
    elif y == float("inf"):
        return float("nan")
    return x / y

@@ -27,10 +27,9 @@ def all_reduce_mean(x: float, world_size: int) -> float:


class Timer:

    def __init__(self) -> None:
        self.start_time: Optional[float] = None
        self.duration: float = 0.
        self.duration: float = 0.0

    def start(self) -> None:
        self.start_time = time()
@@ -41,7 +40,7 @@ class Timer:
        self.start_time = None

    def reset(self) -> None:
        self.duration = 0.
        self.duration = 0.0


class PerformanceEvaluator:
@@ -56,11 +55,13 @@ class PerformanceEvaluator:
        ignore_episodes: The number of episodes to ignore when calculating the performance.
    """

    def __init__(self,
                 model_numel: int,
                 enable_grad_checkpoint: bool = False,
                 ignore_steps: int = 0,
                 dp_world_size: Optional[int] = None) -> None:
    def __init__(
        self,
        model_numel: int,
        enable_grad_checkpoint: bool = False,
        ignore_steps: int = 0,
        dp_world_size: Optional[int] = None,
    ) -> None:
        self.model_numel = model_numel
        self.enable_grad_checkpoint = enable_grad_checkpoint
        self.ignore_steps = ignore_steps
@@ -96,7 +97,9 @@ class PerformanceEvaluator:
        mp_world_size = self.coordinator.world_size // self.dp_world_size
        avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size
        self.coordinator.print_on_master(
            f'num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, '
            f'avg_throughput: {avg_throughput}')
            f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, "
            f"avg_throughput: {avg_throughput}"
        )
        self.coordinator.print_on_master(
            f'Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}')
            f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}"
        )
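To make the reported metric concrete, a small worked example of the per-GPU TFLOPS formula above, with illustrative (not measured) numbers:

flop = 2.0e15        # total FLOPs accumulated over the measured steps
avg_duration = 10.0  # seconds spent in those steps
mp_world_size = 4    # model-parallel ranks sharing the same computation
avg_tflops_per_gpu = flop / 1e12 / (avg_duration + 1e-12) / mp_world_size
print(avg_tflops_per_gpu)  # ~50.0 TFLOPS per GPU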
@@ -29,21 +29,22 @@ from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

MODEL_CONFIGS = {
    '7b':
        LlamaConfig(max_position_embeddings=4096),
    '13b':
        LlamaConfig(hidden_size=5120,
                    intermediate_size=13824,
                    num_hidden_layers=40,
                    num_attention_heads=40,
                    max_position_embeddings=4096),
    '70b':
        LlamaConfig(hidden_size=8192,
                    intermediate_size=28672,
                    num_hidden_layers=80,
                    num_attention_heads=64,
                    max_position_embeddings=4096,
                    num_key_value_heads=8),
    "7b": LlamaConfig(max_position_embeddings=4096),
    "13b": LlamaConfig(
        hidden_size=5120,
        intermediate_size=13824,
        num_hidden_layers=40,
        num_attention_heads=40,
        max_position_embeddings=4096,
    ),
    "70b": LlamaConfig(
        hidden_size=8192,
        intermediate_size=28672,
        num_hidden_layers=80,
        num_attention_heads=64,
        max_position_embeddings=4096,
        num_key_value_heads=8,
    ),
}

@@ -56,20 +57,20 @@ def format_numel_str(numel: int) -> str:
    M = 1024**2
    K = 1024
    if numel >= B:
        return f'{numel / B:.2f} B'
        return f"{numel / B:.2f} B"
    elif numel >= M:
        return f'{numel / M:.2f} M'
        return f"{numel / M:.2f} M"
    elif numel >= K:
        return f'{numel / K:.2f} K'
        return f"{numel / K:.2f} K"
    else:
        return f'{numel}'
        return f"{numel}"


def tokenize_batch_for_pretrain(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048):
    texts = [sample['text'] for sample in batch]
    data = tokenizer(texts, return_tensors="pt", padding='max_length', truncation=True, max_length=max_length)
    texts = [sample["text"] for sample in batch]
    data = tokenizer(texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length)
    data = {k: v.cuda() for k, v in data.items()}
    data['labels'] = data['input_ids'].clone()
    data["labels"] = data["input_ids"].clone()
    return data
@@ -79,30 +80,40 @@ def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    return tensor


def save(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, epoch: int, step: int,
         batch_size: int, coordinator: DistCoordinator, save_dir: str):
    save_dir = os.path.join(save_dir, f'epoch{epoch}-step{step}')
    os.makedirs(os.path.join(save_dir, 'model'), exist_ok=True)
def save(
    booster: Booster,
    model: nn.Module,
    optimizer: Optimizer,
    lr_scheduler: _LRScheduler,
    epoch: int,
    step: int,
    batch_size: int,
    coordinator: DistCoordinator,
    save_dir: str,
):
    save_dir = os.path.join(save_dir, f"epoch{epoch}-step{step}")
    os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)

    booster.save_model(model, os.path.join(save_dir, 'model'), shard=True)
    booster.save_optimizer(optimizer, os.path.join(save_dir, 'optimizer'), shard=True)
    booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, 'lr_scheduler'))
    booster.save_model(model, os.path.join(save_dir, "model"), shard=True)
    booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
    booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
    running_states = {
        'epoch': epoch,
        'step': step,
        'sample_start_index': step * batch_size,
        "epoch": epoch,
        "step": step,
        "sample_start_index": step * batch_size,
    }
    if coordinator.is_master():
        save_json(running_states, os.path.join(save_dir, 'running_states.json'))
        save_json(running_states, os.path.join(save_dir, "running_states.json"))


def load(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler,
         load_dir: str) -> Tuple[int, int, int]:
    booster.load_model(model, os.path.join(load_dir, 'model'))
    booster.load_optimizer(optimizer, os.path.join(load_dir, 'optimizer'))
    booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, 'lr_scheduler'))
    running_states = load_json(os.path.join(load_dir, 'running_states.json'))
    return running_states['epoch'], running_states['step'], running_states['sample_start_index']
def load(
    booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, load_dir: str
) -> Tuple[int, int, int]:
    booster.load_model(model, os.path.join(load_dir, "model"))
    booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer"))
    booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler"))
    running_states = load_json(os.path.join(load_dir, "running_states.json"))
    return running_states["epoch"], running_states["step"], running_states["sample_start_index"]


def _criterion(outputs, inputs):
@@ -114,31 +125,31 @@ def main():
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default='7b', help='Model configuration')
    parser.add_argument('-p',
                        '--plugin',
                        choices=['gemini', 'gemini_auto', 'zero2', 'zero2_cpu', 'hybrid_parallel'],
                        default='gemini',
                        help='Choose which plugin to use')
    parser.add_argument('-d',
                        '--dataset',
                        type=str,
                        default='togethercomputer/RedPajama-Data-1T-Sample',
                        help='Data set path')
    parser.add_argument('-e', '--num_epochs', type=int, default=1, help='Number of epochs')
    parser.add_argument('-b', '--batch_size', type=int, default=2, help='Local batch size')
    parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate')
    parser.add_argument('-w', '--weigth_decay', type=float, default=0.1, help='Weight decay')
    parser.add_argument('-s', '--warmup_steps', type=int, default=2000, help='Warmup steps')
    parser.add_argument('-g', '--grad_checkpoint', action='store_true', help='Use gradient checkpointing')
    parser.add_argument('-l', '--max_length', type=int, default=4096, help='Max sequence length')
    parser.add_argument('-x', '--mixed_precision', default='fp16', choices=['fp16', 'bf16'], help='Mixed precision')
    parser.add_argument('-i', '--save_interval', type=int, default=1000, help='Save interval')
    parser.add_argument('-o', '--save_dir', type=str, default='checkpoint', help='Checkpoint directory')
    parser.add_argument('-f', '--load', type=str, default=None, help='Load checkpoint')
    parser.add_argument('--grad_clip', type=float, default=1.0, help='Gradient clipping')
    parser.add_argument('-t', '--tensorboard_dir', type=str, default='tb_logs', help='Tensorboard directory')
    parser.add_argument('-a', '--flash_attention', action='store_true', help='Use Flash Attention')
    parser.add_argument("-c", "--config", type=str, default="7b", help="Model configuration")
    parser.add_argument(
        "-p",
        "--plugin",
        choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "hybrid_parallel"],
        default="gemini",
        help="Choose which plugin to use",
    )
    parser.add_argument(
        "-d", "--dataset", type=str, default="togethercomputer/RedPajama-Data-1T-Sample", help="Data set path"
    )
    parser.add_argument("-e", "--num_epochs", type=int, default=1, help="Number of epochs")
    parser.add_argument("-b", "--batch_size", type=int, default=2, help="Local batch size")
    parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
    parser.add_argument("-w", "--weigth_decay", type=float, default=0.1, help="Weight decay")
    parser.add_argument("-s", "--warmup_steps", type=int, default=2000, help="Warmup steps")
    parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing")
    parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length")
    parser.add_argument("-x", "--mixed_precision", default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
    parser.add_argument("-i", "--save_interval", type=int, default=1000, help="Save interval")
    parser.add_argument("-o", "--save_dir", type=str, default="checkpoint", help="Checkpoint directory")
    parser.add_argument("-f", "--load", type=str, default=None, help="Load checkpoint")
    parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping")
    parser.add_argument("-t", "--tensorboard_dir", type=str, default="tb_logs", help="Tensorboard directory")
    parser.add_argument("-a", "--flash_attention", action="store_true", help="Use Flash Attention")
    args = parser.parse_args()

    # ==============================
@@ -150,36 +161,34 @@ def main():
    # ==============================
    # Initialize Booster
    # ==============================
    if args.plugin == 'gemini':
    if args.plugin == "gemini":
        plugin = GeminiPlugin(precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip)
    elif args.plugin == 'gemini_auto':
        plugin = GeminiPlugin(precision=args.mixed_precision,
                              placement_policy='auto',
                              initial_scale=2**16,
                              max_norm=args.grad_clip)
    elif args.plugin == 'zero2':
        plugin = LowLevelZeroPlugin(stage=2,
                                    precision=args.mixed_precision,
                                    initial_scale=2**16,
                                    max_norm=args.grad_clip)
    elif args.plugin == 'zero2_cpu':
        plugin = LowLevelZeroPlugin(stage=2,
                                    precision=args.mixed_precision,
                                    initial_scale=2**16,
                                    cpu_offload=True,
                                    max_norm=args.grad_clip)
    elif args.plugin == 'hybrid_parallel':
    elif args.plugin == "gemini_auto":
        plugin = GeminiPlugin(
            precision=args.mixed_precision, placement_policy="auto", initial_scale=2**16, max_norm=args.grad_clip
        )
    elif args.plugin == "zero2":
        plugin = LowLevelZeroPlugin(
            stage=2, precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip
        )
    elif args.plugin == "zero2_cpu":
        plugin = LowLevelZeroPlugin(
            stage=2, precision=args.mixed_precision, initial_scale=2**16, cpu_offload=True, max_norm=args.grad_clip
        )
    elif args.plugin == "hybrid_parallel":
        # modify the param accordingly, default configuration is for llama2-7b
        plugin = HybridParallelPlugin(tp_size=4,
                                      pp_size=2,
                                      num_microbatches=None,
                                      microbatch_size=1,
                                      enable_jit_fused=False,
                                      zero_stage=0,
                                      precision='fp32',
                                      initial_scale=1)
        plugin = HybridParallelPlugin(
            tp_size=4,
            pp_size=2,
            num_microbatches=None,
            microbatch_size=1,
            enable_jit_fused=False,
            zero_stage=0,
            precision="fp32",
            initial_scale=1,
        )
    else:
        raise ValueError(f'Unknown plugin {args.plugin}')
        raise ValueError(f"Unknown plugin {args.plugin}")

    booster = Booster(plugin=plugin)
@@ -197,27 +206,28 @@ def main():
    # ==============================
    # Initialize Tokenizer, Dataset and Dataloader
    # ==============================
    tokenizer = LlamaTokenizer.from_pretrained('hf-internal-testing/llama-tokenizer')
    tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    # follows fast chat: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L257
    tokenizer.pad_token = tokenizer.unk_token

    dataset = load_dataset(args.dataset)
    train_ds = dataset['train']
    dataloader = prepare_dataloader(train_ds,
                                    batch_size=args.batch_size,
                                    shuffle=True,
                                    drop_last=True,
                                    collate_fn=partial(tokenize_batch_for_pretrain,
                                                       tokenizer=tokenizer,
                                                       max_length=args.max_length))
    train_ds = dataset["train"]
    dataloader = prepare_dataloader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=partial(tokenize_batch_for_pretrain, tokenizer=tokenizer, max_length=args.max_length),
    )

    # ==============================
    # Initialize Model, Optimizer and LR Scheduler
    # ==============================
    config = MODEL_CONFIGS[args.config]
    # use lazy init when using GeminiPlugin
    init_ctx = LazyInitContext(
        default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext()
    init_ctx = (
        LazyInitContext(default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext()
    )

    with init_ctx:
        model = LlamaForCausalLM(config)
@@ -225,37 +235,36 @@ def main():
    if args.grad_checkpoint:
        model.gradient_checkpointing_enable()
    if args.flash_attention:
        assert SUPPORT_XFORMERS, 'Use flash attention while xfomers is not installed'
        assert SUPPORT_XFORMERS, "Use flash attention while xfomers is not installed"
        replace_xformers(model)

    model_numel = get_model_numel(model)
    coordinator.print_on_master(f'Model params: {format_numel_str(model_numel)}')
    coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")

    optimizer = HybridAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weigth_decay)
    lr_scheduler = CosineAnnealingWarmupLR(optimizer,
                                           total_steps=args.num_epochs * len(dataloader),
                                           warmup_steps=args.warmup_steps,
                                           eta_min=0.1 * args.lr)
    default_dtype = torch.float16 if args.mixed_precision == 'fp16' else torch.bfloat16
    lr_scheduler = CosineAnnealingWarmupLR(
        optimizer, total_steps=args.num_epochs * len(dataloader), warmup_steps=args.warmup_steps, eta_min=0.1 * args.lr
    )
    default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
    torch.set_default_dtype(default_dtype)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(model,
                                                                  optimizer,
                                                                  dataloader=dataloader,
                                                                  lr_scheduler=lr_scheduler)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
        model, optimizer, dataloader=dataloader, lr_scheduler=lr_scheduler
    )
    torch.set_default_dtype(torch.float)

    coordinator.print_on_master(f'Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB')
    coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
    coordinator.print_on_master(
        f'Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB')
        f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
    )

    # load checkpoint if specified
    start_epoch = 0
    start_step = 0
    sampler_start_idx = 0
    if args.load is not None:
        coordinator.print_on_master('Loading checkpoint')
        coordinator.print_on_master("Loading checkpoint")
        start_epoch, start_step, sampler_start_idx = load(booster, model, optimizer, lr_scheduler, args.load)
        coordinator.print_on_master(f'Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}')
        coordinator.print_on_master(f"Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}")

    num_steps_per_epoch = len(dataloader)
@@ -266,19 +275,18 @@ def main():
        step_nums = num_steps_per_epoch - start_step
        dataloader_iter = iter(dataloader)

        with tqdm(range(step_nums),
                  desc=f'Epoch {epoch}',
                  disable=not print_flag,
                  total=num_steps_per_epoch,
                  initial=start_step) as pbar:
        with tqdm(
            range(step_nums),
            desc=f"Epoch {epoch}",
            disable=not print_flag,
            total=num_steps_per_epoch,
            initial=start_step,
        ) as pbar:
            for step in pbar:
                if use_pipeline:
                    outputs = booster.execute_pipeline(dataloader_iter,
                                                       model,
                                                       _criterion,
                                                       optimizer,
                                                       return_loss=True,
                                                       return_outputs=True)
                    outputs = booster.execute_pipeline(
                        dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True
                    )
                    loss = outputs["loss"]
                else:
                    batch = next(dataloader_iter)
@@ -293,20 +301,29 @@ def main():
                if not use_pipeline:
                    all_reduce_mean(loss)
                if print_flag:
                    pbar.set_postfix({'loss': loss.item()})
                    writer.add_scalar('loss', loss.item(), epoch * num_steps_per_epoch + step)
                    pbar.set_postfix({"loss": loss.item()})
                    writer.add_scalar("loss", loss.item(), epoch * num_steps_per_epoch + step)

                if args.save_interval > 0 and (step + 1) % args.save_interval == 0:
                    coordinator.print_on_master(f'Saving checkpoint')
                    save(booster, model, optimizer, lr_scheduler, epoch, step + 1, args.batch_size, coordinator,
                         args.save_dir)
                    coordinator.print_on_master(f'Saved checkpoint at epoch {epoch} step {step + 1}')
                    coordinator.print_on_master(f"Saving checkpoint")
                    save(
                        booster,
                        model,
                        optimizer,
                        lr_scheduler,
                        epoch,
                        step + 1,
                        args.batch_size,
                        coordinator,
                        args.save_dir,
                    )
                    coordinator.print_on_master(f"Saved checkpoint at epoch {epoch} step {step + 1}")
        # the continue epochs are not resumed, so we need to reset the sampler start index and start step
        dataloader.sampler.set_start_index(0)
        start_step = 0

    coordinator.print_on_master(f'Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB')
    coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")


if __name__ == '__main__':
if __name__ == "__main__":
    main()