Mirror of https://github.com/hwchase17/langchain.git, synced 2025-07-01 19:03:25 +00:00
Update MosaicML endpoint input/output API (#7391)
As noted in prior PRs (https://github.com/hwchase17/langchain/pull/6060, https://github.com/hwchase17/langchain/pull/7348), the input/output format has changed a few times as we've stabilized our inference API. This PR updates the integration to the latest stable version of the API, as documented at https://docs.mosaicml.com/en/latest/inference.html.

The input format looks like this: `{"inputs": [<prompt>]}`

The output format looks like this: `{"outputs": [<output_text>]}`

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent ade482c17e
commit 30151c99c7
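
For reference, a minimal sketch of what a raw request against one of these endpoints looks like under the new format. The endpoint URL and prompt are reused from the tests below; the MOSAICML_API_TOKEN environment variable and the plain Authorization header are assumptions for illustration, not something this PR defines:

import os

import requests

# Hypothetical direct call to a MosaicML inference endpoint using the new format.
endpoint_url = "https://models.hosted-on.mosaicml.hosting/mpt-30b-instruct/v1/predict"
response = requests.post(
    endpoint_url,
    # Assumed auth scheme: token passed straight through the Authorization header.
    headers={"Authorization": os.environ["MOSAICML_API_TOKEN"]},
    json={"inputs": ["Say foo:"]},  # input format: {"inputs": [<prompt>]}
)
# output format: {"outputs": [<output_text>]}
print(response.json()["outputs"][0])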
@@ -63,7 +63,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "llm = MosaicML(inject_instruction_format=True, model_kwargs={\"do_sample\": False})"
+    "llm = MosaicML(inject_instruction_format=True, model_kwargs={\"max_new_tokens\": 128})"
    ]
   },
   {
@@ -79,14 +79,8 @@ class MosaicMLInstructorEmbeddings(BaseModel, Embeddings):
             raise ValueError(f"Error raised by inference endpoint: {e}")
 
         try:
-            parsed_response = response.json()
-
-            if "error" in parsed_response:
-                # if we get rate limited, try sleeping for 1 second
-                if (
-                    not is_retry
-                    and "rate limit exceeded" in parsed_response["error"].lower()
-                ):
+            if response.status_code == 429:
+                if not is_retry:
                     import time
 
                     time.sleep(self.retry_sleep)
@@ -94,16 +88,20 @@ class MosaicMLInstructorEmbeddings(BaseModel, Embeddings):
                     return self._embed(input, is_retry=True)
 
                 raise ValueError(
-                    f"Error raised by inference API: {parsed_response['error']}"
+                    f"Error raised by inference API: rate limit exceeded.\nResponse: "
+                    f"{response.text}"
                 )
 
+            parsed_response = response.json()
+
             # The inference API has changed a couple of times, so we add some handling
             # to be robust to multiple response formats.
             if isinstance(parsed_response, dict):
-                if "data" in parsed_response:
-                    output_item = parsed_response["data"]
-                elif "output" in parsed_response:
-                    output_item = parsed_response["output"]
+                output_keys = ["data", "output", "outputs"]
+                for key in output_keys:
+                    if key in parsed_response:
+                        output_item = parsed_response[key]
+                        break
                 else:
                     raise ValueError(
                         f"No key data or output in response: {parsed_response}"
@@ -113,19 +111,6 @@ class MosaicMLInstructorEmbeddings(BaseModel, Embeddings):
                     embeddings = output_item
                 else:
                     embeddings = [output_item]
-            elif isinstance(parsed_response, list):
-                first_item = parsed_response[0]
-                if isinstance(first_item, list):
-                    embeddings = parsed_response
-                elif isinstance(first_item, dict):
-                    if "output" in first_item:
-                        embeddings = [item["output"] for item in parsed_response]
-                    else:
-                        raise ValueError(
-                            f"No key data or output in response: {parsed_response}"
-                        )
-                else:
-                    raise ValueError(f"Unexpected response format: {parsed_response}")
             else:
                 raise ValueError(f"Unexpected response type: {parsed_response}")
 
@@ -138,14 +138,8 @@ class MosaicML(LLM):
             raise ValueError(f"Error raised by inference endpoint: {e}")
 
         try:
-            parsed_response = response.json()
-
-            if "error" in parsed_response:
-                # if we get rate limited, try sleeping for 1 second
-                if (
-                    not is_retry
-                    and "rate limit exceeded" in parsed_response["error"].lower()
-                ):
+            if response.status_code == 429:
+                if not is_retry:
                     import time
 
                     time.sleep(self.retry_sleep)
@@ -153,9 +147,12 @@ class MosaicML(LLM):
                     return self._call(prompt, stop, run_manager, is_retry=True)
 
                 raise ValueError(
-                    f"Error raised by inference API: {parsed_response['error']}"
+                    f"Error raised by inference API: rate limit exceeded.\nResponse: "
+                    f"{response.text}"
                 )
 
+            parsed_response = response.json()
+
             # The inference API has changed a couple of times, so we add some handling
             # to be robust to multiple response formats.
             if isinstance(parsed_response, dict):
@@ -173,23 +170,12 @@ class MosaicML(LLM):
                     text = output_item[0]
                 else:
                     text = output_item
-            elif isinstance(parsed_response, list):
-                first_item = parsed_response[0]
-                if isinstance(first_item, str):
-                    text = first_item
-                elif isinstance(first_item, dict):
-                    if "output" in parsed_response:
-                        text = first_item["output"]
-                    else:
-                        raise ValueError(
-                            f"No key data or output in response: {parsed_response}"
-                        )
-                else:
-                    raise ValueError(f"Unexpected response format: {parsed_response}")
             else:
                 raise ValueError(f"Unexpected response type: {parsed_response}")
 
-            text = text[len(prompt) :]
+            # Older versions of the API include the input in the output response
+            if text.startswith(prompt):
+                text = text[len(prompt) :]
 
         except requests.exceptions.JSONDecodeError as e:
             raise ValueError(
@@ -34,7 +34,9 @@ def test_mosaicml_embedding_endpoint() -> None:
     """Test MosaicML embeddings with a different endpoint"""
     documents = ["foo bar"]
     embedding = MosaicMLInstructorEmbeddings(
-        endpoint_url="https://models.hosted-on.mosaicml.hosting/instructor-xl/v1/predict"
+        endpoint_url=(
+            "https://models.hosted-on.mosaicml.hosting/instructor-xl/v1/predict"
+        )
     )
     output = embedding.embed_documents(documents)
     assert len(output) == 1
@@ -1,4 +1,6 @@
 """Test MosaicML API wrapper."""
 
+import re
+
 import pytest
 
 from langchain.llms.mosaicml import PROMPT_FOR_GENERATION_FORMAT, MosaicML
@@ -13,7 +15,7 @@ def test_mosaicml_llm_call() -> None:
 
 def test_mosaicml_endpoint_change() -> None:
     """Test valid call to MosaicML."""
-    new_url = "https://models.hosted-on.mosaicml.hosting/dolly-12b/v1/predict"
+    new_url = "https://models.hosted-on.mosaicml.hosting/mpt-30b-instruct/v1/predict"
     llm = MosaicML(endpoint_url=new_url)
     assert llm.endpoint_url == new_url
     output = llm("Say foo:")
@@ -34,7 +36,7 @@ def test_mosaicml_extra_kwargs() -> None:
 
 def test_instruct_prompt() -> None:
     """Test instruct prompt."""
-    llm = MosaicML(inject_instruction_format=True, model_kwargs={"do_sample": False})
+    llm = MosaicML(inject_instruction_format=True, model_kwargs={"max_new_tokens": 10})
     instruction = "Repeat the word foo"
     prompt = llm._transform_prompt(instruction)
     expected_prompt = PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction)
@@ -45,7 +47,7 @@ def test_instruct_prompt() -> None:
 
 def test_retry_logic() -> None:
     """Tests that two queries (which would usually exceed the rate limit) works"""
-    llm = MosaicML(inject_instruction_format=True, model_kwargs={"do_sample": False})
+    llm = MosaicML(inject_instruction_format=True, model_kwargs={"max_new_tokens": 10})
     instruction = "Repeat the word foo"
     prompt = llm._transform_prompt(instruction)
     expected_prompt = PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction)
@@ -70,9 +72,11 @@ def test_short_retry_does_not_loop() -> None:
 
     with pytest.raises(
         ValueError,
-        match="Error raised by inference API: Rate limit exceeded: 1 per 1 second",
+        match=re.escape(
+            "Error raised by inference API: rate limit exceeded.\nResponse: You have "
+            "reached maximum request limit.\n"
+        ),
     ):
-        output = llm(prompt)
-        assert isinstance(output, str)
-        output = llm(prompt)
-        assert isinstance(output, str)
+        for _ in range(10):
+            output = llm(prompt)
+            assert isinstance(output, str)