Mirror of https://github.com/hpcaitech/ColossalAI.git
[coati] Fix LlamaCritic (#3475)
* mv LlamaForCausalLM to LlamaModel

* rm unused imports

---------

Co-authored-by: gongenlei <gongenlei@baidu.com>
@@ -1,8 +1,7 @@
 from typing import Optional
 
-import torch
 import torch.nn as nn
 
-from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
+from transformers import LlamaConfig, LlamaModel
 
 from ..base import Critic
||||||
@@ -28,11 +27,11 @@ class LlamaCritic(Critic):
                  **kwargs) -> None:
 
         if pretrained is not None:
-            model = LlamaForCausalLM.from_pretrained(pretrained)
+            model = LlamaModel.from_pretrained(pretrained)
         elif config is not None:
-            model = LlamaForCausalLM(config)
+            model = LlamaModel(config)
         else:
-            model = LlamaForCausalLM(LlamaConfig())
+            model = LlamaModel(LlamaConfig())
 
         if checkpoint:
             model.gradient_checkpointing_enable()
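The swap from LlamaForCausalLM to LlamaModel matters because a critic only needs the transformer's hidden states to feed a scalar value head; the causal-LM wrapper would also construct an lm_head projecting to the full vocabulary, which the critic never uses. Below is a minimal, hypothetical sketch of that pattern. TinyLlamaCritic and its pooling logic are illustrative assumptions, not coati's actual Critic base class; only LlamaModel and LlamaConfig come from the diff above.

```python
import torch
import torch.nn as nn
from transformers import LlamaConfig, LlamaModel


class TinyLlamaCritic(nn.Module):
    """Hypothetical critic: bare transformer backbone + scalar value head."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        # LlamaModel returns hidden states only; LlamaForCausalLM would
        # additionally allocate an unused vocabulary-sized lm_head.
        self.model = LlamaModel(config)
        self.value_head = nn.Linear(config.hidden_size, 1)

    def forward(self, input_ids: torch.LongTensor,
                attention_mask: torch.Tensor) -> torch.Tensor:
        outputs = self.model(input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state             # (batch, seq, hidden)
        values = self.value_head(hidden).squeeze(-1)   # (batch, seq)
        # One scalar per sequence: mean over non-padding positions
        # (an illustrative pooling choice, not prescribed by the commit).
        return (values * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)


if __name__ == "__main__":
    cfg = LlamaConfig(hidden_size=64, intermediate_size=128,
                      num_hidden_layers=2, num_attention_heads=4,
                      vocab_size=1000)
    critic = TinyLlamaCritic(cfg)
    ids = torch.randint(0, 1000, (2, 8))
    mask = torch.ones(2, 8, dtype=torch.long)
    print(critic(ids, mask).shape)  # torch.Size([2])
```

With this structure, `gradient_checkpointing_enable()` on the backbone (as in the diff) still works, since LlamaModel supports it directly.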