Mirror of https://github.com/hpcaitech/ColossalAI.git
[chatgpt]add flag of action mask in critic(#3086)
@@ -37,7 +37,7 @@ class Actor(LoRAModule):
         if pad_token_id is not None:
             attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
         if not return_action_mask:
-            return sequences, attention_mask
+            return sequences, attention_mask, None
         input_len = input_ids.size(1)
         eos_token_id = kwargs.get('eos_token_id', None)
         if eos_token_id is None:
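With this change, Actor.generate(..., return_action_mask=False) returns a three-tuple with None in place of the action mask, so callers can unpack the result the same way regardless of the flag. A minimal sketch of the new return contract, using a stand-in function rather than the library's Actor (the function name and the placeholder mask below are assumptions for illustration):

# Minimal sketch (not the library code): mimic the return shape of
# Actor.generate() after this commit.
import torch

def generate_like(sequences: torch.Tensor, pad_token_id: int, return_action_mask: bool):
    attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long)
    if not return_action_mask:
        return sequences, attention_mask, None  # previously a 2-tuple
    # Placeholder action mask; the real Actor derives it from eos/pad positions.
    action_mask = torch.ones_like(sequences, dtype=torch.bool)
    return sequences, attention_mask, action_mask

seqs = torch.tensor([[5, 6, 7, 0, 0]])
_, attn, act = generate_like(seqs, pad_token_id=0, return_action_mask=False)
assert act is None and attn.tolist() == [[1, 1, 1, 0, 0]]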
@@ -18,15 +18,19 @@ class Critic(LoRAModule):
         lora_train_bias (str): LoRA bias training mode.
     """
 
-    def __init__(self,
-                 model: nn.Module,
-                 value_head: nn.Module,
-                 lora_rank: int = 0,
-                 lora_train_bias: str = 'none') -> None:
+    def __init__(
+        self,
+        model: nn.Module,
+        value_head: nn.Module,
+        lora_rank: int = 0,
+        lora_train_bias: str = 'none',
+        use_action_mask: bool = False,
+    ) -> None:
 
         super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
         self.model = model
         self.value_head = value_head
+        self.use_action_mask = use_action_mask
         self.convert_to_lora()
 
     def forward(self,
@@ -38,7 +42,7 @@ class Critic(LoRAModule):
 
         values = self.value_head(last_hidden_states).squeeze(-1)
 
-        if action_mask is not None:
+        if action_mask is not None and self.use_action_mask:
             num_actions = action_mask.size(1)
             prompt_mask = attention_mask[:, :-num_actions]
             values = values[:, :-num_actions]
@@ -46,5 +50,5 @@ class Critic(LoRAModule):
             return value
 
         values = values[:, :-1]
-        value = values.mean(dim=1).squeeze(1)
+        value = values.mean(dim=1)
         return value
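The use_action_mask flag defaults to False, so existing critics keep the old behaviour of averaging values over the truncated sequence; only when the flag is set and an action mask is supplied does forward() restrict the value to the prompt portion. A minimal self-contained sketch of that control flow, with a linear layer standing in for the transformer backbone and a plain masked mean where the real Critic uses its own helper (both assumptions, not the library code):

# Stand-in critic illustrating how use_action_mask gates the masked path.
import torch
import torch.nn as nn

class TinyCritic(nn.Module):
    def __init__(self, hidden_size: int = 8, use_action_mask: bool = False) -> None:
        super().__init__()
        self.model = nn.Linear(hidden_size, hidden_size)  # stand-in for the LM backbone
        self.value_head = nn.Linear(hidden_size, 1)
        self.use_action_mask = use_action_mask

    def forward(self, hidden_states, attention_mask=None, action_mask=None):
        values = self.value_head(self.model(hidden_states)).squeeze(-1)
        if action_mask is not None and self.use_action_mask:
            # Keep only the prompt part of the sequence and average it under the mask.
            num_actions = action_mask.size(1)
            prompt_mask = attention_mask[:, :-num_actions]
            values = values[:, :-num_actions]
            return (values * prompt_mask).sum(dim=1) / prompt_mask.sum(dim=1)
        values = values[:, :-1]
        return values.mean(dim=1)

x = torch.randn(2, 10, 8)
attn = torch.ones(2, 10, dtype=torch.long)
act = torch.ones(2, 4, dtype=torch.long)
print(TinyCritic(use_action_mask=True)(x, attn, act).shape)   # torch.Size([2])
print(TinyCritic(use_action_mask=False)(x, attn, act).shape)  # torch.Size([2])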
@@ -24,7 +24,8 @@ class BLOOMCritic(Critic):
                  config: Optional[BloomConfig] = None,
                  checkpoint: bool = False,
                  lora_rank: int = 0,
-                 lora_train_bias: str = 'none') -> None:
+                 lora_train_bias: str = 'none',
+                 **kwargs) -> None:
         if pretrained is not None:
             model = BloomModel.from_pretrained(pretrained)
         elif config is not None:
@@ -34,4 +35,4 @@ class BLOOMCritic(Critic):
         if checkpoint:
             model.gradient_checkpointing_enable()
         value_head = nn.Linear(model.config.hidden_size, 1)
-        super().__init__(model, value_head, lora_rank, lora_train_bias)
+        super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
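The model-specific critics now take **kwargs and forward them to the base Critic, so the new flag can be passed straight through a subclass constructor (for example use_action_mask=True). Continuing the TinyCritic sketch above, the pass-through pattern looks like this (the subclass name is illustrative only):

# Sketch of the **kwargs pass-through added to BLOOMCritic, GPTCritic and OPTCritic,
# reusing the TinyCritic stand-in from the previous example.
class TinyModelCritic(TinyCritic):
    def __init__(self, hidden_size: int = 8, **kwargs) -> None:
        # Flags such as use_action_mask=True ride along in **kwargs and reach
        # the base class without being listed here explicitly.
        super().__init__(hidden_size=hidden_size, **kwargs)

critic = TinyModelCritic(use_action_mask=True)
assert critic.use_action_mask is True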
@@ -20,7 +20,8 @@ class GPTCritic(Critic):
     def __init__(self,
                  pretrained: Optional[str] = None,
                  config: Optional[GPT2Config] = None,
-                 checkpoint: bool = False) -> None:
+                 checkpoint: bool = False,
+                 **kwargs) -> None:
         if pretrained is not None:
             model = GPT2Model.from_pretrained(pretrained)
         elif config is not None:
@@ -30,4 +31,4 @@ class GPTCritic(Critic):
         if checkpoint:
             model.gradient_checkpointing_enable()
         value_head = nn.Linear(model.config.n_embd, 1)
-        super().__init__(model, value_head)
+        super().__init__(model, value_head, **kwargs)
@@ -24,7 +24,8 @@ class OPTCritic(Critic):
                  config: Optional[OPTConfig] = None,
                  checkpoint: bool = False,
                  lora_rank: int = 0,
-                 lora_train_bias: str = 'none') -> None:
+                 lora_train_bias: str = 'none',
+                 **kwargs) -> None:
         if pretrained is not None:
             model = OPTModel.from_pretrained(pretrained)
         elif config is not None:
@@ -34,4 +35,4 @@ class OPTCritic(Critic):
         if checkpoint:
             model.gradient_checkpointing_enable()
         value_head = nn.Linear(model.config.hidden_size, 1)
-        super().__init__(model, value_head, lora_rank, lora_train_bias)
+        super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)