Speed up openai tests (#9943)

Saves ~8-10 seconds from total unit tests times

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Eugene Yurtsev 2023-08-29 17:30:41 -04:00 committed by GitHub
parent bcc3463ff4
commit 5cce6529a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4,7 +4,9 @@ from typing import Any
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
from tenacity import wait_none
from langchain.llms import base
from langchain.llms.openai import OpenAI from langchain.llms.openai import OpenAI
from tests.unit_tests.callbacks.fake_callback_handler import ( from tests.unit_tests.callbacks.fake_callback_handler import (
FakeAsyncCallbackHandler, FakeAsyncCallbackHandler,
@ -55,6 +57,16 @@ def mock_completion() -> dict:
} }
def _patched_retry(*args: Any, **kwargs: Any) -> Any:
"""Patched retry for unit tests that does not wait."""
from tenacity import retry
assert "wait" in kwargs
kwargs["wait"] = wait_none()
r = retry(*args, **kwargs)
return r
@pytest.mark.requires("openai") @pytest.mark.requires("openai")
def test_openai_retries(mock_completion: dict) -> None: def test_openai_retries(mock_completion: dict) -> None:
llm = OpenAI() llm = OpenAI()
@ -73,13 +85,16 @@ def test_openai_retries(mock_completion: dict) -> None:
mock_client.create = raise_once mock_client.create = raise_once
callback_handler = FakeCallbackHandler() callback_handler = FakeCallbackHandler()
with patch.object(
llm, # Patch the retry to avoid waiting during a unit test
"client", with patch.object(base, "retry", _patched_retry):
mock_client, with patch.object(
): llm,
res = llm.predict("bar", callbacks=[callback_handler]) "client",
assert res == "Bar Baz" mock_client,
):
res = llm.predict("bar", callbacks=[callback_handler])
assert res == "Bar Baz"
assert completed assert completed
assert raised assert raised
assert callback_handler.retries == 1 assert callback_handler.retries == 1
@ -105,13 +120,15 @@ async def test_openai_async_retries(mock_completion: dict) -> None:
mock_client.acreate = araise_once mock_client.acreate = araise_once
callback_handler = FakeAsyncCallbackHandler() callback_handler = FakeAsyncCallbackHandler()
with patch.object( # Patch the retry to avoid waiting during a unit test
llm, with patch.object(base, "retry", _patched_retry):
"client", with patch.object(
mock_client, llm,
): "client",
res = await llm.apredict("bar", callbacks=[callback_handler]) mock_client,
assert res == "Bar Baz" ):
res = await llm.apredict("bar", callbacks=[callback_handler])
assert res == "Bar Baz"
assert completed assert completed
assert raised assert raised
assert callback_handler.retries == 1 assert callback_handler.retries == 1