Add Retry Events (#8053)

Co-authored-by: Bagatur <baskaryan@gmail.com>

parent 94a693e2ee
commit ff98fad2d9
@@ -242,6 +242,11 @@ class BaseCallbackHandler(
         """Whether to ignore LLM callbacks."""
         return False
 
+    @property
+    def ignore_retry(self) -> bool:
+        """Whether to ignore retry callbacks."""
+        return False
+
     @property
     def ignore_chain(self) -> bool:
         """Whether to ignore chain callbacks."""
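The two hooks above are the handler-facing half of this change: a handler receives each tenacity RetryCallState through on_retry unless its ignore_retry property returns True. A minimal sketch of a user-defined handler, assuming only the BaseCallbackHandler API shown in this hunk (the class name and the print body are illustrative, not part of the commit):

from typing import Any

from tenacity import RetryCallState

from langchain.callbacks.base import BaseCallbackHandler


class RetryLoggingHandler(BaseCallbackHandler):
    """Illustrative handler that reports retry attempts."""

    @property
    def ignore_retry(self) -> bool:
        # Return True here to opt out of retry callbacks entirely.
        return False

    def on_retry(self, retry_state: RetryCallState, **kwargs: Any) -> None:
        # Called once per retry attempt; run_id, parent_run_id and tags
        # arrive in **kwargs from the callback manager.
        print(f"retry #{retry_state.attempt_number}, slept {retry_state.idle_for}s")

Such a handler is passed like any other, e.g. via callbacks=[RetryLoggingHandler()], as the tests at the end of this commit do.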
@@ -23,6 +23,8 @@ from typing import (
 )
 from uuid import UUID
 
+from tenacity import RetryCallState
+
 import langchain
 from langchain.callbacks.base import (
     BaseCallbackHandler,
@@ -572,6 +574,22 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
             **kwargs,
         )
 
+    def on_retry(
+        self,
+        retry_state: RetryCallState,
+        **kwargs: Any,
+    ) -> None:
+        _handle_event(
+            self.handlers,
+            "on_retry",
+            "ignore_retry",
+            retry_state,
+            run_id=self.run_id,
+            parent_run_id=self.parent_run_id,
+            tags=self.tags,
+            **kwargs,
+        )
+
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         """Run when LLM ends running.
 
@@ -635,6 +653,22 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
             **kwargs,
         )
 
+    async def on_retry(
+        self,
+        retry_state: RetryCallState,
+        **kwargs: Any,
+    ) -> None:
+        await _ahandle_event(
+            self.handlers,
+            "on_retry",
+            "ignore_retry",
+            retry_state,
+            run_id=self.run_id,
+            parent_run_id=self.parent_run_id,
+            tags=self.tags,
+            **kwargs,
+        )
+
     async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         """Run when LLM ends running.
 
@@ -7,6 +7,8 @@ from datetime import datetime
 from typing import Any, Dict, List, Optional, Sequence, Union, cast
 from uuid import UUID
 
+from tenacity import RetryCallState
+
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.callbacks.tracers.schemas import Run, RunTypeEnum
 from langchain.load.dump import dumpd
@@ -138,6 +140,41 @@ class BaseTracer(BaseCallbackHandler, ABC):
             },
         )
 
+    def on_retry(
+        self,
+        retry_state: RetryCallState,
+        *,
+        run_id: UUID,
+        **kwargs: Any,
+    ) -> None:
+        if not run_id:
+            raise TracerException("No run_id provided for on_retry callback.")
+        run_id_ = str(run_id)
+        llm_run = self.run_map.get(run_id_)
+        if llm_run is None or llm_run.run_type != RunTypeEnum.llm:
+            raise TracerException("No LLM Run found to be traced for on_retry")
+        retry_d: Dict[str, Any] = {
+            "slept": retry_state.idle_for,
+            "attempt": retry_state.attempt_number,
+        }
+        if retry_state.outcome is None:
+            retry_d["outcome"] = "N/A"
+        elif retry_state.outcome.failed:
+            retry_d["outcome"] = "failed"
+            exception = retry_state.outcome.exception()
+            retry_d["exception"] = str(exception)
+            retry_d["exception_type"] = exception.__class__.__name__
+        else:
+            retry_d["outcome"] = "success"
+            retry_d["result"] = str(retry_state.outcome.result())
+        llm_run.events.append(
+            {
+                "name": "retry",
+                "time": datetime.utcnow(),
+                "kwargs": retry_d,
+            },
+        )
+
     def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> None:
         """End a trace for an LLM run."""
         if not run_id:
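For reference, each retry the tracer records is appended to llm_run.events with the shape below; the concrete values are illustrative, not taken from a real run. A successful attempt carries a "result" entry instead of the exception fields, and an attempt with no outcome is recorded as "N/A".

from datetime import datetime

# Shape of one recorded retry event (values illustrative):
retry_event = {
    "name": "retry",
    "time": datetime.utcnow(),
    "kwargs": {
        "slept": 4.0,                  # retry_state.idle_for
        "attempt": 2,                  # retry_state.attempt_number
        "outcome": "failed",           # "N/A", "failed", or "success"
        "exception": "rate limited",   # present only when the attempt failed
        "exception_type": "RateLimitError",
    },
}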
@@ -18,23 +18,14 @@ from typing import (
 )
 
 from pydantic import Field, root_validator
-from tenacity import (
-    before_sleep_log,
-    retry,
-    retry_if_exception_type,
-    stop_after_attempt,
-    wait_exponential,
-)
 
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
 from langchain.chat_models.base import BaseChatModel
-from langchain.schema import (
-    ChatGeneration,
-    ChatResult,
-)
+from langchain.llms.base import create_base_retry_decorator
+from langchain.schema import ChatGeneration, ChatResult
 from langchain.schema.messages import (
     AIMessage,
     AIMessageChunk,
@@ -70,31 +61,33 @@ def _import_tiktoken() -> Any:
     return tiktoken
 
 
-def _create_retry_decorator(llm: ChatOpenAI) -> Callable[[Any], Any]:
+def _create_retry_decorator(
+    llm: ChatOpenAI,
+    run_manager: Optional[
+        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
+    ] = None,
+) -> Callable[[Any], Any]:
     import openai
 
-    min_seconds = 1
-    max_seconds = 60
-    # Wait 2^x * 1 second between each retry starting with
-    # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
-    return retry(
-        reraise=True,
-        stop=stop_after_attempt(llm.max_retries),
-        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
-        retry=(
-            retry_if_exception_type(openai.error.Timeout)
-            | retry_if_exception_type(openai.error.APIError)
-            | retry_if_exception_type(openai.error.APIConnectionError)
-            | retry_if_exception_type(openai.error.RateLimitError)
-            | retry_if_exception_type(openai.error.ServiceUnavailableError)
-        ),
-        before_sleep=before_sleep_log(logger, logging.WARNING),
+    errors = [
+        openai.error.Timeout,
+        openai.error.APIError,
+        openai.error.APIConnectionError,
+        openai.error.RateLimitError,
+        openai.error.ServiceUnavailableError,
+    ]
+    return create_base_retry_decorator(
+        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
     )
 
 
-async def acompletion_with_retry(llm: ChatOpenAI, **kwargs: Any) -> Any:
+async def acompletion_with_retry(
+    llm: ChatOpenAI,
+    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+    **kwargs: Any,
+) -> Any:
     """Use tenacity to retry the async completion call."""
-    retry_decorator = _create_retry_decorator(llm)
+    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
 
     @retry_decorator
     async def _completion_with_retry(**kwargs: Any) -> Any:
@@ -322,9 +315,11 @@ class ChatOpenAI(BaseChatModel):
             **self.model_kwargs,
         }
 
-    def completion_with_retry(self, **kwargs: Any) -> Any:
+    def completion_with_retry(
+        self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
+    ) -> Any:
         """Use tenacity to retry the completion call."""
-        retry_decorator = _create_retry_decorator(self)
+        retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
 
         @retry_decorator
         def _completion_with_retry(**kwargs: Any) -> Any:
@@ -357,7 +352,9 @@ class ChatOpenAI(BaseChatModel):
         params = {**params, **kwargs, "stream": True}
 
         default_chunk_class = AIMessageChunk
-        for chunk in self.completion_with_retry(messages=message_dicts, **params):
+        for chunk in self.completion_with_retry(
+            messages=message_dicts, run_manager=run_manager, **params
+        ):
             if len(chunk["choices"]) == 0:
                 continue
             delta = chunk["choices"][0]["delta"]
@@ -388,7 +385,9 @@ class ChatOpenAI(BaseChatModel):
 
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
-        response = self.completion_with_retry(messages=message_dicts, **params)
+        response = self.completion_with_retry(
+            messages=message_dicts, run_manager=run_manager, **params
+        )
         return self._create_chat_result(response)
 
     def _create_message_dicts(
@@ -427,7 +426,7 @@ class ChatOpenAI(BaseChatModel):
 
         default_chunk_class = AIMessageChunk
         async for chunk in await acompletion_with_retry(
-            self, messages=message_dicts, **params
+            self, messages=message_dicts, run_manager=run_manager, **params
         ):
             if len(chunk["choices"]) == 0:
                 continue
@@ -459,7 +458,9 @@ class ChatOpenAI(BaseChatModel):
 
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
-        response = await acompletion_with_retry(self, messages=message_dicts, **params)
+        response = await acompletion_with_retry(
+            self, messages=message_dicts, run_manager=run_manager, **params
+        )
         return self._create_chat_result(response)
 
     @property
@@ -2,6 +2,7 @@
 from __future__ import annotations
 
+import asyncio
 import functools
 import inspect
 import json
 import logging
@@ -28,6 +29,7 @@ from typing import (
 import yaml
 from pydantic import Field, root_validator, validator
 from tenacity import (
+    RetryCallState,
     before_sleep_log,
     retry,
     retry_base,
@@ -66,11 +68,36 @@ def _get_verbosity() -> bool:
     return langchain.verbose
 
 
+@functools.lru_cache
+def _log_error_once(msg: str) -> None:
+    """Log an error once."""
+    logger.error(msg)
+
+
 def create_base_retry_decorator(
-    error_types: List[Type[BaseException]], max_retries: int = 1
+    error_types: List[Type[BaseException]],
+    max_retries: int = 1,
+    run_manager: Optional[
+        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
+    ] = None,
 ) -> Callable[[Any], Any]:
     """Create a retry decorator for a given LLM and provided list of error types."""
 
+    _logging = before_sleep_log(logger, logging.WARNING)
+
+    def _before_sleep(retry_state: RetryCallState) -> None:
+        _logging(retry_state)
+        if run_manager:
+            if isinstance(run_manager, AsyncCallbackManagerForLLMRun):
+                coro = run_manager.on_retry(retry_state)
+                try:
+                    asyncio.run(coro)
+                except Exception as e:
+                    _log_error_once(f"Error in on_retry: {e}")
+            else:
+                run_manager.on_retry(retry_state)
+        return None
+
     min_seconds = 4
     max_seconds = 10
     # Wait 2^x * 1 second between each retry starting with
@@ -83,7 +110,7 @@ def create_base_retry_decorator(
         stop=stop_after_attempt(max_retries),
         wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
         retry=retry_instance,
-        before_sleep=before_sleep_log(logger, logging.WARNING),
+        before_sleep=_before_sleep,
     )
 
 
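With the run_manager parameter and the _before_sleep wrapper in place, integrations keep using create_base_retry_decorator as before; run_manager is optional and only adds the on_retry emission before each sleep. A sketch under that assumption, with a made-up error type and function (neither is part of the commit):

from langchain.llms.base import create_base_retry_decorator


class FlakyServiceError(Exception):
    """Hypothetical transient error used only for illustration."""


retry_decorator = create_base_retry_decorator(
    error_types=[FlakyServiceError],
    max_retries=3,
    # run_manager=<CallbackManagerForLLMRun>,  # optional: emits on_retry events
)


@retry_decorator
def _call_service() -> str:
    # A real integration would call its API here; tenacity retries this
    # function whenever it raises FlakyServiceError.
    return "ok"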
@@ -80,7 +80,12 @@ def _streaming_response_template() -> Dict[str, Any]:
     }
 
 
-def _create_retry_decorator(llm: Union[BaseOpenAI, OpenAIChat]) -> Callable[[Any], Any]:
+def _create_retry_decorator(
+    llm: Union[BaseOpenAI, OpenAIChat],
+    run_manager: Optional[
+        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
+    ] = None,
+) -> Callable[[Any], Any]:
     import openai
 
     errors = [
@@ -90,12 +95,18 @@ def _create_retry_decorator(llm: Union[BaseOpenAI, OpenAIChat]) -> Callable[[Any
         openai.error.RateLimitError,
         openai.error.ServiceUnavailableError,
     ]
-    return create_base_retry_decorator(error_types=errors, max_retries=llm.max_retries)
+    return create_base_retry_decorator(
+        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
+    )
 
 
-def completion_with_retry(llm: Union[BaseOpenAI, OpenAIChat], **kwargs: Any) -> Any:
+def completion_with_retry(
+    llm: Union[BaseOpenAI, OpenAIChat],
+    run_manager: Optional[CallbackManagerForLLMRun] = None,
+    **kwargs: Any,
+) -> Any:
     """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator(llm)
+    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
 
     @retry_decorator
     def _completion_with_retry(**kwargs: Any) -> Any:
@@ -105,10 +116,12 @@ def completion_with_retry(llm: Union[BaseOpenAI, OpenAIChat], **kwargs: Any) ->
 
 
 async def acompletion_with_retry(
-    llm: Union[BaseOpenAI, OpenAIChat], **kwargs: Any
+    llm: Union[BaseOpenAI, OpenAIChat],
+    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+    **kwargs: Any,
 ) -> Any:
     """Use tenacity to retry the async completion call."""
-    retry_decorator = _create_retry_decorator(llm)
+    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
 
     @retry_decorator
     async def _completion_with_retry(**kwargs: Any) -> Any:
@@ -291,8 +304,10 @@ class BaseOpenAI(BaseLLM):
         **kwargs: Any,
     ) -> Iterator[GenerationChunk]:
         params = {**self._invocation_params, **kwargs, "stream": True}
-        self.get_sub_prompts(params, [prompt], stop)  # this mutate params
-        for stream_resp in completion_with_retry(self, prompt=prompt, **params):
+        self.get_sub_prompts(params, [prompt], stop)  # this mutates params
+        for stream_resp in completion_with_retry(
+            self, prompt=prompt, run_manager=run_manager, **params
+        ):
             chunk = _stream_response_to_generation_chunk(stream_resp)
             yield chunk
             if run_manager:
@@ -314,7 +329,7 @@ class BaseOpenAI(BaseLLM):
         params = {**self._invocation_params, **kwargs, "stream": True}
         self.get_sub_prompts(params, [prompt], stop)  # this mutate params
         async for stream_resp in await acompletion_with_retry(
-            self, prompt=prompt, **params
+            self, prompt=prompt, run_manager=run_manager, **params
         ):
             chunk = _stream_response_to_generation_chunk(stream_resp)
             yield chunk
@@ -381,7 +396,9 @@ class BaseOpenAI(BaseLLM):
                     }
                 )
             else:
-                response = completion_with_retry(self, prompt=_prompts, **params)
+                response = completion_with_retry(
+                    self, prompt=_prompts, run_manager=run_manager, **params
+                )
                 choices.extend(response["choices"])
                 update_token_usage(_keys, response, token_usage)
         return self.create_llm_result(choices, prompts, token_usage)
@@ -428,7 +445,9 @@ class BaseOpenAI(BaseLLM):
                     }
                 )
             else:
-                response = await acompletion_with_retry(self, prompt=_prompts, **params)
+                response = await acompletion_with_retry(
+                    self, prompt=_prompts, run_manager=run_manager, **params
+                )
                 choices.extend(response["choices"])
                 update_token_usage(_keys, response, token_usage)
         return self.create_llm_result(choices, prompts, token_usage)
@@ -818,7 +837,9 @@ class OpenAIChat(BaseLLM):
     ) -> Iterator[GenerationChunk]:
         messages, params = self._get_chat_params([prompt], stop)
         params = {**params, **kwargs, "stream": True}
-        for stream_resp in completion_with_retry(self, messages=messages, **params):
+        for stream_resp in completion_with_retry(
+            self, messages=messages, run_manager=run_manager, **params
+        ):
             token = stream_resp["choices"][0]["delta"].get("content", "")
             yield GenerationChunk(text=token)
             if run_manager:
@@ -834,7 +855,7 @@ class OpenAIChat(BaseLLM):
         messages, params = self._get_chat_params([prompt], stop)
         params = {**params, **kwargs, "stream": True}
         async for stream_resp in await acompletion_with_retry(
-            self, messages=messages, **params
+            self, messages=messages, run_manager=run_manager, **params
         ):
             token = stream_resp["choices"][0]["delta"].get("content", "")
             yield GenerationChunk(text=token)
@@ -860,7 +881,9 @@ class OpenAIChat(BaseLLM):
 
         messages, params = self._get_chat_params(prompts, stop)
         params = {**params, **kwargs}
-        full_response = completion_with_retry(self, messages=messages, **params)
+        full_response = completion_with_retry(
+            self, messages=messages, run_manager=run_manager, **params
+        )
         llm_output = {
             "token_usage": full_response["usage"],
             "model_name": self.model_name,
@@ -891,7 +914,9 @@ class OpenAIChat(BaseLLM):
 
         messages, params = self._get_chat_params(prompts, stop)
         params = {**params, **kwargs}
-        full_response = await acompletion_with_retry(self, messages=messages, **params)
+        full_response = await acompletion_with_retry(
+            self, messages=messages, run_manager=run_manager, **params
+        )
         llm_output = {
             "token_usage": full_response["usage"],
             "model_name": self.model_name,
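Threading run_manager through every completion_with_retry / acompletion_with_retry call site above is what makes retries visible to request-time callbacks. A hedged end-to-end sketch of that effect (the handler class is illustrative, and actually running it needs the openai package plus an API key):

from typing import Any

from langchain.callbacks.base import BaseCallbackHandler
from langchain.llms import OpenAI


class RetryCounter(BaseCallbackHandler):
    retries = 0

    def on_retry(self, *args: Any, **kwargs: Any) -> None:
        self.retries += 1


counter = RetryCounter()
llm = OpenAI(max_retries=2)
# Any rate-limit or timeout retry during this call is now reported to the handler.
llm.predict("hello", callbacks=[counter])
print(counter.retries)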
@@ -1,7 +1,7 @@
 """Test OpenAI API wrapper."""
 
 from pathlib import Path
-from typing import Generator
+from typing import Any, Generator
 from unittest.mock import MagicMock, patch
 
 import pytest
@@ -10,7 +10,10 @@ from langchain.chat_models.openai import ChatOpenAI
 from langchain.llms.loading import load_llm
 from langchain.llms.openai import OpenAI, OpenAIChat
 from langchain.schema import LLMResult
-from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
+from tests.unit_tests.callbacks.fake_callback_handler import (
+    FakeAsyncCallbackHandler,
+    FakeCallbackHandler,
+)
 
 
 def test_openai_call() -> None:
@@ -334,3 +337,77 @@ def test_chat_openai_get_num_tokens(model: str) -> None:
     """Test get_tokens."""
     llm = ChatOpenAI(model=model)
     assert llm.get_num_tokens("表情符号是\n🦜🔗") == _EXPECTED_NUM_TOKENS[model]
+
+
+@pytest.fixture
+def mock_completion() -> dict:
+    return {
+        "id": "cmpl-3evkmQda5Hu7fcZavknQda3SQ",
+        "object": "text_completion",
+        "created": 1689989000,
+        "model": "text-davinci-003",
+        "choices": [
+            {"text": "Bar Baz", "index": 0, "logprobs": None, "finish_reason": "length"}
+        ],
+        "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
+    }
+
+
+@pytest.mark.requires("openai")
+def test_openai_retries(mock_completion: dict) -> None:
+    llm = OpenAI()
+    mock_client = MagicMock()
+    completed = False
+    raised = False
+    import openai
+
+    def raise_once(*args: Any, **kwargs: Any) -> Any:
+        nonlocal completed, raised
+        if not raised:
+            raised = True
+            raise openai.error.APIError
+        completed = True
+        return mock_completion
+
+    mock_client.create = raise_once
+    callback_handler = FakeCallbackHandler()
+    with patch.object(
+        llm,
+        "client",
+        mock_client,
+    ):
+        res = llm.predict("bar", callbacks=[callback_handler])
+        assert res == "Bar Baz"
+    assert completed
+    assert raised
+    assert callback_handler.retries == 1
+
+
+@pytest.mark.requires("openai")
+async def test_openai_async_retries(mock_completion: dict) -> None:
+    llm = OpenAI()
+    mock_client = MagicMock()
+    completed = False
+    raised = False
+    import openai
+
+    def raise_once(*args: Any, **kwargs: Any) -> Any:
+        nonlocal completed, raised
+        if not raised:
+            raised = True
+            raise openai.error.APIError
+        completed = True
+        return mock_completion
+
+    mock_client.create = raise_once
+    callback_handler = FakeAsyncCallbackHandler()
+    with patch.object(
+        llm,
+        "client",
+        mock_client,
+    ):
+        res = llm.apredict("bar", callbacks=[callback_handler])
+        assert res == "Bar Baz"
+    assert completed
+    assert raised
+    assert callback_handler.retries == 1
@@ -39,6 +39,7 @@ class BaseFakeCallbackHandler(BaseModel):
     retriever_starts: int = 0
     retriever_ends: int = 0
    retriever_errors: int = 0
+    retries: int = 0
 
 
 class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
@@ -58,8 +59,10 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
     def on_llm_new_token_common(self) -> None:
         self.llm_streams += 1
 
+    def on_retry_common(self) -> None:
+        self.retries += 1
+
     def on_chain_start_common(self) -> None:
-        ("CHAIN START")
         self.chain_starts += 1
         self.starts += 1
 
@@ -82,7 +85,6 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
         self.errors += 1
 
     def on_agent_action_common(self) -> None:
-        print("AGENT ACTION")
         self.agent_actions += 1
         self.starts += 1
 
@@ -91,7 +93,6 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
         self.ends += 1
 
     def on_chat_model_start_common(self) -> None:
-        print("STARTING CHAT MODEL")
         self.chat_model_starts += 1
         self.starts += 1
 
@@ -162,6 +163,13 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
     ) -> Any:
         self.on_llm_error_common()
 
+    def on_retry(
+        self,
+        *args: Any,
+        **kwargs: Any,
+    ) -> Any:
+        self.on_retry_common()
+
     def on_chain_start(
         self,
         *args: Any,
@@ -1,8 +1,12 @@
 """Test OpenAI Chat API wrapper."""
 
 import json
+from typing import Any
+from unittest.mock import MagicMock, patch
 
+import pytest
+
 from langchain.chat_models.openai import (
     ChatOpenAI,
     _convert_dict_to_message,
 )
 from langchain.schema.messages import FunctionMessage
@@ -21,3 +25,67 @@ def test_function_message_dict_to_function_message() -> None:
     assert isinstance(result, FunctionMessage)
     assert result.name == name
     assert result.content == content
+
+
+@pytest.fixture
+def mock_completion() -> dict:
+    return {
+        "id": "chatcmpl-7fcZavknQda3SQ",
+        "object": "chat.completion",
+        "created": 1689989000,
+        "model": "gpt-3.5-turbo-0613",
+        "choices": [
+            {
+                "index": 0,
+                "message": {
+                    "role": "assistant",
+                    "content": "Bar Baz",
+                },
+                "finish_reason": "stop",
+            }
+        ],
+    }
+
+
+@pytest.mark.requires("openai")
+def test_openai_predict(mock_completion: dict) -> None:
+    llm = ChatOpenAI()
+    mock_client = MagicMock()
+    completed = False
+
+    def mock_create(*args: Any, **kwargs: Any) -> Any:
+        nonlocal completed
+        completed = True
+        return mock_completion
+
+    mock_client.create = mock_create
+    with patch.object(
+        llm,
+        "client",
+        mock_client,
+    ):
+        res = llm.predict("bar")
+        assert res == "Bar Baz"
+    assert completed
+
+
+@pytest.mark.requires("openai")
+async def test_openai_apredict(mock_completion: dict) -> None:
+    llm = ChatOpenAI()
+    mock_client = MagicMock()
+    completed = False
+
+    def mock_create(*args: Any, **kwargs: Any) -> Any:
+        nonlocal completed
+        completed = True
+        return mock_completion
+
+    mock_client.create = mock_create
+    with patch.object(
+        llm,
+        "client",
+        mock_client,
+    ):
+        res = llm.predict("bar")
+        assert res == "Bar Baz"
+    assert completed
@@ -1,4 +1,6 @@
 import os
+from typing import Any
+from unittest.mock import MagicMock, patch
 
 import pytest
 
@@ -26,3 +28,61 @@ def test_openai_incorrect_field() -> None:
     with pytest.warns(match="not default parameter"):
         llm = OpenAI(foo="bar")
     assert llm.model_kwargs == {"foo": "bar"}
+
+
+@pytest.fixture
+def mock_completion() -> dict:
+    return {
+        "id": "cmpl-3evkmQda5Hu7fcZavknQda3SQ",
+        "object": "text_completion",
+        "created": 1689989000,
+        "model": "text-davinci-003",
+        "choices": [
+            {"text": "Bar Baz", "index": 0, "logprobs": None, "finish_reason": "length"}
+        ],
+        "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
+    }
+
+
+@pytest.mark.requires("openai")
+def test_openai_calls(mock_completion: dict) -> None:
+    llm = OpenAI()
+    mock_client = MagicMock()
+    completed = False
+
+    def raise_once(*args: Any, **kwargs: Any) -> Any:
+        nonlocal completed
+        completed = True
+        return mock_completion
+
+    mock_client.create = raise_once
+    with patch.object(
+        llm,
+        "client",
+        mock_client,
+    ):
+        res = llm.predict("bar")
+        assert res == "Bar Baz"
+    assert completed
+
+
+@pytest.mark.requires("openai")
+async def test_openai_async_retries(mock_completion: dict) -> None:
+    llm = OpenAI()
+    mock_client = MagicMock()
+    completed = False
+
+    def raise_once(*args: Any, **kwargs: Any) -> Any:
+        nonlocal completed
+        completed = True
+        return mock_completion
+
+    mock_client.create = raise_once
+    with patch.object(
+        llm,
+        "client",
+        mock_client,
+    ):
+        res = llm.apredict("bar")
+        assert res == "Bar Baz"
+    assert completed