[coati] Fix LlamaCritic (#3475)

* mv LlamaForCausalLM to LlamaModel

* rm unused imports

---------

Co-authored-by: gongenlei <gongenlei@baidu.com>
Author:    gongenlei <gongenlei@baidu.com>
Date:      2023-04-07 11:39:09 +08:00
Committer: GitHub
Parent:    8f2c55f9c9
Commit:    a7ca297281


@@ -1,8 +1,7 @@
 from typing import Optional
 
-import torch
 import torch.nn as nn
-from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
+from transformers import LlamaConfig, LlamaModel
 
 from ..base import Critic
 
@@ -28,11 +27,11 @@ class LlamaCritic(Critic):
                  **kwargs) -> None:
         if pretrained is not None:
-            model = LlamaForCausalLM.from_pretrained(pretrained)
+            model = LlamaModel.from_pretrained(pretrained)
         elif config is not None:
-            model = LlamaForCausalLM(config)
+            model = LlamaModel(config)
         else:
-            model = LlamaForCausalLM(LlamaConfig())
+            model = LlamaModel(LlamaConfig())
         if checkpoint:
             model.gradient_checkpointing_enable()
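
Why the bare backbone: LlamaForCausalLM adds a vocabulary-sized LM head and returns logits, whereas a critic only needs per-token hidden states to score. A minimal sketch of the pattern, assuming a hypothetical ValueHeadSketch wrapper; the real Critic base class in coati wires the value head itself and may differ in detail:

# Minimal sketch, not coati's actual Critic: shows why a value model wants
# LlamaModel (bare transformer) rather than LlamaForCausalLM (adds an LM head).
# `ValueHeadSketch` is a hypothetical name invented for this illustration.
import torch
import torch.nn as nn
from transformers import LlamaConfig, LlamaModel


class ValueHeadSketch(nn.Module):
    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        # LlamaModel.forward returns last_hidden_state of shape
        # (batch, seq_len, hidden_size) -- exactly what a critic scores.
        # LlamaForCausalLM would instead return vocabulary logits and carry
        # an unused (and large) lm_head.
        self.backbone = LlamaModel(config)
        self.value_head = nn.Linear(config.hidden_size, 1)

    def forward(self, input_ids: torch.Tensor,
                attention_mask: torch.Tensor = None) -> torch.Tensor:
        outputs = self.backbone(input_ids, attention_mask=attention_mask)
        # One scalar value per token position.
        return self.value_head(outputs.last_hidden_state).squeeze(-1)


# Tiny config so the sketch runs on CPU in seconds.
cfg = LlamaConfig(vocab_size=1000, hidden_size=64, intermediate_size=128,
                  num_hidden_layers=2, num_attention_heads=4)
critic = ValueHeadSketch(cfg)
values = critic(torch.randint(0, 1000, (2, 16)))
print(values.shape)  # torch.Size([2, 16])

The same reasoning explains the removed `import torch` in the first hunk: once the LM-head class is gone, nothing in the file references `torch` directly, only `torch.nn`.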