Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-04-30 20:55:17 +00:00)
[chat] refactor actor class (#3968)
* refactor: separate log_probs fn from Actor forward fn
* refactor: separate generate fn from Actor class
* feat: update unwrap_model and get_base_model
  * unwrap_model returns model not wrapped by Strategy
  * get_base_model returns HF model for Actor, Critic and RewardModel
* feat: simplify Strategy.prepare
* style: remove get_base_model method of Actor
* perf: tokenize text in batches
* refactor: move calc_action_log_probs to utils of model
* test: update test with new forward fn
* style: rename forward fn args
* fix: do not unwrap model in save_model fn of naive strategy
* test: add gemini test for train_prompts
* fix: fix _set_default_generate_kwargs
This commit is contained in:
parent
b3ab7fbabf
commit
9d02590c9a
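In short, callers no longer get action log probs or generated sequences from `Actor` itself: `Actor.forward` now returns the raw model output, `calc_action_log_probs` turns that output into per-token log probs, and generation lives in the free function `generate_with_actor`. Below is a minimal sketch of the new call pattern, assuming a small randomly initialized `GPTActor` built from a `GPT2Config` (as the repository's tests do) and sampling-style generate kwargs like those passed by the PPO training script; it is an illustration, not code from this commit.

import torch
from coati.models.generation import generate_with_actor
from coati.models.gpt import GPTActor
from coati.models.utils import calc_action_log_probs
from transformers.models.gpt2.configuration_gpt2 import GPT2Config

# Tiny randomly initialized actor; the config values are illustrative only.
actor = GPTActor(config=GPT2Config(n_layer=2, n_head=2, n_embd=64))
input_ids = torch.randint(0, 50257, (2, 8))

# Generation is now a free function instead of Actor.generate.
sequences, attention_mask, action_mask = generate_with_actor(
    actor, input_ids, return_action_mask=True,
    max_length=16, do_sample=True, pad_token_id=50256, eos_token_id=50256)
num_actions = action_mask.size(1)

# Actor.forward returns the raw HF-style output; log probs are computed separately.
actor_output = actor(sequences, attention_mask)
action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions)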
@@ -35,14 +35,14 @@ class PromptDataset(Dataset):
             logger.info(f"Limiting dataset to {max_datasets_size} examples.")
             list_data_dict = list_data_dict[:max_datasets_size]

-        for data_dict in list_data_dict:
-            token = tokenizer(data_dict["instruction"],
-                              return_tensors='pt',
-                              max_length=max_length,
-                              padding='max_length',
-                              truncation=True)
-            for k, tensor in token.items():
-                self.keyed_prompt[k].extend(tensor.to(torch.cuda.current_device()).unbind())
+        instructions = [data_dict["instruction"] for data_dict in list_data_dict]
+        tokens = tokenizer(instructions,
+                           return_tensors='pt',
+                           max_length=max_length,
+                           padding='max_length',
+                           truncation=True)
+        for k, tensor in tokens.items():
+            self.keyed_prompt[k] = tensor.to(torch.cuda.current_device()).unbind()

     def __len__(self):
         return len(self.keyed_prompt["input_ids"])
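For reference, a standalone illustration of the batched tokenization pattern adopted above (a hypothetical snippet using the HuggingFace tokenizer API directly; it keeps tensors on CPU and skips the `.to(torch.cuda.current_device())` call):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token    # GPT-2 ships without a pad token

instructions = ["Explain RLHF briefly.", "Write a haiku about GPUs."]
tokens = tokenizer(instructions,
                   return_tensors='pt',
                   max_length=16,
                   padding='max_length',
                   truncation=True)
# One (num_instructions, max_length) tensor per key; unbind() yields per-sample rows.
rows = tokens['input_ids'].unbind()
print(len(rows), rows[0].shape)    # 2 torch.Size([16])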
@@ -74,21 +74,18 @@ class SFTDataset(Dataset):
         return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])


-def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, max_length: int) -> Dict:
+def _tokenize_fn(strings: Sequence[str],
+                 tokenizer: transformers.PreTrainedTokenizer,
+                 max_length: int
+                 ) -> Dict[str, torch.Tensor]:
     """Tokenize a list of strings."""
-    tokenized_list = [
-        tokenizer(
-            text,
-            return_tensors="pt",
-            padding="longest",
-            max_length=max_length,
-            truncation=True,
-        ) for text in strings
-    ]
-    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
-    input_ids_lens = labels_lens = [
-        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
-    ]
+    tokenized_list = tokenizer(
+        strings, return_tensors="pt", padding="longest",
+        max_length=max_length, truncation=True
+    )
+    input_ids = labels = tokenized_list["input_ids"]
+    input_ids_lens = labels_lens = \
+        tokenized_list["input_ids"].ne(tokenizer.pad_token_id).sum(dim=-1)
     return dict(
         input_ids=input_ids,
         labels=labels,
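A small standalone check of the new length computation: with one batched tokenizer call, per-example lengths fall out of the padded `input_ids` tensor without a Python loop (the pad id of 0 and the token values below are assumptions for illustration):

import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0, 0],
                          [8, 9, 4, 3, 0]])
lens = input_ids.ne(pad_token_id).sum(dim=-1)
print(lens)    # tensor([3, 4]) -- one length per example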
@@ -105,7 +102,10 @@ def preprocess(
 ) -> Dict:
     """Preprocess the data by tokenizing."""
     examples = [s + t for s, t in zip(sources, targets)]
-    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)]
+    examples_tokenized, sources_tokenized = [
+        _tokenize_fn(strings, tokenizer, max_length)
+        for strings in (examples, sources)
+    ]
     input_ids = examples_tokenized["input_ids"]
     labels = copy.deepcopy(input_ids)
     for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
@@ -1,5 +1,6 @@
 import torch
-from coati.models.utils import compute_reward, normalize
+from coati.models.generation import generate_with_actor
+from coati.models.utils import calc_action_log_probs, compute_reward, normalize

 from .base import Experience, ExperienceMaker

@@ -16,13 +17,16 @@ class NaiveExperienceMaker(ExperienceMaker):
         self.initial_model.eval()
         self.reward_model.eval()

-        sequences, attention_mask, action_mask = self.actor.generate(input_ids,
-                                                                     return_action_mask=True,
-                                                                     **generate_kwargs)
+        sequences, attention_mask, action_mask = generate_with_actor(self.actor,
+                                                                     input_ids,
+                                                                     return_action_mask=True,
+                                                                     **generate_kwargs)
         num_actions = action_mask.size(1)

-        action_log_probs = self.actor(sequences, num_actions, attention_mask)
-        base_action_log_probs = self.initial_model(sequences, num_actions, attention_mask)
+        actor_output = self.actor(sequences, attention_mask)
+        action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions)
+        base_model_output = self.initial_model(sequences, attention_mask)
+        base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions)
         value = self.critic(sequences, action_mask, attention_mask)
         r = self.reward_model(sequences, attention_mask)
         reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask)
@@ -1,3 +1,5 @@
+from typing import Union
+
 import torch.nn as nn

 from .actor import Actor
@@ -5,10 +7,10 @@ from .critic import Critic
 from .reward_model import RewardModel


-def get_base_model(model: nn.Module) -> nn.Module:
+def get_base_model(model: Union[Actor, Critic, RewardModel]) -> nn.Module:
     """Get the base model of our wrapper classes.

-    For Actor, it's base model is ``actor.model`` and it's usually a ``transformers.PreTrainedModel``.
-    For Critic and RewardModel, it's base model is itself.
+    For Actor, Critic and RewardModel, return ``model.model``,
+    it's usually a ``transformers.PreTrainedModel``.

     Args:
         model (nn.Module): model to get base model from
@@ -16,9 +18,9 @@ def get_base_model(model: nn.Module) -> nn.Module:
     Returns:
         nn.Module: the base model
     """
-    if isinstance(model, Actor):
-        return model.get_base_model()
-    return model
+    assert isinstance(model, (Actor, Critic, RewardModel)), \
+        f'Expect Actor, Critic or RewardModel, got {type(model)}, use unwrap_model first.'
+    return model.model


 __all__ = ['Actor', 'Critic', 'RewardModel', 'get_base_model']
@@ -1,12 +1,9 @@
-from typing import Optional, Tuple, Union
+from typing import Optional

 import torch
 import torch.nn as nn
-import torch.nn.functional as F

-from ..generation import generate
 from ..lora import LoRAModule
-from ..utils import log_probs_from_logits


 class Actor(LoRAModule):
@@ -24,42 +21,16 @@ class Actor(LoRAModule):
         self.model = model
         self.convert_to_lora()

-    @torch.no_grad()
-    def generate(
-        self,
-        input_ids: torch.Tensor,
-        return_action_mask: bool = True,
-        **kwargs
-    ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
-        sequences = generate(self.model, input_ids, **kwargs)
-        attention_mask = None
-        pad_token_id = kwargs.get('pad_token_id', None)
-        if pad_token_id is not None:
-            attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
-        if not return_action_mask:
-            return sequences, attention_mask, None
-        input_len = input_ids.size(1)
-        eos_token_id = kwargs.get('eos_token_id', None)
-        if eos_token_id is None:
-            action_mask = torch.ones_like(sequences, dtype=torch.bool)
-        else:
-            # left padding may be applied, only mask action
-            action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
-            action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)    # include eos token and input
-            action_mask[:, :input_len] = False
-            action_mask = action_mask[:, 1:]
-        return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]
-
     def forward(self,
-                sequences: torch.LongTensor,
-                num_actions: int,
-                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        """Returns action log probs
+                input_ids: torch.LongTensor,
+                attention_mask: Optional[torch.Tensor] = None,
+                **model_kwargs,    # HACK: `generate` method may pass more kwargs
+                ) -> torch.Tensor:
+        """Returns model output.
         """
-        output = self.model(sequences, attention_mask=attention_mask)
-        logits = output['logits']
-        log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
-        return log_probs[:, -num_actions:]
-
-    def get_base_model(self):
-        return self.model
+        output = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            **model_kwargs
+        )
+        return output
@@ -1,8 +1,10 @@
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, Tuple, Union

 import torch
 import torch.distributed as dist
 import torch.nn as nn
+import torch.nn.functional as F
+

 try:
     from transformers.generation_logits_process import (
@@ -55,9 +57,8 @@ def sample(model: nn.Module,
     unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)

     for _ in range(input_ids.size(1), max_length):
-        model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {
-            'input_ids': input_ids
-        }
+        model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) \
+            if prepare_inputs_fn is not None else {'input_ids': input_ids}
         outputs = model(**model_inputs)

         next_token_logits = outputs['logits'][:, -1, :]
@@ -144,3 +145,35 @@ def generate(model: nn.Module,
         raise NotImplementedError
     else:
         raise ValueError("Unsupported generation mode")
+
+
+@torch.no_grad()
+def generate_with_actor(actor_model: nn.Module,
+                        input_ids: torch.Tensor,
+                        return_action_mask: bool = True,
+                        **kwargs
+                        ) -> Union[Tuple[torch.LongTensor, torch.LongTensor],
+                                   Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
+    """Generate token sequence with actor model. Refer to `generate` for more details.
+    """
+    # generate sequences
+    sequences = generate(actor_model, input_ids, **kwargs)
+
+    # calculate auxiliary tensors
+    attention_mask = None
+    pad_token_id = kwargs.get('pad_token_id', None)
+    if pad_token_id is not None:
+        attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
+    if not return_action_mask:
+        return sequences, attention_mask, None
+    input_len = input_ids.size(1)
+    eos_token_id = kwargs.get('eos_token_id', None)
+    if eos_token_id is None:
+        action_mask = torch.ones_like(sequences, dtype=torch.bool)
+    else:
+        # left padding may be applied, only mask action
+        action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
+        action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)    # include eos token and input
+        action_mask[:, :input_len] = False
+        action_mask = action_mask[:, 1:]
+    return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]
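The action-mask arithmetic in `generate_with_actor` is easiest to see on a toy example. The sketch below re-runs the same operations on made-up token ids (the eos_token_id of 2 and the sequence values are assumptions for illustration):

import torch
import torch.nn.functional as F

eos_token_id = 2
input_len = 3                                    # prompt length
# 3 prompt tokens, then generated tokens with eos at the third generated position
sequences = torch.tensor([[5, 6, 7, 9, 8, 2, 0, 0]])

action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)    # include eos token and input
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
print(action_mask[:, -(sequences.size(1) - input_len):])
# tensor([[ True,  True,  True, False, False]]) -> generated tokens up to and
# including eos count as actions; padding after eos is masked out.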
@@ -46,6 +46,25 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.T
     return log_probs_labels.squeeze(-1)


+def calc_action_log_probs(output: torch.Tensor,
+                          sequences: torch.LongTensor,
+                          num_actions: int
+                          ) -> torch.Tensor:
+    """Calculate action log probs.
+
+    Args:
+        output (torch.Tensor): Output tensor of Actor.forward.
+        sequences (torch.LongTensor): Input sequences.
+        num_actions (int): Number of actions.
+
+    Returns:
+        torch.Tensor: Action log probs.
+    """
+    logits = output['logits']
+    log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
+    return log_probs[:, -num_actions:]
+
+
 def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
     tensor = tensor * mask
     tensor = tensor.sum(dim=dim)
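What `calc_action_log_probs` computes, re-derived in plain PyTorch (a standalone sketch, not the coati implementation): logits at position t score the token at position t+1, so the logits are shifted before gathering, and only the last `num_actions` positions (the generated response) are kept.

import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 6, 11
logits = torch.randn(batch, seq_len, vocab)            # output['logits'] from the actor
sequences = torch.randint(0, vocab, (batch, seq_len))
num_actions = 3                                        # number of generated tokens

log_probs_all = F.log_softmax(logits[:, :-1, :], dim=-1)
log_probs = log_probs_all.gather(-1, sequences[:, 1:].unsqueeze(-1)).squeeze(-1)
action_log_probs = log_probs[:, -num_actions:]
print(action_log_probs.shape)                          # torch.Size([2, 3])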
@@ -3,8 +3,9 @@ from typing import Any, Callable, Dict, List, Optional, Union
 import torch
 import torch.nn as nn
 from coati.experience_maker import Experience, NaiveExperienceMaker
-from coati.models.base import Actor, Critic
+from coati.models.base import Actor, Critic, get_base_model
 from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
+from coati.models.utils import calc_action_log_probs
 from coati.replay_buffer import NaiveReplayBuffer
 from torch import Tensor
 from torch.optim import Optimizer
@@ -165,7 +166,8 @@ class PPOTrainer(Trainer):
         self.critic.train()
         # policy loss
         num_actions = experience.action_mask.size(1)
-        action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
+        actor_output = self.actor(experience.sequences, attention_mask=experience.attention_mask)
+        action_log_probs = calc_action_log_probs(actor_output, experience.sequences, num_actions)
         actor_loss = self.actor_loss_fn(action_log_probs,
                                         experience.action_log_probs,
                                         experience.advantages,
@@ -175,7 +177,7 @@ class PPOTrainer(Trainer):
         if self.ptx_coef != 0:
             batch = next(iter(self.pretrain_dataloader))
             batch = to_device(batch, self.device)
-            ptx_log_probs = self.actor.get_base_model()(batch['input_ids'],
+            ptx_log_probs = self.actor(batch['input_ids'],
                                        attention_mask=batch['attention_mask'])['logits']
             ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch['labels'])
             actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)
@@ -200,14 +202,15 @@ class PPOTrainer(Trainer):
         return {'reward': experience.reward.mean().item()}


-def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None:
-    origin_model = strategy.unwrap_model(actor)
+def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict:
+    unwrapper_model = strategy.unwrap_model(actor)
+    hf_model = get_base_model(unwrapper_model)
     new_kwargs = {**generate_kwargs}
     # use huggingface models method directly
-    if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
-        new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
+    if 'prepare_inputs_fn' not in generate_kwargs and hasattr(hf_model, 'prepare_inputs_for_generation'):
+        new_kwargs['prepare_inputs_fn'] = hf_model.prepare_inputs_for_generation

-    if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(origin_model, '_update_model_kwargs_for_generation'):
-        new_kwargs['update_model_kwargs_fn'] = origin_model._update_model_kwargs_for_generation
+    if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(hf_model, '_update_model_kwargs_for_generation'):
+        new_kwargs['update_model_kwargs_fn'] = hf_model._update_model_kwargs_for_generation

     return new_kwargs
@@ -4,7 +4,6 @@ from typing import Any, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
-from coati.models.base import Actor, get_base_model
 from coati.replay_buffer import ReplayBuffer
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
@@ -69,21 +68,16 @@ class Strategy(ABC):
             Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]: Models or model-optimizer-pairs in the original order.
         """

-        def prepare_model(model: nn.Module):
-            if isinstance(model, Actor):
-                return Actor(self.setup_model(model.get_base_model()))
-            return self.setup_model(model)
-
         rets = []
         for arg in models_or_model_optim_pairs:
             if isinstance(arg, tuple):
                 assert len(arg) == 2, f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"'
                 model, optimizer = arg
-                model = prepare_model(model)
-                optimizer = self.setup_optimizer(optimizer, get_base_model(model))
+                model = self.setup_model(model)
+                optimizer = self.setup_optimizer(optimizer, model)
                 rets.append((model, optimizer))
             elif isinstance(arg, nn.Module):
-                rets.append(prepare_model(arg))
+                rets.append(self.setup_model(model))
             else:
                 raise RuntimeError(f'Expect model or (model, optimizer) pair, got {type(arg)}')

@@ -93,16 +87,15 @@ class Strategy(ABC):

     @staticmethod
     def unwrap_model(model: nn.Module) -> nn.Module:
-        """Get the unwrapped model from a wrapped model. Useful for getting original huggingface model.
-        For Actor, it will unwrap `actor.model`.
+        """Get the unwrapped model from a wrapped model made by Strategy.prepare.

         Args:
             model (nn.Module): the model to unwrap

         Returns:
-            nn.Module: the original model (usually a huggingface model)
+            nn.Module: the original model
         """
-        return get_base_model(model)
+        return model

     @abstractmethod
     def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
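After this change the two unwrapping steps have distinct jobs: `Strategy.unwrap_model` only removes whatever wrapper `Strategy.prepare` added (DDP, ZeroDDP, or nothing), while `get_base_model` strips the coati `Actor`/`Critic`/`RewardModel` wrapper. Reaching the underlying transformers model is the composition of both, as `_set_default_generate_kwargs` above does. A short sketch, assuming `strategy` and `actor` were already created and passed through `Strategy.prepare` (setup not shown here):

from coati.models.base import get_base_model

unwrapped_actor = strategy.unwrap_model(actor)    # removes the DDP/ZeroDDP wrapper, if any
hf_model = get_base_model(unwrapped_actor)        # Actor -> underlying transformers model
# e.g. HF generation helpers can now be taken from hf_model, as the trainer does:
prepare_inputs_fn = hf_model.prepare_inputs_for_generation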
@@ -5,7 +5,6 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch.optim as optim
-from coati.models.base import get_base_model
 from torch.optim import Optimizer
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase

|
|||||||
def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
|
def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
|
||||||
if only_rank0 and dist.get_rank() != 0 and self.stage != 3:
|
if only_rank0 and dist.get_rank() != 0 and self.stage != 3:
|
||||||
return
|
return
|
||||||
base_model = get_base_model(model)
|
|
||||||
if self.stage == 3:
|
if self.stage == 3:
|
||||||
assert isinstance(base_model, ZeroDDP)
|
assert isinstance(model, ZeroDDP)
|
||||||
# for stage 3, state_dict() method should be called on every rank
|
# for stage 3, state_dict() method should be called on every rank
|
||||||
state_dict = base_model.state_dict(only_rank_0=only_rank0)
|
state_dict = model.state_dict(only_rank_0=only_rank0)
|
||||||
else:
|
else:
|
||||||
# only_rank0 is false or rank == 0
|
# only_rank0 is false or rank == 0
|
||||||
state_dict = base_model.state_dict()
|
state_dict = model.state_dict()
|
||||||
if only_rank0 and dist.get_rank() != 0:
|
if only_rank0 and dist.get_rank() != 0:
|
||||||
return
|
return
|
||||||
torch.save(state_dict, path)
|
torch.save(state_dict, path)
|
||||||
@@ -172,11 +170,10 @@ class ColossalAIStrategy(DDPStrategy):
         torch.save(optimizer.state_dict(), path)

     def unwrap_model(self, model: nn.Module) -> nn.Module:
-        base_model: Union[nn.Module, ZeroDDP] = get_base_model(model)
         if self.stage == 3:
-            assert isinstance(base_model, ZeroDDP)
-            return base_model.module
-        return base_model
+            assert isinstance(model, ZeroDDP)
+            return model.module
+        return model

     def save_pretrained(self,
                         model: nn.Module,
@@ -196,5 +193,5 @@ class ColossalAIStrategy(DDPStrategy):
         #     if isinstance(module, LoraLinear):
         #         module.merge_weights = True
         #         module.eval()
-        base_model: ZeroDDP = get_base_model(model)
-        yield from base_model.state_dict_shard(max_shard_size=1024, only_rank_0=False)
+        assert isinstance(model, ZeroDDP)
+        yield from model.state_dict_shard(max_shard_size=1024, only_rank_0=False)
|
@ -69,8 +69,8 @@ class DDPStrategy(NaiveStrategy):
|
|||||||
return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank())
|
return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank())
|
||||||
|
|
||||||
def unwrap_model(self, model: nn.Module) -> nn.Module:
|
def unwrap_model(self, model: nn.Module) -> nn.Module:
|
||||||
base_model: DDP = super().unwrap_model(model)
|
assert isinstance(model, DDP)
|
||||||
return base_model.module
|
return model.module
|
||||||
|
|
||||||
def save_pretrained(self,
|
def save_pretrained(self,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
|
@@ -58,14 +58,13 @@ class NaiveStrategy(Strategy):
                           collate_fn=replay_buffer.collate_fn)

     def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
-        base_model = get_base_model(model)
-        state_dict = base_model.state_dict()
+        state_dict = model.state_dict()
         torch.save(state_dict, path)

     def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
-        base_model = get_base_model(model)
+        unwrapped_model = self.unwrap_model(model)
         state_dict = torch.load(path, map_location=map_location)
-        base_model.load_state_dict(state_dict, strict=strict)
+        unwrapped_model.load_state_dict(state_dict, strict=strict)

     def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
         torch.save(optimizer.state_dict(), path)
@@ -121,6 +121,14 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py --prompt_datas
     --rm_pretrain 'gpt2' \
     --rm_path ${BASE}/rm_ckpt_gpt.pt \
     --save_path ${BASE}/actor_checkpoint_prompts.pt
+
+torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
+    --strategy colossalai_gemini --num_episodes 1 --max_timesteps 2 \
+    --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
+    --pretrain 'gpt2' --model gpt2 \
+    --rm_pretrain 'gpt2' \
+    --rm_path ${BASE}/rm_ckpt_gpt.pt \
+    --save_path ${BASE}/actor_checkpoint_prompts.pt
 rm -rf ${BASE}/rm_ckpt_gpt.pt

 rm -rf ${BASE}/actor_checkpoint_prompts.pt
@@ -6,6 +6,7 @@ import pytest
 import torch
 import torch.distributed as dist
 from coati.models.gpt import GPTActor
+from coati.models.utils import calc_action_log_probs
 from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy
 from transformers.models.gpt2.configuration_gpt2 import GPT2Config

@@ -43,7 +44,8 @@ def run_test_checkpoint(strategy):
     def run_step():
         data = get_data(BATCH_SIZE)
         action_mask = torch.ones_like(data['attention_mask'], dtype=torch.bool)
-        action_log_probs = actor(data['input_ids'], action_mask.size(1), data['attention_mask'])
+        actor_output = actor(data['input_ids'], data['attention_mask'])
+        action_log_probs = calc_action_log_probs(actor_output, data['input_ids'], action_mask.size(1))
         loss = action_log_probs.sum()
         strategy.backward(loss, actor, actor_optim)
         strategy.optimizer_step(actor_optim)