[chatgpt]add flag of action mask in critic(#3086)

This commit is contained in:
Fazzie-Maqianli
2023-03-10 14:40:14 +08:00
committed by GitHub
parent 95a36eae63
commit 02ae80bf9c
5 changed files with 21 additions and 14 deletions

View File

@@ -37,7 +37,7 @@ class Actor(LoRAModule):
if pad_token_id is not None: if pad_token_id is not None:
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device) attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
if not return_action_mask: if not return_action_mask:
return sequences, attention_mask return sequences, attention_mask, None
input_len = input_ids.size(1) input_len = input_ids.size(1)
eos_token_id = kwargs.get('eos_token_id', None) eos_token_id = kwargs.get('eos_token_id', None)
if eos_token_id is None: if eos_token_id is None:

View File

@@ -18,15 +18,19 @@ class Critic(LoRAModule):
lora_train_bias (str): LoRA bias training mode. lora_train_bias (str): LoRA bias training mode.
""" """
def __init__(self, def __init__(
model: nn.Module, self,
value_head: nn.Module, model: nn.Module,
lora_rank: int = 0, value_head: nn.Module,
lora_train_bias: str = 'none') -> None: lora_rank: int = 0,
lora_train_bias: str = 'none',
use_action_mask: bool = False,
) -> None:
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias) super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
self.model = model self.model = model
self.value_head = value_head self.value_head = value_head
self.use_action_mask = use_action_mask
self.convert_to_lora() self.convert_to_lora()
def forward(self, def forward(self,
@@ -38,7 +42,7 @@ class Critic(LoRAModule):
values = self.value_head(last_hidden_states).squeeze(-1) values = self.value_head(last_hidden_states).squeeze(-1)
if action_mask is not None: if action_mask is not None and self.use_action_mask:
num_actions = action_mask.size(1) num_actions = action_mask.size(1)
prompt_mask = attention_mask[:, :-num_actions] prompt_mask = attention_mask[:, :-num_actions]
values = values[:, :-num_actions] values = values[:, :-num_actions]
@@ -46,5 +50,5 @@ class Critic(LoRAModule):
return value return value
values = values[:, :-1] values = values[:, :-1]
value = values.mean(dim=1).squeeze(1) value = values.mean(dim=1)
return value return value

View File

@@ -24,7 +24,8 @@ class BLOOMCritic(Critic):
config: Optional[BloomConfig] = None, config: Optional[BloomConfig] = None,
checkpoint: bool = False, checkpoint: bool = False,
lora_rank: int = 0, lora_rank: int = 0,
lora_train_bias: str = 'none') -> None: lora_train_bias: str = 'none',
**kwargs) -> None:
if pretrained is not None: if pretrained is not None:
model = BloomModel.from_pretrained(pretrained) model = BloomModel.from_pretrained(pretrained)
elif config is not None: elif config is not None:
@@ -34,4 +35,4 @@ class BLOOMCritic(Critic):
if checkpoint: if checkpoint:
model.gradient_checkpointing_enable() model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.hidden_size, 1) value_head = nn.Linear(model.config.hidden_size, 1)
super().__init__(model, value_head, lora_rank, lora_train_bias) super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)

View File

@@ -20,7 +20,8 @@ class GPTCritic(Critic):
def __init__(self, def __init__(self,
pretrained: Optional[str] = None, pretrained: Optional[str] = None,
config: Optional[GPT2Config] = None, config: Optional[GPT2Config] = None,
checkpoint: bool = False) -> None: checkpoint: bool = False,
**kwargs) -> None:
if pretrained is not None: if pretrained is not None:
model = GPT2Model.from_pretrained(pretrained) model = GPT2Model.from_pretrained(pretrained)
elif config is not None: elif config is not None:
@@ -30,4 +31,4 @@ class GPTCritic(Critic):
if checkpoint: if checkpoint:
model.gradient_checkpointing_enable() model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.n_embd, 1) value_head = nn.Linear(model.config.n_embd, 1)
super().__init__(model, value_head) super().__init__(model, value_head, **kwargs)

View File

@@ -24,7 +24,8 @@ class OPTCritic(Critic):
config: Optional[OPTConfig] = None, config: Optional[OPTConfig] = None,
checkpoint: bool = False, checkpoint: bool = False,
lora_rank: int = 0, lora_rank: int = 0,
lora_train_bias: str = 'none') -> None: lora_train_bias: str = 'none',
**kargs) -> None:
if pretrained is not None: if pretrained is not None:
model = OPTModel.from_pretrained(pretrained) model = OPTModel.from_pretrained(pretrained)
elif config is not None: elif config is not None:
@@ -34,4 +35,4 @@ class OPTCritic(Critic):
if checkpoint: if checkpoint:
model.gradient_checkpointing_enable() model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.hidden_size, 1) value_head = nn.Linear(model.config.hidden_size, 1)
super().__init__(model, value_head, lora_rank, lora_train_bias) super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)