Mirror of https://github.com/hpcaitech/ColossalAI.git
[chat] refactor trainer (#3648)
* [chat] ppo trainer remove useless args
* [chat] update examples
* [chat] update benchmark
* [chat] update examples
* [chat] fix sft training with wandb
* [chat] polish docstr
@@ -1,14 +1,19 @@
-import torch.distributed as dist
-from typing import Any, Callable, Dict, List, Optional
-from coati.models.bloom import BLOOMActor, BLOOMCritic
-from coati.models.gpt import GPTActor, GPTCritic
-from coati.models.opt import OPTActor, OPTCritic
-from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
-import os
+from typing import Any
+
+import torch
+import torch.distributed as dist
+from torch.utils._pytree import tree_map
 
 
 def is_rank_0() -> bool:
     return not dist.is_initialized() or dist.get_rank() == 0
+
+
+def to_device(x: Any, device: torch.device) -> Any:
+
+    def _to(t: Any):
+        if isinstance(t, torch.Tensor):
+            return t.to(device)
+        return t
+
+    return tree_map(_to, x)
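The new to_device helper leans on torch.utils._pytree.tree_map, which applies _to to every leaf of an arbitrarily nested container, so a whole training batch can be moved to a device in one call without the helper knowing the batch schema. A minimal sketch of how a trainer might use both helpers; the import path coati.trainer.utils is an assumption (this hunk does not show the file name), and the batch layout is illustrative, not taken from the commit:

    import torch

    from coati.trainer.utils import is_rank_0, to_device  # assumed module path

    # A nested batch: tree_map visits every leaf, so tensors are moved
    # while non-tensor leaves (the list of strings) pass through unchanged.
    batch = {
        "input_ids": torch.randint(0, 100, (2, 8)),
        "attention_mask": torch.ones(2, 8, dtype=torch.long),
        "prompts": ["hello", "world"],
    }
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch = to_device(batch, device)

    # is_rank_0() is True in single-process runs and on rank 0 of a
    # torch.distributed job, so this prints exactly once per job.
    if is_rank_0():
        print("batch moved to", batch["input_ids"].device)

Because tree_map handles dicts, lists, and tuples uniformly, to_device stays agnostic to the batch structure, and callers can reshape their batches without touching this utility.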