mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-01-06 16:14:07 +00:00
[doc] migrate the markdown files (#2652)
This commit is contained in:
61
docs/source/en/basics/model_checkpoint.md
Normal file
61
docs/source/en/basics/model_checkpoint.md
Normal file
@@ -0,0 +1,61 @@
|
||||
# Model Checkpoint
|
||||
|
||||
Author : Guangyang Lu
|
||||
|
||||
**Prerequisite:**
|
||||
- [Launch Colossal-AI](./launch_colossalai.md)
|
||||
- [Initialize Colossal-AI](./initialize_features.md)
|
||||
|
||||
**Example Code:**
|
||||
- [ColossalAI-Examples Model Checkpoint](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/utils/checkpoint)
|
||||
|
||||
**This function is experiential.**
|
||||
|
||||
## Introduction
|
||||
|
||||
In this tutorial, you will learn how to save and load model checkpoints.
|
||||
|
||||
To leverage the power of parallel strategies in Colossal-AI, modifications to models and tensors are needed, for which you cannot directly use `torch.save` or `torch.load` to save or load model checkpoints. Therefore, we have provided you with the API to achieve the same thing.
|
||||
|
||||
Moreover, when loading, you are not demanded to use the same parallel strategy as saving.
|
||||
|
||||
## How to use
|
||||
|
||||
### Save
|
||||
|
||||
There are two ways to train a model in Colossal-AI, by engine or by trainer.
|
||||
**Be aware that we only save the `state_dict`.** Therefore, when loading the checkpoints, you need to define the model first.
|
||||
|
||||
#### Save when using engine
|
||||
|
||||
```python
|
||||
from colossalai.utils import save_checkpoint
|
||||
model = ...
|
||||
engine, _, _, _ = colossalai.initialize(model=model, ...)
|
||||
for epoch in range(num_epochs):
|
||||
... # do some training
|
||||
save_checkpoint('xxx.pt', epoch, model)
|
||||
```
|
||||
|
||||
#### Save when using trainer
|
||||
```python
|
||||
from colossalai.trainer import Trainer, hooks
|
||||
model = ...
|
||||
engine, _, _, _ = colossalai.initialize(model=model, ...)
|
||||
trainer = Trainer(engine, ...)
|
||||
hook_list = [
|
||||
hooks.SaveCheckpointHook(1, 'xxx.pt', model)
|
||||
...]
|
||||
|
||||
trainer.fit(...
|
||||
hook=hook_list)
|
||||
```
|
||||
|
||||
### Load
|
||||
|
||||
```python
|
||||
from colossalai.utils import load_checkpoint
|
||||
model = ...
|
||||
load_checkpoint('xxx.pt', model)
|
||||
... # train or test
|
||||
```
|
||||
Reference in New Issue
Block a user