mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 17:17:05 +00:00
[legacy] move trainer to legacy (#4545)
* [legacy] move trainer to legacy * [doc] update docs related to trainer * [test] ignore legacy test
This commit is contained in:
@@ -43,7 +43,7 @@ from colossalai.engine.schedule import (InterleavedPipelineSchedule,
|
||||
PipelineSchedule)
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper
|
||||
from colossalai.trainer import Trainer, hooks
|
||||
from colossalai.legacy.trainer import Trainer, hooks
|
||||
from colossalai.utils.timer import MultiTimer
|
||||
from model_zoo.gpt import GPTLMLoss
|
||||
from torch.nn import functional as F
|
||||
@@ -273,3 +273,4 @@ def train():
|
||||
return_output_label=False,
|
||||
)
|
||||
```
|
||||
<!-- doc-test-command: echo -->
|
||||
|
@@ -36,7 +36,7 @@ from colossalai.builder import build_pipeline_model
|
||||
from colossalai.engine.schedule import (InterleavedPipelineSchedule,
|
||||
PipelineSchedule)
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.trainer import Trainer, hooks
|
||||
from colossalai.legacy.trainer import Trainer, hooks
|
||||
from colossalai.utils import MultiTimer, get_dataloader
|
||||
from timm.models import vision_transformer as vit
|
||||
from torchvision import transforms
|
||||
@@ -244,3 +244,4 @@ def train():
|
||||
hooks=hook_list,
|
||||
display_progress=True)
|
||||
```
|
||||
<!-- doc-test-command: echo -->
|
||||
|
@@ -74,7 +74,7 @@ from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.nn.lr_scheduler import LinearWarmupLR
|
||||
from colossalai.nn.metric import Accuracy
|
||||
from colossalai.trainer import Trainer, hooks
|
||||
from colossalai.legacy.trainer import Trainer, hooks
|
||||
```
|
||||
|
||||
- 其他模块
|
||||
@@ -589,3 +589,4 @@ torchrun --standalone --nproc_per_node <NUM_GPUs> train_hybrid.py --config ./co
|
||||
# If your torch >= 1.9.0
|
||||
# python -m torch.distributed.run --standalone --nproc_per_node=<NUM_GPUs> train_hybrid.py --config ./configs/config_hybrid_parallel.py
|
||||
```
|
||||
<!-- doc-test-command: echo -->
|
||||
|
@@ -61,7 +61,7 @@ Trainer 的参数 `schedule` 默认值是 `None` 。在大多数情况下,除
|
||||
|
||||
```python
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.trainer import Trainer, hooks
|
||||
from colossalai.legacy.trainer import Trainer, hooks
|
||||
|
||||
# build components and initialize with colossalai.initialize
|
||||
...
|
||||
@@ -104,7 +104,7 @@ trainer.fit(
|
||||
|
||||
```python
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.trainer import hooks
|
||||
from colossalai.legacy.trainer import hooks
|
||||
|
||||
class LogMessageHook(hooks.BaseHook):
|
||||
|
||||
@@ -341,7 +341,7 @@ for epoch in range(gpc.config.NUM_EPOCHS):
|
||||
|
||||
```python
|
||||
from colossalai.nn.metric import Accuracy
|
||||
from colossalai.trainer import Trainer, hooks
|
||||
from colossalai.legacy.trainer import Trainer, hooks
|
||||
|
||||
|
||||
# create a trainer object
|
||||
@@ -384,3 +384,4 @@ python -m torch.distributed.launch --nproc_per_node <num_gpus> --master_addr loc
|
||||
# with trainer
|
||||
python -m torch.distributed.launch --nproc_per_node <num_gpus> --master_addr localhost --master_port 29500 run_resnet_cifar10_with_trainer.py
|
||||
```
|
||||
<!-- doc-test-command: echo -->
|
||||
|
@@ -41,7 +41,7 @@ for epoch in range(num_epochs):
|
||||
|
||||
#### 用 trainer 保存
|
||||
```python
|
||||
from colossalai.trainer import Trainer, hooks
|
||||
from colossalai.legacy.trainer import Trainer, hooks
|
||||
model = ...
|
||||
engine, _, _, _ = colossalai.initialize(model=model, ...)
|
||||
trainer = Trainer(engine, ...)
|
||||
@@ -61,3 +61,4 @@ model = ...
|
||||
load_checkpoint('xxx.pt', model)
|
||||
... # train or test
|
||||
```
|
||||
<!-- doc-test-command: echo -->
|
||||
|
@@ -245,7 +245,7 @@ from pathlib import Path
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import get_dataloader
|
||||
from colossalai.trainer import Trainer, hooks
|
||||
from colossalai.legacy.trainer import Trainer, hooks
|
||||
from colossalai.nn.lr_scheduler import LinearWarmupLR
|
||||
from timm.models import vit_base_patch16_224
|
||||
from torchvision import datasets, transforms
|
||||
|
@@ -78,7 +78,7 @@ import colossalai.nn as col_nn
|
||||
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.trainer import Trainer, hooks
|
||||
from colossalai.legacy.trainer import Trainer, hooks
|
||||
from colossalai.utils import MultiTimer, get_dataloader
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.pipeline.pipelinable import PipelinableContext
|
||||
@@ -156,3 +156,4 @@ trainer.fit(train_dataloader=train_dataloader,
|
||||
```
|
||||
|
||||
我们使用 `2` 个流水段,并且 batch 将被切分为 `4` 个 micro batches。
|
||||
<!-- doc-test-command: echo -->
|
||||
|
Reference in New Issue
Block a user