ColossalAI/applications/Chat/coati/experience_maker/base.py
Wenhao Chen 7b9b86441f
[chat]: update rm, add wandb and fix bugs (#4471)
* feat: modify forward fn of critic and reward model

* feat: modify calc_action_log_probs

* to: add wandb in sft and rm trainer

* feat: update train_sft

* feat: update train_rm

* style: modify type annotation and add warning

* feat: pass tokenizer to ppo trainer

* to: modify trainer base and maker base

* feat: add wandb in ppo trainer

* feat: pass tokenizer to generate

* test: update generate fn tests

* test: update train tests

* fix: remove action_mask

* feat: remove unused code

* fix: fix wrong ignore_index

* fix: fix mock tokenizer

* chore: update requirements

* revert: modify make_experience

* fix: fix inference

* fix: add padding side

* style: modify _on_learn_batch_end

* test: use mock tokenizer

* fix: use bf16 to avoid overflow

* fix: fix workflow

* [chat] fix gemini strategy

* [chat] fix

* sync: update colossalai strategy

* fix: fix args and model dtype

* fix: fix checkpoint test

* fix: fix requirements

* fix: fix missing import and wrong arg

* fix: temporarily skip gemini test in stage 3

* style: apply pre-commit

* fix: temporarily skip gemini test in stage 1&2

---------

Co-authored-by: Mingyan Jiang <1829166702@qq.com>
2023-09-20 15:53:58 +08:00

71 lines
2.3 KiB
Python

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional
import torch
from coati.models.base import Actor, Critic, RewardModel
@dataclass
class Experience:
    """Experience is a batch of data.

    These data should have the sequence length and number of actions.
    Left padding for sequences is applied.

    Shapes of each tensor:
        sequences: (B, S)
        action_log_probs: (B, A)
        values: (B)
        reward: (B)
        advantages: (B)
        attention_mask: (B, S)
        action_mask: (B, A)

    "A" is the number of actions.
    """

    sequences: torch.Tensor
    action_log_probs: torch.Tensor
    values: torch.Tensor
    reward: torch.Tensor
    advantages: torch.Tensor
    attention_mask: Optional[torch.LongTensor]
    action_mask: Optional[torch.BoolTensor]

    # Tensor fields that must always be present. These are unannotated class
    # attributes, so ``@dataclass`` does not turn them into constructor fields.
    _REQUIRED_TENSORS = ("sequences", "action_log_probs", "values", "reward", "advantages")
    # Tensor fields that may be ``None``; they are left untouched when unset.
    _OPTIONAL_TENSORS = ("attention_mask", "action_mask")

    def _apply(self, fn) -> "Experience":
        """Replace every tensor field with ``fn(tensor)`` in place and return ``self``.

        Required fields are passed to ``fn`` unconditionally, so a missing
        required tensor fails loudly instead of being silently skipped.
        Optional masks stay ``None`` when unset.
        """
        for name in self._REQUIRED_TENSORS:
            setattr(self, name, fn(getattr(self, name)))
        for name in self._OPTIONAL_TENSORS:
            tensor = getattr(self, name)
            if tensor is not None:
                setattr(self, name, fn(tensor))
        return self

    @torch.no_grad()
    def to_device(self, device: torch.device, *, non_blocking: bool = False) -> "Experience":
        """Move every tensor field to ``device`` in place.

        Args:
            device: Target device; anything accepted by ``Tensor.to``.
            non_blocking: Forwarded to ``Tensor.to``. When the tensors were first
                placed in pinned memory via :meth:`pin_memory`, this lets the
                host-to-device copy run asynchronously. Defaults to ``False``
                (a blocking copy).

        Returns:
            ``self``, so the call can be chained just like :meth:`pin_memory`.
        """
        return self._apply(lambda t: t.to(device, non_blocking=non_blocking))

    def pin_memory(self) -> "Experience":
        """Copy every tensor field into pinned (page-locked) host memory in place.

        Returns:
            ``self``, so the call can be chained.
        """
        return self._apply(lambda t: t.pin_memory())
class ExperienceMaker(ABC):
    """Abstract base for objects that build :class:`Experience` batches.

    This base class only stores references to the models it is given; concrete
    subclasses decide how those models are used in :meth:`make_experience`.

    Args:
        actor: The ``Actor`` model, kept as ``self.actor``.
        critic: The ``Critic`` model, kept as ``self.critic``.
        reward_model: The ``RewardModel``, kept as ``self.reward_model``.
        initial_model: A second ``Actor``, kept as ``self.initial_model``
            (presumably a frozen reference copy of the policy — confirm in subclasses).
    """

    def __init__(self, actor: Actor, critic: Critic, reward_model: RewardModel, initial_model: Actor) -> None:
        super().__init__()
        # Plain references only: nothing is copied, moved, or switched to eval mode here.
        self.actor, self.critic = actor, critic
        self.reward_model, self.initial_model = reward_model, initial_model

    @abstractmethod
    def make_experience(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **generate_kwargs) -> Experience:
        """Build one :class:`Experience` batch from ``input_ids`` / ``attention_mask``.

        Subclasses must override this. ``generate_kwargs`` are presumably forwarded
        to the actor's generation routine — confirm in the concrete implementation.
        """