From 7c69dc248a0b847b3bc4ee05f81660d165e3345a Mon Sep 17 00:00:00 2001
From: csunny
Date: Sun, 30 Apr 2023 14:21:58 +0800
Subject: [PATCH] checkout load_model to fastchat

---
 pilot/model/inference.py      |  3 ++-
 pilot/server/vicuna_server.py | 14 +++++++++++---
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/pilot/model/inference.py b/pilot/model/inference.py
index 426043aa5..4a013026f 100644
--- a/pilot/model/inference.py
+++ b/pilot/model/inference.py
@@ -84,4 +84,5 @@ def get_embeddings(model, tokenizer, prompt):
     input_embeddings = model.get_input_embeddings()
     embeddings = input_embeddings(torch.LongTensor([input_ids]))
     mean = torch.mean(embeddings[0], 0).cpu().detach()
-    return mean
\ No newline at end of file
+    return mean
+
diff --git a/pilot/server/vicuna_server.py b/pilot/server/vicuna_server.py
index 20ed928d0..9fa5ddbee 100644
--- a/pilot/server/vicuna_server.py
+++ b/pilot/server/vicuna_server.py
@@ -1,17 +1,20 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
+import uvicorn
 from typing import Optional, List
 from fastapi import FastAPI
 from pydantic import BaseModel
 from pilot.model.inference import generate_output, get_embeddings
+from fastchat.serve.inference import load_model
 from pilot.model.loader import ModerLoader
 from pilot.configs.model_config import *
 
 model_path = llm_model_config[LLM_MODEL]
-ml = ModerLoader(model_path=model_path)
-model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)
+# ml = ModerLoader(model_path=model_path)
+# model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)
+model, tokenizer = load_model(model_path=model_path, device=DEVICE, num_gpus=1, load_8bit=True, debug=False)
 
 app = FastAPI()
 
 class PromptRequest(BaseModel):
@@ -45,4 +48,9 @@ def embeddings(prompt_request: EmbeddingRequest):
     params = {"prompt": prompt_request.prompt}
     print("Received prompt: ", params["prompt"])
     output = get_embeddings(model, tokenizer, params["prompt"])
-    return {"response": [float(x) for x in output]}
\ No newline at end of file
+    return {"response": [float(x) for x in output]}
+
+
+
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", log_level="info")
\ No newline at end of file
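
The patch swaps the project's ModerLoader for fastchat's load_model, but the
new loader path is only exercised implicitly when vicuna_server.py starts.
Below is a minimal smoke-test sketch of that path, using only names visible
in this diff (load_model, get_embeddings, llm_model_config, LLM_MODEL,
DEVICE); it assumes pilot.configs.model_config exports DEVICE, as the patched
call site implies, and the prompt string is a placeholder.

    # Sketch: load the model the same way the patched vicuna_server.py does,
    # then run one prompt through the embeddings helper.
    from fastchat.serve.inference import load_model
    from pilot.model.inference import get_embeddings
    from pilot.configs.model_config import *  # llm_model_config, LLM_MODEL, DEVICE

    model_path = llm_model_config[LLM_MODEL]
    model, tokenizer = load_model(model_path=model_path, device=DEVICE,
                                  num_gpus=1, load_8bit=True, debug=False)

    # get_embeddings (patched above) mean-pools the prompt's input-token
    # embeddings and returns a 1-D tensor.
    vec = get_embeddings(model, tokenizer, "hello world")
    print(len(vec))  # embedding dimensionality of the loaded model

One behavioral difference worth flagging: the old path read load_8bit from
the isload_8bit config flag, while the new call hard-codes load_8bit=True,
so that flag no longer has any effect.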