diff --git a/pilot/model/inference.py b/pilot/model/inference.py
index 4a013026f..b15e1d749 100644
--- a/pilot/model/inference.py
+++ b/pilot/model/inference.py
@@ -81,8 +81,9 @@ def generate_output(model, tokenizer, params, device, context_len=2048):
 @torch.inference_mode()
 def get_embeddings(model, tokenizer, prompt):
     input_ids = tokenizer(prompt).input_ids
-    input_embeddings = model.get_input_embeddings()
-    embeddings = input_embeddings(torch.LongTensor([input_ids]))
-    mean = torch.mean(embeddings[0], 0).cpu().detach()
-    return mean
+    # Use the model's own device rather than guessing from CUDA availability;
+    # move only the input tensor -- never relocate the model's embedding layer.
+    device = next(model.parameters()).device
+    input_embeddings = model.get_input_embeddings()
+    embeddings = input_embeddings(torch.LongTensor([input_ids]).to(device))
+    # Detach before transfer; return a CPU tensor to preserve the original contract.
+    return torch.mean(embeddings[0], 0).detach().cpu()
 