mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 04:24:47 +00:00
upgrade ppo dpo rm script
This commit is contained in:
@@ -32,3 +32,9 @@ class Critic(BaseModel):
|
||||
)
|
||||
values = self.value_head(sequence_hidden_states).squeeze(-1) # ensure shape is (B, sequence length)
|
||||
return values
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.get_input_embeddings()
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.model.get_output_embeddings()
|
Reference in New Issue
Block a user