[doc] fix docs about booster api usage (#3898)

Baizhou Zhang
2023-06-06 13:36:11 +08:00
committed by GitHub
parent ec9bbc0094
commit c1535ccbba
3 changed files with 6 additions and 6 deletions

@@ -195,7 +195,7 @@ def get_data(batch_size, seq_len, vocab_size):
 Finally, we define a model which uses Gemini + ZeRO DDP and define our training loop. As we pre-train GPT in this example, we just use a simple language model loss:
 ```python
-from torch.optim import Adam
+from colossalai.nn.optimizer import HybridAdam
 from colossalai.booster import Booster
 from colossalai.zero import ColoInitContext
@@ -211,7 +211,7 @@ def main():
 # build criterion
 criterion = GPTLMLoss()
-optimizer = Adam(model.parameters(), lr=0.001)
+optimizer = HybridAdam(model.parameters(), lr=0.001)
 torch.manual_seed(123)
 default_pg = ProcessGroup(tp_degree=args.tp_degree)

@@ -197,7 +197,7 @@ def get_data(batch_size, seq_len, vocab_size):
 Finally, we use the booster to inject the Gemini + ZeRO DDP features and define the training loop. Since we pre-train GPT in this example, we only use a simple language model loss function:
 ```python
-from torch.optim import Adam
+from colossalai.nn.optimizer import HybridAdam
 from colossalai.booster import Booster
 from colossalai.zero import ColoInitContext
@@ -213,7 +213,7 @@ def main():
 # build criterion
 criterion = GPTLMLoss()
-optimizer = Adam(model.parameters(), lr=0.001)
+optimizer = HybridAdam(model.parameters(), lr=0.001)
 torch.manual_seed(123)
 default_pg = ProcessGroup(tp_degree=args.tp_degree)
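The change is the same in both the English and Chinese tutorials: build the optimizer with ColossalAI's HybridAdam (a CPU/GPU hybrid Adam designed to work with Gemini's parameter offloading) rather than torch.optim.Adam. Below is a minimal sketch of how the changed lines sit in the booster workflow, not the full tutorial script: the GeminiPlugin default arguments and the gpt2_medium builder are assumptions, while GPTLMLoss and get_data come from the tutorial itself.
```python
import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})
torch.manual_seed(123)

# GeminiPlugin provides the Gemini + ZeRO DDP features that the tutorial
# injects via the booster; default arguments are an assumption here.
plugin = GeminiPlugin()
booster = Booster(plugin=plugin)

model = gpt2_medium()    # hypothetical builder standing in for the tutorial's GPT model
criterion = GPTLMLoss()  # language model loss defined earlier in the tutorial
optimizer = HybridAdam(model.parameters(), lr=0.001)  # the line this commit fixes

# booster.boost wraps the model, optimizer and criterion with the plugin's features
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

model.train()
for step in range(10):
    input_ids, attn_mask = get_data(batch_size=2, seq_len=1024, vocab_size=50257)
    outputs = model(input_ids, attn_mask)
    loss = criterion(outputs, input_ids)
    booster.backward(loss, optimizer)  # replaces loss.backward() under the booster API
    optimizer.step()
    optimizer.zero_grad()
```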