mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-31 08:34:14 +00:00
[coati] Fix LlamaCritic (#3475)
* mv LlamaForCausalLM to LlamaModel
* rm unused imports

---------

Co-authored-by: gongenlei <gongenlei@baidu.com>
This commit is contained in:
@@ -1,8 +1,7 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
|
||||
from transformers import LlamaConfig, LlamaModel
|
||||
|
||||
from ..base import Critic
|
||||
|
||||
@@ -28,11 +27,11 @@ class LlamaCritic(Critic):
|
||||
**kwargs) -> None:
|
||||
|
||||
if pretrained is not None:
|
||||
model = LlamaForCausalLM.from_pretrained(pretrained)
|
||||
model = LlamaModel.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
model = LlamaForCausalLM(config)
|
||||
model = LlamaModel(config)
|
||||
else:
|
||||
model = LlamaForCausalLM(LlamaConfig())
|
||||
model = LlamaModel(LlamaConfig())
|
||||
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
Reference in New Issue
Block a user