mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-17 15:10:14 +00:00
fix multi device error
This commit is contained in:
@@ -81,8 +81,9 @@ def generate_output(model, tokenizer, params, device, context_len=2048):
|
|||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def get_embeddings(model, tokenizer, prompt):
|
def get_embeddings(model, tokenizer, prompt):
|
||||||
input_ids = tokenizer(prompt).input_ids
|
input_ids = tokenizer(prompt).input_ids
|
||||||
input_embeddings = model.get_input_embeddings()
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
embeddings = input_embeddings(torch.LongTensor([input_ids]))
|
input_embeddings = model.get_input_embeddings().to(device)
|
||||||
mean = torch.mean(embeddings[0], 0).cpu().detach()
|
|
||||||
return mean
|
|
||||||
|
|
||||||
|
embeddings = input_embeddings(torch.LongTensor([input_ids]).to(device))
|
||||||
|
mean = torch.mean(embeddings[0], 0).cpu().detach()
|
||||||
|
return mean.to(device)
|
||||||
|
Reference in New Issue
Block a user