From 30151c99c72903320d7d865a9d5bf49237a4f267 Mon Sep 17 00:00:00 2001
From: Margaret Qian
Date: Thu, 24 Aug 2023 22:13:17 -0700
Subject: [PATCH] Update Mosaic endpoint input/output api (#7391)

As noted in prior PRs (https://github.com/hwchase17/langchain/pull/6060,
https://github.com/hwchase17/langchain/pull/7348), the input/output format
has changed a few times as we've stabilized our inference API. This PR
updates the integration to the latest stable version of the API, as
indicated in our docs: https://docs.mosaicml.com/en/latest/inference.html

The input format looks like this:
`{"inputs": []}`

The output format looks like this:
`{"outputs": []}`
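For illustration, a raw request in this stable format looks roughly like the sketch below. The endpoint URL, model, environment variable, and header values are placeholders rather than part of this change; the LangChain wrappers build the same payload internally.

```python
import os

import requests

# Placeholder deployment values; substitute your own endpoint and token.
ENDPOINT_URL = "https://models.hosted-on.mosaicml.hosting/mpt-7b-instruct/v1/predict"
MOSAICML_API_TOKEN = os.environ["MOSAICML_API_TOKEN"]

# Stable input format: a JSON object with an "inputs" list.
payload = {"inputs": ["Tell me a joke."]}

response = requests.post(
    ENDPOINT_URL,
    headers={
        "Authorization": MOSAICML_API_TOKEN,
        "Content-Type": "application/json",
    },
    json=payload,
)

if response.status_code == 429:
    # Rate limited; a caller could sleep briefly and retry once,
    # mirroring what the updated wrappers do below.
    raise RuntimeError(f"Rate limit exceeded: {response.text}")

# Stable output format: a JSON object with an "outputs" list.
outputs = response.json()["outputs"]
print(outputs[0])
```

Code that goes through `MosaicML` or `MosaicMLInstructorEmbeddings` does not need to build this payload by hand; this patch keeps those wrappers in sync with the format above.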
---------

Co-authored-by: Bagatur
---
 docs/extras/integrations/llms/mosaicml.ipynb  |  2 +-
 .../langchain/embeddings/mosaicml.py          | 37 ++++++-------------
 libs/langchain/langchain/llms/mosaicml.py     | 32 +++++-----------
 .../embeddings/test_mosaicml.py               |  4 +-
 .../integration_tests/llms/test_mosaicml.py   | 20 ++++++----
 5 files changed, 36 insertions(+), 59 deletions(-)

diff --git a/docs/extras/integrations/llms/mosaicml.ipynb b/docs/extras/integrations/llms/mosaicml.ipynb
index 596ee2d7b5f..cd9be156fcb 100644
--- a/docs/extras/integrations/llms/mosaicml.ipynb
+++ b/docs/extras/integrations/llms/mosaicml.ipynb
@@ -63,7 +63,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "llm = MosaicML(inject_instruction_format=True, model_kwargs={\"do_sample\": False})"
+    "llm = MosaicML(inject_instruction_format=True, model_kwargs={\"max_new_tokens\": 128})"
    ]
   },
   {
diff --git a/libs/langchain/langchain/embeddings/mosaicml.py b/libs/langchain/langchain/embeddings/mosaicml.py
index 8346bf7cfc1..6a3c3e11c04 100644
--- a/libs/langchain/langchain/embeddings/mosaicml.py
+++ b/libs/langchain/langchain/embeddings/mosaicml.py
@@ -79,14 +79,8 @@ class MosaicMLInstructorEmbeddings(BaseModel, Embeddings):
             raise ValueError(f"Error raised by inference endpoint: {e}")
 
         try:
-            parsed_response = response.json()
-
-            if "error" in parsed_response:
-                # if we get rate limited, try sleeping for 1 second
-                if (
-                    not is_retry
-                    and "rate limit exceeded" in parsed_response["error"].lower()
-                ):
+            if response.status_code == 429:
+                if not is_retry:
                     import time
 
                     time.sleep(self.retry_sleep)
@@ -94,16 +88,20 @@ class MosaicMLInstructorEmbeddings(BaseModel, Embeddings):
                     return self._embed(input, is_retry=True)
 
                 raise ValueError(
-                    f"Error raised by inference API: {parsed_response['error']}"
+                    f"Error raised by inference API: rate limit exceeded.\nResponse: "
+                    f"{response.text}"
                 )
 
+            parsed_response = response.json()
+
             # The inference API has changed a couple of times, so we add some handling
             # to be robust to multiple response formats.
             if isinstance(parsed_response, dict):
-                if "data" in parsed_response:
-                    output_item = parsed_response["data"]
-                elif "output" in parsed_response:
-                    output_item = parsed_response["output"]
+                output_keys = ["data", "output", "outputs"]
+                for key in output_keys:
+                    if key in parsed_response:
+                        output_item = parsed_response[key]
+                        break
                 else:
                     raise ValueError(
                         f"No key data or output in response: {parsed_response}"
@@ -113,19 +111,6 @@ class MosaicMLInstructorEmbeddings(BaseModel, Embeddings):
                     embeddings = output_item
                 else:
                     embeddings = [output_item]
-            elif isinstance(parsed_response, list):
-                first_item = parsed_response[0]
-                if isinstance(first_item, list):
-                    embeddings = parsed_response
-                elif isinstance(first_item, dict):
-                    if "output" in first_item:
-                        embeddings = [item["output"] for item in parsed_response]
-                    else:
-                        raise ValueError(
-                            f"No key data or output in response: {parsed_response}"
-                        )
-                else:
-                    raise ValueError(f"Unexpected response format: {parsed_response}")
             else:
                 raise ValueError(f"Unexpected response type: {parsed_response}")
 
diff --git a/libs/langchain/langchain/llms/mosaicml.py b/libs/langchain/langchain/llms/mosaicml.py
index 780e7a8b4da..718466178b3 100644
--- a/libs/langchain/langchain/llms/mosaicml.py
+++ b/libs/langchain/langchain/llms/mosaicml.py
@@ -138,14 +138,8 @@ class MosaicML(LLM):
             raise ValueError(f"Error raised by inference endpoint: {e}")
 
         try:
-            parsed_response = response.json()
-
-            if "error" in parsed_response:
-                # if we get rate limited, try sleeping for 1 second
-                if (
-                    not is_retry
-                    and "rate limit exceeded" in parsed_response["error"].lower()
-                ):
+            if response.status_code == 429:
+                if not is_retry:
                     import time
 
                     time.sleep(self.retry_sleep)
@@ -153,9 +147,12 @@ class MosaicML(LLM):
                     return self._call(prompt, stop, run_manager, is_retry=True)
 
                 raise ValueError(
-                    f"Error raised by inference API: {parsed_response['error']}"
+                    f"Error raised by inference API: rate limit exceeded.\nResponse: "
+                    f"{response.text}"
                 )
 
+            parsed_response = response.json()
+
             # The inference API has changed a couple of times, so we add some handling
             # to be robust to multiple response formats.
             if isinstance(parsed_response, dict):
@@ -173,23 +170,12 @@ class MosaicML(LLM):
                     text = output_item[0]
                 else:
                     text = output_item
-            elif isinstance(parsed_response, list):
-                first_item = parsed_response[0]
-                if isinstance(first_item, str):
-                    text = first_item
-                elif isinstance(first_item, dict):
-                    if "output" in parsed_response:
-                        text = first_item["output"]
-                    else:
-                        raise ValueError(
-                            f"No key data or output in response: {parsed_response}"
-                        )
-                else:
-                    raise ValueError(f"Unexpected response format: {parsed_response}")
             else:
                 raise ValueError(f"Unexpected response type: {parsed_response}")
 
-            text = text[len(prompt) :]
+            # Older versions of the API include the input in the output response
+            if text.startswith(prompt):
+                text = text[len(prompt) :]
 
         except requests.exceptions.JSONDecodeError as e:
             raise ValueError(
diff --git a/libs/langchain/tests/integration_tests/embeddings/test_mosaicml.py b/libs/langchain/tests/integration_tests/embeddings/test_mosaicml.py
index a04c6f2c1fa..ae0bec3ddac 100644
--- a/libs/langchain/tests/integration_tests/embeddings/test_mosaicml.py
+++ b/libs/langchain/tests/integration_tests/embeddings/test_mosaicml.py
@@ -34,7 +34,9 @@ def test_mosaicml_embedding_endpoint() -> None:
     """Test MosaicML embeddings with a different endpoint"""
     documents = ["foo bar"]
     embedding = MosaicMLInstructorEmbeddings(
-        endpoint_url="https://models.hosted-on.mosaicml.hosting/instructor-xl/v1/predict"
+        endpoint_url=(
+            "https://models.hosted-on.mosaicml.hosting/instructor-xl/v1/predict"
+        )
     )
     output = embedding.embed_documents(documents)
     assert len(output) == 1
diff --git a/libs/langchain/tests/integration_tests/llms/test_mosaicml.py b/libs/langchain/tests/integration_tests/llms/test_mosaicml.py
index 2b532ab6688..e15fce0fead 100644
--- a/libs/langchain/tests/integration_tests/llms/test_mosaicml.py
+++ b/libs/langchain/tests/integration_tests/llms/test_mosaicml.py
@@ -1,4 +1,6 @@
 """Test MosaicML API wrapper."""
+import re
+
 import pytest
 
 from langchain.llms.mosaicml import PROMPT_FOR_GENERATION_FORMAT, MosaicML
@@ -13,7 +15,7 @@ def test_mosaicml_llm_call() -> None:
 
 def test_mosaicml_endpoint_change() -> None:
     """Test valid call to MosaicML."""
-    new_url = "https://models.hosted-on.mosaicml.hosting/dolly-12b/v1/predict"
+    new_url = "https://models.hosted-on.mosaicml.hosting/mpt-30b-instruct/v1/predict"
     llm = MosaicML(endpoint_url=new_url)
     assert llm.endpoint_url == new_url
     output = llm("Say foo:")
@@ -34,7 +36,7 @@ def test_mosaicml_extra_kwargs() -> None:
 
 def test_instruct_prompt() -> None:
     """Test instruct prompt."""
-    llm = MosaicML(inject_instruction_format=True, model_kwargs={"do_sample": False})
+    llm = MosaicML(inject_instruction_format=True, model_kwargs={"max_new_tokens": 10})
     instruction = "Repeat the word foo"
     prompt = llm._transform_prompt(instruction)
     expected_prompt = PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction)
@@ -45,7 +47,7 @@ def test_instruct_prompt() -> None:
 
 def test_retry_logic() -> None:
     """Tests that two queries (which would usually exceed the rate limit) works"""
-    llm = MosaicML(inject_instruction_format=True, model_kwargs={"do_sample": False})
+    llm = MosaicML(inject_instruction_format=True, model_kwargs={"max_new_tokens": 10})
     instruction = "Repeat the word foo"
     prompt = llm._transform_prompt(instruction)
     expected_prompt = PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction)
@@ -70,9 +72,11 @@ def test_short_retry_does_not_loop() -> None:
 
     with pytest.raises(
         ValueError,
-        match="Error raised by inference API: Rate limit exceeded: 1 per 1 second",
per 1 second", + match=re.escape( + "Error raised by inference API: rate limit exceeded.\nResponse: You have " + "reached maximum request limit.\n" + ), ): - output = llm(prompt) - assert isinstance(output, str) - output = llm(prompt) - assert isinstance(output, str) + for _ in range(10): + output = llm(prompt) + assert isinstance(output, str)