mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-08 03:24:07 +00:00
[doc] Fix gradient accumulation doc. (#4349)
* [doc] fix gradient accumulation doc
* [doc] fix gradient accumulation doc
This commit is contained in:
parent
38b792aab2
commit
f40b718959
@@ -103,10 +103,12 @@ for idx, (img, label) in enumerate(train_dataloader):
|
|||||||
with sync_context:
|
with sync_context:
|
||||||
output = model(img)
|
output = model(img)
|
||||||
train_loss = criterion(output, label)
|
train_loss = criterion(output, label)
|
||||||
|
train_loss = train_loss / GRADIENT_ACCUMULATION
|
||||||
booster.backward(train_loss, optimizer)
|
booster.backward(train_loss, optimizer)
|
||||||
else:
|
else:
|
||||||
output = model(img)
|
output = model(img)
|
||||||
train_loss = criterion(output, label)
|
train_loss = criterion(output, label)
|
||||||
|
train_loss = train_loss / GRADIENT_ACCUMULATION
|
||||||
booster.backward(train_loss, optimizer)
|
booster.backward(train_loss, optimizer)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
@@ -106,10 +106,12 @@ for idx, (img, label) in enumerate(train_dataloader):
|
|||||||
with sync_context:
|
with sync_context:
|
||||||
output = model(img)
|
output = model(img)
|
||||||
train_loss = criterion(output, label)
|
train_loss = criterion(output, label)
|
||||||
|
train_loss = train_loss / GRADIENT_ACCUMULATION
|
||||||
booster.backward(train_loss, optimizer)
|
booster.backward(train_loss, optimizer)
|
||||||
else:
|
else:
|
||||||
output = model(img)
|
output = model(img)
|
||||||
train_loss = criterion(output, label)
|
train_loss = criterion(output, label)
|
||||||
|
train_loss = train_loss / GRADIENT_ACCUMULATION
|
||||||
booster.backward(train_loss, optimizer)
|
booster.backward(train_loss, optimizer)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
Loading…
Reference in New Issue
Block a user