mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 13:30:19 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
This commit is contained in:
@@ -13,13 +13,13 @@ from torchvision import datasets, transforms
|
||||
|
||||
try:
|
||||
from transformer_engine import pytorch as te
|
||||
|
||||
HAVE_TE = True
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
HAVE_TE = False
|
||||
|
||||
|
||||
class Net(nn.Module):
|
||||
|
||||
def __init__(self, use_te=False):
|
||||
super(Net, self).__init__()
|
||||
self.conv1 = nn.Conv2d(1, 32, 3, 1)
|
||||
@@ -64,10 +64,12 @@ def train(args, model, device, train_loader, optimizer, epoch, use_fp8):
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
if batch_idx % args.log_interval == 0:
|
||||
print(f"Train Epoch: {epoch} "
|
||||
f"[{batch_idx * len(data)}/{len(train_loader.dataset)} "
|
||||
f"({100. * batch_idx / len(train_loader):.0f}%)]\t"
|
||||
f"Loss: {loss.item():.6f}")
|
||||
print(
|
||||
f"Train Epoch: {epoch} "
|
||||
f"[{batch_idx * len(data)}/{len(train_loader.dataset)} "
|
||||
f"({100. * batch_idx / len(train_loader):.0f}%)]\t"
|
||||
f"Loss: {loss.item():.6f}"
|
||||
)
|
||||
if args.dry_run:
|
||||
break
|
||||
|
||||
@@ -75,13 +77,11 @@ def train(args, model, device, train_loader, optimizer, epoch, use_fp8):
|
||||
def calibrate(model, device, test_loader):
|
||||
"""Calibration function."""
|
||||
model.eval()
|
||||
test_loss = 0
|
||||
correct = 0
|
||||
with torch.no_grad():
|
||||
for data, target in test_loader:
|
||||
data, target = data.to(device), target.to(device)
|
||||
with te.fp8_autocast(enabled=False, calibrating=True):
|
||||
output = model(data)
|
||||
model(data)
|
||||
|
||||
|
||||
def test(model, device, test_loader, use_fp8):
|
||||
@@ -94,15 +94,17 @@ def test(model, device, test_loader, use_fp8):
|
||||
data, target = data.to(device), target.to(device)
|
||||
with te.fp8_autocast(enabled=use_fp8):
|
||||
output = model(data)
|
||||
test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss
|
||||
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
|
||||
test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss
|
||||
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
|
||||
correct += pred.eq(target.view_as(pred)).sum().item()
|
||||
|
||||
test_loss /= len(test_loader.dataset)
|
||||
|
||||
print(f"\nTest set: Average loss: {test_loss:.4f}, "
|
||||
f"Accuracy: {correct}/{len(test_loader.dataset)} "
|
||||
f"({100. * correct / len(test_loader.dataset):.0f}%)\n")
|
||||
print(
|
||||
f"\nTest set: Average loss: {test_loss:.4f}, "
|
||||
f"Accuracy: {correct}/{len(test_loader.dataset)} "
|
||||
f"({100. * correct / len(test_loader.dataset):.0f}%)\n"
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
@@ -163,10 +165,9 @@ def main():
|
||||
default=False,
|
||||
help="For Saving the current Model",
|
||||
)
|
||||
parser.add_argument("--use-fp8",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use FP8 for inference and training without recalibration")
|
||||
parser.add_argument(
|
||||
"--use-fp8", action="store_true", default=False, help="Use FP8 for inference and training without recalibration"
|
||||
)
|
||||
parser.add_argument("--use-fp8-infer", action="store_true", default=False, help="Use FP8 inference only")
|
||||
parser.add_argument("--use-te", action="store_true", default=False, help="Use Transformer Engine")
|
||||
args = parser.parse_args()
|
||||
@@ -215,7 +216,7 @@ def main():
|
||||
|
||||
if args.save_model or args.use_fp8_infer:
|
||||
torch.save(model.state_dict(), "mnist_cnn.pt")
|
||||
print('Eval with reloaded checkpoint : fp8=' + str(args.use_fp8_infer))
|
||||
print("Eval with reloaded checkpoint : fp8=" + str(args.use_fp8_infer))
|
||||
weights = torch.load("mnist_cnn.pt")
|
||||
model.load_state_dict(weights)
|
||||
test(model, device, test_loader, args.use_fp8_infer)
|
||||
|
Reference in New Issue
Block a user