# Colossal-AI Engine & Customize Your Trainer

## Colossal-AI engine

To better understand how the `Engine` class works, let's start from the concept of the process function in common engines. The process function usually controls the behavior over a batch of a dataset, and the `Engine` class simply controls the process function. A standard process function looks like the following code block.

```python
def process_function(dataloader, model, criterion, optim):
    optim.zero_grad()
    data, label = next(dataloader)
    output = model(data)
    loss = criterion(output, label)
    loss.backward()
    optim.step()
```

In `ignite.engine` or `keras.engine`, the process function is always provided by the user. However, it is tricky for users to write their own process functions for pipeline parallelism. To offer accessible hybrid parallelism, we provide the powerful `Engine` class, which enables pipeline parallelism and offers a one-forward-one-backward (1F1B) non-interleaving strategy. You can also use a pre-defined learning rate scheduler in the `Engine` class to adjust the learning rate during training.

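For reference, the parallel layout that the engine runs with is declared in the configuration file consumed by `colossalai.initialize`. The following is a minimal sketch, assuming the `parallel` field follows the convention from the *Get Started* section; the sizes here are placeholders, not recommendations.

```python
# config.py -- illustrative sketch of a hybrid-parallel configuration
parallel = dict(
    pipeline=2,                      # number of pipeline stages
    tensor=dict(size=2, mode='1d')   # tensor-parallel group size and mode
)
```
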
The `Engine` class is a high-level wrapper around these frequently-used training functions: it preserves the PyTorch-like function signature while integrating with our features, and it is built for you automatically when you start with `colossalai.initialize`. The following code block provides an example.

```python
import torch
import torch.nn as nn
import torchvision.models as models
import colossalai
from torchvision.datasets import CIFAR10

model = models.resnet18()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
dataset = CIFAR10(...)
dataloader = colossalai.utils.get_dataloader(dataset)

# colossalai.initialize builds the engine from these components
engine, dataloader, _, _ = colossalai.initialize(model, optimizer, criterion, dataloader)

# example of a training iteration
for img, label in dataloader:
    engine.zero_grad()
    output = engine(img)
    loss = engine.criterion(output, label)
    engine.backward(loss)
    engine.step()
```

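Evaluation follows the same pattern. The following is a minimal sketch, assuming `engine.eval()` switches the wrapped model out of training mode (mirroring `model.eval()` in PyTorch) and that a `test_dataloader` has been created the same way as above.

```python
import torch

engine.eval()  # assumed to mirror model.eval()
correct, total = 0, 0
for img, label in test_dataloader:
    with torch.no_grad():  # no gradients needed for evaluation
        output = engine(img)
    correct += (output.argmax(dim=-1) == label).sum().item()
    total += label.size(0)
print(f'accuracy: {correct / total:.4f}')
```
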
More information regarding the class can be found in the API references.

## Customize your trainer

To learn how to customize a trainer that meets your needs, let's first take a look at the `Trainer` class. We highly recommend that you read the *Get Started* section and the *Colossal-AI engine* section above first.

The `Trainer` class enables researchers and engineers to use our system more conveniently. Instead of writing your own scripts, you can simply construct your own trainer by instantiating the `Trainer` class, as in the following code block.

```python
trainer = Trainer(engine)
```

After that, you can use the `fit` method to train or evaluate your model. To make our `Trainer` class even more powerful, we incorporate a set of handy tools into it. For example, you can monitor or record the running states and metrics that indicate the current performance of the model. These functions are realized by hooks: the `BaseHook` class allows you to execute your hook functions at specified times. We have already created some practical hooks for you, as listed in the code block below; you only need to pick the ones that suit your needs. Detailed descriptions of the class can be found in the API references.

These hook functions will record metrics, elapsed time, and memory usage, and write them to the log after each epoch. They also print the current loss and accuracy so that users can monitor the performance of the model.

```python
import colossalai
from colossalai.trainer import hooks, Trainer
from colossalai.utils import MultiTimer
from colossalai.logging import get_dist_logger

... = colossalai.initialize(...)

timer = MultiTimer()
logger = get_dist_logger()

# if you want to save the log to a file
logger.log_to_file('./logs/')

trainer = Trainer(
    engine=engine,
    timer=timer,
    logger=logger
)

hook_list = [
    hooks.LossHook(),
    hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),
    hooks.AccuracyHook(),
    hooks.TensorboardHook(log_dir='./tb_logs', ranks=[0]),
    hooks.LogMetricByEpochHook(logger),
    hooks.LogMemoryByEpochHook(logger),
    hooks.LogTimingByEpochHook(timer, logger),
    hooks.SaveCheckpointHook(checkpoint_dir='./ckpt')
]

trainer.fit(
    train_dataloader=train_dataloader,
    epochs=NUM_EPOCHS,
    test_dataloader=test_dataloader,
    test_interval=1,
    hooks=hook_list,
    display_progress=True
)
```

### Hook

If you have specific needs, feel free to extend our `BaseHook` class to add your own functions, or our `MetricHook` class to write a metric collector. These hook functions can be called at different stages in the trainer's life cycle. Besides, you can define the priorities of all hooks to arrange their execution order. More information can be found in the API references.
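
As an illustration, a custom hook might look like the sketch below. It assumes that `BaseHook` is importable from `colossalai.trainer.hooks`, that it accepts a `priority` argument, and that callbacks such as `before_train` and `after_train_epoch` receive the trainer instance; check the API references for the exact signatures before relying on them.

```python
from colossalai.logging import get_dist_logger
from colossalai.trainer import hooks


class EpochAnnounceHook(hooks.BaseHook):
    """Hypothetical hook that logs a message around each training epoch."""

    def __init__(self, priority: int = 10):
        # priority is assumed to control execution order among hooks
        super().__init__(priority=priority)
        self._logger = get_dist_logger()
        self._epochs_done = 0

    def before_train(self, trainer):
        self._logger.info('training is about to start')

    def after_train_epoch(self, trainer):
        # count epochs locally instead of relying on trainer internals
        self._epochs_done += 1
        self._logger.info(f'finished epoch {self._epochs_done}')
```

Such a hook could then be appended to `hook_list` above alongside the built-in hooks.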