mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 04:50:17 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -11,7 +11,6 @@ def exists(val):
|
||||
|
||||
|
||||
def eval_decorator(fn):
|
||||
|
||||
def inner(model, *args, **kwargs):
|
||||
was_training = model.training
|
||||
model.eval()
|
||||
@@ -34,7 +33,6 @@ def top_k(logits, thres=0.9):
|
||||
|
||||
|
||||
class AutoregressiveWrapper(nn.Module):
|
||||
|
||||
def __init__(self, net, max_seq_len=2048, pad_value=0):
|
||||
super().__init__()
|
||||
self.max_seq_len = max_seq_len
|
||||
|
@@ -1,14 +1,13 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch import einsum, matmul, nn
|
||||
from torch import matmul, nn
|
||||
|
||||
# normalization
|
||||
# they use layernorm without bias, something that pytorch does not offer
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
|
||||
def __init__(self, dim, eps=1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
@@ -24,7 +23,6 @@ class LayerNorm(nn.Module):
|
||||
|
||||
|
||||
class ParallelResidual(nn.Module):
|
||||
|
||||
def __init__(self, *fns):
|
||||
super().__init__()
|
||||
self.fns = nn.ModuleList(fns)
|
||||
@@ -38,16 +36,15 @@ class ParallelResidual(nn.Module):
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2).float() / dim))
|
||||
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
|
||||
self.register_buffer("inv_freq", inv_freq)
|
||||
|
||||
def forward(self, max_seq_len, *, device):
|
||||
seq = torch.arange(max_seq_len, device=device)
|
||||
#freqs = einsum("i , j -> i j", seq.type_as(self.inv_freq), self.inv_freq)
|
||||
#freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
|
||||
# freqs = einsum("i , j -> i j", seq.type_as(self.inv_freq), self.inv_freq)
|
||||
# freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
|
||||
i, j = len(seq.type_as(self.inv_freq)), len(self.inv_freq)
|
||||
freqs = matmul(seq.type_as(self.inv_freq).reshape(i, 1), self.inv_freq.reshape(1, j))
|
||||
return torch.cat((freqs, freqs), dim=-1)
|
||||
@@ -69,7 +66,6 @@ def apply_rotary_pos_emb(pos, t):
|
||||
|
||||
|
||||
class SwiGLU(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = x.chunk(2, dim=-1)
|
||||
return F.silu(gate) * x
|
||||
@@ -87,7 +83,6 @@ def FeedForward(dim, mult=4):
|
||||
|
||||
# attention
|
||||
class Attention(nn.Module):
|
||||
|
||||
def __init__(self, dim, dim_head=64, heads=8):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
@@ -160,7 +155,7 @@ class Attention(nn.Module):
|
||||
|
||||
# similarity
|
||||
|
||||
#sim = einsum("b h i d, b j d -> b h i j", q, k)
|
||||
# sim = einsum("b h i d, b j d -> b h i j", q, k)
|
||||
sim = matmul(q.reshape(b, h * i, d), k.transpose(1, 2))
|
||||
sim = sim.reshape(b, h, i, j)
|
||||
|
||||
@@ -178,7 +173,7 @@ class Attention(nn.Module):
|
||||
|
||||
# aggregate values
|
||||
|
||||
#out = einsum("b h i j, b j d -> b h i d", attn, v)
|
||||
# out = einsum("b h i j, b j d -> b h i d", attn, v)
|
||||
out = matmul(attn.reshape(b_, h_ * i_, j_), v)
|
||||
out = out.reshape(b_, h_, i_, d_)
|
||||
|
||||
@@ -193,12 +188,17 @@ class Attention(nn.Module):
|
||||
|
||||
def PaLM(*, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4):
|
||||
net = nn.Sequential(
|
||||
nn.Embedding(num_tokens, dim), *[
|
||||
nn.Embedding(num_tokens, dim),
|
||||
*[
|
||||
ParallelResidual(
|
||||
Attention(dim=dim, dim_head=dim_head, heads=heads),
|
||||
FeedForward(dim=dim, mult=ff_mult),
|
||||
) for _ in range(depth)
|
||||
], LayerNorm(dim), nn.Linear(dim, num_tokens, bias=False))
|
||||
)
|
||||
for _ in range(depth)
|
||||
],
|
||||
LayerNorm(dim),
|
||||
nn.Linear(dim, num_tokens, bias=False),
|
||||
)
|
||||
|
||||
# they used embedding weight tied projection out to logits, not common, but works
|
||||
net[-1].weight = net[0].weight
|
||||
|
@@ -37,7 +37,7 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--distplan",
|
||||
type=str,
|
||||
default='colossalai',
|
||||
default="colossalai",
|
||||
help="The distributed plan [colossalai, pytorch].",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -46,12 +46,14 @@ def parse_args():
|
||||
default=1.0,
|
||||
help="Fraction of optimizer states to be offloaded. This is only used for gemini.",
|
||||
)
|
||||
parser.add_argument('-p',
|
||||
'--plugin',
|
||||
type=str,
|
||||
default='torch_ddp',
|
||||
choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
|
||||
help="plugin to use")
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--plugin",
|
||||
type=str,
|
||||
default="torch_ddp",
|
||||
choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero"],
|
||||
help="plugin to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
@@ -122,7 +124,6 @@ print("generate dataset ready!")
|
||||
|
||||
|
||||
class TextSamplerDataset(Dataset):
|
||||
|
||||
def __init__(self, data, seq_len):
|
||||
super().__init__()
|
||||
self.data = data
|
||||
@@ -130,7 +131,7 @@ class TextSamplerDataset(Dataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
|
||||
full_seq = self.data[rand_start:rand_start + self.seq_len + 1].long()
|
||||
full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
|
||||
return full_seq.cuda()
|
||||
|
||||
def __len__(self):
|
||||
@@ -146,18 +147,18 @@ if args.distplan == "colossalai":
|
||||
# instantiate GPT-like decoder model
|
||||
|
||||
booster_kwargs = {}
|
||||
if args.plugin == 'torch_ddp_fp16':
|
||||
booster_kwargs['mixed_precision'] = 'fp16'
|
||||
if args.plugin.startswith('torch_ddp'):
|
||||
if args.plugin == "torch_ddp_fp16":
|
||||
booster_kwargs["mixed_precision"] = "fp16"
|
||||
if args.plugin.startswith("torch_ddp"):
|
||||
plugin = TorchDDPPlugin()
|
||||
elif args.plugin == 'gemini':
|
||||
elif args.plugin == "gemini":
|
||||
plugin = GeminiPlugin(offload_optim_frac=args.offload_optim_frac, initial_scale=2**5)
|
||||
elif args.plugin == 'low_level_zero':
|
||||
elif args.plugin == "low_level_zero":
|
||||
plugin = LowLevelZeroPlugin(initial_scale=2**5)
|
||||
logger.info(f"plugin: {plugin}")
|
||||
booster = Booster(plugin=plugin, **booster_kwargs)
|
||||
|
||||
ctx = LazyInitContext(default_device=get_current_device()) if args.plugin == 'gemini' else nullcontext()
|
||||
ctx = LazyInitContext(default_device=get_current_device()) if args.plugin == "gemini" else nullcontext()
|
||||
|
||||
with ctx:
|
||||
model = PaLM(num_tokens=50304, dim=4096, depth=64)
|
||||
@@ -182,7 +183,6 @@ get_tflops_func = partial(get_tflops, numel, args.batch_size, SEQ_LEN)
|
||||
model.train()
|
||||
tflops_list = []
|
||||
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
|
||||
|
||||
if args.distplan == "colossalai":
|
||||
optimizer.zero_grad()
|
||||
start = time()
|
||||
@@ -231,12 +231,12 @@ logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
|
||||
# loss = model(next(val_loader))
|
||||
# print(f"validation loss: {loss.item()}")
|
||||
|
||||
# if i % GENERATE_EVERY == 0:
|
||||
# model.eval()
|
||||
# inp = random.choice(val_dataset)[:-1]
|
||||
# prime = decode_tokens(inp)
|
||||
# print(f"%s \n\n %s", (prime, "*" * 100))
|
||||
# if i % GENERATE_EVERY == 0:
|
||||
# model.eval()
|
||||
# inp = random.choice(val_dataset)[:-1]
|
||||
# prime = decode_tokens(inp)
|
||||
# print(f"%s \n\n %s", (prime, "*" * 100))
|
||||
|
||||
# sample = model.generate(inp[None, ...], GENERATE_LENGTH)
|
||||
# output_str = decode_tokens(sample[0])
|
||||
# print(output_str)
|
||||
# sample = model.generate(inp[None, ...], GENERATE_LENGTH)
|
||||
# output_str = decode_tokens(sample[0])
|
||||
# print(output_str)
|
||||
|
Reference in New Issue
Block a user