ColossalAI/applications/Chat/coati/trainer/strategies/ddp.py
Hongxin Liu b5f0566363
[chat] add distributed PPO trainer (#3740)
* Detached ppo (#9)

* run the base

* working on dist ppo

* sync

* detached trainer

* update detached trainer. no maker update function

* facing init problem

* 1 maker 1 trainer detached run. but no model update

* facing cuda problem

* fix save functions

* verified maker update

* nothing

* add ignore

* analyize loss issue

* remove some debug codes

* facing 2m1t stuck issue

* 2m1t verified

* do not use torchrun

* working on 2m2t

* working on 2m2t

* initialize strategy in ray actor env

* facing actor's init order issue

* facing ddp model update issue (need unwarp ddp)

* unwrap ddp actor

* checking 1m2t stuck problem

* nothing

* set timeout for trainer choosing. It solves the stuck problem!

* delete some debug output

* rename to sync with upstream

* rename to sync with upstream

* coati rename

* nothing

* I am going to detach the replaybuffer from trainer and make it a Ray Actor. Two benefits: 1. support TP trainer. 2. asynchronized buffer operations

* experience_maker_holder performs target-revolving _send_experience() instead of length comparison.

* move code to ray subfolder

* working on pipeline inference

* apply comments

* working on pipeline strategy. in progress.

* remove pipeline code. clean this branch

* update remote parameters by state_dict. no test

* nothing

* state_dict sharding transfer

* merge debug branch

* gemini _unwrap_model fix

* simplify code

* simplify code & fix LoRALinear AttributeError

* critic unwrapped state_dict

---------

Co-authored-by: csric <richcsr256@gmail.com>

* [chat] add perfomance evaluator and fix bugs (#10)

* [chat] add performance evaluator for ray

* [chat] refactor debug arg

* [chat] support hf config

* [chat] fix generation

* [chat] add 1mmt dummy example

* [chat] fix gemini ckpt

* split experience to send (#11)

Co-authored-by: csric <richcsr256@gmail.com>

* [chat] refactor trainer and maker (#12)

* [chat] refactor experience maker holder

* [chat] refactor model init

* [chat] refactor trainer args

* [chat] refactor model init

* [chat] refactor trainer

* [chat] refactor experience sending logic and training loop args (#13)

* [chat] refactor experience send logic

* [chat] refactor trainer

* [chat] refactor trainer

* [chat] refactor experience maker

* [chat] refactor pbar

* [chat] refactor example folder (#14)

* [chat] support quant (#15)

* [chat] add quant

* [chat] add quant example

* prompt example (#16)

* prompt example

* prompt load csv data

* remove legacy try

---------

Co-authored-by: csric <richcsr256@gmail.com>

* [chat] add mmmt dummy example and refactor experience sending (#17)

* [chat] add mmmt dummy example

* [chat] refactor naive strategy

* [chat] fix struck problem

* [chat] fix naive strategy

* [chat] optimize experience maker sending logic

* [chat] refactor sending assignment

* [chat] refactor performance evaluator (#18)

* Prompt Example & requires_grad state_dict & sharding state_dict (#19)

* prompt example

* prompt load csv data

* remove legacy try

* maker models require_grad set to False

* working on zero redundancy update

* mmmt_prompt example; naive strategy requires_grad state_dict & sharding; maker model requires_no_grad.

* remove legacy examples

* remove legacy examples

* remove replay buffer tp state. bad design

---------

Co-authored-by: csric <richcsr256@gmail.com>

* state_dict sending adapts to new unwrap function (#20)

* prompt example

* prompt load csv data

* remove legacy try

* maker models require_grad set to False

* working on zero redundancy update

* mmmt_prompt example; naive strategy requires_grad state_dict & sharding; maker model requires_no_grad.

* remove legacy examples

* remove legacy examples

* remove replay buffer tp state. bad design

* opt benchmark

* better script

* nothing

* [chat] strategy refactor unwrap model

* [chat] strategy refactor save model

* [chat] add docstr

* [chat] refactor trainer save model

* [chat] fix strategy typing

* [chat] refactor trainer save model

* [chat] update readme

* [chat] fix unit test

* working on lora reconstruction

* state_dict sending adapts to new unwrap function

* remove comments

---------

Co-authored-by: csric <richcsr256@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>

* [chat-ray] add readme (#21)

* add readme

* transparent graph

* add note background

---------

Co-authored-by: csric <richcsr256@gmail.com>

* [chat] get images from url (#22)

* Refactor/chat ray (#23)

* [chat] lora add todo

* [chat] remove unused pipeline strategy

* [chat] refactor example structure

* [chat] setup ci for ray

* [chat-ray] Support LoRA trainer. LoRA weights reconstruction. (#24)

* lora support prototype

* lora support

* 1mmt lora & remove useless code

---------

Co-authored-by: csric <richcsr256@gmail.com>

* [chat] fix test ci for ray

* [chat] fix test ci requirements for ray

* [chat] fix ray runtime env

* [chat] fix ray runtime env

* [chat] fix example ci docker args

* [chat] add debug info in trainer

* [chat] add nccl debug info

* [chat] skip ray test

* [doc] fix typo

---------

Co-authored-by: csric <59389055+CsRic@users.noreply.github.com>
Co-authored-by: csric <richcsr256@gmail.com>
2023-06-07 10:41:16 +08:00

83 lines
2.9 KiB
Python

import os
import random
from typing import Optional
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from coati.replay_buffer import ReplayBuffer
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from .naive import NaiveStrategy
from .sampler import DistributedSampler
class DDPStrategy(NaiveStrategy):
"""
Strategy for distributed training using torch.distributed.
"""
def __init__(self, seed: int = 42) -> None:
self.seed = seed
super().__init__()
def setup_distributed(self) -> None:
self._try_init_dist(force=True)
self.set_seed(self.seed)
def set_seed(self, seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
def setup_model(self, model: nn.Module) -> nn.Module:
device = torch.cuda.current_device()
return DDP(model, device_ids=[device])
def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
# DDP only mode, replay buffers on each rank are different.
# sampler = DistributedSampler(replay_buffer,
# num_replicas=dist.get_world_size(),
# rank=dist.get_rank(),
# shuffle=True,
# seed=self.seed,
# drop_last=True)
return DataLoader(
replay_buffer,
batch_size=replay_buffer.sample_batch_size,
# sampler=sampler,
shuffle=True,
drop_last=True,
pin_memory=pin_memory,
collate_fn=replay_buffer.collate_fn)
def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
if only_rank0 and dist.get_rank() != 0:
return
super().save_model(model, path, only_rank0)
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
if only_rank0 and dist.get_rank() != 0:
return
super().save_optimizer(optimizer, path, only_rank0)
def setup_sampler(self, dataset) -> DistributedSampler:
return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank())
def unwrap_model(self, model: nn.Module) -> nn.Module:
base_model: DDP = super().unwrap_model(model)
return base_model.module
def save_pretrained(self,
model: nn.Module,
path: str,
only_rank0: bool = True,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
if only_rank0 and dist.get_rank() != 0:
return
super().save_pretrained(model, path, only_rank0, tokenizer)