From 467b082c34105176ffa3d1e42adb15b957fa438d Mon Sep 17 00:00:00 2001
From: kYLe
Date: Thu, 12 Oct 2023 10:41:25 -0500
Subject: [PATCH] Modify Anyscale integration to work with Anyscale Endpoint
 (#11569)

**Description:** Modify the Anyscale integration to work with
[Anyscale Endpoint](https://docs.endpoints.anyscale.com/). It supports the
invoke, async invoke, stream, and async stream features.

---------

Co-authored-by: Bagatur
---
 docs/docs/integrations/llms/anyscale.ipynb |  20 +-
 libs/langchain/langchain/llms/anyscale.py  | 314 +++++++++++++++------
 2 files changed, 237 insertions(+), 97 deletions(-)

diff --git a/docs/docs/integrations/llms/anyscale.ipynb b/docs/docs/integrations/llms/anyscale.ipynb
index b94df065409..f59e5b154f9 100644
--- a/docs/docs/integrations/llms/anyscale.ipynb
+++ b/docs/docs/integrations/llms/anyscale.ipynb
@@ -1,7 +1,6 @@
 {
  "cells": [
   {
-   "attachments": {},
    "cell_type": "markdown",
    "id": "9597802c",
    "metadata": {},
@@ -10,9 +9,7 @@
     "\n",
     "[Anyscale](https://www.anyscale.com/) is a fully-managed [Ray](https://www.ray.io/) platform, on which you can build, deploy, and manage scalable AI and Python applications\n",
     "\n",
-    "This example goes over how to use LangChain to interact with `Anyscale` [service](https://docs.anyscale.com/productionize/services-v2/get-started). \n",
-    "\n",
-    "It will send the requests to Anyscale Service endpoint, which is concatenate `ANYSCALE_SERVICE_URL` and `ANYSCALE_SERVICE_ROUTE`, with a token defined in `ANYSCALE_SERVICE_TOKEN`"
+    "This example goes over how to use LangChain to interact with [Anyscale Endpoint](https://app.endpoints.anyscale.com/). "
    ]
   },
   {
@@ -26,9 +23,8 @@
    "source": [
     "import os\n",
     "\n",
-    "os.environ[\"ANYSCALE_SERVICE_URL\"] = ANYSCALE_SERVICE_URL\n",
-    "os.environ[\"ANYSCALE_SERVICE_ROUTE\"] = ANYSCALE_SERVICE_ROUTE\n",
-    "os.environ[\"ANYSCALE_SERVICE_TOKEN\"] = ANYSCALE_SERVICE_TOKEN"
+    "os.environ[\"ANYSCALE_API_BASE\"] = ANYSCALE_API_BASE\n",
+    "os.environ[\"ANYSCALE_API_KEY\"] = ANYSCALE_API_KEY"
    ]
   },
   {
@@ -41,7 +37,8 @@
    "outputs": [],
    "source": [
     "from langchain.llms import Anyscale\n",
-    "from langchain.prompts import PromptTemplate\nfrom langchain.chains import LLMChain"
+    "from langchain.prompts import PromptTemplate\n",
+    "from langchain.chains import LLMChain"
    ]
   },
   {
@@ -69,7 +66,7 @@
    },
    "outputs": [],
    "source": [
-    "llm = Anyscale()"
+    "llm = Anyscale(model_name=ANYSCALE_MODEL_NAME)"
    ]
   },
   {
@@ -99,7 +96,6 @@
    ]
   },
   {
-   "attachments": {},
    "cell_type": "markdown",
    "id": "42f05b34-1a44-4cbd-8342-35c1572b6765",
    "metadata": {},
@@ -136,13 +132,11 @@
    "source": [
     "import ray\n",
     "\n",
-    "\n",
-    "@ray.remote\n",
+    "@ray.remote(num_cpus=0.1)\n",
     "def send_query(llm, prompt):\n",
     "    resp = llm(prompt)\n",
     "    return resp\n",
     "\n",
-    "\n",
     "futures = [send_query.remote(llm, prompt) for prompt in prompt_list]\n",
     "results = ray.get(futures)"
    ]
diff --git a/libs/langchain/langchain/llms/anyscale.py b/libs/langchain/langchain/llms/anyscale.py
index b3582961141..65e4a9aded3 100644
--- a/libs/langchain/langchain/llms/anyscale.py
+++ b/libs/langchain/langchain/llms/anyscale.py
@@ -1,126 +1,272 @@
-from typing import Any, Dict, List, Mapping, Optional
+"""Wrapper around Anyscale Endpoint"""
+from typing import (
+    Any,
+    AsyncIterator,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Set,
+    Tuple,
+)

-import requests
-
-from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
-from langchain.llms.utils import enforce_stop_tokens
-from langchain.pydantic_v1 import Extra, root_validator
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain.llms.openai import (
+    BaseOpenAI,
+    acompletion_with_retry,
+    completion_with_retry,
+)
+from langchain.pydantic_v1 import Field, root_validator
+from langchain.schema import Generation, LLMResult
+from langchain.schema.output import GenerationChunk
 from langchain.utils import get_from_dict_or_env


-class Anyscale(LLM):
-    """Anyscale Service models.
+def update_token_usage(
+    keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any]
+) -> None:
+    """Update token usage."""
+    _keys_to_use = keys.intersection(response["usage"])
+    for _key in _keys_to_use:
+        if _key not in token_usage:
+            token_usage[_key] = response["usage"][_key]
+        else:
+            token_usage[_key] += response["usage"][_key]

-    To use, you should have the environment variable ``ANYSCALE_SERVICE_URL``,
-    ``ANYSCALE_SERVICE_ROUTE`` and ``ANYSCALE_SERVICE_TOKEN`` set with your Anyscale
-    Service, or pass it as a named parameter to the constructor.
+
+def create_llm_result(
+    choices: Any, prompts: List[str], token_usage: Dict[str, int], model_name: str
+) -> LLMResult:
+    """Create the LLMResult from the choices and prompts."""
+    generations = []
+    for i, _ in enumerate(prompts):
+        choice = choices[i]
+        generations.append(
+            [
+                Generation(
+                    text=choice["message"]["content"],
+                    generation_info=dict(
+                        finish_reason=choice.get("finish_reason"),
+                        logprobs=choice.get("logprobs"),
+                    ),
+                )
+            ]
+        )
+    llm_output = {"token_usage": token_usage, "model_name": model_name}
+    return LLMResult(generations=generations, llm_output=llm_output)
+
+
+class Anyscale(BaseOpenAI):
+    """Wrapper around Anyscale Endpoint.
+    To use, you should have the environment variable ``ANYSCALE_API_BASE`` and
+    ``ANYSCALE_API_KEY``set with your Anyscale Endpoint, or pass it as a named
+    parameter to the constructor.

     Example:
         .. code-block:: python
-            from langchain.llms import Anyscale
-            anyscale = Anyscale(anyscale_service_url="SERVICE_URL",
-                                anyscale_service_route="SERVICE_ROUTE",
-                                anyscale_service_token="SERVICE_TOKEN")
-
-            # Use Ray for distributed processing
-            import ray
-            prompt_list=[]
-            @ray.remote
-            def send_query(llm, prompt):
-                resp = llm(prompt)
+            anyscalellm = Anyscale(anyscale_api_base="ANYSCALE_API_BASE",
+                                   anyscale_api_key="ANYSCALE_API_KEY",
+                                   model_name="meta-llama/Llama-2-7b-chat-hf")
+            # To leverage Ray for parallel processing
+            @ray.remote(num_cpus=1)
+            def send_query(llm, text):
+                resp = llm(text)
                 return resp
-            futures = [send_query.remote(anyscale, prompt) for prompt in prompt_list]
+            futures = [send_query.remote(anyscalellm, text) for text in texts]
             results = ray.get(futures)
     """

-    model_kwargs: Optional[dict] = None
-    """Keyword arguments to pass to the model. Reserved for future use"""
+    """Key word arguments to pass to the model."""
+    anyscale_api_base: Optional[str] = None
+    anyscale_api_key: Optional[str] = None

-    anyscale_service_url: Optional[str] = None
-    anyscale_service_route: Optional[str] = None
-    anyscale_service_token: Optional[str] = None
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = Extra.forbid
+    prefix_messages: List = Field(default_factory=list)

     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
-        anyscale_service_url = get_from_dict_or_env(
-            values, "anyscale_service_url", "ANYSCALE_SERVICE_URL"
+        values["anyscale_api_base"] = get_from_dict_or_env(
+            values, "anyscale_api_base", "ANYSCALE_API_BASE"
         )
-        anyscale_service_route = get_from_dict_or_env(
-            values, "anyscale_service_route", "ANYSCALE_SERVICE_ROUTE"
+        values["anyscale_api_key"] = get_from_dict_or_env(
+            values, "anyscale_api_key", "ANYSCALE_API_KEY"
         )
-        anyscale_service_token = get_from_dict_or_env(
-            values, "anyscale_service_token", "ANYSCALE_SERVICE_TOKEN"
-        )
-        if anyscale_service_url.endswith("/"):
-            anyscale_service_url = anyscale_service_url[:-1]
-        if not anyscale_service_route.startswith("/"):
-            anyscale_service_route = "/" + anyscale_service_route
         try:
-            anyscale_service_endpoint = f"{anyscale_service_url}/-/routes"
-            headers = {"Authorization": f"Bearer {anyscale_service_token}"}
-            requests.get(anyscale_service_endpoint, headers=headers)
-        except requests.exceptions.RequestException as e:
-            raise ValueError(e)
+            import openai
+
+            ## Always create ChatComplete client, replacing the legacy Complete client
+            values["client"] = openai.ChatCompletion
+        except ImportError:
+            raise ImportError(
+                "Could not import openai python package. "
+                "Please install it with `pip install openai`."
+            )
+        if values["streaming"] and values["n"] > 1:
+            raise ValueError("Cannot stream results when n > 1.")
+        if values["streaming"] and values["best_of"] > 1:
+            raise ValueError("Cannot stream results when best_of > 1.")
-        values["anyscale_service_url"] = anyscale_service_url
-        values["anyscale_service_route"] = anyscale_service_route
-        values["anyscale_service_token"] = anyscale_service_token
         return values

     @property
     def _identifying_params(self) -> Mapping[str, Any]:
         """Get the identifying parameters."""
         return {
-            "anyscale_service_url": self.anyscale_service_url,
-            "anyscale_service_route": self.anyscale_service_route,
+            **{"model_name": self.model_name},
+            **super()._identifying_params,
         }

+    @property
+    def _invocation_params(self) -> Dict[str, Any]:
+        """Get the parameters used to invoke the model."""
+        openai_creds: Dict[str, Any] = {
+            "api_key": self.anyscale_api_key,
+            "api_base": self.anyscale_api_base,
+        }
+        return {**openai_creds, **{"model": self.model_name}, **super()._default_params}
+
     @property
     def _llm_type(self) -> str:
         """Return type of llm."""
-        return "anyscale"
+        return "Anyscale LLM"

-    def _call(
+    def _get_chat_messages(
+        self, prompts: List[str], stop: Optional[List[str]] = None
+    ) -> Tuple:
+        if len(prompts) > 1:
+            raise ValueError(
+                f"Anyscale currently only supports single prompt, got {prompts}"
+            )
+        messages = self.prefix_messages + [{"role": "user", "content": prompts[0]}]
+        params: Dict[str, Any] = self._invocation_params
+        if stop is not None:
+            if "stop" in params:
+                raise ValueError("`stop` found in both the input and default params.")
+            params["stop"] = stop
+        if params.get("max_tokens") == -1:
+            # for Chat api, omitting max_tokens is equivalent to having no limit
+            del params["max_tokens"]
+        return messages, params
+
+    def _stream(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> str:
-        """Call out to Anyscale Service endpoint.
-        Args:
-            prompt: The prompt to pass into the model.
-            stop: Optional list of stop words to use when generating.
-        Returns:
-            The string generated by the model.
-        Example:
-            .. code-block:: python
-                response = anyscale("Tell me a joke.")
-        """
+    ) -> Iterator[GenerationChunk]:
+        messages, params = self._get_chat_messages([prompt], stop)
+        params = {**params, **kwargs, "stream": True}
+        for stream_resp in completion_with_retry(
+            self, messages=messages, run_manager=run_manager, **params
+        ):
+            token = stream_resp["choices"][0]["delta"].get("content", "")
+            chunk = GenerationChunk(text=token)
+            yield chunk
+            if run_manager:
+                run_manager.on_llm_new_token(token, chunk=chunk)

-        anyscale_service_endpoint = (
-            f"{self.anyscale_service_url}{self.anyscale_service_route}"
-        )
-        headers = {"Authorization": f"Bearer {self.anyscale_service_token}"}
-        body = {"prompt": prompt}
-        resp = requests.post(anyscale_service_endpoint, headers=headers, json=body)
+    async def _astream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[GenerationChunk]:
+        messages, params = self._get_chat_messages([prompt], stop)
+        params = {**params, **kwargs, "stream": True}
+        async for stream_resp in await acompletion_with_retry(
+            self, messages=messages, run_manager=run_manager, **params
+        ):
+            token = stream_resp["choices"][0]["delta"].get("content", "")
+            chunk = GenerationChunk(text=token)
+            yield chunk
+            if run_manager:
+                await run_manager.on_llm_new_token(token, chunk=chunk)

-        if resp.status_code != 200:
-            raise ValueError(
-                f"Error returned by service, status code {resp.status_code}"
-            )
-        text = resp.text
+    def _generate(
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> LLMResult:
+        choices = []
+        token_usage: Dict[str, int] = {}
+        _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
+        for prompt in prompts:
+            if self.streaming:
+                generation: Optional[GenerationChunk] = None
+                for chunk in self._stream(prompt, stop, run_manager, **kwargs):
+                    if generation is None:
+                        generation = chunk
+                    else:
+                        generation += chunk
+                assert generation is not None
+                choices.append(
+                    {
+                        "message": {"content": generation.text},
+                        "finish_reason": generation.generation_info.get("finish_reason")
+                        if generation.generation_info
+                        else None,
+                        "logprobs": generation.generation_info.get("logprobs")
+                        if generation.generation_info
+                        else None,
+                    }
+                )

-        if stop is not None:
-            # This is a bit hacky, but I can't figure out a better way to enforce
-            # stop tokens when making calls to huggingface_hub.
-            text = enforce_stop_tokens(text, stop)
-        return text
+            else:
+                messages, params = self._get_chat_messages([prompt], stop)
+                params = {**params, **kwargs}
+                response = completion_with_retry(
+                    self, messages=messages, run_manager=run_manager, **params
+                )
+                choices.extend(response["choices"])
+                update_token_usage(_keys, response, token_usage)
+        return create_llm_result(choices, prompts, token_usage, self.model_name)
+
+    async def _agenerate(
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> LLMResult:
+        choices = []
+        token_usage: Dict[str, int] = {}
+        _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
+        for prompt in prompts:
+            messages = self.prefix_messages + [{"role": "user", "content": prompt}]
+            if self.streaming:
+                generation: Optional[GenerationChunk] = None
+                async for chunk in self._astream(prompt, stop, run_manager, **kwargs):
+                    if generation is None:
+                        generation = chunk
+                    else:
+                        generation += chunk
+                assert generation is not None
+                choices.append(
+                    {
+                        "message": {"content": generation.text},
+                        "finish_reason": generation.generation_info.get("finish_reason")
+                        if generation.generation_info
+                        else None,
+                        "logprobs": generation.generation_info.get("logprobs")
+                        if generation.generation_info
+                        else None,
+                    }
+                )
+            else:
+                messages, params = self._get_chat_messages([prompt], stop)
+                params = {**params, **kwargs}
+                response = await acompletion_with_retry(
+                    self, messages=messages, run_manager=run_manager, **params
+                )
+                choices.extend(response["choices"])
+                update_token_usage(_keys, response, token_usage)
+        return create_llm_result(choices, prompts, token_usage, self.model_name)
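
For readers who want to try the new integration, the sketch below shows the calling patterns this patch enables. It is illustrative only and not part of the patch: the API base URL, API key, and model name are placeholders to replace with values from your own Anyscale Endpoint account, and the `.stream()` / `.agenerate()` helpers are assumed to be available on the LLM base class in your installed LangChain version.

```python
# Usage sketch only (not part of the patch). Endpoint URL, API key, and model
# name below are placeholders; substitute values from your Anyscale account.
import asyncio
import os

from langchain.llms import Anyscale

os.environ["ANYSCALE_API_BASE"] = "https://api.endpoints.anyscale.com/v1"  # placeholder
os.environ["ANYSCALE_API_KEY"] = "esecret_..."  # placeholder

llm = Anyscale(model_name="meta-llama/Llama-2-7b-chat-hf")

# Invoke: a single synchronous completion.
print(llm("When was George Washington president?"))

# Stream: tokens are yielded incrementally by the new _stream implementation.
for chunk in llm.stream("Tell me a one-sentence story about a robot."):
    print(chunk, end="", flush=True)

# Async invoke: backed by the new _agenerate implementation.
async def main() -> None:
    result = await llm.agenerate(["Why is the sky blue?"])
    print(result.generations[0][0].text)

asyncio.run(main())
```

Note that when the model is constructed with `streaming=True`, the `_generate` path added in this patch consumes `_stream` internally, so token callbacks still fire for plain `llm(prompt)` calls.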