[shardformer] Align bert value (#3907)

* add bert align test, fix dist loss bug

* forward and backward align

* add ignore index

* add shardformer CI

* add gather_output optional for user in shardconfig

* update readme with optional gather_output

* add dist crossentropy loss test, remove unused files

* remove unused file

* remove unused file

* rename the file

* polish code
FoolPlayer
2023-06-09 14:36:54 +08:00
committed by Frank Lee
parent 79f8d5d54b
commit f1cb5ac6bf
11 changed files with 174 additions and 197 deletions

@@ -20,7 +20,7 @@
The sample API usage is given below:
``` python
from colossalai.shardformer import shard_model
from colossalai.shardformer import ShardConfig, shard_model
from transformers import BertForMaskedLM
# create huggingface model as normal
@@ -28,7 +28,12 @@ model = BertForMaskedLM.from_pretrained("bert-base-uncased")
# parallelize the huggingface model with ShardModel
# auto policy:
sharded_model = shard_model(model)
shardconfig = ShardConfig(
rank=rank,
world_size=world_size,
gather_output=True,
)
sharded_model = shard_model(model, config=shardconfig)
# custom policy:
from xxx import <POLICYCLASS>
@@ -72,7 +77,7 @@ More details can be found in shardformer/policies/basepolicy.py
``` python
from colossalai.shardformer.policies.basepolicy import Policy, Layer, Col_Layer, Row_Layer, Argument
CustomPolicy(Policy):
class CustomPolicy(Policy):
@staticmethod
def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]:
r"""
@@ -235,7 +240,7 @@ CustomPolicy(Policy):
This class is used to specify the replacement policy for a particular layer. If `replace_layer` is None, only parameter partitioning will be performed without replacing the layer class.
CLASS `Col_Layer(Layer)`:
- gather_output (bool): Whether to gather the output of the layer
- gather_output (bool): Whether to gather the output of this layer. Typically only the last layer's output needs to be gathered; intermediate layers usually stay sharded.
This class inherited from `Layer`, representing the layer will be sliced along column.
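To make `gather_output` concrete, here is a minimal illustrative sketch (not the shardformer implementation; the class name and layout are assumptions) of a column-sharded linear layer that reassembles the full output with an all-gather only when `gather_output=True`:
``` python
# Illustrative sketch only -- not the colossalai.shardformer implementation.
# Each rank holds a column slice of the weight; gather_output decides whether
# the full output dimension is reassembled before returning.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist


class ToyColLinear(nn.Module):

    def __init__(self, in_features: int, out_features: int, gather_output: bool = False):
        super().__init__()
        world_size = dist.get_world_size()
        assert out_features % world_size == 0
        # local shard of the full [out_features, in_features] weight
        self.weight = nn.Parameter(torch.empty(out_features // world_size, in_features))
        self.gather_output = gather_output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        local_out = F.linear(x, self.weight)    # [..., out_features / world_size]
        if not self.gather_output:
            # intermediate layers usually stay sharded; the following
            # row-parallel layer consumes the partial output directly
            return local_out
        # a last layer (e.g. the MLM decoder head) gathers the shards so the
        # caller sees the full output dimension (a real implementation wraps
        # this in a custom autograd function so gradients flow back correctly)
        shards = [torch.empty_like(local_out) for _ in range(dist.get_world_size())]
        dist.all_gather(shards, local_out.contiguous())
        return torch.cat(shards, dim=-1)
```
Note that in the `ModelSharder` hunk further down, the effective flag is `policy_layer.gather_output and self.shard_config.gather_output`, so a user can switch gathering off globally even for layers whose policy enables it.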

@@ -0,0 +1 @@
from .shard import ShardConfig, shard_model

@@ -14,7 +14,7 @@ class DistCrossEntropy(Function):
"""
@staticmethod
def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor):
def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: int):
r"""
Calculate the cross entropy loss before gathering; the original loss function is as follows:
loss = -log(exp(x[class]) / sum(exp(x[i])))
@@ -75,8 +75,8 @@ class DistCrossEntropy(Function):
# calculate the loss
# loss = log(sum(exp(x[i]))) - x[class]
loss = torch.log(sum_exp_logits) - pred_logits
loss = torch.sum(loss).div_(loss.numel())
loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits)
loss = torch.sum(loss).div_(torch.sum(loss != 0.0))
# calculate the softmax
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
@@ -101,5 +101,5 @@ class DistCrossEntropy(Function):
return grad_logits, None, None
def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
return DistCrossEntropy.apply(vocab_logits, labels)
def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index)
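To see what the new `ignore_index` handling computes, here is a small single-process sketch (no tensor parallelism; all names are local to the example): ignored tokens contribute zero loss and are dropped from the denominator, which matches `F.cross_entropy` with mean reduction:
``` python
# Toy check of the masked mean introduced above, without tensor parallelism.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(6, 10)                      # 6 tokens, vocab size 10
target = torch.tensor([1, 2, -100, 4, -100, 7])  # -100 marks ignored tokens

# per-token loss = log(sum(exp(x[i]))) - x[class], zeroed on ignored tokens
picked = logits.gather(1, target.clamp(min=0).unsqueeze(1)).squeeze(1)
per_token = torch.logsumexp(logits, dim=-1) - picked
per_token = torch.where(target == -100, torch.zeros_like(per_token), per_token)

# average only over tokens that actually contribute; like the patch, this
# denominator would also skip a real token whose loss happened to be exactly 0
loss = per_token.sum() / (per_token != 0.0).sum()

print(loss, F.cross_entropy(logits, target, ignore_index=-100))  # should match
```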

@@ -141,7 +141,7 @@ class BertPolicy(Policy):
weight="decoder.weight",
bias="decoder.bias",
replace_layer=col_nn.Linear1D_Col,
# gather_output=True,
gather_output=True,
)
]
@@ -155,7 +155,8 @@ class BertForMaskedLMPolicy(BertPolicy):
@staticmethod
def inject_policy() -> Tuple[nn.Module, nn.Module]:
return (BertForMaskedLM, BertForMaskedLM_)
# return (BertForMaskedLM, BertForMaskedLM_)
return None
class BertForSequenceClassificationPolicy(BertPolicy):

@@ -5,16 +5,14 @@ __all__ = ['ShardConfig']
@dataclass
class ShardConfig:
"""
The config for sharding the huggingface model for test
r"""
The config for sharding the huggingface model
Args:
rank (int): The rank of local process
world_size (int): The world size of the distributed process
gather_output (bool): Whether to gather the output of the model of the last layer
"""
rank: int
fp16: bool = True
num_gpus: int = 2
world_size: int = 2
backend = "nccl"
verbose: str = 'simple'
seed: int = None
require_grad: bool = False
master_addr: str = "127.0.0.1"
master_port: int = 29500
gather_output: bool = True
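As a usage note, here is a hedged sketch of filling the slimmed-down config from an already-initialised process group (the `torch.distributed` calls are standard; the shardformer imports follow the README snippet above):
``` python
# Sketch: build the simplified ShardConfig for the current process group.
# Assumes the distributed environment is already initialised, e.g. via
# colossalai.launch_from_torch(config={}).
import torch.distributed as dist
from colossalai.shardformer import ShardConfig, shard_model
from transformers import BertForMaskedLM

shard_config = ShardConfig(
    rank=dist.get_rank(),
    world_size=dist.get_world_size(),
    gather_output=True,    # gather the last layer's output; set False to keep it sharded
)
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
sharded_model = shard_model(model, config=shard_config)
```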

@@ -65,6 +65,8 @@ class ModelSharder(object):
BertForMaskedLM.forward -> BertForMaskedLM_.forward
"""
inject_policy = self.policy.inject_policy()
if inject_policy is None:
return
@@ -148,7 +150,7 @@ class ModelSharder(object):
n_cast = policy_layer.n_cast
reversed = policy_layer.reversed
if policy_layer.__class__.__name__ == "Col_Layer":
gather_output = policy_layer.gather_output
gather_output = policy_layer.gather_output and self.shard_config.gather_output
if weight_attr is not None:
if hasattr_(org_layer, weight_attr):

@@ -1 +0,0 @@
parallel = dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d'))

@@ -1,50 +0,0 @@
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import colossalai
from colossalai.shardformer.layer.dist_crossentropy import applyDistCrossEntropy
from colossalai.shardformer.layer.dropout import Dropout1D
def get_args():
parser = colossalai.get_default_parser()
parser.add_argument("--module", type=str, default='distloss')
return parser.parse_args()
def test_dist_crossentropy():
pred = torch.randn(2, 4, 8, requires_grad=True)
labels = torch.randint(8, (1, 4)).repeat(2, 1)
pred_ = pred.view(-1, 8)
labels_ = labels.view(-1)
loss = F.cross_entropy(pred_, labels_)
loss.backward()
print(f"normal loss:{loss}")
pred = pred.chunk(int(os.environ['WORLD_SIZE']), -1)[int(os.environ['RANK'])]
loss = applyDistCrossEntropy(pred.to('cuda'), labels.to('cuda'))
loss.backward()
print(f"dist loss:{loss}")
def test_dropout():
input = torch.randn(5, 4).to("cuda")
m = Dropout1D(p=0.2).to("cuda")
for i in range(2):
print(f"Output: {m(input)}")
print(torch.randn(1))
if __name__ == '__main__':
args = get_args()
colossalai.launch_from_torch(config={})
if args.module == 'distloss':
test_dist_crossentropy()
elif args.module == 'dropout':
test_dropout()
else:
print("not implemented yet")

@@ -1,124 +0,0 @@
import os
import random
import torch
import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, GPT2LMHeadModel, get_scheduler
import colossalai
from colossalai.shardformer.shard import ShardConfig, shard_model
from colossalai.utils import get_current_device, print_rank_0
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
def get_args():
parser = colossalai.get_default_parser()
parser.add_argument("--mode", type=str, default='inference')
parser.add_argument("--save_model", action='store_true')
parser.add_argument("--model", type=str, default='bert-base-uncased')
return parser.parse_args()
def load_data(args):
tokenizer = AutoTokenizer.from_pretrained(args.model)
if tokenizer.pad_token is None:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
# tokenizer.pad_token_id = 0
datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')
# datasets=load_dataset("yelp_review_full")
tokenized_datasets = datasets.map(
lambda examples: tokenizer(examples["text"], truncation=True, padding="max_length"), batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(["text"])
# tokenized_datasets=tokenized_datasets.rename_column("label","labels")
tokenized_datasets.set_format("torch")
train_dataset = tokenized_datasets["train"]
test_dataset = tokenized_datasets["test"]
datacollector = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15, return_tensors="pt")
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=datacollector)
eval_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True, collate_fn=datacollector)
return train_dataloader, eval_dataloader
def inference(model: nn.Module, args):
print(model)
# print(model.wte.weight.shape)
tokenizer = AutoTokenizer.from_pretrained(args.model)
if tokenizer.pad_token is None:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
tokenizer.pad_token_id = 0
token = "Hello, my dog is cute"
inputs = tokenizer(token, return_tensors="pt")
inputs.to("cuda")
model.eval()
model.to("cuda")
outputs = model(**inputs)
print(outputs[0])
def train(model: nn.Module, args, num_epoch: int = 3):
train_dataloader, eval_dataloader = load_data(args)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
num_training = num_epoch * len(train_dataloader)
progress_bar = tqdm(range(num_training))
lr_scheduler = get_scheduler(name="linear",
optimizer=optimizer,
num_warmup_steps=0,
num_training_steps=num_training)
best_test_loss = float("inf")
model.to("cuda")
model.train()
for epoch in range(num_epoch):
progress_bar.set_description(f"Rank {get_current_device()} epoch {epoch}")
for batch in train_dataloader:
optimizer.zero_grad()
batch = {k: v.to('cuda') for k, v in batch.items()}
outputs = model(**batch)
loss = outputs.loss
loss.backward()
optimizer.step()
lr_scheduler.step()
progress_bar.update(1)
train_loss = loss
loss = 0.0
for batch in eval_dataloader:
batch = {k: v.to('cuda') for k, v in batch.items()}
outputs = model(**batch)
# loss = outputs.loss
assert not torch.isnan(outputs.loss), f"{batch}"
loss += outputs.loss.item()
# loss = criterion(outputs.logits, batch["input_ids"])
test_loss = loss / len(eval_dataloader)
print_rank_0(f"Train Loss: {train_loss:.4f} Test Loss:{test_loss:.4f}")
if args.save_model and test_loss < best_test_loss:
best_test_loss = test_loss
torch.save(model.state_dict(), "./checkpoints/best_model.pth")
if __name__ == "__main__":
args = get_args()
colossalai.launch_from_torch(config=args.config)
if args.model == 'bert-base-uncased':
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
elif args.model == 'gpt2':
model = GPT2LMHeadModel.from_pretrained("gpt2")
else:
raise AttributeError("model not supported")
shard_config = ShardConfig(
rank=int(str(get_current_device()).split(':')[-1]),
world_size=int(os.environ['WORLD_SIZE']),
)
sharded_model = shard_model(model, shard_config)
if args.mode == "train":
train(sharded_model, args)
elif args.mode == "inference":
inference(sharded_model, args)
else:
raise NotImplementedError