mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-25 15:01:43 +00:00
[chatgpt] change critic input as state (#3042)
* fix Critic
* fix Critic
* fix Critic
* fix neglect of attention mask
* fix neglect of attention mask
* fix neglect of attention mask
* add return

---------

Co-authored-by: yangwenjun <yangwenjun@soyoung.com>
Co-authored-by: yangwjd <yangwjd@chanjet.com>
This commit is contained in:
parent
2ef855c798
commit
b51bfec357
@ -36,12 +36,15 @@ class Critic(LoRAModule):
|
||||
outputs = self.model(sequences, attention_mask=attention_mask)
|
||||
last_hidden_states = outputs['last_hidden_state']
|
||||
|
||||
values = self.value_head(last_hidden_states).squeeze(-1)[:, :-1]
|
||||
values = self.value_head(last_hidden_states).squeeze(-1)
|
||||
|
||||
if action_mask is not None:
|
||||
num_actions = action_mask.size(1)
|
||||
values = values[:, -num_actions:]
|
||||
value = masked_mean(values, action_mask, dim=1)
|
||||
prompt_mask = attention_mask[:, :-num_actions]
|
||||
values = values[:, :-num_actions]
|
||||
value = masked_mean(values, prompt_mask, dim=1)
|
||||
return value
|
||||
|
||||
values = values[:, :-1]
|
||||
value = values.mean(dim=1).squeeze(1)
|
||||
return value
|
||||
|
Loading…
Reference in New Issue
Block a user