#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import json
from typing import Any, List, Mapping, Optional
from urllib.parse import urljoin

import requests
from langchain.embeddings.base import Embeddings
from langchain.llms.base import LLM
from pydantic import BaseModel

from pilot.configs.config import Config

CFG = Config()


class VicunaLLM(LLM):
    vicuna_generate_path: str = "generate_stream"

    def _call(
        self,
        prompt: str,
        temperature: float,
        max_new_tokens: int,
        stop: Optional[List[str]] = None,
    ) -> str:
        # Note: this implementation streams the model server's output, so calling
        # it returns a generator of text chunks rather than a single string.
        params = {
            "prompt": prompt,
            "temperature": temperature,
            "max_new_tokens": max_new_tokens,
            "stop": stop,
        }
        response = requests.post(
            url=urljoin(CFG.MODEL_SERVER, self.vicuna_generate_path),
            data=json.dumps(params),
        )

        # Length of the echoed prompt to skip in each streamed chunk; "</s>"
        # markers are collapsed by the server, hence the adjustment per occurrence.
        skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                if data["error_code"] == 0:
                    output = data["text"][skip_echo_len:].strip()
                    yield output

    @property
    def _llm_type(self) -> str:
        return "custom"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {}


class VicunaEmbeddingLLM(BaseModel, Embeddings):
    vicuna_embedding_path: str = "embedding"

    def _call(self, prompt: str) -> str:
        p = prompt.strip()
        print("Sending prompt ", p)

        response = requests.post(
            url=urljoin(CFG.MODEL_SERVER, self.vicuna_embedding_path),
            json={"prompt": p},
        )
        response.raise_for_status()
        return response.json()["response"]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to Vicuna's server embedding endpoint for embedding search docs.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        results = []
        for text in texts:
            response = self.embed_query(text)
            results.append(response)
        return results

    def embed_query(self, text: str) -> List[float]:
        """Call out to Vicuna's server embedding endpoint for embedding query text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        embedding = self._call(text)
        return embedding
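

# Usage sketch, not part of the original module: it assumes CFG.MODEL_SERVER points
# at a running Vicuna model server exposing the "generate_stream" and "embedding"
# endpoints used above. The prompt text and generation parameters are illustrative.
if __name__ == "__main__":
    llm = VicunaLLM()
    # _call streams output, so iterate over the returned generator to print chunks.
    for piece in llm._call(prompt="Say hello.", temperature=0.7, max_new_tokens=64):
        print(piece)

    embedder = VicunaEmbeddingLLM()
    # embed_documents delegates to embed_query, which posts each text to the server.
    vectors = embedder.embed_documents(["first document", "second document"])
    print(len(vectors), "embeddings returned")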