From b09adff724c2bbded1c71cc51a707f736a0e2899 Mon Sep 17 00:00:00 2001
From: Yuanchen <70520919+chengeharrison@users.noreply.github.com>
Date: Tue, 4 Apr 2023 09:46:23 +0800
Subject: [PATCH] [chat]fix sft training for bloom, gpt and opt (#3418)

fix sft training for bloom, gpt and opt
---
 applications/Chat/coati/models/bloom/bloom_lm.py | 3 +++
 applications/Chat/coati/models/gpt/gpt_lm.py     | 3 +++
 applications/Chat/coati/models/opt/opt_lm.py     | 3 +++
 3 files changed, 9 insertions(+)

diff --git a/applications/Chat/coati/models/bloom/bloom_lm.py b/applications/Chat/coati/models/bloom/bloom_lm.py
index 628af2e34..e4184fcd0 100644
--- a/applications/Chat/coati/models/bloom/bloom_lm.py
+++ b/applications/Chat/coati/models/bloom/bloom_lm.py
@@ -33,3 +33,6 @@ class BLOOMLM(LM):
         if checkpoint:
             model.gradient_checkpointing_enable()
         super().__init__(model, lora_rank, lora_train_bias)
+
+    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
+        return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
diff --git a/applications/Chat/coati/models/gpt/gpt_lm.py b/applications/Chat/coati/models/gpt/gpt_lm.py
index 23fc13bf2..c558d7e9e 100644
--- a/applications/Chat/coati/models/gpt/gpt_lm.py
+++ b/applications/Chat/coati/models/gpt/gpt_lm.py
@@ -33,3 +33,6 @@ class GPTLM(LM):
         if checkpoint:
             model.gradient_checkpointing_enable()
         super().__init__(model, lora_rank, lora_train_bias)
+
+    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
+        return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
diff --git a/applications/Chat/coati/models/opt/opt_lm.py b/applications/Chat/coati/models/opt/opt_lm.py
index 65d79e1b2..47afae847 100644
--- a/applications/Chat/coati/models/opt/opt_lm.py
+++ b/applications/Chat/coati/models/opt/opt_lm.py
@@ -33,3 +33,6 @@ class OPTLM(LM):
         if checkpoint:
             model.gradient_checkpointing_enable()
         super().__init__(model, lora_rank, lora_train_bias)
+
+    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
+        return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
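
The sketch below (not part of the patch, and not coati's actual code) illustrates why the added forward fixes SFT: the wrapper now delegates to the underlying Hugging Face causal-LM with labels passed through, so the model returns an output whose .loss the SFT trainer can backpropagate. The CausalLMWrapper class and the choice of GPT2LMHeadModel/"gpt2" here are illustrative stand-ins for the BLOOMLM/GPTLM/OPTLM wrappers in the diff.

    # Minimal, self-contained sketch of the pattern added by this patch.
    # CausalLMWrapper is a hypothetical stand-in for coati's BLOOMLM/GPTLM/OPTLM.
    import torch.nn as nn
    from transformers import GPT2LMHeadModel, GPT2Tokenizer


    class CausalLMWrapper(nn.Module):
        def __init__(self, pretrained: str = "gpt2"):
            super().__init__()
            self.model = GPT2LMHeadModel.from_pretrained(pretrained)

        def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
            # Same shape as the patched forward: pass labels through so the
            # HF model computes the causal-LM loss internally.
            return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)


    if __name__ == "__main__":
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        model = CausalLMWrapper("gpt2")

        batch = tokenizer("Instruction: say hello.\nResponse: hello!", return_tensors="pt")
        # For SFT, labels are typically the input ids (a real pipeline would mask
        # prompt tokens to -100; omitted here for brevity).
        outputs = model(batch["input_ids"],
                        attention_mask=batch["attention_mask"],
                        labels=batch["input_ids"])
        outputs.loss.backward()  # the loss the SFT trainer optimizes
        print(f"SFT loss: {outputs.loss.item():.4f}")

Without a forward that forwards labels, calling the wrapper with labels= has no defined path to the underlying model's loss computation, which is what the commit message ("fix sft training for bloom, gpt and opt") addresses.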