upgrade ppo dpo rm script

This commit is contained in:
YeAnbang
2024-05-28 03:04:39 +00:00
parent 7a7e86987d
commit 929e1e3da4
15 changed files with 169 additions and 139 deletions

View File

@@ -36,3 +36,9 @@ class RewardModel(BaseModel):
)
values = self.value_head(sequence_hidden_states).squeeze(-1) # Ensure shape is (B,)
return values
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def get_output_embeddings(self):
return self.model.get_output_embeddings()