[chat] refactor trainer (#3648)

* [chat] ppo trainer remove useless args

* [chat] update examples

* [chat] update benchmark

* [chat] update examples

* [chat] fix sft training with wandb

* [chat] polish docstr
This commit is contained in:
Hongxin Liu
2023-04-26 18:11:49 +08:00
committed by GitHub
parent f8288315d9
commit 2a951955ad
12 changed files with 72 additions and 536 deletions

View File

@@ -1,14 +1,19 @@
import torch.distributed as dist
from typing import Any, Callable, Dict, List, Optional
from coati.models.bloom import BLOOMActor, BLOOMCritic
from coati.models.gpt import GPTActor, GPTCritic
from coati.models.opt import OPTActor, OPTCritic
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from typing import Any
import torch
import os
import torch.distributed as dist
from torch.utils._pytree import tree_map
def is_rank_0() -> bool:
    """Return ``True`` when this process should act as the single "rank 0" worker.

    That is the case either when torch.distributed has not been initialized
    (single-process run) or when this process holds global rank 0.
    """
    if dist.is_initialized():
        return dist.get_rank() == 0
    return True
def to_device(x: Any, device: torch.device) -> Any:
    """Recursively move every tensor inside *x* to *device*.

    *x* may be any pytree (nested lists/tuples/dicts, or a bare leaf).
    Tensor leaves are relocated with ``Tensor.to``; all other leaves are
    passed through untouched.
    """

    def _move(leaf: Any) -> Any:
        # Only tensors know how to change device; everything else is kept as-is.
        return leaf.to(device) if isinstance(leaf, torch.Tensor) else leaf

    return tree_map(_move, x)