Update MosaicML endpoint input/output API (#7391)

As noted in prior PRs (https://github.com/hwchase17/langchain/pull/6060,
https://github.com/hwchase17/langchain/pull/7348), the input/output
format has changed a few times as we've stabilized our inference API.
This PR updates the integration to the latest stable version of the API, as
described in our docs: https://docs.mosaicml.com/en/latest/inference.html

The input format looks like this:

```
{"inputs": [<prompt>]}
```

The output format looks like this:

```
{"outputs": [<output_text>]}
```
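
For illustration, a minimal sketch of what a raw request against this format might look like outside of LangChain. The endpoint URL is just an example deployment, and the bare-token `Authorization` header mirrors what the wrapper sends; treat both as assumptions rather than part of this PR:

```python
import os

import requests

# Example endpoint; substitute the URL of your own deployed model.
endpoint_url = "https://models.hosted-on.mosaicml.hosting/mpt-7b-instruct/v1/predict"

response = requests.post(
    endpoint_url,
    headers={"Authorization": os.environ["MOSAICML_API_TOKEN"]},
    json={"inputs": ["Say foo:"]},  # new input format
    timeout=60,
)
response.raise_for_status()
print(response.json()["outputs"][0])  # new output format
```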
---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
commit 30151c99c7
parent ade482c17e
Margaret Qian authored 2023-08-24 22:13:17 -07:00; committed by GitHub
5 changed files with 36 additions and 59 deletions

MosaicML example notebook:

```diff
@@ -63,7 +63,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "llm = MosaicML(inject_instruction_format=True, model_kwargs={\"do_sample\": False})"
+    "llm = MosaicML(inject_instruction_format=True, model_kwargs={\"max_new_tokens\": 128})"
    ]
   },
   {
```
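
The notebook now passes `max_new_tokens` instead of `do_sample`. For context, the updated cell is roughly equivalent to this standalone snippet (a sketch, assuming `MOSAICML_API_TOKEN` is set in the environment):

```python
from langchain.llms import MosaicML

llm = MosaicML(inject_instruction_format=True, model_kwargs={"max_new_tokens": 128})
print(llm("What is one good reason to train your own LLM?"))
```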

`MosaicMLInstructorEmbeddings` (embeddings wrapper):

```diff
@@ -79,14 +79,8 @@ class MosaicMLInstructorEmbeddings(BaseModel, Embeddings):
             raise ValueError(f"Error raised by inference endpoint: {e}")

         try:
-            parsed_response = response.json()
-
-            if "error" in parsed_response:
-                # if we get rate limited, try sleeping for 1 second
-                if (
-                    not is_retry
-                    and "rate limit exceeded" in parsed_response["error"].lower()
-                ):
+            if response.status_code == 429:
+                if not is_retry:
                     import time

                     time.sleep(self.retry_sleep)
```
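
This hunk stops sniffing the response body for an error string and instead keys the retry off the HTTP 429 status code, sleeping `retry_sleep` seconds and retrying exactly once. Here is a standalone sketch of the same pattern; the helper name and signature are hypothetical, not the library's API:

```python
import time

import requests


def post_with_one_retry(
    url: str, payload: dict, retry_sleep: float = 1.0
) -> requests.Response:
    """POST once; on HTTP 429, sleep and retry exactly once."""
    response = requests.post(url, json=payload, timeout=60)
    if response.status_code == 429:
        time.sleep(retry_sleep)
        response = requests.post(url, json=payload, timeout=60)
        if response.status_code == 429:
            # Still rate limited after the single retry: surface the server text.
            raise ValueError(
                f"Error raised by inference API: rate limit exceeded.\n"
                f"Response: {response.text}"
            )
    return response
```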
```diff
@@ -94,16 +88,20 @@ class MosaicMLInstructorEmbeddings(BaseModel, Embeddings):
                     return self._embed(input, is_retry=True)

                 raise ValueError(
-                    f"Error raised by inference API: {parsed_response['error']}"
+                    f"Error raised by inference API: rate limit exceeded.\nResponse: "
+                    f"{response.text}"
                 )

+            parsed_response = response.json()
+
             # The inference API has changed a couple of times, so we add some handling
             # to be robust to multiple response formats.
             if isinstance(parsed_response, dict):
-                if "data" in parsed_response:
-                    output_item = parsed_response["data"]
-                elif "output" in parsed_response:
-                    output_item = parsed_response["output"]
+                output_keys = ["data", "output", "outputs"]
+                for key in output_keys:
+                    if key in parsed_response:
+                        output_item = parsed_response[key]
+                        break
                 else:
                     raise ValueError(
                         f"No key data or output in response: {parsed_response}"
@@ -113,19 +111,6 @@ class MosaicMLInstructorEmbeddings(BaseModel, Embeddings):
                     embeddings = output_item
                 else:
                     embeddings = [output_item]
-            elif isinstance(parsed_response, list):
-                first_item = parsed_response[0]
-                if isinstance(first_item, list):
-                    embeddings = parsed_response
-                elif isinstance(first_item, dict):
-                    if "output" in first_item:
-                        embeddings = [item["output"] for item in parsed_response]
-                    else:
-                        raise ValueError(
-                            f"No key data or output in response: {parsed_response}"
-                        )
-                else:
-                    raise ValueError(f"Unexpected response format: {parsed_response}")
             else:
                 raise ValueError(f"Unexpected response type: {parsed_response}")
```

`MosaicML` LLM wrapper:

```diff
@@ -138,14 +138,8 @@ class MosaicML(LLM):
             raise ValueError(f"Error raised by inference endpoint: {e}")

         try:
-            parsed_response = response.json()
-
-            if "error" in parsed_response:
-                # if we get rate limited, try sleeping for 1 second
-                if (
-                    not is_retry
-                    and "rate limit exceeded" in parsed_response["error"].lower()
-                ):
+            if response.status_code == 429:
+                if not is_retry:
                     import time

                     time.sleep(self.retry_sleep)
@@ -153,9 +147,12 @@ class MosaicML(LLM):
                     return self._call(prompt, stop, run_manager, is_retry=True)

                 raise ValueError(
-                    f"Error raised by inference API: {parsed_response['error']}"
+                    f"Error raised by inference API: rate limit exceeded.\nResponse: "
+                    f"{response.text}"
                 )

+            parsed_response = response.json()
+
             # The inference API has changed a couple of times, so we add some handling
             # to be robust to multiple response formats.
             if isinstance(parsed_response, dict):
@@ -173,23 +170,12 @@ class MosaicML(LLM):
                     text = output_item[0]
                 else:
                     text = output_item
-            elif isinstance(parsed_response, list):
-                first_item = parsed_response[0]
-                if isinstance(first_item, str):
-                    text = first_item
-                elif isinstance(first_item, dict):
-                    if "output" in parsed_response:
-                        text = first_item["output"]
-                    else:
-                        raise ValueError(
-                            f"No key data or output in response: {parsed_response}"
-                        )
-                else:
-                    raise ValueError(f"Unexpected response format: {parsed_response}")
             else:
                 raise ValueError(f"Unexpected response type: {parsed_response}")

-            text = text[len(prompt) :]
+            # Older versions of the API include the input in the output response
+            if text.startswith(prompt):
+                text = text[len(prompt) :]
         except requests.exceptions.JSONDecodeError as e:
             raise ValueError(
```
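
The final hunk also fixes prompt handling: older endpoint versions echoed the prompt at the start of the generated text, so the wrapper unconditionally sliced off `len(prompt)` characters, which mangled output from newer endpoints that don't echo. The slice is now conditional; a tiny illustration with a hypothetical helper:

```python
def strip_prompt_echo(text: str, prompt: str) -> str:
    """Drop a leading copy of the prompt if the endpoint echoed it back."""
    if text.startswith(prompt):
        return text[len(prompt) :]
    return text


assert strip_prompt_echo("Say foo: foo", "Say foo: ") == "foo"  # older API, echoed
assert strip_prompt_echo("foo", "Say foo: ") == "foo"  # newer API, no echo
```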

Embeddings integration test:

```diff
@@ -34,7 +34,9 @@ def test_mosaicml_embedding_endpoint() -> None:
     """Test MosaicML embeddings with a different endpoint"""
     documents = ["foo bar"]
     embedding = MosaicMLInstructorEmbeddings(
-        endpoint_url="https://models.hosted-on.mosaicml.hosting/instructor-xl/v1/predict"
+        endpoint_url=(
+            "https://models.hosted-on.mosaicml.hosting/instructor-xl/v1/predict"
+        )
     )
     output = embedding.embed_documents(documents)
     assert len(output) == 1
```

LLM integration test:

```diff
@@ -1,4 +1,6 @@
 """Test MosaicML API wrapper."""

+import re
+
 import pytest

 from langchain.llms.mosaicml import PROMPT_FOR_GENERATION_FORMAT, MosaicML
@@ -13,7 +15,7 @@ def test_mosaicml_llm_call() -> None:
 def test_mosaicml_endpoint_change() -> None:
     """Test valid call to MosaicML."""
-    new_url = "https://models.hosted-on.mosaicml.hosting/dolly-12b/v1/predict"
+    new_url = "https://models.hosted-on.mosaicml.hosting/mpt-30b-instruct/v1/predict"
     llm = MosaicML(endpoint_url=new_url)
     assert llm.endpoint_url == new_url
     output = llm("Say foo:")
@@ -34,7 +36,7 @@ def test_mosaicml_extra_kwargs() -> None:
 def test_instruct_prompt() -> None:
     """Test instruct prompt."""
-    llm = MosaicML(inject_instruction_format=True, model_kwargs={"do_sample": False})
+    llm = MosaicML(inject_instruction_format=True, model_kwargs={"max_new_tokens": 10})
     instruction = "Repeat the word foo"
     prompt = llm._transform_prompt(instruction)
     expected_prompt = PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction)
@@ -45,7 +47,7 @@ def test_instruct_prompt() -> None:
 def test_retry_logic() -> None:
     """Tests that two queries (which would usually exceed the rate limit) works"""
-    llm = MosaicML(inject_instruction_format=True, model_kwargs={"do_sample": False})
+    llm = MosaicML(inject_instruction_format=True, model_kwargs={"max_new_tokens": 10})
     instruction = "Repeat the word foo"
     prompt = llm._transform_prompt(instruction)
     expected_prompt = PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction)
@@ -70,9 +72,11 @@ def test_short_retry_does_not_loop() -> None:
     with pytest.raises(
         ValueError,
-        match="Error raised by inference API: Rate limit exceeded: 1 per 1 second",
+        match=re.escape(
+            "Error raised by inference API: rate limit exceeded.\nResponse: You have "
+            "reached maximum request limit.\n"
+        ),
     ):
-        output = llm(prompt)
-        assert isinstance(output, str)
-        output = llm(prompt)
-        assert isinstance(output, str)
+        for _ in range(10):
+            output = llm(prompt)
+            assert isinstance(output, str)
```
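
One detail worth calling out in the updated test: `pytest.raises(match=...)` treats its argument as a regular expression (it's applied with `re.search`), so the `.` and newline in the new error message are wrapped in `re.escape` to match literally. A minimal demonstration:

```python
import re

import pytest


def raise_rate_limit() -> None:
    raise ValueError("rate limit exceeded.\nResponse: try again later")


def test_match_is_a_regex() -> None:
    # re.escape turns the literal message into a regex that matches it verbatim.
    with pytest.raises(ValueError, match=re.escape("exceeded.\nResponse:")):
        raise_rate_limit()
```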