[coati] Fix LlamaCritic (#3475)

* mv LlamaForCausalLM to LlamaModel

* rm unused imports

---------

Co-authored-by: gongenlei <gongenlei@baidu.com>
This commit is contained in:
gongenlei
2023-04-07 11:39:09 +08:00
committed by GitHub
parent 8f2c55f9c9
commit a7ca297281

View File

@@ -1,8 +1,7 @@
from typing import Optional
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
from transformers import LlamaConfig, LlamaModel
from ..base import Critic
@@ -28,11 +27,11 @@ class LlamaCritic(Critic):
**kwargs) -> None:
if pretrained is not None:
model = LlamaForCausalLM.from_pretrained(pretrained)
model = LlamaModel.from_pretrained(pretrained)
elif config is not None:
model = LlamaForCausalLM(config)
model = LlamaModel(config)
else:
model = LlamaForCausalLM(LlamaConfig())
model = LlamaModel(LlamaConfig())
if checkpoint:
model.gradient_checkpointing_enable()