mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-05 11:12:47 +00:00
WIP: openai settings (#5792)
[] need to test more [] make sure they arent saved when serializing [] do for embeddings
This commit is contained in:
parent
b7999a9bc1
commit
3954bcf396
@ -53,33 +53,33 @@ class AzureChatOpenAI(ChatOpenAI):
|
|||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate that api key and python package exists in environment."""
|
"""Validate that api key and python package exists in environment."""
|
||||||
openai_api_key = get_from_dict_or_env(
|
values["openai_api_key"] = get_from_dict_or_env(
|
||||||
values,
|
values,
|
||||||
"openai_api_key",
|
"openai_api_key",
|
||||||
"OPENAI_API_KEY",
|
"OPENAI_API_KEY",
|
||||||
)
|
)
|
||||||
openai_api_base = get_from_dict_or_env(
|
values["openai_api_base"] = get_from_dict_or_env(
|
||||||
values,
|
values,
|
||||||
"openai_api_base",
|
"openai_api_base",
|
||||||
"OPENAI_API_BASE",
|
"OPENAI_API_BASE",
|
||||||
)
|
)
|
||||||
openai_api_version = get_from_dict_or_env(
|
values["openai_api_version"] = get_from_dict_or_env(
|
||||||
values,
|
values,
|
||||||
"openai_api_version",
|
"openai_api_version",
|
||||||
"OPENAI_API_VERSION",
|
"OPENAI_API_VERSION",
|
||||||
)
|
)
|
||||||
openai_api_type = get_from_dict_or_env(
|
values["openai_api_type"] = get_from_dict_or_env(
|
||||||
values,
|
values,
|
||||||
"openai_api_type",
|
"openai_api_type",
|
||||||
"OPENAI_API_TYPE",
|
"OPENAI_API_TYPE",
|
||||||
)
|
)
|
||||||
openai_organization = get_from_dict_or_env(
|
values["openai_organization"] = get_from_dict_or_env(
|
||||||
values,
|
values,
|
||||||
"openai_organization",
|
"openai_organization",
|
||||||
"OPENAI_ORGANIZATION",
|
"OPENAI_ORGANIZATION",
|
||||||
default="",
|
default="",
|
||||||
)
|
)
|
||||||
openai_proxy = get_from_dict_or_env(
|
values["openai_proxy"] = get_from_dict_or_env(
|
||||||
values,
|
values,
|
||||||
"openai_proxy",
|
"openai_proxy",
|
||||||
"OPENAI_PROXY",
|
"OPENAI_PROXY",
|
||||||
@ -88,14 +88,6 @@ class AzureChatOpenAI(ChatOpenAI):
|
|||||||
try:
|
try:
|
||||||
import openai
|
import openai
|
||||||
|
|
||||||
openai.api_type = openai_api_type
|
|
||||||
openai.api_base = openai_api_base
|
|
||||||
openai.api_version = openai_api_version
|
|
||||||
openai.api_key = openai_api_key
|
|
||||||
if openai_organization:
|
|
||||||
openai.organization = openai_organization
|
|
||||||
if openai_proxy:
|
|
||||||
openai.proxy = {"http": openai_proxy, "https": openai_proxy} # type: ignore[assignment] # noqa: E501
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Could not import openai python package. "
|
"Could not import openai python package. "
|
||||||
@ -128,6 +120,14 @@ class AzureChatOpenAI(ChatOpenAI):
|
|||||||
"""Get the identifying parameters."""
|
"""Get the identifying parameters."""
|
||||||
return {**self._default_params}
|
return {**self._default_params}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _invocation_params(self) -> Mapping[str, Any]:
|
||||||
|
openai_creds = {
|
||||||
|
"api_type": self.openai_api_type,
|
||||||
|
"api_version": self.openai_api_version,
|
||||||
|
}
|
||||||
|
return {**openai_creds, **super()._invocation_params}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
return "azure-openai-chat"
|
return "azure-openai-chat"
|
||||||
|
@ -196,22 +196,22 @@ class ChatOpenAI(BaseChatModel):
|
|||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate that api key and python package exists in environment."""
|
"""Validate that api key and python package exists in environment."""
|
||||||
openai_api_key = get_from_dict_or_env(
|
values["openai_api_key"] = get_from_dict_or_env(
|
||||||
values, "openai_api_key", "OPENAI_API_KEY"
|
values, "openai_api_key", "OPENAI_API_KEY"
|
||||||
)
|
)
|
||||||
openai_organization = get_from_dict_or_env(
|
values["openai_organization"] = get_from_dict_or_env(
|
||||||
values,
|
values,
|
||||||
"openai_organization",
|
"openai_organization",
|
||||||
"OPENAI_ORGANIZATION",
|
"OPENAI_ORGANIZATION",
|
||||||
default="",
|
default="",
|
||||||
)
|
)
|
||||||
openai_api_base = get_from_dict_or_env(
|
values["openai_api_base"] = get_from_dict_or_env(
|
||||||
values,
|
values,
|
||||||
"openai_api_base",
|
"openai_api_base",
|
||||||
"OPENAI_API_BASE",
|
"OPENAI_API_BASE",
|
||||||
default="",
|
default="",
|
||||||
)
|
)
|
||||||
openai_proxy = get_from_dict_or_env(
|
values["openai_proxy"] = get_from_dict_or_env(
|
||||||
values,
|
values,
|
||||||
"openai_proxy",
|
"openai_proxy",
|
||||||
"OPENAI_PROXY",
|
"OPENAI_PROXY",
|
||||||
@ -225,13 +225,6 @@ class ChatOpenAI(BaseChatModel):
|
|||||||
"Could not import openai python package. "
|
"Could not import openai python package. "
|
||||||
"Please install it with `pip install openai`."
|
"Please install it with `pip install openai`."
|
||||||
)
|
)
|
||||||
openai.api_key = openai_api_key
|
|
||||||
if openai_organization:
|
|
||||||
openai.organization = openai_organization
|
|
||||||
if openai_api_base:
|
|
||||||
openai.api_base = openai_api_base
|
|
||||||
if openai_proxy:
|
|
||||||
openai.proxy = {"http": openai_proxy, "https": openai_proxy} # type: ignore[assignment] # noqa: E501
|
|
||||||
try:
|
try:
|
||||||
values["client"] = openai.ChatCompletion
|
values["client"] = openai.ChatCompletion
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
@ -333,7 +326,7 @@ class ChatOpenAI(BaseChatModel):
|
|||||||
def _create_message_dicts(
|
def _create_message_dicts(
|
||||||
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
||||||
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
||||||
params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
|
params = dict(self._invocation_params)
|
||||||
if stop is not None:
|
if stop is not None:
|
||||||
if "stop" in params:
|
if "stop" in params:
|
||||||
raise ValueError("`stop` found in both the input and default params.")
|
raise ValueError("`stop` found in both the input and default params.")
|
||||||
@ -384,6 +377,21 @@ class ChatOpenAI(BaseChatModel):
|
|||||||
"""Get the identifying parameters."""
|
"""Get the identifying parameters."""
|
||||||
return {**{"model_name": self.model_name}, **self._default_params}
|
return {**{"model_name": self.model_name}, **self._default_params}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _invocation_params(self) -> Mapping[str, Any]:
|
||||||
|
"""Get the parameters used to invoke the model."""
|
||||||
|
openai_creds: Dict[str, Any] = {
|
||||||
|
"api_key": self.openai_api_key,
|
||||||
|
"api_base": self.openai_api_base,
|
||||||
|
"organization": self.openai_organization,
|
||||||
|
"model": self.model_name,
|
||||||
|
}
|
||||||
|
if self.openai_proxy:
|
||||||
|
openai_creds["proxy"] = (
|
||||||
|
{"http": self.openai_proxy, "https": self.openai_proxy},
|
||||||
|
)
|
||||||
|
return {**openai_creds, **self._default_params}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
"""Return type of chat model."""
|
"""Return type of chat model."""
|
||||||
|
@ -136,38 +136,38 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate that api key and python package exists in environment."""
|
"""Validate that api key and python package exists in environment."""
|
||||||
openai_api_key = get_from_dict_or_env(
|
values["openai_api_key"] = get_from_dict_or_env(
|
||||||
values, "openai_api_key", "OPENAI_API_KEY"
|
values, "openai_api_key", "OPENAI_API_KEY"
|
||||||
)
|
)
|
||||||
openai_api_base = get_from_dict_or_env(
|
values["openai_api_base"] = get_from_dict_or_env(
|
||||||
values,
|
values,
|
||||||
"openai_api_base",
|
"openai_api_base",
|
||||||
"OPENAI_API_BASE",
|
"OPENAI_API_BASE",
|
||||||
default="",
|
default="",
|
||||||
)
|
)
|
||||||
openai_api_type = get_from_dict_or_env(
|
values["openai_api_type"] = get_from_dict_or_env(
|
||||||
values,
|
values,
|
||||||
"openai_api_type",
|
"openai_api_type",
|
||||||
"OPENAI_API_TYPE",
|
"OPENAI_API_TYPE",
|
||||||
default="",
|
default="",
|
||||||
)
|
)
|
||||||
openai_proxy = get_from_dict_or_env(
|
values["openai_proxy"] = get_from_dict_or_env(
|
||||||
values,
|
values,
|
||||||
"openai_proxy",
|
"openai_proxy",
|
||||||
"OPENAI_PROXY",
|
"OPENAI_PROXY",
|
||||||
default="",
|
default="",
|
||||||
)
|
)
|
||||||
if openai_api_type in ("azure", "azure_ad", "azuread"):
|
if values["openai_api_type"] in ("azure", "azure_ad", "azuread"):
|
||||||
default_api_version = "2022-12-01"
|
default_api_version = "2022-12-01"
|
||||||
else:
|
else:
|
||||||
default_api_version = ""
|
default_api_version = ""
|
||||||
openai_api_version = get_from_dict_or_env(
|
values["openai_api_version"] = get_from_dict_or_env(
|
||||||
values,
|
values,
|
||||||
"openai_api_version",
|
"openai_api_version",
|
||||||
"OPENAI_API_VERSION",
|
"OPENAI_API_VERSION",
|
||||||
default=default_api_version,
|
default=default_api_version,
|
||||||
)
|
)
|
||||||
openai_organization = get_from_dict_or_env(
|
values["openai_organization"] = get_from_dict_or_env(
|
||||||
values,
|
values,
|
||||||
"openai_organization",
|
"openai_organization",
|
||||||
"OPENAI_ORGANIZATION",
|
"OPENAI_ORGANIZATION",
|
||||||
@ -176,17 +176,6 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|||||||
try:
|
try:
|
||||||
import openai
|
import openai
|
||||||
|
|
||||||
openai.api_key = openai_api_key
|
|
||||||
if openai_organization:
|
|
||||||
openai.organization = openai_organization
|
|
||||||
if openai_api_base:
|
|
||||||
openai.api_base = openai_api_base
|
|
||||||
if openai_api_type:
|
|
||||||
openai.api_version = openai_api_version
|
|
||||||
if openai_api_type:
|
|
||||||
openai.api_type = openai_api_type
|
|
||||||
if openai_proxy:
|
|
||||||
openai.proxy = {"http": openai_proxy, "https": openai_proxy} # type: ignore[assignment] # noqa: E501
|
|
||||||
values["client"] = openai.Embedding
|
values["client"] = openai.Embedding
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
@ -195,6 +184,25 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|||||||
)
|
)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _invocation_params(self) -> Dict:
|
||||||
|
openai_args = {
|
||||||
|
"engine": self.deployment,
|
||||||
|
"request_timeout": self.request_timeout,
|
||||||
|
"headers": self.headers,
|
||||||
|
"api_key": self.openai_api_key,
|
||||||
|
"organization": self.openai_organization,
|
||||||
|
"api_base": self.openai_api_base,
|
||||||
|
"api_type": self.openai_api_type,
|
||||||
|
"api_version": self.openai_api_version,
|
||||||
|
}
|
||||||
|
if self.openai_proxy:
|
||||||
|
openai_args["proxy"] = {
|
||||||
|
"http": self.openai_proxy,
|
||||||
|
"https": self.openai_proxy,
|
||||||
|
}
|
||||||
|
return openai_args
|
||||||
|
|
||||||
# please refer to
|
# please refer to
|
||||||
# https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
|
# https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
|
||||||
def _get_len_safe_embeddings(
|
def _get_len_safe_embeddings(
|
||||||
@ -233,9 +241,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|||||||
response = embed_with_retry(
|
response = embed_with_retry(
|
||||||
self,
|
self,
|
||||||
input=tokens[i : i + _chunk_size],
|
input=tokens[i : i + _chunk_size],
|
||||||
engine=self.deployment,
|
**self._invocation_params,
|
||||||
request_timeout=self.request_timeout,
|
|
||||||
headers=self.headers,
|
|
||||||
)
|
)
|
||||||
batched_embeddings += [r["embedding"] for r in response["data"]]
|
batched_embeddings += [r["embedding"] for r in response["data"]]
|
||||||
|
|
||||||
|
@ -211,22 +211,22 @@ class BaseOpenAI(BaseLLM):
|
|||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate that api key and python package exists in environment."""
|
"""Validate that api key and python package exists in environment."""
|
||||||
openai_api_key = get_from_dict_or_env(
|
values["openai_api_key"] = get_from_dict_or_env(
|
||||||
values, "openai_api_key", "OPENAI_API_KEY"
|
values, "openai_api_key", "OPENAI_API_KEY"
|
||||||
)
|
)
|
||||||
openai_api_base = get_from_dict_or_env(
|
values["openai_api_base"] = get_from_dict_or_env(
|
||||||
values,
|
values,
|
||||||
"openai_api_base",
|
"openai_api_base",
|
||||||
"OPENAI_API_BASE",
|
"OPENAI_API_BASE",
|
||||||
default="",
|
default="",
|
||||||
)
|
)
|
||||||
openai_proxy = get_from_dict_or_env(
|
values["openai_proxy"] = get_from_dict_or_env(
|
||||||
values,
|
values,
|
||||||
"openai_proxy",
|
"openai_proxy",
|
||||||
"OPENAI_PROXY",
|
"OPENAI_PROXY",
|
||||||
default="",
|
default="",
|
||||||
)
|
)
|
||||||
openai_organization = get_from_dict_or_env(
|
values["openai_organization"] = get_from_dict_or_env(
|
||||||
values,
|
values,
|
||||||
"openai_organization",
|
"openai_organization",
|
||||||
"OPENAI_ORGANIZATION",
|
"OPENAI_ORGANIZATION",
|
||||||
@ -235,13 +235,6 @@ class BaseOpenAI(BaseLLM):
|
|||||||
try:
|
try:
|
||||||
import openai
|
import openai
|
||||||
|
|
||||||
openai.api_key = openai_api_key
|
|
||||||
if openai_api_base:
|
|
||||||
openai.api_base = openai_api_base
|
|
||||||
if openai_organization:
|
|
||||||
openai.organization = openai_organization
|
|
||||||
if openai_proxy:
|
|
||||||
openai.proxy = {"http": openai_proxy, "https": openai_proxy} # type: ignore[assignment] # noqa: E501
|
|
||||||
values["client"] = openai.Completion
|
values["client"] = openai.Completion
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
@ -452,7 +445,17 @@ class BaseOpenAI(BaseLLM):
|
|||||||
@property
|
@property
|
||||||
def _invocation_params(self) -> Dict[str, Any]:
|
def _invocation_params(self) -> Dict[str, Any]:
|
||||||
"""Get the parameters used to invoke the model."""
|
"""Get the parameters used to invoke the model."""
|
||||||
return self._default_params
|
openai_creds: Dict[str, Any] = {
|
||||||
|
"api_key": self.openai_api_key,
|
||||||
|
"api_base": self.openai_api_base,
|
||||||
|
"organization": self.openai_organization,
|
||||||
|
}
|
||||||
|
if self.openai_proxy:
|
||||||
|
openai_creds["proxy"] = {
|
||||||
|
"http": self.openai_proxy,
|
||||||
|
"https": self.openai_proxy,
|
||||||
|
}
|
||||||
|
return {**openai_creds, **self._default_params}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _identifying_params(self) -> Mapping[str, Any]:
|
def _identifying_params(self) -> Mapping[str, Any]:
|
||||||
@ -596,6 +599,22 @@ class AzureOpenAI(BaseOpenAI):
|
|||||||
|
|
||||||
deployment_name: str = ""
|
deployment_name: str = ""
|
||||||
"""Deployment name to use."""
|
"""Deployment name to use."""
|
||||||
|
openai_api_type: str = "azure"
|
||||||
|
openai_api_version: str = ""
|
||||||
|
|
||||||
|
@root_validator()
|
||||||
|
def validate_azure_settings(cls, values: Dict) -> Dict:
|
||||||
|
values["openai_api_version"] = get_from_dict_or_env(
|
||||||
|
values,
|
||||||
|
"openai_api_version",
|
||||||
|
"OPENAI_API_VERSION",
|
||||||
|
)
|
||||||
|
values["openai_api_type"] = get_from_dict_or_env(
|
||||||
|
values,
|
||||||
|
"openai_api_type",
|
||||||
|
"OPENAI_API_TYPE",
|
||||||
|
)
|
||||||
|
return values
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _identifying_params(self) -> Mapping[str, Any]:
|
def _identifying_params(self) -> Mapping[str, Any]:
|
||||||
@ -606,7 +625,12 @@ class AzureOpenAI(BaseOpenAI):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def _invocation_params(self) -> Dict[str, Any]:
|
def _invocation_params(self) -> Dict[str, Any]:
|
||||||
return {**{"engine": self.deployment_name}, **super()._invocation_params}
|
openai_params = {
|
||||||
|
"engine": self.deployment_name,
|
||||||
|
"api_type": self.openai_api_type,
|
||||||
|
"api_version": self.openai_api_version,
|
||||||
|
}
|
||||||
|
return {**openai_params, **super()._invocation_params}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
|
Loading…
Reference in New Issue
Block a user