Mirror of https://github.com/hwchase17/langchain.git (synced 2025-04-28 11:55:21 +00:00)
openai[patch]: code cleaning (#17355)
h/t @tdene for finding the cleanup opportunity in #17047
This commit is contained in:
parent a9d6da609a
commit 42648061ad
@@ -1,10 +1,10 @@
 """OpenAI chat wrapper."""

 from __future__ import annotations

 import logging
 import os
 import sys
 import warnings
 from typing import (
     Any,
     AsyncIterator,
@@ -64,6 +64,7 @@ from langchain_core.utils.function_calling import (
     convert_to_openai_function,
     convert_to_openai_tool,
 )
+from langchain_core.utils.utils import build_extra_kwargs

 logger = logging.getLogger(__name__)
@@ -290,25 +291,9 @@ class ChatOpenAI(BaseChatModel):
         """Build extra kwargs from additional params that were passed in."""
         all_required_field_names = get_pydantic_field_names(cls)
         extra = values.get("model_kwargs", {})
-        for field_name in list(values):
-            if field_name in extra:
-                raise ValueError(f"Found {field_name} supplied twice.")
-            if field_name not in all_required_field_names:
-                warnings.warn(
-                    f"""WARNING! {field_name} is not default parameter.
-                    {field_name} was transferred to model_kwargs.
-                    Please confirm that {field_name} is what you intended."""
-                )
-                extra[field_name] = values.pop(field_name)
-
-        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
-        if invalid_model_kwargs:
-            raise ValueError(
-                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
-                f"Instead they were passed in as part of `model_kwargs` parameter."
-            )
-
-        values["model_kwargs"] = extra
+        values["model_kwargs"] = build_extra_kwargs(
+            extra, values, all_required_field_names
+        )
         return values

     @root_validator()
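Aside: the removed lines above are the per-class boilerplate that langchain_core's shared helper now absorbs. As a reading aid, here is a minimal standalone sketch of that behavior, reconstructed from the removed block; it is not the langchain_core source, and the name build_extra_kwargs_sketch is illustrative:

    import warnings
    from typing import Any, Dict, Set

    def build_extra_kwargs_sketch(
        extra: Dict[str, Any],
        values: Dict[str, Any],
        all_required_field_names: Set[str],
    ) -> Dict[str, Any]:
        """Move unknown params into extra, warning as they are transferred."""
        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                # Unknown field: warn, then move it into model_kwargs.
                warnings.warn(
                    f"WARNING! {field_name} is not default parameter. "
                    f"{field_name} was transferred to model_kwargs. "
                    f"Please confirm that {field_name} is what you intended."
                )
                extra[field_name] = values.pop(field_name)
        # Explicit fields must not be smuggled in through model_kwargs.
        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )
        return extra

With the shared helper, each model's build_extra validator collapses to the three added lines in the hunk above.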
@@ -339,9 +324,11 @@ class ChatOpenAI(BaseChatModel):
         )

         client_params = {
-            "api_key": values["openai_api_key"].get_secret_value()
-            if values["openai_api_key"]
-            else None,
+            "api_key": (
+                values["openai_api_key"].get_secret_value()
+                if values["openai_api_key"]
+                else None
+            ),
             "organization": values["openai_organization"],
             "base_url": values["openai_api_base"],
             "timeout": values["request_timeout"],
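Aside: this hunk and every remaining hunk in the commit make the same purely cosmetic change, wrapping a multi-line conditional expression (used as a dict value or keyword argument) in parentheses so its if/else continuation lines indent under the expression instead of aligning with sibling entries. The two forms are semantically identical; a tiny self-contained check, with a made-up info dict standing in for generation_info:

    # Hypothetical stand-in for chunk.generation_info.
    info = {"logprobs": [-0.5, -0.1]}

    # Form removed by this commit: continuation lines align with neighbors.
    before = info["logprobs"] if info else None

    # Form added by this commit: the conditional reads as one visual unit.
    after = (
        info["logprobs"]
        if info
        else None
    )

    assert before == after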
@@ -198,9 +198,11 @@ class BaseOpenAI(BaseLLM):
         )

         client_params = {
-            "api_key": values["openai_api_key"].get_secret_value()
-            if values["openai_api_key"]
-            else None,
+            "api_key": (
+                values["openai_api_key"].get_secret_value()
+                if values["openai_api_key"]
+                else None
+            ),
             "organization": values["openai_organization"],
             "base_url": values["openai_api_base"],
             "timeout": values["request_timeout"],
@@ -257,9 +259,11 @@ class BaseOpenAI(BaseLLM):
                     chunk.text,
                     chunk=chunk,
                     verbose=self.verbose,
-                    logprobs=chunk.generation_info["logprobs"]
-                    if chunk.generation_info
-                    else None,
+                    logprobs=(
+                        chunk.generation_info["logprobs"]
+                        if chunk.generation_info
+                        else None
+                    ),
                 )

     async def _astream(
@@ -283,9 +287,11 @@ class BaseOpenAI(BaseLLM):
                     chunk.text,
                     chunk=chunk,
                     verbose=self.verbose,
-                    logprobs=chunk.generation_info["logprobs"]
-                    if chunk.generation_info
-                    else None,
+                    logprobs=(
+                        chunk.generation_info["logprobs"]
+                        if chunk.generation_info
+                        else None
+                    ),
                 )

     def _generate(
@@ -334,12 +340,16 @@ class BaseOpenAI(BaseLLM):
                 choices.append(
                     {
                         "text": generation.text,
-                        "finish_reason": generation.generation_info.get("finish_reason")
-                        if generation.generation_info
-                        else None,
-                        "logprobs": generation.generation_info.get("logprobs")
-                        if generation.generation_info
-                        else None,
+                        "finish_reason": (
+                            generation.generation_info.get("finish_reason")
+                            if generation.generation_info
+                            else None
+                        ),
+                        "logprobs": (
+                            generation.generation_info.get("logprobs")
+                            if generation.generation_info
+                            else None
+                        ),
                     }
                 )
             else:
@@ -395,12 +405,16 @@ class BaseOpenAI(BaseLLM):
                 choices.append(
                     {
                         "text": generation.text,
-                        "finish_reason": generation.generation_info.get("finish_reason")
-                        if generation.generation_info
-                        else None,
-                        "logprobs": generation.generation_info.get("logprobs")
-                        if generation.generation_info
-                        else None,
+                        "finish_reason": (
+                            generation.generation_info.get("finish_reason")
+                            if generation.generation_info
+                            else None
+                        ),
+                        "logprobs": (
+                            generation.generation_info.get("logprobs")
+                            if generation.generation_info
+                            else None
+                        ),
                     }
                 )
             else: