Mirror of https://github.com/hwchase17/langchain.git, synced 2025-07-01 10:54:15 +00:00
Update Mosaic endpoint input/output api (#7391)
As noted in prior PRs (https://github.com/hwchase17/langchain/pull/6060, https://github.com/hwchase17/langchain/pull/7348), the input/output format has changed a few times as we've stabilized our inference API. This PR updates the API to the latest stable version as indicated in our docs: https://docs.mosaicml.com/en/latest/inference.html

The input format looks like this: `{"inputs": [<prompt>]}`

The output format looks like this: `{"outputs": [<output_text>]}`

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
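For illustration, here is a minimal sketch of a raw request/response against the stable format described above. The endpoint URL and auth header value are assumptions for the example only, not values taken from this PR:

import requests

# Hypothetical endpoint URL and token placeholder, for illustration only.
endpoint_url = "https://models.hosted-on.mosaicml.hosting/mpt-7b-instruct/v1/predict"
headers = {"Authorization": "<MOSAICML_API_TOKEN>"}

# New input format: a list of prompts under "inputs".
payload = {"inputs": ["Tell me a joke."]}

response = requests.post(endpoint_url, headers=headers, json=payload)
response.raise_for_status()

# New output format: generated text comes back as a list under "outputs".
print(response.json()["outputs"][0])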
This commit is contained in: parent ade482c17e, commit 30151c99c7
@@ -63,7 +63,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "llm = MosaicML(inject_instruction_format=True, model_kwargs={\"do_sample\": False})"
+    "llm = MosaicML(inject_instruction_format=True, model_kwargs={\"max_new_tokens\": 128})"
    ]
   },
   {
@@ -79,14 +79,8 @@ class MosaicMLInstructorEmbeddings(BaseModel, Embeddings):
             raise ValueError(f"Error raised by inference endpoint: {e}")
 
         try:
-            parsed_response = response.json()
-
-            if "error" in parsed_response:
-                # if we get rate limited, try sleeping for 1 second
-                if (
-                    not is_retry
-                    and "rate limit exceeded" in parsed_response["error"].lower()
-                ):
+            if response.status_code == 429:
+                if not is_retry:
                     import time
 
                     time.sleep(self.retry_sleep)
@@ -94,16 +88,20 @@ class MosaicMLInstructorEmbeddings(BaseModel, Embeddings):
                     return self._embed(input, is_retry=True)
 
                 raise ValueError(
-                    f"Error raised by inference API: {parsed_response['error']}"
+                    f"Error raised by inference API: rate limit exceeded.\nResponse: "
+                    f"{response.text}"
                 )
 
+            parsed_response = response.json()
+
             # The inference API has changed a couple of times, so we add some handling
             # to be robust to multiple response formats.
             if isinstance(parsed_response, dict):
-                if "data" in parsed_response:
-                    output_item = parsed_response["data"]
-                elif "output" in parsed_response:
-                    output_item = parsed_response["output"]
+                output_keys = ["data", "output", "outputs"]
+                for key in output_keys:
+                    if key in parsed_response:
+                        output_item = parsed_response[key]
+                        break
                 else:
                     raise ValueError(
                         f"No key data or output in response: {parsed_response}"
@@ -113,19 +111,6 @@ class MosaicMLInstructorEmbeddings(BaseModel, Embeddings):
                     embeddings = output_item
                 else:
                     embeddings = [output_item]
-            elif isinstance(parsed_response, list):
-                first_item = parsed_response[0]
-                if isinstance(first_item, list):
-                    embeddings = parsed_response
-                elif isinstance(first_item, dict):
-                    if "output" in first_item:
-                        embeddings = [item["output"] for item in parsed_response]
-                    else:
-                        raise ValueError(
-                            f"No key data or output in response: {parsed_response}"
-                        )
-                else:
-                    raise ValueError(f"Unexpected response format: {parsed_response}")
             else:
                 raise ValueError(f"Unexpected response type: {parsed_response}")
 
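(Aside: the key lookup added above relies on Python's for/else. A standalone sketch with a hypothetical helper name, not part of the library, showing how that branch behaves — the else clause runs only when the loop finishes without hitting break:)

def pick_output(parsed_response: dict):
    # Hypothetical helper for illustration; mirrors the lookup added in the diff.
    output_keys = ["data", "output", "outputs"]
    for key in output_keys:
        if key in parsed_response:
            output_item = parsed_response[key]
            break
    else:
        # Reached only if no key matched, i.e. the loop completed without `break`.
        raise ValueError(f"No key data or output in response: {parsed_response}")
    return output_item

assert pick_output({"outputs": [[0.1, 0.2]]}) == [[0.1, 0.2]]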
@@ -138,14 +138,8 @@ class MosaicML(LLM):
             raise ValueError(f"Error raised by inference endpoint: {e}")
 
         try:
-            parsed_response = response.json()
-
-            if "error" in parsed_response:
-                # if we get rate limited, try sleeping for 1 second
-                if (
-                    not is_retry
-                    and "rate limit exceeded" in parsed_response["error"].lower()
-                ):
+            if response.status_code == 429:
+                if not is_retry:
                     import time
 
                     time.sleep(self.retry_sleep)
@@ -153,9 +147,12 @@ class MosaicML(LLM):
                     return self._call(prompt, stop, run_manager, is_retry=True)
 
                 raise ValueError(
-                    f"Error raised by inference API: {parsed_response['error']}"
+                    f"Error raised by inference API: rate limit exceeded.\nResponse: "
+                    f"{response.text}"
                 )
 
+            parsed_response = response.json()
+
             # The inference API has changed a couple of times, so we add some handling
             # to be robust to multiple response formats.
             if isinstance(parsed_response, dict):
@@ -173,22 +170,11 @@ class MosaicML(LLM):
                     text = output_item[0]
                 else:
                     text = output_item
-            elif isinstance(parsed_response, list):
-                first_item = parsed_response[0]
-                if isinstance(first_item, str):
-                    text = first_item
-                elif isinstance(first_item, dict):
-                    if "output" in parsed_response:
-                        text = first_item["output"]
-                    else:
-                        raise ValueError(
-                            f"No key data or output in response: {parsed_response}"
-                        )
-                else:
-                    raise ValueError(f"Unexpected response format: {parsed_response}")
             else:
                 raise ValueError(f"Unexpected response type: {parsed_response}")
 
             # Older versions of the API include the input in the output response
             if text.startswith(prompt):
                 text = text[len(prompt) :]
 
         except requests.exceptions.JSONDecodeError as e:
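(Aside: a self-contained sketch of the retry-on-429 behaviour the hunks above adopt. The helper name and single-retry policy are assumptions for illustration, not the library's API:)

import time
import requests

def post_with_single_retry(url: str, payload: dict, retry_sleep: float = 1.0) -> dict:
    # Hypothetical helper; mirrors the pattern above: on HTTP 429, sleep once,
    # retry, and raise if the endpoint is still rate limiting.
    response = requests.post(url, json=payload)
    if response.status_code == 429:
        time.sleep(retry_sleep)
        response = requests.post(url, json=payload)
        if response.status_code == 429:
            raise ValueError(
                "Error raised by inference API: rate limit exceeded.\n"
                f"Response: {response.text}"
            )
    return response.json()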
@@ -34,7 +34,9 @@ def test_mosaicml_embedding_endpoint() -> None:
     """Test MosaicML embeddings with a different endpoint"""
     documents = ["foo bar"]
     embedding = MosaicMLInstructorEmbeddings(
-        endpoint_url="https://models.hosted-on.mosaicml.hosting/instructor-xl/v1/predict"
+        endpoint_url=(
+            "https://models.hosted-on.mosaicml.hosting/instructor-xl/v1/predict"
+        )
     )
     output = embedding.embed_documents(documents)
     assert len(output) == 1
@@ -1,4 +1,6 @@
 """Test MosaicML API wrapper."""
+import re
+
 import pytest
 
 from langchain.llms.mosaicml import PROMPT_FOR_GENERATION_FORMAT, MosaicML
@@ -13,7 +15,7 @@ def test_mosaicml_llm_call() -> None:
 
 def test_mosaicml_endpoint_change() -> None:
     """Test valid call to MosaicML."""
-    new_url = "https://models.hosted-on.mosaicml.hosting/dolly-12b/v1/predict"
+    new_url = "https://models.hosted-on.mosaicml.hosting/mpt-30b-instruct/v1/predict"
     llm = MosaicML(endpoint_url=new_url)
     assert llm.endpoint_url == new_url
     output = llm("Say foo:")
@@ -34,7 +36,7 @@ def test_mosaicml_extra_kwargs() -> None:
 
 def test_instruct_prompt() -> None:
     """Test instruct prompt."""
-    llm = MosaicML(inject_instruction_format=True, model_kwargs={"do_sample": False})
+    llm = MosaicML(inject_instruction_format=True, model_kwargs={"max_new_tokens": 10})
     instruction = "Repeat the word foo"
     prompt = llm._transform_prompt(instruction)
     expected_prompt = PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction)
@@ -45,7 +47,7 @@ def test_instruct_prompt() -> None:
 
 def test_retry_logic() -> None:
     """Tests that two queries (which would usually exceed the rate limit) works"""
-    llm = MosaicML(inject_instruction_format=True, model_kwargs={"do_sample": False})
+    llm = MosaicML(inject_instruction_format=True, model_kwargs={"max_new_tokens": 10})
    instruction = "Repeat the word foo"
     prompt = llm._transform_prompt(instruction)
     expected_prompt = PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction)
@@ -70,9 +72,11 @@ def test_short_retry_does_not_loop() -> None:
 
     with pytest.raises(
         ValueError,
-        match="Error raised by inference API: Rate limit exceeded: 1 per 1 second",
+        match=re.escape(
+            "Error raised by inference API: rate limit exceeded.\nResponse: You have "
+            "reached maximum request limit.\n"
+        ),
     ):
-        output = llm(prompt)
-        assert isinstance(output, str)
+        for _ in range(10):
+            output = llm(prompt)
+            assert isinstance(output, str)