[chat] refactor actor class (#3968)

* refactor: separate log_probs fn from Actor forward fn

* refactor: separate generate fn from Actor class

* feat: update unwrap_model and get_base_model
  * unwrap_model returns model not wrapped by Strategy
  * get_base_model returns HF model for Actor, Critic and RewardModel

* feat: simplify Strategy.prepare

* style: remove get_base_model method of Actor

* perf: tokenize text in batches

* refactor: move calc_action_log_probs to utils of model

* test: update test with new forward fn

* style: rename forward fn args

* fix: do not unwrap model in save_model fn of naive strategy

* test: add gemini test for train_prompts

* fix: fix _set_default_generate_kwargs
Wenhao Chen 2023-06-13 13:31:56 +08:00 committed by GitHub
parent b3ab7fbabf
commit 9d02590c9a
14 changed files with 151 additions and 120 deletions

View File

@@ -35,14 +35,14 @@ class PromptDataset(Dataset):
             logger.info(f"Limiting dataset to {max_datasets_size} examples.")
             list_data_dict = list_data_dict[:max_datasets_size]

-        for data_dict in list_data_dict:
-            token = tokenizer(data_dict["instruction"],
-                              return_tensors='pt',
-                              max_length=max_length,
-                              padding='max_length',
-                              truncation=True)
-            for k, tensor in token.items():
-                self.keyed_prompt[k].extend(tensor.to(torch.cuda.current_device()).unbind())
+        instructions = [data_dict["instruction"] for data_dict in list_data_dict]
+        tokens = tokenizer(instructions,
+                           return_tensors='pt',
+                           max_length=max_length,
+                           padding='max_length',
+                           truncation=True)
+        for k, tensor in tokens.items():
+            self.keyed_prompt[k] = tensor.to(torch.cuda.current_device()).unbind()

     def __len__(self):
         return len(self.keyed_prompt["input_ids"])
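The change above replaces per-sample tokenizer calls with a single batched call. Below is a minimal standalone sketch of that pattern, kept on CPU and using an illustrative gpt2 tokenizer and made-up prompts (none of which come from the patch):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token    # gpt2 ships without a pad token

instructions = ["Write a haiku about the sea.", "Summarize the PPO algorithm in one sentence."]
tokens = tokenizer(instructions,
                   return_tensors='pt',
                   max_length=96,
                   padding='max_length',
                   truncation=True)
# One call returns batched tensors; unbind() splits each (batch, max_length)
# tensor back into per-sample rows, which is what keyed_prompt now stores.
keyed_prompt = {k: tensor.unbind() for k, tensor in tokens.items()}
print(len(keyed_prompt['input_ids']), keyed_prompt['input_ids'][0].shape)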

View File

@@ -74,21 +74,18 @@ class SFTDataset(Dataset):
         return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])


-def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, max_length: int) -> Dict:
+def _tokenize_fn(strings: Sequence[str],
+                 tokenizer: transformers.PreTrainedTokenizer,
+                 max_length: int
+                 ) -> Dict[str, torch.Tensor]:
     """Tokenize a list of strings."""
-    tokenized_list = [
-        tokenizer(
-            text,
-            return_tensors="pt",
-            padding="longest",
-            max_length=max_length,
-            truncation=True,
-        ) for text in strings
-    ]
-    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
-    input_ids_lens = labels_lens = [
-        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
-    ]
+    tokenized_list = tokenizer(
+        strings, return_tensors="pt", padding="longest",
+        max_length=max_length, truncation=True
+    )
+    input_ids = labels = tokenized_list["input_ids"]
+    input_ids_lens = labels_lens = \
+        tokenized_list["input_ids"].ne(tokenizer.pad_token_id).sum(dim=-1)
     return dict(
         input_ids=input_ids,
         labels=labels,

@@ -105,7 +102,10 @@ def preprocess(
 ) -> Dict:
     """Preprocess the data by tokenizing."""
     examples = [s + t for s, t in zip(sources, targets)]
-    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)]
+    examples_tokenized, sources_tokenized = [
+        _tokenize_fn(strings, tokenizer, max_length)
+        for strings in (examples, sources)
+    ]
     input_ids = examples_tokenized["input_ids"]
     labels = copy.deepcopy(input_ids)
     for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
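As a sanity check of the new `_tokenize_fn`, the per-sample lengths can be read off the batched tensor with a single vectorized count of non-pad tokens, replacing the old per-string `.sum().item()` loop. A small sketch, again with an illustrative gpt2 tokenizer and example strings:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

batch = tokenizer(["short", "a somewhat longer example sentence"],
                  return_tensors="pt", padding="longest",
                  max_length=32, truncation=True)
input_ids = batch["input_ids"]
# Number of non-pad tokens per row: the quantity returned as
# input_ids_lens / labels_lens by the batched _tokenize_fn.
lens = input_ids.ne(tokenizer.pad_token_id).sum(dim=-1)
print(input_ids.shape, lens)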

View File

@@ -1,5 +1,6 @@
 import torch
-from coati.models.utils import compute_reward, normalize
+from coati.models.generation import generate_with_actor
+from coati.models.utils import calc_action_log_probs, compute_reward, normalize

 from .base import Experience, ExperienceMaker

@@ -16,13 +17,16 @@ class NaiveExperienceMaker(ExperienceMaker):
         self.initial_model.eval()
         self.reward_model.eval()

-        sequences, attention_mask, action_mask = self.actor.generate(input_ids,
-                                                                      return_action_mask=True,
-                                                                      **generate_kwargs)
+        sequences, attention_mask, action_mask = generate_with_actor(self.actor,
+                                                                     input_ids,
+                                                                     return_action_mask=True,
+                                                                     **generate_kwargs)
         num_actions = action_mask.size(1)

-        action_log_probs = self.actor(sequences, num_actions, attention_mask)
-        base_action_log_probs = self.initial_model(sequences, num_actions, attention_mask)
+        actor_output = self.actor(sequences, attention_mask)
+        action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions)
+        base_model_output = self.initial_model(sequences, attention_mask)
+        base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions)
         value = self.critic(sequences, action_mask, attention_mask)
         r = self.reward_model(sequences, attention_mask)
         reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask)
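Both the actor and the frozen initial model now go through the same forward-then-extract path, which is what the KL-penalized reward needs: a per-token log-probability from each model over the same action tokens. A rough shape-level sketch (random logits stand in for the two models, and the `log_ratio` at the end only shows where the divergence term comes from, not coati's exact `compute_reward` formula):

import torch
import torch.nn.functional as F

def calc_action_log_probs(output, sequences, num_actions):
    log_probs = F.log_softmax(output['logits'][:, :-1, :], dim=-1)
    picked = log_probs.gather(-1, sequences[:, 1:].unsqueeze(-1)).squeeze(-1)
    return picked[:, -num_actions:]

batch, seq_len, vocab, num_actions = 2, 12, 50, 5
sequences = torch.randint(vocab, (batch, seq_len))
actor_output = {'logits': torch.randn(batch, seq_len, vocab)}   # stand-in for self.actor(...)
base_output = {'logits': torch.randn(batch, seq_len, vocab)}    # stand-in for self.initial_model(...)

action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions)
base_action_log_probs = calc_action_log_probs(base_output, sequences, num_actions)
log_ratio = action_log_probs - base_action_log_probs    # per-token divergence fed into the KL penalty
print(log_ratio.shape)    # torch.Size([2, 5])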

View File

@@ -1,3 +1,5 @@
+from typing import Union
+
 import torch.nn as nn

 from .actor import Actor

@@ -5,10 +7,10 @@ from .critic import Critic
 from .reward_model import RewardModel


-def get_base_model(model: nn.Module) -> nn.Module:
+def get_base_model(model: Union[Actor, Critic, RewardModel]) -> nn.Module:
     """Get the base model of our wrapper classes.

-    For Actor, it's base model is ``actor.model`` and it's usually a ``transformers.PreTrainedModel``.
-    For Critic and RewardModel, it's base model is itself.
+    For Actor, Critic and RewardModel, return ``model.model``,
+    it's usually a ``transformers.PreTrainedModel``.

     Args:
         model (nn.Module): model to get base model from

@@ -16,9 +18,9 @@ def get_base_model(model: nn.Module) -> nn.Module:
     Returns:
         nn.Module: the base model
     """
-    if isinstance(model, Actor):
-        return model.get_base_model()
-    return model
+    assert isinstance(model, (Actor, Critic, RewardModel)), \
+        f'Expect Actor, Critic or RewardModel, got {type(model)}, use unwrap_model first.'
+    return model.model


 __all__ = ['Actor', 'Critic', 'RewardModel', 'get_base_model']
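The new contract is a two-step unwrap: `Strategy.unwrap_model` removes the strategy wrapper (DDP/ZeroDDP), then `get_base_model` removes the coati wrapper and hands back the HF model stored in `.model`. A toy, coati-free sketch of that second step (the `Wrapper` class is a stand-in for Actor/Critic/RewardModel, not coati code):

import torch.nn as nn

class Wrapper(nn.Module):
    """Stand-in for Actor/Critic/RewardModel: each keeps its HF model in `.model`."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

def get_base_model(model: Wrapper) -> nn.Module:
    # mirrors the new behaviour: refuse anything that is not a known wrapper
    assert isinstance(model, Wrapper), \
        f'Expect a wrapper, got {type(model)}, use unwrap_model first.'
    return model.model

base = nn.Linear(4, 4)    # stands in for a transformers.PreTrainedModel
assert get_base_model(Wrapper(base)) is base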

View File

@@ -1,12 +1,9 @@
-from typing import Optional, Tuple, Union
+from typing import Optional

 import torch
 import torch.nn as nn
-import torch.nn.functional as F

-from ..generation import generate
 from ..lora import LoRAModule
-from ..utils import log_probs_from_logits


 class Actor(LoRAModule):

@@ -24,42 +21,16 @@ class Actor(LoRAModule):
         self.model = model
         self.convert_to_lora()

-    @torch.no_grad()
-    def generate(
-        self,
-        input_ids: torch.Tensor,
-        return_action_mask: bool = True,
-        **kwargs
-    ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
-        sequences = generate(self.model, input_ids, **kwargs)
-        attention_mask = None
-        pad_token_id = kwargs.get('pad_token_id', None)
-        if pad_token_id is not None:
-            attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
-        if not return_action_mask:
-            return sequences, attention_mask, None
-        input_len = input_ids.size(1)
-        eos_token_id = kwargs.get('eos_token_id', None)
-        if eos_token_id is None:
-            action_mask = torch.ones_like(sequences, dtype=torch.bool)
-        else:
-            # left padding may be applied, only mask action
-            action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
-            action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)    # include eos token and input
-            action_mask[:, :input_len] = False
-        action_mask = action_mask[:, 1:]
-        return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]
-
     def forward(self,
-                sequences: torch.LongTensor,
-                num_actions: int,
-                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        """Returns action log probs
+                input_ids: torch.LongTensor,
+                attention_mask: Optional[torch.Tensor] = None,
+                **model_kwargs,    # HACK: `generate` method may pass more kwargs
+                ) -> torch.Tensor:
+        """Returns model output.
         """
-        output = self.model(sequences, attention_mask=attention_mask)
-        logits = output['logits']
-        log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
-        return log_probs[:, -num_actions:]
-
-    def get_base_model(self):
-        return self.model
+        output = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            **model_kwargs
+        )
+        return output
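After this change `Actor.forward` no longer computes action log probs; it simply forwards to the wrapped HF model and returns its output (so `generate` can route extra kwargs through it). A rough illustration of what callers now receive, using a small randomly initialized GPT-2 purely as an example:

import torch
from transformers import GPT2Config, GPT2LMHeadModel

model = GPT2LMHeadModel(GPT2Config(n_layer=2, n_head=2, n_embd=64))    # toy config for illustration
input_ids = torch.randint(0, model.config.vocab_size, (2, 8))
attention_mask = torch.ones_like(input_ids)

# What actor(input_ids, attention_mask) now effectively returns: the raw HF
# output, from which callers pick 'logits' themselves.
output = model(input_ids, attention_mask=attention_mask)
print(output['logits'].shape)    # (2, 8, vocab_size)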

View File

@@ -1,8 +1,10 @@
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, Tuple, Union

 import torch
 import torch.distributed as dist
 import torch.nn as nn
+import torch.nn.functional as F

 try:
     from transformers.generation_logits_process import (

@@ -55,9 +57,8 @@ def sample(model: nn.Module,
     unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)

     for _ in range(input_ids.size(1), max_length):
-        model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {
-            'input_ids': input_ids
-        }
+        model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) \
+            if prepare_inputs_fn is not None else {'input_ids': input_ids}
         outputs = model(**model_inputs)

         next_token_logits = outputs['logits'][:, -1, :]

@@ -144,3 +145,35 @@ def generate(model: nn.Module,
             raise NotImplementedError
     else:
         raise ValueError("Unsupported generation mode")
+
+
+@torch.no_grad()
+def generate_with_actor(actor_model: nn.Module,
+                        input_ids: torch.Tensor,
+                        return_action_mask: bool = True,
+                        **kwargs
+                        ) -> Union[Tuple[torch.LongTensor, torch.LongTensor],
+                                   Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
+    """Generate token sequence with actor model. Refer to `generate` for more details.
+    """
+    # generate sequences
+    sequences = generate(actor_model, input_ids, **kwargs)
+
+    # calculate auxiliary tensors
+    attention_mask = None
+    pad_token_id = kwargs.get('pad_token_id', None)
+    if pad_token_id is not None:
+        attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
+    if not return_action_mask:
+        return sequences, attention_mask, None
+    input_len = input_ids.size(1)
+    eos_token_id = kwargs.get('eos_token_id', None)
+    if eos_token_id is None:
+        action_mask = torch.ones_like(sequences, dtype=torch.bool)
+    else:
+        # left padding may be applied, only mask action
+        action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
+        action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)    # include eos token and input
+        action_mask[:, :input_len] = False
+    action_mask = action_mask[:, 1:]
+    return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]
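The only non-obvious part of `generate_with_actor` is the action-mask construction: everything after the prompt counts as an action up to and including the first EOS, and padding after EOS is masked out. A tiny hand-made check of that logic (the token ids 0 and 2 are arbitrary pad/eos choices for this example, not real vocabulary ids):

import torch
import torch.nn.functional as F

eos_token_id = 2
input_len = 3                                            # prompt length
sequences = torch.tensor([[5, 6, 7, 9, 9, 2, 0, 0]])     # 3 prompt tokens, 2 generated tokens, eos, 2 pads

action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)    # include eos token and input
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
print(action_mask[:, -(sequences.size(1) - input_len):])    # tensor([[True, True, True, False, False]])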

View File

@@ -46,6 +46,25 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
     return log_probs_labels.squeeze(-1)


+def calc_action_log_probs(output: torch.Tensor,
+                          sequences: torch.LongTensor,
+                          num_actions: int
+                          ) -> torch.Tensor:
+    """Calculate action log probs.
+
+    Args:
+        output (torch.Tensor): Output tensor of Actor.forward.
+        sequences (torch.LongTensor): Input sequences.
+        num_actions (int): Number of actions.
+
+    Returns:
+        torch.Tensor: Action log probs.
+    """
+    logits = output['logits']
+    log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
+    return log_probs[:, -num_actions:]
+
+
 def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
     tensor = tensor * mask
     tensor = tensor.sum(dim=dim)
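For intuition, the helper gathers the log-probability of each realized next token and keeps only the last `num_actions` positions. A self-contained equivalent in plain PyTorch (the inlined gather plays the role of coati's `log_probs_from_logits`, and the random logits stand in for an actor forward pass):

import torch
import torch.nn.functional as F

def calc_action_log_probs(output, sequences, num_actions):
    logits = output['logits']
    # log prob of each observed next token, i.e. what log_probs_from_logits computes
    log_probs_all = F.log_softmax(logits[:, :-1, :], dim=-1)
    log_probs = log_probs_all.gather(dim=-1, index=sequences[:, 1:].unsqueeze(-1)).squeeze(-1)
    return log_probs[:, -num_actions:]

batch, seq_len, vocab, num_actions = 2, 10, 50, 4
sequences = torch.randint(vocab, (batch, seq_len))
output = {'logits': torch.randn(batch, seq_len, vocab)}    # stand-in for actor(sequences, attention_mask)
print(calc_action_log_probs(output, sequences, num_actions).shape)    # torch.Size([2, 4])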

View File

@@ -3,8 +3,9 @@ from typing import Any, Callable, Dict, List, Optional, Union
 import torch
 import torch.nn as nn
 from coati.experience_maker import Experience, NaiveExperienceMaker
-from coati.models.base import Actor, Critic
+from coati.models.base import Actor, Critic, get_base_model
 from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
+from coati.models.utils import calc_action_log_probs
 from coati.replay_buffer import NaiveReplayBuffer
 from torch import Tensor
 from torch.optim import Optimizer

@@ -165,7 +166,8 @@ class PPOTrainer(Trainer):
         self.critic.train()
         # policy loss
         num_actions = experience.action_mask.size(1)
-        action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
+        actor_output = self.actor(experience.sequences, attention_mask=experience.attention_mask)
+        action_log_probs = calc_action_log_probs(actor_output, experience.sequences, num_actions)
         actor_loss = self.actor_loss_fn(action_log_probs,
                                         experience.action_log_probs,
                                         experience.advantages,

@@ -175,7 +177,7 @@ class PPOTrainer(Trainer):
         if self.ptx_coef != 0:
             batch = next(iter(self.pretrain_dataloader))
             batch = to_device(batch, self.device)
-            ptx_log_probs = self.actor.get_base_model()(batch['input_ids'],
-                                                        attention_mask=batch['attention_mask'])['logits']
+            ptx_log_probs = self.actor(batch['input_ids'],
+                                       attention_mask=batch['attention_mask'])['logits']
             ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch['labels'])
             actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)

@@ -200,14 +202,15 @@ class PPOTrainer(Trainer):
         return {'reward': experience.reward.mean().item()}


-def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None:
-    origin_model = strategy.unwrap_model(actor)
+def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict:
+    unwrapper_model = strategy.unwrap_model(actor)
+    hf_model = get_base_model(unwrapper_model)
     new_kwargs = {**generate_kwargs}
     # use huggingface models method directly
-    if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
-        new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
+    if 'prepare_inputs_fn' not in generate_kwargs and hasattr(hf_model, 'prepare_inputs_for_generation'):
+        new_kwargs['prepare_inputs_fn'] = hf_model.prepare_inputs_for_generation

-    if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(origin_model, '_update_model_kwargs_for_generation'):
-        new_kwargs['update_model_kwargs_fn'] = origin_model._update_model_kwargs_for_generation
+    if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(hf_model, '_update_model_kwargs_for_generation'):
+        new_kwargs['update_model_kwargs_fn'] = hf_model._update_model_kwargs_for_generation

     return new_kwargs
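The fix makes the helper unwrap twice (strategy wrapper, then the Actor wrapper) before probing for HF generation hooks, and it now returns the kwargs dict its annotation promises. A standalone sketch of the probing step, with a toy GPT-2 standing in for `get_base_model(strategy.unwrap_model(actor))`:

from transformers import GPT2Config, GPT2LMHeadModel

hf_model = GPT2LMHeadModel(GPT2Config(n_layer=2, n_head=2, n_embd=64))    # illustrative stand-in
generate_kwargs = {'max_length': 64, 'do_sample': True}

new_kwargs = {**generate_kwargs}
# Borrow the HF model's own generation helpers when the caller did not supply them.
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(hf_model, 'prepare_inputs_for_generation'):
    new_kwargs['prepare_inputs_fn'] = hf_model.prepare_inputs_for_generation
if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(hf_model, '_update_model_kwargs_for_generation'):
    new_kwargs['update_model_kwargs_fn'] = hf_model._update_model_kwargs_for_generation
print(sorted(new_kwargs))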

View File

@@ -4,7 +4,6 @@ from typing import Any, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
-from coati.models.base import Actor, get_base_model
 from coati.replay_buffer import ReplayBuffer
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader

@@ -69,21 +68,16 @@ class Strategy(ABC):
             Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]: Models or model-optimizer-pairs in the original order.
         """

-        def prepare_model(model: nn.Module):
-            if isinstance(model, Actor):
-                return Actor(self.setup_model(model.get_base_model()))
-            return self.setup_model(model)
-
         rets = []
         for arg in models_or_model_optim_pairs:
             if isinstance(arg, tuple):
                 assert len(arg) == 2, f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"'
                 model, optimizer = arg
-                model = prepare_model(model)
-                optimizer = self.setup_optimizer(optimizer, get_base_model(model))
+                model = self.setup_model(model)
+                optimizer = self.setup_optimizer(optimizer, model)
                 rets.append((model, optimizer))
             elif isinstance(arg, nn.Module):
-                rets.append(prepare_model(arg))
+                rets.append(self.setup_model(model))
             else:
                 raise RuntimeError(f'Expect model or (model, optimizer) pair, got {type(arg)}')

@@ -93,16 +87,15 @@ class Strategy(ABC):

     @staticmethod
     def unwrap_model(model: nn.Module) -> nn.Module:
-        """Get the unwrapped model from a wrapped model. Useful for getting original huggingface model.
-
-        For Actor, it will unwrap `actor.model`.
+        """Get the unwrapped model from a wrapped model made by Strategy.prepare.

         Args:
             model (nn.Module): the model to unwrap

         Returns:
-            nn.Module: the original model (usually a huggingface model)
+            nn.Module: the original model
         """
-        return get_base_model(model)
+        return model

     @abstractmethod
     def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
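With the Actor special case gone, `prepare` reduces to "run everything through setup_model and hand the optimizer the same wrapped model". A toy version of that control flow (the no-op `setup_model` stands in for the real DDP/ZeroDDP wrapping done by concrete strategies):

import torch.nn as nn
from torch.optim import Adam

def setup_model(model: nn.Module) -> nn.Module:
    return model    # a real strategy would wrap the model here

def prepare(*models_or_model_optim_pairs):
    rets = []
    for arg in models_or_model_optim_pairs:
        if isinstance(arg, tuple):
            assert len(arg) == 2, f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"'
            model, optimizer = arg
            rets.append((setup_model(model), optimizer))
        elif isinstance(arg, nn.Module):
            rets.append(setup_model(arg))
        else:
            raise RuntimeError(f'Expect model or (model, optimizer) pair, got {type(arg)}')
    return rets

actor, critic = nn.Linear(4, 4), nn.Linear(4, 2)
(actor, actor_optim), critic = prepare((actor, Adam(actor.parameters(), lr=1e-3)), critic)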

View File

@@ -5,7 +5,6 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch.optim as optim
-from coati.models.base import get_base_model
 from torch.optim import Optimizer
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase

@@ -153,14 +152,13 @@ class ColossalAIStrategy(DDPStrategy):
     def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
         if only_rank0 and dist.get_rank() != 0 and self.stage != 3:
             return
-        base_model = get_base_model(model)
         if self.stage == 3:
-            assert isinstance(base_model, ZeroDDP)
+            assert isinstance(model, ZeroDDP)
             # for stage 3, state_dict() method should be called on every rank
-            state_dict = base_model.state_dict(only_rank_0=only_rank0)
+            state_dict = model.state_dict(only_rank_0=only_rank0)
         else:
             # only_rank0 is false or rank == 0
-            state_dict = base_model.state_dict()
+            state_dict = model.state_dict()
         if only_rank0 and dist.get_rank() != 0:
             return
         torch.save(state_dict, path)

@@ -172,11 +170,10 @@ class ColossalAIStrategy(DDPStrategy):
         torch.save(optimizer.state_dict(), path)

     def unwrap_model(self, model: nn.Module) -> nn.Module:
-        base_model: Union[nn.Module, ZeroDDP] = get_base_model(model)
         if self.stage == 3:
-            assert isinstance(base_model, ZeroDDP)
-            return base_model.module
-        return base_model
+            assert isinstance(model, ZeroDDP)
+            return model.module
+        return model

     def save_pretrained(self,
                         model: nn.Module,

@@ -196,5 +193,5 @@ class ColossalAIStrategy(DDPStrategy):
         # if isinstance(module, LoraLinear):
         #     module.merge_weights = True
         #     module.eval()
-        base_model: ZeroDDP = get_base_model(model)
-        yield from base_model.state_dict_shard(max_shard_size=1024, only_rank_0=False)
+        assert isinstance(model, ZeroDDP)
+        yield from model.state_dict_shard(max_shard_size=1024, only_rank_0=False)

View File

@@ -69,8 +69,8 @@ class DDPStrategy(NaiveStrategy):
         return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank())

     def unwrap_model(self, model: nn.Module) -> nn.Module:
-        base_model: DDP = super().unwrap_model(model)
-        return base_model.module
+        assert isinstance(model, DDP)
+        return model.module

     def save_pretrained(self,
                         model: nn.Module,

View File

@@ -58,14 +58,13 @@ class NaiveStrategy(Strategy):
                           collate_fn=replay_buffer.collate_fn)

     def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
-        base_model = get_base_model(model)
-        state_dict = base_model.state_dict()
+        state_dict = model.state_dict()
         torch.save(state_dict, path)

     def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
-        base_model = get_base_model(model)
+        unwrapped_model = self.unwrap_model(model)
         state_dict = torch.load(path, map_location=map_location)
-        base_model.load_state_dict(state_dict, strict=strict)
+        unwrapped_model.load_state_dict(state_dict, strict=strict)

     def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
         torch.save(optimizer.state_dict(), path)
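The naive strategy never wraps the model, so `save_model` can serialize it directly and `load_model` loads into whatever `unwrap_model` returns (here the same object). A trivial round-trip illustrating the state_dict contract, with a toy module and a temporary checkpoint path:

import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 4)    # stands in for a model kept unwrapped by the naive strategy
path = os.path.join(tempfile.mkdtemp(), 'naive_ckpt.pt')

torch.save(model.state_dict(), path)                    # save_model: no unwrapping
model.load_state_dict(torch.load(path), strict=True)    # load_model: load into the unwrapped model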

View File

@@ -121,6 +121,14 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py --prompt_dataset
                         --rm_pretrain 'gpt2' \
                         --rm_path ${BASE}/rm_ckpt_gpt.pt \
                         --save_path ${BASE}/actor_checkpoint_prompts.pt
+
+torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
+                        --strategy colossalai_gemini --num_episodes 1 --max_timesteps 2 \
+                        --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
+                        --pretrain 'gpt2' --model gpt2 \
+                        --rm_pretrain 'gpt2' \
+                        --rm_path ${BASE}/rm_ckpt_gpt.pt \
+                        --save_path ${BASE}/actor_checkpoint_prompts.pt
+
 rm -rf ${BASE}/rm_ckpt_gpt.pt
 rm -rf ${BASE}/actor_checkpoint_prompts.pt

View File

@@ -6,6 +6,7 @@ import pytest
 import torch
 import torch.distributed as dist
 from coati.models.gpt import GPTActor
+from coati.models.utils import calc_action_log_probs
 from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy
 from transformers.models.gpt2.configuration_gpt2 import GPT2Config

@@ -43,7 +44,8 @@ def run_test_checkpoint(strategy):
     def run_step():
         data = get_data(BATCH_SIZE)
         action_mask = torch.ones_like(data['attention_mask'], dtype=torch.bool)
-        action_log_probs = actor(data['input_ids'], action_mask.size(1), data['attention_mask'])
+        actor_output = actor(data['input_ids'], data['attention_mask'])
+        action_log_probs = calc_action_log_probs(actor_output, data['input_ids'], action_mask.size(1))
         loss = action_log_probs.sum()
         strategy.backward(loss, actor, actor_optim)
         strategy.optimizer_step(actor_optim)