Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-09 04:50:17 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
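The diff below is almost entirely mechanical restyling produced by the refreshed hooks: single-quoted strings become double-quoted, yapf-style aligned continuation lines become one-argument-per-line layouts with trailing commas, the blank line that used to follow each class statement is dropped, and a few unused imports and assignments are removed. The sketch here only illustrates the shape of that rewrite; build_config is a hypothetical helper, not code from this commit, and it assumes a black-style formatter with the roughly 120-character line length that the reformatted lines suggest.

# Hypothetical example, not part of the commit: the old style would have written
# this call with single quotes and yapf-aligned continuations, e.g.
#     return GPT2Config(n_embd=hidden_size,
#                       n_layer=num_layers,
#                       vocab_size=vocab_size)
from transformers import GPT2Config


def build_config(hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257):
    # After the updated hooks run, a call that no longer fits on one line is
    # broken into one keyword argument per line with a trailing comma, and
    # string literals in the touched files are rewritten with double quotes.
    return GPT2Config(
        n_embd=hidden_size,
        n_layer=num_layers,
        n_head=num_attention_heads,
        n_positions=max_seq_len,
        n_ctx=max_seq_len,
        vocab_size=vocab_size,
    )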
@@ -2,22 +2,20 @@ import torch
 import torch.nn as nn
 from transformers import GPT2Config, GPT2LMHeadModel

+
 class GPTLMModel(nn.Module):
-
-    def __init__(self,
-                 hidden_size=768,
-                 num_layers=12,
-                 num_attention_heads=12,
-                 max_seq_len=1024,
-                 vocab_size=50257):
+    def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257):
         super().__init__()
         self.model = GPT2LMHeadModel(
-            GPT2Config(n_embd=hidden_size,
-                       n_layer=num_layers,
-                       n_head=num_attention_heads,
-                       n_positions=max_seq_len,
-                       n_ctx=max_seq_len,
-                       vocab_size=vocab_size))
+            GPT2Config(
+                n_embd=hidden_size,
+                n_layer=num_layers,
+                n_head=num_attention_heads,
+                n_positions=max_seq_len,
+                n_ctx=max_seq_len,
+                vocab_size=vocab_size,
+            )
+        )

     def forward(self, input_ids, attention_mask):
         # Only return lm_logits
@@ -25,7 +23,6 @@ class GPTLMModel(nn.Module):


 class GPTLMLoss(nn.Module):
-
     def __init__(self):
         super().__init__()
         self.loss_fn = nn.CrossEntropyLoss()
@@ -36,6 +33,7 @@ class GPTLMLoss(nn.Module):
         # Flatten the tokens
         return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

+
 def get_gpt2_components(model_type: str, batch_size: int):
     vocab_size = 1024
     seq_len = 8
@@ -62,4 +60,4 @@ def get_gpt2_components(model_type: str, batch_size: int):
         kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
         return kwargs

-    return gpt2_model_builder, gpt2_data_gen
+    return gpt2_model_builder, gpt2_data_gen
@@ -1,2 +1,2 @@
 colossalai >= 0.1.12
-torch >= 1.8.1
+torch >= 1.8.1
@@ -3,7 +3,6 @@ import time

 import pytest
 import torch
-from model_zoo import GPTLMLoss, get_gpt2_components
 from torch.utils._pytree import tree_map

 import colossalai
@@ -14,18 +13,19 @@ from colossalai.fx.profiler import parameter_size
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import spawn
 from colossalai.utils import get_current_device
+from model_zoo import GPTLMLoss, get_gpt2_components


 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model_type', type=str, default="gpt2_medium")
-    parser.add_argument('--batch_size', type=int, default=64)
-    parser.add_argument('--solver_type', type=str, default='asyn')
-    parser.add_argument('--memory_budget', type=float, default=16)
+    parser.add_argument("--model_type", type=str, default="gpt2_medium")
+    parser.add_argument("--batch_size", type=int, default=64)
+    parser.add_argument("--solver_type", type=str, default="asyn")
+    parser.add_argument("--memory_budget", type=float, default=16)
     return parser.parse_args()


-@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
+@pytest.mark.skipif(NOT_NVML, reason="pynvml is not installed")
 def train_gpt(args):
     memory_budget = args.memory_budget * 1024 * 1024 * 1024
     solver_type = args.solver_type
@@ -34,10 +34,15 @@ def train_gpt(args):

     # build model
     model_builder, data_gen = get_gpt2_components(model_type=model_type, batch_size=batch_size)
-    label = torch.randint(low=0, high=128, size=(
-        64,
-        8,
-    ), device=get_current_device())
+    label = torch.randint(
+        low=0,
+        high=128,
+        size=(
+            64,
+            8,
+        ),
+        device=get_current_device(),
+    )
     criterion = GPTLMLoss()

     start_time = time.time()
@@ -80,18 +85,20 @@ def train_gpt(args):
     exec_time = sum(sorted(time_list)[:5]) / 5
     runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2
     runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2
-    print(f'solver_type: {solver_type} | model_type: {model_type}')
-    print(f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
-          f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|')
+    print(f"solver_type: {solver_type} | model_type: {model_type}")
+    print(
+        f"| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB "
+        f"| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|"
+    )
     print(time_list)


 def run(rank, world_size, port, args):
     config = {}
-    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     train_gpt(args)


-if __name__ == '__main__':
+if __name__ == "__main__":
     args = parse_args()
     spawn(run, 1, args=args)
@@ -29,8 +29,8 @@ def get_gpu_mem():
     return torch.cuda.memory_allocated() / 1024**2


-def get_mem_info(prefix=''):
-    return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'
+def get_mem_info(prefix=""):
+    return f"{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB"


 def get_tflops(model_numel, batch_size, seq_len, step_time):
@@ -51,14 +51,14 @@ def main():
     logger = get_dist_logger()
     config = transformers.GPT2Config(n_position=SEQ_LENGTH, n_layer=NUM_LAYERS, n_head=NUM_HEADS, n_embd=HIDDEN_DIM)
     if FP16:
-        model = GPT2LMHeadModel(config=config).half().to('cuda')
+        model = GPT2LMHeadModel(config=config).half().to("cuda")
     else:
-        model = GPT2LMHeadModel(config=config).to('cuda')
+        model = GPT2LMHeadModel(config=config).to("cuda")
     global_numel = sum([p.numel() for p in model.parameters()])

     meta_input_sample = {
-        'input_ids': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'),
-        'attention_mask': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'),
+        "input_ids": torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to("meta"),
+        "attention_mask": torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to("meta"),
     }

     gm, solution = autoparallelize(model, meta_input_sample, return_solution=True)
@@ -72,7 +72,7 @@ def main():
     criterion = GPTLMLoss()

     optimizer = torch.optim.Adam(gm.parameters(), lr=0.01)
-    logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
+    logger.info(get_mem_info(prefix="After init model, "), ranks=[0])
     get_tflops_func = partial(get_tflops, global_numel, BATCH_SIZE, SEQ_LENGTH)
     torch.cuda.synchronize()
     model.train()
@@ -89,10 +89,11 @@ def main():
         torch.cuda.synchronize()
         step_time = time() - start
         logger.info(
-            f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}',
-            ranks=[0])
+            f"[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}",
+            ranks=[0],
+        )
     torch.cuda.synchronize()


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
@@ -8,7 +8,6 @@ from transformers.pytorch_utils import Conv1D


 class GPT2MLP(nn.Module):
-
     def __init__(self, intermediate_size, config):
         super().__init__()
         embed_dim = config.hidden_size
@@ -30,15 +29,15 @@ class GPT2MLP(nn.Module):
 # 2. The order of split and view op has been changed in the customized GPT2Attention class, the new
 # order is same as megatron-lm gpt model.
 class GPT2Attention(nn.Module):
-
     def __init__(self, config, layer_idx=None):
         super().__init__()

         max_positions = config.max_position_embeddings
         self.register_buffer(
             "bias",
-            torch.tril(torch.ones((max_positions, max_positions),
-                                  dtype=torch.uint8)).view(1, 1, max_positions, max_positions),
+            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
+                1, 1, max_positions, max_positions
+            ),
         )
         self.register_buffer("masked_bias", torch.tensor(-1e4))

@@ -64,7 +63,7 @@ class GPT2Attention(nn.Module):
         attn_weights = torch.matmul(query, key.transpose(-1, -2))

         if self.scale_attn_weights:
-            attn_weights = attn_weights / (value.size(-1)**0.5)
+            attn_weights = attn_weights / (value.size(-1) ** 0.5)

         # Layer-wise attention scaling
         if self.scale_attn_by_inverse_layer_idx:
@@ -72,7 +71,7 @@ class GPT2Attention(nn.Module):

         # if only "normal" attention layer implements causal mask
         query_length, key_length = query.size(-2), key.size(-2)
-        causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].to(torch.bool)
+        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
         attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))

         if attention_mask is not None:
@@ -93,7 +92,7 @@ class GPT2Attention(nn.Module):
     def _split_heads(self, tensor, num_heads, attn_head_size):
         new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
         tensor = tensor.view(new_shape)
-        return tensor.permute(0, 2, 1, 3)    # (batch, head, seq_length, head_features)
+        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

     def _merge_heads(self, tensor, num_heads, attn_head_size):
         tensor = tensor.permute(0, 2, 1, 3).contiguous()
@@ -106,10 +105,9 @@ class GPT2Attention(nn.Module):
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
     ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
-
         qkv = self.c_attn(hidden_states)
         query, key, value = self._split_heads(qkv, self.num_heads, 3 * self.head_dim).split(self.head_dim, dim=3)
-        present = (key, value)
+        (key, value)
         attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
         attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
         attn_output = self.c_proj(attn_output)
@@ -117,7 +115,6 @@ class GPT2Attention(nn.Module):


 class GPT2Block(nn.Module):
-
     def __init__(self, config, layer_idx=None):
         super().__init__()
         hidden_size = config.hidden_size
@@ -152,7 +149,6 @@ class GPT2Block(nn.Module):


 class GPT2Model(GPT2PreTrainedModel):
-
     def __init__(self, config):
         super().__init__(config)

@@ -189,11 +185,9 @@ class GPT2Model(GPT2PreTrainedModel):
         # GPT2Attention mask.
         attention_mask = attention_mask.view(batch_size, -1)
         attention_mask = attention_mask[:, None, None, :]
-        attention_mask = attention_mask.to(dtype=self.dtype)    # fp16 compatibility
+        attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
         attention_mask = (1.0 - attention_mask) * -10000.0

         encoder_attention_mask = None

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
         # attention_probs has shape bsz x n_heads x N x N
@@ -217,7 +211,6 @@ class GPT2Model(GPT2PreTrainedModel):


 class GPT2LMHeadModel(GPT2PreTrainedModel):
-
     def __init__(self, config):
         super().__init__(config)
         self.transformer = GPT2Model(config)
@@ -241,7 +234,6 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):


 class GPTLMLoss(nn.Module):
-
     def __init__(self):
         super().__init__()
         self.loss_fn = nn.CrossEntropyLoss()
@@ -4,22 +4,25 @@ from transformers import GPT2Config, GPT2LMHeadModel

 ## Define the Model and Loss Based on Huggingface transformers GPT2LMHeadModel
 class GPTLMModel(nn.Module):
-
-    def __init__(self,
-                 hidden_size=768,
-                 num_layers=12,
-                 num_attention_heads=12,
-                 max_seq_len=1024,
-                 vocab_size=50257,
-                 checkpoint=False):
+    def __init__(
+        self,
+        hidden_size=768,
+        num_layers=12,
+        num_attention_heads=12,
+        max_seq_len=1024,
+        vocab_size=50257,
+        checkpoint=False,
+    ):
         super().__init__()
         self.checkpoint = checkpoint
-        self.config = GPT2Config(n_embd=hidden_size,
-                                 n_layer=num_layers,
-                                 n_head=num_attention_heads,
-                                 n_positions=max_seq_len,
-                                 n_ctx=max_seq_len,
-                                 vocab_size=vocab_size)
+        self.config = GPT2Config(
+            n_embd=hidden_size,
+            n_layer=num_layers,
+            n_head=num_attention_heads,
+            n_positions=max_seq_len,
+            n_ctx=max_seq_len,
+            vocab_size=vocab_size,
+        )
         self.model = GPT2LMHeadModel(self.config)
         if checkpoint:
             self.model.gradient_checkpointing_enable()
@@ -70,4 +73,4 @@ def model_builder(model_size: str) -> callable:
         raise TypeError(f"model_builder {model_size}")


-__all__ = ['model_builder']
+__all__ = ["model_builder"]
@@ -3,41 +3,34 @@ import time
 from functools import partial

 import torch
-from model_zoo import model_builder
 from torch import nn
 from tqdm import tqdm

 from colossalai.fx import ColoTracer
-from colossalai.fx.passes.adding_split_node_pass import (
-    avgnode_split_pass,
-    gpipe_dp_split_pass,
-    split_with_split_nodes_pass,
-)
+from colossalai.fx.passes.adding_split_node_pass import gpipe_dp_split_pass, split_with_split_nodes_pass
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
 from colossalai.legacy.pipeline.middleware.adaptor import get_fx_topology
-from colossalai.legacy.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
+from colossalai.legacy.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine
 from colossalai.legacy.pipeline.rpc.utils import rpc_run
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
+from model_zoo import model_builder


 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model_type', type=str, default="gpt2_medium")
-    parser.add_argument('--world_size', type=int, default=2)
-    parser.add_argument('--batch_size', type=int, default=16)
-    parser.add_argument('--dp_degree', type=int, default=1)
-    parser.add_argument('--tp_degree', type=int, default=1)
-    parser.add_argument('--num_microbatches', type=int, default=2)
-    parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
-    parser.add_argument('--master_addr', type=str, default='localhost')
-    parser.add_argument('--master_port', type=str, default='29011')
-    parser.add_argument('--num_worker_threads', type=int, default=128)
+    parser.add_argument("--model_type", type=str, default="gpt2_medium")
+    parser.add_argument("--world_size", type=int, default=2)
+    parser.add_argument("--batch_size", type=int, default=16)
+    parser.add_argument("--dp_degree", type=int, default=1)
+    parser.add_argument("--tp_degree", type=int, default=1)
+    parser.add_argument("--num_microbatches", type=int, default=2)
+    parser.add_argument("--device", type=str, choices=["cpu", "cuda"], default="cuda")
+    parser.add_argument("--master_addr", type=str, default="localhost")
+    parser.add_argument("--master_port", type=str, default="29011")
+    parser.add_argument("--num_worker_threads", type=int, default=128)
     return parser.parse_args()


 class GPTLMLoss(nn.Module):
-
     def __init__(self):
         super().__init__()
         self.loss_fn = nn.CrossEntropyLoss()
@@ -63,16 +56,16 @@ def get_tflops(model_numel, batch_size, seq_len, step_time):
 # Create annotated model which is noted where to be splitted.
 def get_annotated_model(model, data_kwargs, num_stages, num_microbatches):
     tracer = ColoTracer()
-    meta_args = {k: v.to('meta') for k, v in data_kwargs.items()}
+    meta_args = {k: v.to("meta") for k, v in data_kwargs.items()}
     graph = tracer.trace(root=model, meta_args=meta_args)
     gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)

-    interp_meta_args = tuple([v.to('meta') for k, v in data_kwargs.items()])
+    interp_meta_args = tuple([v.to("meta") for k, v in data_kwargs.items()])
     interp = MetaInfoProp(gm)
     interp.run(*interp_meta_args)

-    #annotated_model = avgnode_split_pass(gm, num_stages)
-    annotated_model = gpipe_dp_split_pass(gm, num_stages, num_microbatches, mode='block', block_limit=0.01)
+    # annotated_model = avgnode_split_pass(gm, num_stages)
+    annotated_model = gpipe_dp_split_pass(gm, num_stages, num_microbatches, mode="block", block_limit=0.01)

     return annotated_model

@@ -83,7 +76,7 @@ def create_partition_module(pp_rank: int, num_stages: int, model, data_kwargs, n
     topo = get_fx_topology(top_module)
     for submodule in split_submodules:
         if isinstance(submodule, torch.fx.GraphModule):
-            setattr(submodule, '_topo', topo)
+            setattr(submodule, "_topo", topo)
     return split_submodules[pp_rank + 1]

@@ -107,8 +100,10 @@ def run_master(args):

     disable_existing_loggers()
     logger = get_dist_logger()
-    logger.info(f"{args.model_type}, batch size {batch_size}, num stage {stage_num}, num microbatch {num_microbatches}",
-                ranks=[0])
+    logger.info(
+        f"{args.model_type}, batch size {batch_size}, num stage {stage_num}, num microbatch {num_microbatches}",
+        ranks=[0],
+    )

     torch.manual_seed(123)

@@ -117,26 +112,28 @@ def run_master(args):

     # warm up pipeline fx partition
     input_ids, attn_mask = get_data(batch_size, SEQ_LEN, VOCAB_SIZE)
-    warmup_data_kwargs = {'input_ids': input_ids, 'attention_mask': attn_mask}
+    warmup_data_kwargs = {"input_ids": input_ids, "attention_mask": attn_mask}

     # create model
-    logger.info(f'start model_builder')
+    logger.info(f"start model_builder")
     model = model_builder(model_type)(checkpoint=False)
-    logger.info(f'end model_builder')
+    logger.info(f"end model_builder")

     # set 1f1b pipeline engine
-    pp_engine = FillDrainPipelineEngine(partition_fn=partial(partition, model, warmup_data_kwargs, num_microbatches),
-                                        stage_num=stage_num,
-                                        num_microbatches=num_microbatches,
-                                        device=device,
-                                        chunk=1,
-                                        criterion=criterion,
-                                        metric=None,
-                                        checkpoint=False)
+    pp_engine = FillDrainPipelineEngine(
+        partition_fn=partial(partition, model, warmup_data_kwargs, num_microbatches),
+        stage_num=stage_num,
+        num_microbatches=num_microbatches,
+        device=device,
+        chunk=1,
+        criterion=criterion,
+        metric=None,
+        checkpoint=False,
+    )

     partition_numels = pp_engine.remote_numels()
     for rank, numel in partition_numels.items():
-        logger.info(f'{rank=} numel in the partition:{numel}')
+        logger.info(f"{rank=} numel in the partition:{numel}")

     # build optim
     pp_engine.initialize_optimizer(torch.optim.Adam, lr=1e-3)
@@ -145,7 +142,7 @@ def run_master(args):
     for n in range(NUM_STEPS):
         # we just use randomly generated data here
         input_ids, attn_mask = get_data(batch_size, SEQ_LEN, VOCAB_SIZE)
-        batch = {'input_ids': input_ids, 'attention_mask': attn_mask}
+        batch = {"input_ids": input_ids, "attention_mask": attn_mask}

         start = time.time()
         outputs = pp_engine.forward_backward(batch=batch, labels=input_ids, forward_only=False)
@@ -175,6 +172,6 @@ def run_master(args):
     logger.info(f"Avg TFLOPS per GPU is {sum(gpu_tflops) / world_size:.3f}")


-if __name__ == '__main__':
+if __name__ == "__main__":
     args = parse_args()
     rpc_run(args, run_master)