diff --git a/applications/Chat/coati/models/bloom/bloom_lm.py b/applications/Chat/coati/models/bloom/bloom_lm.py
index 628af2e34..e4184fcd0 100644
--- a/applications/Chat/coati/models/bloom/bloom_lm.py
+++ b/applications/Chat/coati/models/bloom/bloom_lm.py
@@ -33,3 +33,6 @@ class BLOOMLM(LM):
         if checkpoint:
             model.gradient_checkpointing_enable()
         super().__init__(model, lora_rank, lora_train_bias)
+
+    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
+        return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
diff --git a/applications/Chat/coati/models/gpt/gpt_lm.py b/applications/Chat/coati/models/gpt/gpt_lm.py
index 23fc13bf2..c558d7e9e 100644
--- a/applications/Chat/coati/models/gpt/gpt_lm.py
+++ b/applications/Chat/coati/models/gpt/gpt_lm.py
@@ -33,3 +33,6 @@ class GPTLM(LM):
         if checkpoint:
             model.gradient_checkpointing_enable()
         super().__init__(model, lora_rank, lora_train_bias)
+
+    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
+        return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
diff --git a/applications/Chat/coati/models/opt/opt_lm.py b/applications/Chat/coati/models/opt/opt_lm.py
index 65d79e1b2..47afae847 100644
--- a/applications/Chat/coati/models/opt/opt_lm.py
+++ b/applications/Chat/coati/models/opt/opt_lm.py
@@ -33,3 +33,6 @@ class OPTLM(LM):
         if checkpoint:
             model.gradient_checkpointing_enable()
         super().__init__(model, lora_rank, lora_train_bias)
+
+    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
+        return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)