Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-20 09:57:32 +00:00)

commit ff98fad2d9 (parent 94a693e2ee)

Add Retry Events (#8053)

Co-authored-by: Bagatur <baskaryan@gmail.com>
@@ -242,6 +242,11 @@ class BaseCallbackHandler(
         """Whether to ignore LLM callbacks."""
         return False

+    @property
+    def ignore_retry(self) -> bool:
+        """Whether to ignore retry callbacks."""
+        return False
+
     @property
     def ignore_chain(self) -> bool:
         """Whether to ignore chain callbacks."""
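Not part of the diff: a minimal sketch, assuming only what the hunk above adds (the ignore_retry property on BaseCallbackHandler), of how a custom handler can opt out of retry callbacks. The QuietHandler name is illustrative.

    from langchain.callbacks.base import BaseCallbackHandler


    class QuietHandler(BaseCallbackHandler):
        """Handler that declines retry events."""

        @property
        def ignore_retry(self) -> bool:
            # The callback manager checks this flag and skips on_retry
            # dispatch for handlers that return True.
            return True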
@@ -23,6 +23,8 @@ from typing import (
 )
 from uuid import UUID

+from tenacity import RetryCallState
+
 import langchain
 from langchain.callbacks.base import (
     BaseCallbackHandler,
@@ -572,6 +574,22 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
             **kwargs,
         )

+    def on_retry(
+        self,
+        retry_state: RetryCallState,
+        **kwargs: Any,
+    ) -> None:
+        _handle_event(
+            self.handlers,
+            "on_retry",
+            "ignore_retry",
+            retry_state,
+            run_id=self.run_id,
+            parent_run_id=self.parent_run_id,
+            tags=self.tags,
+            **kwargs,
+        )
+
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         """Run when LLM ends running.

@@ -635,6 +653,22 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
             **kwargs,
         )

+    async def on_retry(
+        self,
+        retry_state: RetryCallState,
+        **kwargs: Any,
+    ) -> None:
+        await _ahandle_event(
+            self.handlers,
+            "on_retry",
+            "ignore_retry",
+            retry_state,
+            run_id=self.run_id,
+            parent_run_id=self.parent_run_id,
+            tags=self.tags,
+            **kwargs,
+        )
+
     async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         """Run when LLM ends running.

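Not part of the diff: a minimal sketch of a handler that consumes the on_retry dispatch added above. RetryCountingHandler is an illustrative name; the manager forwards the tenacity RetryCallState plus run metadata as keyword arguments.

    from typing import Any

    from tenacity import RetryCallState

    from langchain.callbacks.base import BaseCallbackHandler


    class RetryCountingHandler(BaseCallbackHandler):
        """Count how many times the wrapped LLM call was retried."""

        def __init__(self) -> None:
            self.retries = 0

        def on_retry(self, retry_state: RetryCallState, **kwargs: Any) -> None:
            # Invoked once per retry, right before tenacity sleeps and
            # re-attempts the LLM call.
            self.retries += 1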
@@ -7,6 +7,8 @@ from datetime import datetime
 from typing import Any, Dict, List, Optional, Sequence, Union, cast
 from uuid import UUID

+from tenacity import RetryCallState
+
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.callbacks.tracers.schemas import Run, RunTypeEnum
 from langchain.load.dump import dumpd
@@ -138,6 +140,41 @@ class BaseTracer(BaseCallbackHandler, ABC):
             },
         )

+    def on_retry(
+        self,
+        retry_state: RetryCallState,
+        *,
+        run_id: UUID,
+        **kwargs: Any,
+    ) -> None:
+        if not run_id:
+            raise TracerException("No run_id provided for on_retry callback.")
+        run_id_ = str(run_id)
+        llm_run = self.run_map.get(run_id_)
+        if llm_run is None or llm_run.run_type != RunTypeEnum.llm:
+            raise TracerException("No LLM Run found to be traced for on_retry")
+        retry_d: Dict[str, Any] = {
+            "slept": retry_state.idle_for,
+            "attempt": retry_state.attempt_number,
+        }
+        if retry_state.outcome is None:
+            retry_d["outcome"] = "N/A"
+        elif retry_state.outcome.failed:
+            retry_d["outcome"] = "failed"
+            exception = retry_state.outcome.exception()
+            retry_d["exception"] = str(exception)
+            retry_d["exception_type"] = exception.__class__.__name__
+        else:
+            retry_d["outcome"] = "success"
+            retry_d["result"] = str(retry_state.outcome.result())
+        llm_run.events.append(
+            {
+                "name": "retry",
+                "time": datetime.utcnow(),
+                "kwargs": retry_d,
+            },
+        )
+
     def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> None:
         """End a trace for an LLM run."""
         if not run_id:
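Not part of the diff: for reference, the event the tracer above appends to llm_run.events looks roughly like the dict built below. The values are illustrative and correspond to a failed attempt that slept about four seconds before retrying.

    from datetime import datetime

    # Illustrative shape of a "retry" event as recorded by BaseTracer.on_retry.
    retry_event = {
        "name": "retry",
        "time": datetime.utcnow(),
        "kwargs": {
            "slept": 4.0,
            "attempt": 1,
            "outcome": "failed",
            "exception": "Rate limit exceeded",
            "exception_type": "RateLimitError",
        },
    }
    print(retry_event)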
@@ -18,23 +18,14 @@ from typing import (
 )

 from pydantic import Field, root_validator
-from tenacity import (
-    before_sleep_log,
-    retry,
-    retry_if_exception_type,
-    stop_after_attempt,
-    wait_exponential,
-)

 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
 from langchain.chat_models.base import BaseChatModel
-from langchain.schema import (
-    ChatGeneration,
-    ChatResult,
-)
+from langchain.llms.base import create_base_retry_decorator
+from langchain.schema import ChatGeneration, ChatResult
 from langchain.schema.messages import (
     AIMessage,
     AIMessageChunk,
@@ -70,31 +61,33 @@ def _import_tiktoken() -> Any:
     return tiktoken


-def _create_retry_decorator(llm: ChatOpenAI) -> Callable[[Any], Any]:
+def _create_retry_decorator(
+    llm: ChatOpenAI,
+    run_manager: Optional[
+        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
+    ] = None,
+) -> Callable[[Any], Any]:
     import openai

-    min_seconds = 1
-    max_seconds = 60
-    # Wait 2^x * 1 second between each retry starting with
-    # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
-    return retry(
-        reraise=True,
-        stop=stop_after_attempt(llm.max_retries),
-        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
-        retry=(
-            retry_if_exception_type(openai.error.Timeout)
-            | retry_if_exception_type(openai.error.APIError)
-            | retry_if_exception_type(openai.error.APIConnectionError)
-            | retry_if_exception_type(openai.error.RateLimitError)
-            | retry_if_exception_type(openai.error.ServiceUnavailableError)
-        ),
-        before_sleep=before_sleep_log(logger, logging.WARNING),
+    errors = [
+        openai.error.Timeout,
+        openai.error.APIError,
+        openai.error.APIConnectionError,
+        openai.error.RateLimitError,
+        openai.error.ServiceUnavailableError,
+    ]
+    return create_base_retry_decorator(
+        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
     )


-async def acompletion_with_retry(llm: ChatOpenAI, **kwargs: Any) -> Any:
+async def acompletion_with_retry(
+    llm: ChatOpenAI,
+    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+    **kwargs: Any,
+) -> Any:
     """Use tenacity to retry the async completion call."""
-    retry_decorator = _create_retry_decorator(llm)
+    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

     @retry_decorator
     async def _completion_with_retry(**kwargs: Any) -> Any:
@@ -322,9 +315,11 @@ class ChatOpenAI(BaseChatModel):
             **self.model_kwargs,
         }

-    def completion_with_retry(self, **kwargs: Any) -> Any:
+    def completion_with_retry(
+        self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
+    ) -> Any:
         """Use tenacity to retry the completion call."""
-        retry_decorator = _create_retry_decorator(self)
+        retry_decorator = _create_retry_decorator(self, run_manager=run_manager)

         @retry_decorator
         def _completion_with_retry(**kwargs: Any) -> Any:
@@ -357,7 +352,9 @@ class ChatOpenAI(BaseChatModel):
         params = {**params, **kwargs, "stream": True}

         default_chunk_class = AIMessageChunk
-        for chunk in self.completion_with_retry(messages=message_dicts, **params):
+        for chunk in self.completion_with_retry(
+            messages=message_dicts, run_manager=run_manager, **params
+        ):
             if len(chunk["choices"]) == 0:
                 continue
             delta = chunk["choices"][0]["delta"]
@@ -388,7 +385,9 @@ class ChatOpenAI(BaseChatModel):

         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
-        response = self.completion_with_retry(messages=message_dicts, **params)
+        response = self.completion_with_retry(
+            messages=message_dicts, run_manager=run_manager, **params
+        )
         return self._create_chat_result(response)

     def _create_message_dicts(
@@ -427,7 +426,7 @@ class ChatOpenAI(BaseChatModel):

         default_chunk_class = AIMessageChunk
         async for chunk in await acompletion_with_retry(
-            self, messages=message_dicts, **params
+            self, messages=message_dicts, run_manager=run_manager, **params
         ):
             if len(chunk["choices"]) == 0:
                 continue
@@ -459,7 +458,9 @@ class ChatOpenAI(BaseChatModel):

         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
-        response = await acompletion_with_retry(self, messages=message_dicts, **params)
+        response = await acompletion_with_retry(
+            self, messages=message_dicts, run_manager=run_manager, **params
+        )
         return self._create_chat_result(response)

     @property
@@ -2,6 +2,7 @@
 from __future__ import annotations

 import asyncio
+import functools
 import inspect
 import json
 import logging
@@ -28,6 +29,7 @@ from typing import (
 import yaml
 from pydantic import Field, root_validator, validator
 from tenacity import (
+    RetryCallState,
     before_sleep_log,
     retry,
     retry_base,
@@ -66,11 +68,36 @@ def _get_verbosity() -> bool:
     return langchain.verbose


+@functools.lru_cache
+def _log_error_once(msg: str) -> None:
+    """Log an error once."""
+    logger.error(msg)
+
+
 def create_base_retry_decorator(
-    error_types: List[Type[BaseException]], max_retries: int = 1
+    error_types: List[Type[BaseException]],
+    max_retries: int = 1,
+    run_manager: Optional[
+        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
+    ] = None,
 ) -> Callable[[Any], Any]:
     """Create a retry decorator for a given LLM and provided list of error types."""

+    _logging = before_sleep_log(logger, logging.WARNING)
+
+    def _before_sleep(retry_state: RetryCallState) -> None:
+        _logging(retry_state)
+        if run_manager:
+            if isinstance(run_manager, AsyncCallbackManagerForLLMRun):
+                coro = run_manager.on_retry(retry_state)
+                try:
+                    asyncio.run(coro)
+                except Exception as e:
+                    _log_error_once(f"Error in on_retry: {e}")
+            else:
+                run_manager.on_retry(retry_state)
+        return None
+
     min_seconds = 4
     max_seconds = 10
     # Wait 2^x * 1 second between each retry starting with
@@ -83,7 +110,7 @@ def create_base_retry_decorator(
         stop=stop_after_attempt(max_retries),
         wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
         retry=retry_instance,
-        before_sleep=before_sleep_log(logger, logging.WARNING),
+        before_sleep=_before_sleep,
     )


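Not part of the diff: a minimal usage sketch of create_base_retry_decorator with the signature shown above (error_types, max_retries, run_manager). The flaky function and the TimeoutError choice are illustrative; with run_manager=None only the warning log fires before each retry.

    from langchain.llms.base import create_base_retry_decorator

    retrying = create_base_retry_decorator(
        error_types=[TimeoutError], max_retries=2, run_manager=None
    )

    attempts = 0


    @retrying
    def flaky() -> str:
        """Fail on the first attempt, succeed on the second."""
        global attempts
        attempts += 1
        if attempts == 1:
            raise TimeoutError("transient failure")
        return "ok"


    print(flaky())  # waits a few seconds after the first failure, then prints "ok"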
@@ -80,7 +80,12 @@ def _streaming_response_template() -> Dict[str, Any]:
     }


-def _create_retry_decorator(llm: Union[BaseOpenAI, OpenAIChat]) -> Callable[[Any], Any]:
+def _create_retry_decorator(
+    llm: Union[BaseOpenAI, OpenAIChat],
+    run_manager: Optional[
+        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
+    ] = None,
+) -> Callable[[Any], Any]:
     import openai

     errors = [
@@ -90,12 +95,18 @@ def _create_retry_decorator(llm: Union[BaseOpenAI, OpenAIChat]) -> Callable[[Any], Any]:
         openai.error.RateLimitError,
         openai.error.ServiceUnavailableError,
     ]
-    return create_base_retry_decorator(error_types=errors, max_retries=llm.max_retries)
+    return create_base_retry_decorator(
+        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
+    )


-def completion_with_retry(llm: Union[BaseOpenAI, OpenAIChat], **kwargs: Any) -> Any:
+def completion_with_retry(
+    llm: Union[BaseOpenAI, OpenAIChat],
+    run_manager: Optional[CallbackManagerForLLMRun] = None,
+    **kwargs: Any,
+) -> Any:
     """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator(llm)
+    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

     @retry_decorator
     def _completion_with_retry(**kwargs: Any) -> Any:
@@ -105,10 +116,12 @@ def completion_with_retry(llm: Union[BaseOpenAI, OpenAIChat], **kwargs: Any) -> Any:


 async def acompletion_with_retry(
-    llm: Union[BaseOpenAI, OpenAIChat], **kwargs: Any
+    llm: Union[BaseOpenAI, OpenAIChat],
+    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+    **kwargs: Any,
 ) -> Any:
     """Use tenacity to retry the async completion call."""
-    retry_decorator = _create_retry_decorator(llm)
+    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

     @retry_decorator
     async def _completion_with_retry(**kwargs: Any) -> Any:
@@ -291,8 +304,10 @@ class BaseOpenAI(BaseLLM):
         **kwargs: Any,
     ) -> Iterator[GenerationChunk]:
         params = {**self._invocation_params, **kwargs, "stream": True}
-        self.get_sub_prompts(params, [prompt], stop)  # this mutate params
-        for stream_resp in completion_with_retry(self, prompt=prompt, **params):
+        self.get_sub_prompts(params, [prompt], stop)  # this mutates params
+        for stream_resp in completion_with_retry(
+            self, prompt=prompt, run_manager=run_manager, **params
+        ):
             chunk = _stream_response_to_generation_chunk(stream_resp)
             yield chunk
             if run_manager:
@@ -314,7 +329,7 @@ class BaseOpenAI(BaseLLM):
         params = {**self._invocation_params, **kwargs, "stream": True}
         self.get_sub_prompts(params, [prompt], stop)  # this mutate params
         async for stream_resp in await acompletion_with_retry(
-            self, prompt=prompt, **params
+            self, prompt=prompt, run_manager=run_manager, **params
         ):
             chunk = _stream_response_to_generation_chunk(stream_resp)
             yield chunk
@@ -381,7 +396,9 @@ class BaseOpenAI(BaseLLM):
                     }
                 )
             else:
-                response = completion_with_retry(self, prompt=_prompts, **params)
+                response = completion_with_retry(
+                    self, prompt=_prompts, run_manager=run_manager, **params
+                )
                 choices.extend(response["choices"])
                 update_token_usage(_keys, response, token_usage)
         return self.create_llm_result(choices, prompts, token_usage)
@@ -428,7 +445,9 @@ class BaseOpenAI(BaseLLM):
                     }
                 )
             else:
-                response = await acompletion_with_retry(self, prompt=_prompts, **params)
+                response = await acompletion_with_retry(
+                    self, prompt=_prompts, run_manager=run_manager, **params
+                )
                 choices.extend(response["choices"])
                 update_token_usage(_keys, response, token_usage)
         return self.create_llm_result(choices, prompts, token_usage)
@@ -818,7 +837,9 @@ class OpenAIChat(BaseLLM):
     ) -> Iterator[GenerationChunk]:
         messages, params = self._get_chat_params([prompt], stop)
         params = {**params, **kwargs, "stream": True}
-        for stream_resp in completion_with_retry(self, messages=messages, **params):
+        for stream_resp in completion_with_retry(
+            self, messages=messages, run_manager=run_manager, **params
+        ):
             token = stream_resp["choices"][0]["delta"].get("content", "")
             yield GenerationChunk(text=token)
             if run_manager:
@@ -834,7 +855,7 @@ class OpenAIChat(BaseLLM):
         messages, params = self._get_chat_params([prompt], stop)
         params = {**params, **kwargs, "stream": True}
         async for stream_resp in await acompletion_with_retry(
-            self, messages=messages, **params
+            self, messages=messages, run_manager=run_manager, **params
         ):
             token = stream_resp["choices"][0]["delta"].get("content", "")
             yield GenerationChunk(text=token)
@@ -860,7 +881,9 @@ class OpenAIChat(BaseLLM):

         messages, params = self._get_chat_params(prompts, stop)
         params = {**params, **kwargs}
-        full_response = completion_with_retry(self, messages=messages, **params)
+        full_response = completion_with_retry(
+            self, messages=messages, run_manager=run_manager, **params
+        )
         llm_output = {
             "token_usage": full_response["usage"],
             "model_name": self.model_name,
@@ -891,7 +914,9 @@ class OpenAIChat(BaseLLM):

         messages, params = self._get_chat_params(prompts, stop)
         params = {**params, **kwargs}
-        full_response = await acompletion_with_retry(self, messages=messages, **params)
+        full_response = await acompletion_with_retry(
+            self, messages=messages, run_manager=run_manager, **params
+        )
         llm_output = {
             "token_usage": full_response["usage"],
             "model_name": self.model_name,
@@ -1,7 +1,7 @@
 """Test OpenAI API wrapper."""

 from pathlib import Path
-from typing import Generator
+from typing import Any, Generator
+from unittest.mock import MagicMock, patch

 import pytest

@@ -10,7 +10,10 @@ from langchain.chat_models.openai import ChatOpenAI
 from langchain.llms.loading import load_llm
 from langchain.llms.openai import OpenAI, OpenAIChat
 from langchain.schema import LLMResult
-from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
+from tests.unit_tests.callbacks.fake_callback_handler import (
+    FakeAsyncCallbackHandler,
+    FakeCallbackHandler,
+)


 def test_openai_call() -> None:
@@ -334,3 +337,77 @@ def test_chat_openai_get_num_tokens(model: str) -> None:
     """Test get_tokens."""
     llm = ChatOpenAI(model=model)
     assert llm.get_num_tokens("表情符号是\n🦜🔗") == _EXPECTED_NUM_TOKENS[model]
+
+
+@pytest.fixture
+def mock_completion() -> dict:
+    return {
+        "id": "cmpl-3evkmQda5Hu7fcZavknQda3SQ",
+        "object": "text_completion",
+        "created": 1689989000,
+        "model": "text-davinci-003",
+        "choices": [
+            {"text": "Bar Baz", "index": 0, "logprobs": None, "finish_reason": "length"}
+        ],
+        "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
+    }
+
+
+@pytest.mark.requires("openai")
+def test_openai_retries(mock_completion: dict) -> None:
+    llm = OpenAI()
+    mock_client = MagicMock()
+    completed = False
+    raised = False
+    import openai
+
+    def raise_once(*args: Any, **kwargs: Any) -> Any:
+        nonlocal completed, raised
+        if not raised:
+            raised = True
+            raise openai.error.APIError
+        completed = True
+        return mock_completion
+
+    mock_client.create = raise_once
+    callback_handler = FakeCallbackHandler()
+    with patch.object(
+        llm,
+        "client",
+        mock_client,
+    ):
+        res = llm.predict("bar", callbacks=[callback_handler])
+    assert res == "Bar Baz"
+    assert completed
+    assert raised
+    assert callback_handler.retries == 1
+
+
+@pytest.mark.requires("openai")
+async def test_openai_async_retries(mock_completion: dict) -> None:
+    llm = OpenAI()
+    mock_client = MagicMock()
+    completed = False
+    raised = False
+    import openai
+
+    def raise_once(*args: Any, **kwargs: Any) -> Any:
+        nonlocal completed, raised
+        if not raised:
+            raised = True
+            raise openai.error.APIError
+        completed = True
+        return mock_completion
+
+    mock_client.create = raise_once
+    callback_handler = FakeAsyncCallbackHandler()
+    with patch.object(
+        llm,
+        "client",
+        mock_client,
+    ):
+        res = llm.apredict("bar", callbacks=[callback_handler])
+    assert res == "Bar Baz"
+    assert completed
+    assert raised
+    assert callback_handler.retries == 1
@@ -39,6 +39,7 @@ class BaseFakeCallbackHandler(BaseModel):
     retriever_starts: int = 0
     retriever_ends: int = 0
     retriever_errors: int = 0
+    retries: int = 0


 class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
@@ -58,8 +59,10 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
     def on_llm_new_token_common(self) -> None:
         self.llm_streams += 1

+    def on_retry_common(self) -> None:
+        self.retries += 1
+
     def on_chain_start_common(self) -> None:
-        ("CHAIN START")
         self.chain_starts += 1
         self.starts += 1

@@ -82,7 +85,6 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
         self.errors += 1

     def on_agent_action_common(self) -> None:
-        print("AGENT ACTION")
         self.agent_actions += 1
         self.starts += 1

@@ -91,7 +93,6 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
         self.ends += 1

     def on_chat_model_start_common(self) -> None:
-        print("STARTING CHAT MODEL")
         self.chat_model_starts += 1
         self.starts += 1

@@ -162,6 +163,13 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
     ) -> Any:
         self.on_llm_error_common()

+    def on_retry(
+        self,
+        *args: Any,
+        **kwargs: Any,
+    ) -> Any:
+        self.on_retry_common()
+
     def on_chain_start(
         self,
         *args: Any,
@@ -1,8 +1,12 @@
 """Test OpenAI Chat API wrapper."""

 import json
+from typing import Any
+from unittest.mock import MagicMock, patch
+
+import pytest

 from langchain.chat_models.openai import (
+    ChatOpenAI,
     _convert_dict_to_message,
 )
 from langchain.schema.messages import FunctionMessage
@@ -21,3 +25,67 @@ def test_function_message_dict_to_function_message() -> None:
     assert isinstance(result, FunctionMessage)
     assert result.name == name
     assert result.content == content
+
+
+@pytest.fixture
+def mock_completion() -> dict:
+    return {
+        "id": "chatcmpl-7fcZavknQda3SQ",
+        "object": "chat.completion",
+        "created": 1689989000,
+        "model": "gpt-3.5-turbo-0613",
+        "choices": [
+            {
+                "index": 0,
+                "message": {
+                    "role": "assistant",
+                    "content": "Bar Baz",
+                },
+                "finish_reason": "stop",
+            }
+        ],
+    }
+
+
+@pytest.mark.requires("openai")
+def test_openai_predict(mock_completion: dict) -> None:
+    llm = ChatOpenAI()
+    mock_client = MagicMock()
+    completed = False
+
+    def mock_create(*args: Any, **kwargs: Any) -> Any:
+        nonlocal completed
+        completed = True
+        return mock_completion
+
+    mock_client.create = mock_create
+    with patch.object(
+        llm,
+        "client",
+        mock_client,
+    ):
+        res = llm.predict("bar")
+    assert res == "Bar Baz"
+    assert completed
+
+
+@pytest.mark.requires("openai")
+async def test_openai_apredict(mock_completion: dict) -> None:
+    llm = ChatOpenAI()
+    mock_client = MagicMock()
+    completed = False
+
+    def mock_create(*args: Any, **kwargs: Any) -> Any:
+        nonlocal completed
+        completed = True
+        return mock_completion
+
+    mock_client.create = mock_create
+    with patch.object(
+        llm,
+        "client",
+        mock_client,
+    ):
+        res = llm.predict("bar")
+    assert res == "Bar Baz"
+    assert completed
@@ -1,4 +1,6 @@
 import os
+from typing import Any
+from unittest.mock import MagicMock, patch

 import pytest

@@ -26,3 +28,61 @@ def test_openai_incorrect_field() -> None:
     with pytest.warns(match="not default parameter"):
         llm = OpenAI(foo="bar")
     assert llm.model_kwargs == {"foo": "bar"}
+
+
+@pytest.fixture
+def mock_completion() -> dict:
+    return {
+        "id": "cmpl-3evkmQda5Hu7fcZavknQda3SQ",
+        "object": "text_completion",
+        "created": 1689989000,
+        "model": "text-davinci-003",
+        "choices": [
+            {"text": "Bar Baz", "index": 0, "logprobs": None, "finish_reason": "length"}
+        ],
+        "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
+    }
+
+
+@pytest.mark.requires("openai")
+def test_openai_calls(mock_completion: dict) -> None:
+    llm = OpenAI()
+    mock_client = MagicMock()
+    completed = False
+
+    def raise_once(*args: Any, **kwargs: Any) -> Any:
+        nonlocal completed
+        completed = True
+        return mock_completion
+
+    mock_client.create = raise_once
+    with patch.object(
+        llm,
+        "client",
+        mock_client,
+    ):
+        res = llm.predict("bar")
+    assert res == "Bar Baz"
+    assert completed
+
+
+@pytest.mark.requires("openai")
+async def test_openai_async_retries(mock_completion: dict) -> None:
+    llm = OpenAI()
+    mock_client = MagicMock()
+    completed = False
+
+    def raise_once(*args: Any, **kwargs: Any) -> Any:
+        nonlocal completed
+        completed = True
+        return mock_completion
+
+    mock_client.create = raise_once
+    with patch.object(
+        llm,
+        "client",
+        mock_client,
+    ):
+        res = llm.apredict("bar")
+    assert res == "Bar Baz"
+    assert completed