Update Mosaic endpoint input/output api (#7391)

As noted in prior PRs (https://github.com/hwchase17/langchain/pull/6060,
https://github.com/hwchase17/langchain/pull/7348), the input/output
format has changed a few times as we've stabilized our inference API.
This PR updates the API to the latest stable version as indicated in our
docs: https://docs.mosaicml.com/en/latest/inference.html

The input format looks like this:

`{"inputs": [<prompt>]}
`

The output format looks like this:
`
{"outputs": [<output_text>]}
`
---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Margaret Qian 2023-08-24 22:13:17 -07:00 committed by GitHub
parent ade482c17e
commit 30151c99c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 36 additions and 59 deletions

View File

@ -63,7 +63,7 @@
"metadata": {},
"outputs": [],
"source": [
"llm = MosaicML(inject_instruction_format=True, model_kwargs={\"do_sample\": False})"
"llm = MosaicML(inject_instruction_format=True, model_kwargs={\"max_new_tokens\": 128})"
]
},
{

View File

@ -79,14 +79,8 @@ class MosaicMLInstructorEmbeddings(BaseModel, Embeddings):
raise ValueError(f"Error raised by inference endpoint: {e}")
try:
parsed_response = response.json()
if "error" in parsed_response:
# if we get rate limited, try sleeping for 1 second
if (
not is_retry
and "rate limit exceeded" in parsed_response["error"].lower()
):
if response.status_code == 429:
if not is_retry:
import time
time.sleep(self.retry_sleep)
@ -94,16 +88,20 @@ class MosaicMLInstructorEmbeddings(BaseModel, Embeddings):
return self._embed(input, is_retry=True)
raise ValueError(
f"Error raised by inference API: {parsed_response['error']}"
f"Error raised by inference API: rate limit exceeded.\nResponse: "
f"{response.text}"
)
parsed_response = response.json()
# The inference API has changed a couple of times, so we add some handling
# to be robust to multiple response formats.
if isinstance(parsed_response, dict):
if "data" in parsed_response:
output_item = parsed_response["data"]
elif "output" in parsed_response:
output_item = parsed_response["output"]
output_keys = ["data", "output", "outputs"]
for key in output_keys:
if key in parsed_response:
output_item = parsed_response[key]
break
else:
raise ValueError(
f"No key data or output in response: {parsed_response}"
@ -113,19 +111,6 @@ class MosaicMLInstructorEmbeddings(BaseModel, Embeddings):
embeddings = output_item
else:
embeddings = [output_item]
elif isinstance(parsed_response, list):
first_item = parsed_response[0]
if isinstance(first_item, list):
embeddings = parsed_response
elif isinstance(first_item, dict):
if "output" in first_item:
embeddings = [item["output"] for item in parsed_response]
else:
raise ValueError(
f"No key data or output in response: {parsed_response}"
)
else:
raise ValueError(f"Unexpected response format: {parsed_response}")
else:
raise ValueError(f"Unexpected response type: {parsed_response}")

View File

@ -138,14 +138,8 @@ class MosaicML(LLM):
raise ValueError(f"Error raised by inference endpoint: {e}")
try:
parsed_response = response.json()
if "error" in parsed_response:
# if we get rate limited, try sleeping for 1 second
if (
not is_retry
and "rate limit exceeded" in parsed_response["error"].lower()
):
if response.status_code == 429:
if not is_retry:
import time
time.sleep(self.retry_sleep)
@ -153,9 +147,12 @@ class MosaicML(LLM):
return self._call(prompt, stop, run_manager, is_retry=True)
raise ValueError(
f"Error raised by inference API: {parsed_response['error']}"
f"Error raised by inference API: rate limit exceeded.\nResponse: "
f"{response.text}"
)
parsed_response = response.json()
# The inference API has changed a couple of times, so we add some handling
# to be robust to multiple response formats.
if isinstance(parsed_response, dict):
@ -173,22 +170,11 @@ class MosaicML(LLM):
text = output_item[0]
else:
text = output_item
elif isinstance(parsed_response, list):
first_item = parsed_response[0]
if isinstance(first_item, str):
text = first_item
elif isinstance(first_item, dict):
if "output" in parsed_response:
text = first_item["output"]
else:
raise ValueError(
f"No key data or output in response: {parsed_response}"
)
else:
raise ValueError(f"Unexpected response format: {parsed_response}")
else:
raise ValueError(f"Unexpected response type: {parsed_response}")
# Older versions of the API include the input in the output response
if text.startswith(prompt):
text = text[len(prompt) :]
except requests.exceptions.JSONDecodeError as e:

View File

@ -34,7 +34,9 @@ def test_mosaicml_embedding_endpoint() -> None:
"""Test MosaicML embeddings with a different endpoint"""
documents = ["foo bar"]
embedding = MosaicMLInstructorEmbeddings(
endpoint_url="https://models.hosted-on.mosaicml.hosting/instructor-xl/v1/predict"
endpoint_url=(
"https://models.hosted-on.mosaicml.hosting/instructor-xl/v1/predict"
)
)
output = embedding.embed_documents(documents)
assert len(output) == 1

View File

@ -1,4 +1,6 @@
"""Test MosaicML API wrapper."""
import re
import pytest
from langchain.llms.mosaicml import PROMPT_FOR_GENERATION_FORMAT, MosaicML
@ -13,7 +15,7 @@ def test_mosaicml_llm_call() -> None:
def test_mosaicml_endpoint_change() -> None:
"""Test valid call to MosaicML."""
new_url = "https://models.hosted-on.mosaicml.hosting/dolly-12b/v1/predict"
new_url = "https://models.hosted-on.mosaicml.hosting/mpt-30b-instruct/v1/predict"
llm = MosaicML(endpoint_url=new_url)
assert llm.endpoint_url == new_url
output = llm("Say foo:")
@ -34,7 +36,7 @@ def test_mosaicml_extra_kwargs() -> None:
def test_instruct_prompt() -> None:
"""Test instruct prompt."""
llm = MosaicML(inject_instruction_format=True, model_kwargs={"do_sample": False})
llm = MosaicML(inject_instruction_format=True, model_kwargs={"max_new_tokens": 10})
instruction = "Repeat the word foo"
prompt = llm._transform_prompt(instruction)
expected_prompt = PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction)
@ -45,7 +47,7 @@ def test_instruct_prompt() -> None:
def test_retry_logic() -> None:
"""Tests that two queries (which would usually exceed the rate limit) works"""
llm = MosaicML(inject_instruction_format=True, model_kwargs={"do_sample": False})
llm = MosaicML(inject_instruction_format=True, model_kwargs={"max_new_tokens": 10})
instruction = "Repeat the word foo"
prompt = llm._transform_prompt(instruction)
expected_prompt = PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction)
@ -70,9 +72,11 @@ def test_short_retry_does_not_loop() -> None:
with pytest.raises(
ValueError,
match="Error raised by inference API: Rate limit exceeded: 1 per 1 second",
match=re.escape(
"Error raised by inference API: rate limit exceeded.\nResponse: You have "
"reached maximum request limit.\n"
),
):
output = llm(prompt)
assert isinstance(output, str)
for _ in range(10):
output = llm(prompt)
assert isinstance(output, str)