switch model loading to fastchat's load_model

This commit is contained in:
csunny 2023-04-30 14:21:58 +08:00
parent 4381e21d0c
commit 7c69dc248a
2 changed files with 13 additions and 4 deletions

View File

@ -85,3 +85,4 @@ def get_embeddings(model, tokenizer, prompt):
embeddings = input_embeddings(torch.LongTensor([input_ids]))
mean = torch.mean(embeddings[0], 0).cpu().detach()
return mean

View File

@ -1,17 +1,20 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import uvicorn
from typing import Optional, List
from fastapi import FastAPI
from pydantic import BaseModel
from pilot.model.inference import generate_output, get_embeddings
from fastchat.serve.inference import load_model
from pilot.model.loader import ModerLoader
from pilot.configs.model_config import *
# Resolve the on-disk checkpoint path for the configured model name.
# llm_model_config and LLM_MODEL come from the star import of
# pilot.configs.model_config above.
model_path = llm_model_config[LLM_MODEL]
# NOTE(review): this span appears to be a diff rendering with the +/- markers
# stripped — the next two lines are the OLD (removed) ModerLoader path and
# duplicate the commented-out pair below; only the fastchat load_model call
# should survive in the actual file. Confirm against the repository HEAD.
ml = ModerLoader(model_path=model_path)
model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)
# ml = ModerLoader(model_path=model_path)
# model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)
# Load the model via fastchat's helper instead of the project ModerLoader.
# DEVICE presumably comes from pilot.configs.model_config (star import) —
# verify. 8-bit quantization is hard-coded on, debug hard-coded off.
model, tokenizer = load_model(model_path=model_path, device=DEVICE, num_gpus=1, load_8bit=True, debug=False)
# FastAPI application object that the endpoint decorators below attach to.
app = FastAPI()
class PromptRequest(BaseModel):
@ -46,3 +49,8 @@ def embeddings(prompt_request: EmbeddingRequest):
print("Received prompt: ", params["prompt"])
output = get_embeddings(model, tokenizer, params["prompt"])
return {"response": [float(x) for x in output]}
if __name__ == "__main__":
    # Serve the FastAPI app on all interfaces. No port is given, so uvicorn's
    # default (8000) applies — NOTE(review): confirm that is intentional and
    # matches whatever clients of this service expect.
    uvicorn.run(app, host="0.0.0.0", log_level="info")