diff --git a/pilot/model/inference.py b/pilot/model/inference.py
index 71192e877..66766b3b3 100644
--- a/pilot/model/inference.py
+++ b/pilot/model/inference.py
@@ -3,6 +3,71 @@
 import torch
 
 
+@torch.inference_mode()
+def generate_stream(model, tokenizer, params, device,
+                    context_len=2048, stream_interval=2):
+
+    """Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py """
+    prompt = params["prompt"]
+    l_prompt = len(prompt)
+    temperature = float(params.get("temperature", 1.0))
+    max_new_tokens = int(params.get("max_new_tokens", 256))
+    stop_str = params.get("stop", None)
+
+    input_ids = tokenizer(prompt).input_ids
+    output_ids = list(input_ids)
+
+    max_src_len = context_len - max_new_tokens - 8
+    input_ids = input_ids[-max_src_len:]
+
+    for i in range(max_new_tokens):
+        if i == 0:
+            out = model(
+                torch.as_tensor([input_ids], device=device), use_cache=True)
+            logits = out.logits
+            past_key_values = out.past_key_values
+        else:
+            attention_mask = torch.ones(
+                1, past_key_values[0][0].shape[-2] + 1, device=device)
+            out = model(input_ids=torch.as_tensor([[token]], device=device),
+                        use_cache=True,
+                        attention_mask=attention_mask,
+                        past_key_values=past_key_values)
+            logits = out.logits
+            past_key_values = out.past_key_values
+
+        last_token_logits = logits[0][-1]
+
+        if device == "mps":
+            # Switch to CPU by avoiding some bugs in mps backend.
+            last_token_logits = last_token_logits.float().to("cpu")
+
+        if temperature < 1e-4:
+            token = int(torch.argmax(last_token_logits))
+        else:
+            probs = torch.softmax(last_token_logits / temperature, dim=-1)
+            token = int(torch.multinomial(probs, num_samples=1))
+
+        output_ids.append(token)
+
+        if token == tokenizer.eos_token_id:
+            stopped = True
+        else:
+            stopped = False
+
+        if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
+            output = tokenizer.decode(output_ids, skip_special_tokens=True)
+            pos = output.rfind(stop_str, l_prompt)
+            if pos != -1:
+                output = output[:pos]
+                stopped = True
+            yield output
+
+        if stopped:
+            break
+
+    del past_key_values
+
 @torch.inference_mode()
 def generate_output(model, tokenizer, params, device, context_len=2048, stream_interval=2):
     prompt = params["prompt"]
diff --git a/pilot/model/loader.py b/pilot/model/loader.py
index 5f18a023c..e601621f7 100644
--- a/pilot/model/loader.py
+++ b/pilot/model/loader.py
@@ -5,6 +5,7 @@
 import torch
 from transformers import (
     AutoTokenizer,
     AutoModelForCausalLM,
+    AutoModel
 )
 from fastchat.serve.compression import compress_module
@@ -23,20 +24,39 @@ class ModerLoader:
             "device_map": "auto",
         }
 
-    def loader(self, load_8bit=False, debug=False):
-
-        tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False)
-        model = AutoModelForCausalLM.from_pretrained(self.model_path, low_cpu_mem_usage=True, **self.kwargs)
+    def loader(self, num_gpus, load_8bit=False, debug=False):
+        if self.device == "cpu":
+            kwargs = {}
+        elif self.device == "cuda":
+            kwargs = {"torch_dtype": torch.float16}
+            if num_gpus == "auto":
+                kwargs["device_map"] = "auto"
+            else:
+                num_gpus = int(num_gpus)
+                if num_gpus != 1:
+                    kwargs.update({
+                        "device_map": "auto",
+                        "max_memory": {i: "13GiB" for i in range(num_gpus)},
+                    })
+        else:
+            raise ValueError(f"Invalid device: {self.device}")
+
+        if "chatglm" in self.model_path:
+            tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
+            model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True).half().cuda()
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False)
+            model = AutoModelForCausalLM.from_pretrained(self.model_path,
+                                                         low_cpu_mem_usage=True, **kwargs)
+
+        if load_8bit:
+            compress_module(model, self.device)
+
+        if (self.device == "cuda" and num_gpus == 1):
+            model.to(self.device)
 
         if debug:
             print(model)
 
-        if load_8bit:
-            compress_module(model, self.device)
-
-        # if self.device == "cuda":
-        #     model.to(self.device)
-
         return model, tokenizer
-
-
+
diff --git a/pilot/server/vicuna_server.py b/pilot/server/vicuna_server.py
index 79bc1dab3..674afb71b 100644
--- a/pilot/server/vicuna_server.py
+++ b/pilot/server/vicuna_server.py
@@ -7,10 +7,12 @@
 import json
 from typing import Optional, List
 from fastapi import FastAPI, Request, BackgroundTasks
 from fastapi.responses import StreamingResponse
-from fastchat.serve.inference import generate_stream
+from pilot.model.inference import generate_stream
 from pydantic import BaseModel
 from pilot.model.inference import generate_output, get_embeddings
 from fastchat.serve.inference import load_model
+
+
 from pilot.model.loader import ModerLoader
 from pilot.configs.model_config import *
@@ -20,9 +22,9 @@
 model_path = LLM_MODEL_CONFIG[LLM_MODEL]
 global_counter = 0
 model_semaphore = None
-# ml = ModerLoader(model_path=model_path)
-# model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)
-model, tokenizer = load_model(model_path=model_path, device=DEVICE, num_gpus=1, load_8bit=True, debug=False)
+ml = ModerLoader(model_path=model_path)
+model, tokenizer = ml.loader(num_gpus=1, load_8bit=ISLOAD_8BIT, debug=ISDEBUG)
+#model, tokenizer = load_model(model_path=model_path, device=DEVICE, num_gpus=1, load_8bit=True, debug=False)
 
 class ModelWorker:
     def __init__(self):
diff --git a/pilot/vector_store/file_loader.py b/pilot/vector_store/file_loader.py
new file mode 100644
index 000000000..22b1a1c57
--- /dev/null
+++ b/pilot/vector_store/file_loader.py
@@ -0,0 +1,41 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+from langchain.prompts import PromptTemplate
+from langchain.vectorstores import Chroma
+from langchain.text_splitter import CharacterTextSplitter
+from langchain.document_loaders import UnstructuredFileLoader, UnstructuredPDFLoader
+
+VECTOR_SEARCH_TOP_K = 5
+
+class BaseKnownLedgeQA:
+
+    llm: object = None
+    embeddings: object = None
+
+    top_k: int = VECTOR_SEARCH_TOP_K
+
+    def __init__(self) -> None:
+        pass
+
+    def init_vector_store(self):
+        pass
+
+    def load_knownlege(self):
+        pass
+
+    def _load_file(self, filename):
+        # Load the file
+        if filename.lower().endswith(".pdf"):
+            loader = UnstructuredFileLoader(filename)
+            text_splitor = CharacterTextSplitter()
+            docs = loader.load_and_split(text_splitor)
+        else:
+            loader = UnstructuredFileLoader(filename, mode="elements")
+            text_splitor = CharacterTextSplitter()
+            docs = loader.load_and_split(text_splitor)
+        return docs
+
+    def _load_from_url(self, url):
+        pass
+
\ No newline at end of file
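
Below is a minimal usage sketch (not part of the diff) showing how the reworked ModerLoader.loader and the forked generate_stream are intended to fit together, assuming a single-GPU CUDA setup; the model path, prompt, and sampling values are placeholder assumptions.

# Usage sketch only -- not part of this diff. The model path, prompt, and
# sampling values are placeholder assumptions; ModerLoader and generate_stream
# are the symbols introduced or changed above.
from pilot.model.loader import ModerLoader
from pilot.model.inference import generate_stream

ml = ModerLoader(model_path="/path/to/vicuna-13b")  # placeholder path
model, tokenizer = ml.loader(num_gpus=1, load_8bit=False, debug=False)

params = {
    "prompt": "What is a vector database?",  # placeholder prompt
    "temperature": 0.7,
    "max_new_tokens": 128,
    "stop": "###",  # generate_stream calls rfind(stop_str, ...), so a stop string should be set
}

# generate_stream is a generator: every stream_interval steps (or when
# generation stops) it yields the full decoded text produced so far,
# truncated at the stop string if one is found.
for partial_output in generate_stream(model, tokenizer, params, device="cuda"):
    print(partial_output)

Each yield is the full decoded text so far (prompt included), so a caller that wants incremental deltas has to slice off what it has already consumed.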