[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -13,13 +13,13 @@ from torchvision import datasets, transforms
try:
from transformer_engine import pytorch as te
HAVE_TE = True
except (ImportError, ModuleNotFoundError):
HAVE_TE = False
class Net(nn.Module):
def __init__(self, use_te=False):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
@@ -64,10 +64,12 @@ def train(args, model, device, train_loader, optimizer, epoch, use_fp8):
loss.backward()
optimizer.step()
if batch_idx % args.log_interval == 0:
print(f"Train Epoch: {epoch} "
f"[{batch_idx * len(data)}/{len(train_loader.dataset)} "
f"({100. * batch_idx / len(train_loader):.0f}%)]\t"
f"Loss: {loss.item():.6f}")
print(
f"Train Epoch: {epoch} "
f"[{batch_idx * len(data)}/{len(train_loader.dataset)} "
f"({100. * batch_idx / len(train_loader):.0f}%)]\t"
f"Loss: {loss.item():.6f}"
)
if args.dry_run:
break
@@ -75,13 +77,11 @@ def train(args, model, device, train_loader, optimizer, epoch, use_fp8):
def calibrate(model, device, test_loader):
"""Calibration function."""
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
with te.fp8_autocast(enabled=False, calibrating=True):
output = model(data)
model(data)
def test(model, device, test_loader, use_fp8):
@@ -94,15 +94,17 @@ def test(model, device, test_loader, use_fp8):
data, target = data.to(device), target.to(device)
with te.fp8_autocast(enabled=use_fp8):
output = model(data)
test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print(f"\nTest set: Average loss: {test_loss:.4f}, "
f"Accuracy: {correct}/{len(test_loader.dataset)} "
f"({100. * correct / len(test_loader.dataset):.0f}%)\n")
print(
f"\nTest set: Average loss: {test_loss:.4f}, "
f"Accuracy: {correct}/{len(test_loader.dataset)} "
f"({100. * correct / len(test_loader.dataset):.0f}%)\n"
)
def main():
@@ -163,10 +165,9 @@ def main():
default=False,
help="For Saving the current Model",
)
parser.add_argument("--use-fp8",
action="store_true",
default=False,
help="Use FP8 for inference and training without recalibration")
parser.add_argument(
"--use-fp8", action="store_true", default=False, help="Use FP8 for inference and training without recalibration"
)
parser.add_argument("--use-fp8-infer", action="store_true", default=False, help="Use FP8 inference only")
parser.add_argument("--use-te", action="store_true", default=False, help="Use Transformer Engine")
args = parser.parse_args()
@@ -215,7 +216,7 @@ def main():
if args.save_model or args.use_fp8_infer:
torch.save(model.state_dict(), "mnist_cnn.pt")
print('Eval with reloaded checkpoint : fp8=' + str(args.use_fp8_infer))
print("Eval with reloaded checkpoint : fp8=" + str(args.use_fp8_infer))
weights = torch.load("mnist_cnn.pt")
model.load_state_dict(weights)
test(model, device, test_loader, args.use_fp8_infer)