mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-31 23:35:47 +00:00
[doc] Fix gradient accumulation doc. (#4349)
* [doc] fix gradient accumulation doc * [doc] fix gradient accumulation doc
This commit is contained in:
parent
38b792aab2
commit
f40b718959
@ -103,10 +103,12 @@ for idx, (img, label) in enumerate(train_dataloader):
|
||||
with sync_context:
|
||||
output = model(img)
|
||||
train_loss = criterion(output, label)
|
||||
train_loss = train_loss / GRADIENT_ACCUMULATION
|
||||
booster.backward(train_loss, optimizer)
|
||||
else:
|
||||
output = model(img)
|
||||
train_loss = criterion(output, label)
|
||||
train_loss = train_loss / GRADIENT_ACCUMULATION
|
||||
booster.backward(train_loss, optimizer)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
@ -106,10 +106,12 @@ for idx, (img, label) in enumerate(train_dataloader):
|
||||
with sync_context:
|
||||
output = model(img)
|
||||
train_loss = criterion(output, label)
|
||||
train_loss = train_loss / GRADIENT_ACCUMULATION
|
||||
booster.backward(train_loss, optimizer)
|
||||
else:
|
||||
output = model(img)
|
||||
train_loss = criterion(output, label)
|
||||
train_loss = train_loss / GRADIENT_ACCUMULATION
|
||||
booster.backward(train_loss, optimizer)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
Loading…
Reference in New Issue
Block a user