openai[patch]: add error checks for AI model call failures (#17909)

**Description:** Add error checks after calling the AI model in
chat_models/base.py and llms/base.py.
**Issue:** The AI model call can sometimes return an error, and we should
raise it.
Otherwise, the subsequent line `choices.extend(response["choices"])`
throws "TypeError: 'NoneType' object is not iterable", masking the real
error, because `response["choices"]` is None.
**Dependencies**: None
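
For illustration, a minimal sketch (not LangChain code; it assumes a provider or proxy returns a body with `error` set and `choices` as `None`) of the masking behaviour this change fixes:

```python
# Hypothetical response body; the exact error payload shape is an assumption.
response = {"error": {"message": "model overloaded"}, "choices": None}

choices = []
try:
    # The original code path: iterating None raises TypeError and hides the cause.
    choices.extend(response["choices"])
except TypeError as exc:
    print(f"masked error: {exc}")

# The check added by this PR surfaces the real failure instead:
if response.get("error"):
    raise ValueError(response.get("error"))
```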

---------

Co-authored-by: yangkx <yangkx@asiainfo-int.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Kaixin Yang 2024-03-29 08:17:32 +08:00 committed by GitHub
parent 833d61adb3
commit a8104ea8e9
2 changed files with 15 additions and 0 deletions


@@ -522,6 +522,14 @@ class ChatOpenAI(BaseChatModel):
generations = []
if not isinstance(response, dict):
response = response.model_dump()
# The AI model call can sometimes return an error; raise it here.
# Otherwise the loop below over response["choices"] would throw
# "TypeError: 'NoneType' object is not iterable" and mask the real
# error, because response["choices"] is None.
if response.get("error"):
    raise ValueError(response.get("error"))
for res in response["choices"]:
message = _convert_dict_to_message(res["message"])
generation_info = dict(finish_reason=res.get("finish_reason"))


@@ -371,6 +371,13 @@ class BaseOpenAI(BaseLLM):
# dict. For the transition period, we deep convert it to dict.
response = response.model_dump()
# The AI model call can sometimes return an error; raise it here.
# Otherwise the next line, 'choices.extend(response["choices"])', would
# throw "TypeError: 'NoneType' object is not iterable" and mask the real
# error, because response["choices"] is None.
if response.get("error"):
    raise ValueError(response.get("error"))
choices.extend(response["choices"])
_update_token_usage(_keys, response, token_usage)
if not system_fingerprint:
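
A minimal test sketch (pytest assumed; exercised on plain dicts rather than the real ChatOpenAI/BaseOpenAI internals) of the behaviour both hunks add:

```python
import pytest


def _raise_if_error(response: dict) -> None:
    # Mirrors the guard added in both files.
    if response.get("error"):
        raise ValueError(response.get("error"))


def test_error_body_raises_value_error() -> None:
    with pytest.raises(ValueError):
        _raise_if_error({"error": {"message": "boom"}, "choices": None})


def test_ok_body_is_untouched() -> None:
    _raise_if_error({"choices": [{"message": {"role": "assistant", "content": "hi"}}]})
```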