mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-28 17:38:36 +00:00
Speed up openai tests (#9943)
Saves ~8-10 seconds from total unit test time --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
bcc3463ff4
commit
5cce6529a4
@ -4,7 +4,9 @@ from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from tenacity import wait_none
|
||||
|
||||
from langchain.llms import base
|
||||
from langchain.llms.openai import OpenAI
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import (
|
||||
FakeAsyncCallbackHandler,
|
||||
@ -55,6 +57,16 @@ def mock_completion() -> dict:
|
||||
}
|
||||
|
||||
|
||||
def _patched_retry(*args: Any, **kwargs: Any) -> Any:
    """Build a tenacity retry decorator whose wait strategy is disabled.

    Used in unit tests in place of ``tenacity.retry`` so that retried calls
    do not sleep between attempts, keeping the test suite fast.
    """
    from tenacity import retry

    # The caller is expected to configure a wait strategy; swap it for the
    # no-op one so no real time is spent waiting during tests.
    assert "wait" in kwargs
    kwargs["wait"] = wait_none()
    return retry(*args, **kwargs)
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
|
||||
def test_openai_retries(mock_completion: dict) -> None:
|
||||
llm = OpenAI()
|
||||
@ -73,13 +85,16 @@ def test_openai_retries(mock_completion: dict) -> None:
|
||||
|
||||
mock_client.create = raise_once
|
||||
callback_handler = FakeCallbackHandler()
|
||||
with patch.object(
|
||||
llm,
|
||||
"client",
|
||||
mock_client,
|
||||
):
|
||||
res = llm.predict("bar", callbacks=[callback_handler])
|
||||
assert res == "Bar Baz"
|
||||
|
||||
# Patch the retry to avoid waiting during a unit test
|
||||
with patch.object(base, "retry", _patched_retry):
|
||||
with patch.object(
|
||||
llm,
|
||||
"client",
|
||||
mock_client,
|
||||
):
|
||||
res = llm.predict("bar", callbacks=[callback_handler])
|
||||
assert res == "Bar Baz"
|
||||
assert completed
|
||||
assert raised
|
||||
assert callback_handler.retries == 1
|
||||
@ -105,13 +120,15 @@ async def test_openai_async_retries(mock_completion: dict) -> None:
|
||||
|
||||
mock_client.acreate = araise_once
|
||||
callback_handler = FakeAsyncCallbackHandler()
|
||||
with patch.object(
|
||||
llm,
|
||||
"client",
|
||||
mock_client,
|
||||
):
|
||||
res = await llm.apredict("bar", callbacks=[callback_handler])
|
||||
assert res == "Bar Baz"
|
||||
# Patch the retry to avoid waiting during a unit test
|
||||
with patch.object(base, "retry", _patched_retry):
|
||||
with patch.object(
|
||||
llm,
|
||||
"client",
|
||||
mock_client,
|
||||
):
|
||||
res = await llm.apredict("bar", callbacks=[callback_handler])
|
||||
assert res == "Bar Baz"
|
||||
assert completed
|
||||
assert raised
|
||||
assert callback_handler.retries == 1
|
||||
|
Loading…
Reference in New Issue
Block a user