mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-17 15:36:53 +00:00
[chat]fix sft training for bloom, gpt and opt (#3418)
fix sft training for bloom, gpt and opt
This commit is contained in:
parent
638a07a7f9
commit
b09adff724
@ -33,3 +33,6 @@ class BLOOMLM(LM):
|
|||||||
if checkpoint:
|
if checkpoint:
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable()
|
||||||
super().__init__(model, lora_rank, lora_train_bias)
|
super().__init__(model, lora_rank, lora_train_bias)
|
||||||
|
|
||||||
|
def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
|
||||||
|
return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
|
||||||
|
@ -33,3 +33,6 @@ class GPTLM(LM):
|
|||||||
if checkpoint:
|
if checkpoint:
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable()
|
||||||
super().__init__(model, lora_rank, lora_train_bias)
|
super().__init__(model, lora_rank, lora_train_bias)
|
||||||
|
|
||||||
|
def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
|
||||||
|
return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
|
||||||
|
@ -33,3 +33,6 @@ class OPTLM(LM):
|
|||||||
if checkpoint:
|
if checkpoint:
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable()
|
||||||
super().__init__(model, lora_rank, lora_train_bias)
|
super().__init__(model, lora_rank, lora_train_bias)
|
||||||
|
|
||||||
|
def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
|
||||||
|
return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
|
||||||
|
Loading…
Reference in New Issue
Block a user