From 8ef6a0538619c2ef9f9ca66e25a958679769bdbb Mon Sep 17 00:00:00 2001
From: csunny
Date: Sun, 30 Apr 2023 14:47:33 +0800
Subject: [PATCH] fix multi device error

---
 pilot/model/inference.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/pilot/model/inference.py b/pilot/model/inference.py
index 4a013026f..b15e1d749 100644
--- a/pilot/model/inference.py
+++ b/pilot/model/inference.py
@@ -81,8 +81,11 @@ def generate_output(model, tokenizer, params, device, context_len=2048):
 @torch.inference_mode()
 def get_embeddings(model, tokenizer, prompt):
     input_ids = tokenizer(prompt).input_ids
-    input_embeddings = model.get_input_embeddings()
-    embeddings = input_embeddings(torch.LongTensor([input_ids]))
-    mean = torch.mean(embeddings[0], 0).cpu().detach()
-    return mean
+    input_embeddings = model.get_input_embeddings()
+    # Run the lookup on whichever device the embedding weights actually live
+    # on (CPU, or a specific GPU) rather than assuming the default CUDA device.
+    device = input_embeddings.weight.device
+    embeddings = input_embeddings(torch.LongTensor([input_ids]).to(device))
+    mean = torch.mean(embeddings[0], 0).cpu().detach()
+    return mean