[NFC] polish applications/Chat/coati/trainer/base.py code style (#4260)

shenggan 2023-07-18 10:59:57 +08:00 committed by binmakeswell
parent b2debdc09b
commit 798cb72907


@@ -25,7 +25,8 @@ class SLTrainer(ABC):
         optim (Optimizer): the optimizer to use for training
     """

-    def __init__(self,
+    def __init__(
+        self,
         strategy: Strategy,
         max_epochs: int,
         model: nn.Module,
@@ -50,10 +51,7 @@ class SLTrainer(ABC):
     def fit(self, *args, **kwargs):
         self._before_fit(*args, **kwargs)
-        for epoch in tqdm.trange(self.max_epochs,
-                                 desc="Epochs",
-                                 disable=not is_rank_0() or self.no_epoch_bar
-                                 ):
+        for epoch in tqdm.trange(self.max_epochs, desc="Epochs", disable=not is_rank_0() or self.no_epoch_bar):
             self._train(epoch)
             self._eval(epoch)
@@ -75,8 +73,7 @@ class OnPolicyTrainer(ABC):
                  buffer: NaiveReplayBuffer,
                  sample_buffer: bool,
                  dataloader_pin_memory: bool,
-                 callbacks: List[Callback] = []
-                 ) -> None:
+                 callbacks: List[Callback] = []) -> None:
         super().__init__()
         self.strategy = strategy
         self.buffer = buffer
@@ -154,7 +151,8 @@ class OnPolicyTrainer(ABC):
         self._learn(update_step)
         self._on_learn_epoch_end(update_step)

-    def fit(self,
+    def fit(
+        self,
         prompt_dataloader: DataLoader,
         pretrain_dataloader: DataLoader,
         num_episodes: int,
@@ -175,23 +173,16 @@ class OnPolicyTrainer(ABC):
         self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader)
         with self._fit_ctx():
-            for episode in tqdm.trange(num_episodes,
-                                       desc="Episodes",
-                                       disable=not is_rank_0()):
+            for episode in tqdm.trange(num_episodes, desc="Episodes", disable=not is_rank_0()):
                 with self._episode_ctx(episode):
-                    for collect_step in tqdm.trange(num_collect_steps,
-                                                    desc="Collect steps",
-                                                    disable=not is_rank_0()):
+                    for collect_step in tqdm.trange(num_collect_steps, desc="Collect steps", disable=not is_rank_0()):
                         self._collect_phase(collect_step)
                     if not self.sample_buffer:
                         # HACK(cwher): according to the design of boost API, dataloader should also be boosted,
                         # but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted.
                         # I only call strategy.setup_dataloader() to setup dataloader.
-                        self.dataloader = self.strategy.setup_dataloader(self.buffer,
-                                                                         self.dataloader_pin_memory)
-                    for update_step in tqdm.trange(num_update_steps,
-                                                   desc="Update steps",
-                                                   disable=not is_rank_0()):
+                        self.dataloader = self.strategy.setup_dataloader(self.buffer, self.dataloader_pin_memory)
+                    for update_step in tqdm.trange(num_update_steps, desc="Update steps", disable=not is_rank_0()):
                         self._update_phase(update_step)
                     # NOTE: this is for on-policy algorithms
                     self.buffer.clear()
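
Note for readers skimming the hunks above: the recurring construct being collapsed onto single lines is a tqdm progress bar that is silenced on every rank except rank 0 via disable=not is_rank_0(). The following is a minimal, self-contained sketch of that pattern, not code from the repo; the local is_rank_0 helper is an assumption based on the usual torch.distributed idiom and stands in for the one coati imports from its trainer utilities.

import tqdm
import torch.distributed as dist


def is_rank_0() -> bool:
    # Assumed stand-in for coati's helper: single-process runs count as rank 0,
    # otherwise ask the default process group for this process's rank.
    return not dist.is_initialized() or dist.get_rank() == 0


def train(max_epochs: int = 3) -> None:
    # Only rank 0 draws the bar; other ranks run the same loop silently.
    for epoch in tqdm.trange(max_epochs, desc="Epochs", disable=not is_rank_0()):
        pass  # per-epoch training/evaluation would go here

The same gating appears on the "Episodes", "Collect steps", and "Update steps" bars in OnPolicyTrainer.fit.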