Compare commits

...

2 Commits

Author SHA1 Message Date
Bagatur
40ffb6253b undo 2023-07-31 13:11:17 -07:00
Bagatur
fb1d36a6fc rfc 2023-07-31 13:10:16 -07:00

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import logging
import sys
import warnings
from copy import deepcopy
from typing import (
AbstractSet,
Any,
@@ -15,13 +16,16 @@ from typing import (
Literal,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
)
import openai.error
from pydantic import Field, root_validator
from langchain.callbacks.base import Callbacks
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
@@ -192,6 +196,56 @@ class BaseOpenAI(BaseLLM):
when using one of the many model providers that expose an OpenAI-like
API but with different models. In those cases, in order to avoid erroring
when tiktoken is called, you can specify a model name to use here."""
fallbacks: Optional[Sequence[Union[str, Mapping, BaseLLM]]] = None
""""""
def _get_fallbacks(self) -> Optional[Iterator[BaseLLM]]:
if not self.fallbacks:
return None
for fb in self.fallbacks:
if isinstance(fb, str):
copy = deepcopy(self.__dict__)
copy["model_name"] = fb
copy.pop("fallbacks")
yield self.__class__(**copy)
elif isinstance(fb, Mapping):
copy = deepcopy(self.__dict__)
copy = {**copy, **fb}
copy.pop("fallbacks")
yield self.__class__(**copy)
else:
yield fb
def generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
*,
tags: Optional[Union[List[str], List[List[str]]]] = None,
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
**kwargs: Any,
) -> LLMResult:
try:
return super().generate(
prompts, stop=stop, callbacks=callbacks, tags=tags, metadata=metadata
)
except Exception as e:
fallbacks = self._get_fallbacks()
if not fallbacks:
raise e
for fb in fallbacks:
try:
return fb.generate(
prompts,
stop=stop,
callbacks=callbacks,
tags=tags,
metadata=metadata,
)
except Exception:
pass
raise e
def __new__(cls, **data: Any) -> Union[OpenAIChat, BaseOpenAI]: # type: ignore
"""Initialize the OpenAI object."""