Mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 14:41:53 +00:00
update markdown docs (english) (#60)
README.md
@@ -42,21 +42,56 @@ pip install -v --no-cache-dir --global-option="--cuda_ext" .
Removed:

```python
import colossalai
from colossalai.trainer import Trainer
from colossalai.core import global_context as gpc
from colossalai.utils import get_dataloader

engine, train_dataloader, test_dataloader = colossalai.initialize()

trainer = Trainer(engine=engine,
                  verbose=True)
trainer.fit(
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    epochs=gpc.config.num_epochs,
    hooks_cfg=gpc.config.hooks,
    display_progress=True,
    test_interval=5
)
```
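Both snippets assume a config is available: the removed one reads `gpc.config.num_epochs` and `gpc.config.hooks`, and the added one passes `my_config` to `colossalai.launch`. The sketch below is a hypothetical minimal config dictionary, not taken from this commit; it only fills in the two fields referenced above with placeholder values.

```python
# Hypothetical minimal config for the snippets in this diff (not part of the
# commit). Only the fields referenced above are shown, with placeholder values.
my_config = dict(
    num_epochs=10,  # read as gpc.config.num_epochs in the removed snippet
    hooks=[],       # read as gpc.config.hooks; hook definitions would go here
)
```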
Added:

```python
import colossalai
from colossalai.utils import get_dataloader

# my_config can be a path to a config file or a dictionary obj
# 'localhost' is only for single node, you need to specify
# the node name if using multiple nodes
colossalai.launch(
    config=my_config,
    rank=rank,
    world_size=world_size,
    backend='nccl',
    port=29500,
    host='localhost'
)

# build your model
model = ...

# build your dataset, the dataloader will have a distributed data
# sampler by default
train_dataset = ...
train_dataloader = get_dataloader(dataset=train_dataset,
                                  shuffle=True,
                                  )

# build your optimizer
optimizer = ...

# build your loss function
criterion = ...

# build your lr_scheduler
lr_scheduler = ...

# initialize the engine with your components
engine, train_dataloader, _, _ = colossalai.initialize(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    train_dataloader=train_dataloader
)

# start training
engine.train()
for epoch in range(NUM_EPOCHS):
    for data, label in train_dataloader:
        engine.zero_grad()
        output = engine(data)
        loss = engine.criterion(output, label)
        engine.backward(loss)
        engine.step()
```
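The added snippet leaves `my_config`, `rank`, and `world_size` undefined. As a sketch only (not part of this commit), the values could be read from the environment when the script is started by a torch.distributed-style launcher such as `torchrun`, which exports `RANK` and `WORLD_SIZE`; the config is assumed here to live in a hypothetical `./config.py`.

```python
import os

import colossalai

# Sketch only: assumes a torch.distributed-style launcher (e.g. torchrun)
# started this process and exported RANK and WORLD_SIZE.
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

colossalai.launch(
    config="./config.py",   # hypothetical config file path (a dict also works)
    rank=rank,
    world_size=world_size,
    backend='nccl',
    port=29500,
    host='localhost'        # single node only, as noted in the README comment
)
```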
### Write a Simple 2D Parallel Model