mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-31 00:29:57 +00:00
WIP: openai settings (#5792)
[] need to test more [] make sure they arent saved when serializing [] do for embeddings
This commit is contained in:
parent
b7999a9bc1
commit
3954bcf396
@ -53,33 +53,33 @@ class AzureChatOpenAI(ChatOpenAI):
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
openai_api_key = get_from_dict_or_env(
|
||||
values["openai_api_key"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_api_key",
|
||||
"OPENAI_API_KEY",
|
||||
)
|
||||
openai_api_base = get_from_dict_or_env(
|
||||
values["openai_api_base"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_api_base",
|
||||
"OPENAI_API_BASE",
|
||||
)
|
||||
openai_api_version = get_from_dict_or_env(
|
||||
values["openai_api_version"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_api_version",
|
||||
"OPENAI_API_VERSION",
|
||||
)
|
||||
openai_api_type = get_from_dict_or_env(
|
||||
values["openai_api_type"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_api_type",
|
||||
"OPENAI_API_TYPE",
|
||||
)
|
||||
openai_organization = get_from_dict_or_env(
|
||||
values["openai_organization"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_organization",
|
||||
"OPENAI_ORGANIZATION",
|
||||
default="",
|
||||
)
|
||||
openai_proxy = get_from_dict_or_env(
|
||||
values["openai_proxy"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_proxy",
|
||||
"OPENAI_PROXY",
|
||||
@ -88,14 +88,6 @@ class AzureChatOpenAI(ChatOpenAI):
|
||||
try:
|
||||
import openai
|
||||
|
||||
openai.api_type = openai_api_type
|
||||
openai.api_base = openai_api_base
|
||||
openai.api_version = openai_api_version
|
||||
openai.api_key = openai_api_key
|
||||
if openai_organization:
|
||||
openai.organization = openai_organization
|
||||
if openai_proxy:
|
||||
openai.proxy = {"http": openai_proxy, "https": openai_proxy} # type: ignore[assignment] # noqa: E501
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import openai python package. "
|
||||
@ -128,6 +120,14 @@ class AzureChatOpenAI(ChatOpenAI):
|
||||
"""Get the identifying parameters."""
|
||||
return {**self._default_params}
|
||||
|
||||
@property
|
||||
def _invocation_params(self) -> Mapping[str, Any]:
|
||||
openai_creds = {
|
||||
"api_type": self.openai_api_type,
|
||||
"api_version": self.openai_api_version,
|
||||
}
|
||||
return {**openai_creds, **super()._invocation_params}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "azure-openai-chat"
|
||||
|
@ -196,22 +196,22 @@ class ChatOpenAI(BaseChatModel):
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
openai_api_key = get_from_dict_or_env(
|
||||
values["openai_api_key"] = get_from_dict_or_env(
|
||||
values, "openai_api_key", "OPENAI_API_KEY"
|
||||
)
|
||||
openai_organization = get_from_dict_or_env(
|
||||
values["openai_organization"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_organization",
|
||||
"OPENAI_ORGANIZATION",
|
||||
default="",
|
||||
)
|
||||
openai_api_base = get_from_dict_or_env(
|
||||
values["openai_api_base"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_api_base",
|
||||
"OPENAI_API_BASE",
|
||||
default="",
|
||||
)
|
||||
openai_proxy = get_from_dict_or_env(
|
||||
values["openai_proxy"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_proxy",
|
||||
"OPENAI_PROXY",
|
||||
@ -225,13 +225,6 @@ class ChatOpenAI(BaseChatModel):
|
||||
"Could not import openai python package. "
|
||||
"Please install it with `pip install openai`."
|
||||
)
|
||||
openai.api_key = openai_api_key
|
||||
if openai_organization:
|
||||
openai.organization = openai_organization
|
||||
if openai_api_base:
|
||||
openai.api_base = openai_api_base
|
||||
if openai_proxy:
|
||||
openai.proxy = {"http": openai_proxy, "https": openai_proxy} # type: ignore[assignment] # noqa: E501
|
||||
try:
|
||||
values["client"] = openai.ChatCompletion
|
||||
except AttributeError:
|
||||
@ -333,7 +326,7 @@ class ChatOpenAI(BaseChatModel):
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
||||
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
||||
params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
|
||||
params = dict(self._invocation_params)
|
||||
if stop is not None:
|
||||
if "stop" in params:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
@ -384,6 +377,21 @@ class ChatOpenAI(BaseChatModel):
|
||||
"""Get the identifying parameters."""
|
||||
return {**{"model_name": self.model_name}, **self._default_params}
|
||||
|
||||
@property
|
||||
def _invocation_params(self) -> Mapping[str, Any]:
|
||||
"""Get the parameters used to invoke the model."""
|
||||
openai_creds: Dict[str, Any] = {
|
||||
"api_key": self.openai_api_key,
|
||||
"api_base": self.openai_api_base,
|
||||
"organization": self.openai_organization,
|
||||
"model": self.model_name,
|
||||
}
|
||||
if self.openai_proxy:
|
||||
openai_creds["proxy"] = (
|
||||
{"http": self.openai_proxy, "https": self.openai_proxy},
|
||||
)
|
||||
return {**openai_creds, **self._default_params}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
|
@ -136,38 +136,38 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
openai_api_key = get_from_dict_or_env(
|
||||
values["openai_api_key"] = get_from_dict_or_env(
|
||||
values, "openai_api_key", "OPENAI_API_KEY"
|
||||
)
|
||||
openai_api_base = get_from_dict_or_env(
|
||||
values["openai_api_base"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_api_base",
|
||||
"OPENAI_API_BASE",
|
||||
default="",
|
||||
)
|
||||
openai_api_type = get_from_dict_or_env(
|
||||
values["openai_api_type"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_api_type",
|
||||
"OPENAI_API_TYPE",
|
||||
default="",
|
||||
)
|
||||
openai_proxy = get_from_dict_or_env(
|
||||
values["openai_proxy"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_proxy",
|
||||
"OPENAI_PROXY",
|
||||
default="",
|
||||
)
|
||||
if openai_api_type in ("azure", "azure_ad", "azuread"):
|
||||
if values["openai_api_type"] in ("azure", "azure_ad", "azuread"):
|
||||
default_api_version = "2022-12-01"
|
||||
else:
|
||||
default_api_version = ""
|
||||
openai_api_version = get_from_dict_or_env(
|
||||
values["openai_api_version"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_api_version",
|
||||
"OPENAI_API_VERSION",
|
||||
default=default_api_version,
|
||||
)
|
||||
openai_organization = get_from_dict_or_env(
|
||||
values["openai_organization"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_organization",
|
||||
"OPENAI_ORGANIZATION",
|
||||
@ -176,17 +176,6 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
try:
|
||||
import openai
|
||||
|
||||
openai.api_key = openai_api_key
|
||||
if openai_organization:
|
||||
openai.organization = openai_organization
|
||||
if openai_api_base:
|
||||
openai.api_base = openai_api_base
|
||||
if openai_api_type:
|
||||
openai.api_version = openai_api_version
|
||||
if openai_api_type:
|
||||
openai.api_type = openai_api_type
|
||||
if openai_proxy:
|
||||
openai.proxy = {"http": openai_proxy, "https": openai_proxy} # type: ignore[assignment] # noqa: E501
|
||||
values["client"] = openai.Embedding
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
@ -195,6 +184,25 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _invocation_params(self) -> Dict:
|
||||
openai_args = {
|
||||
"engine": self.deployment,
|
||||
"request_timeout": self.request_timeout,
|
||||
"headers": self.headers,
|
||||
"api_key": self.openai_api_key,
|
||||
"organization": self.openai_organization,
|
||||
"api_base": self.openai_api_base,
|
||||
"api_type": self.openai_api_type,
|
||||
"api_version": self.openai_api_version,
|
||||
}
|
||||
if self.openai_proxy:
|
||||
openai_args["proxy"] = {
|
||||
"http": self.openai_proxy,
|
||||
"https": self.openai_proxy,
|
||||
}
|
||||
return openai_args
|
||||
|
||||
# please refer to
|
||||
# https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
|
||||
def _get_len_safe_embeddings(
|
||||
@ -233,9 +241,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
response = embed_with_retry(
|
||||
self,
|
||||
input=tokens[i : i + _chunk_size],
|
||||
engine=self.deployment,
|
||||
request_timeout=self.request_timeout,
|
||||
headers=self.headers,
|
||||
**self._invocation_params,
|
||||
)
|
||||
batched_embeddings += [r["embedding"] for r in response["data"]]
|
||||
|
||||
|
@ -211,22 +211,22 @@ class BaseOpenAI(BaseLLM):
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
openai_api_key = get_from_dict_or_env(
|
||||
values["openai_api_key"] = get_from_dict_or_env(
|
||||
values, "openai_api_key", "OPENAI_API_KEY"
|
||||
)
|
||||
openai_api_base = get_from_dict_or_env(
|
||||
values["openai_api_base"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_api_base",
|
||||
"OPENAI_API_BASE",
|
||||
default="",
|
||||
)
|
||||
openai_proxy = get_from_dict_or_env(
|
||||
values["openai_proxy"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_proxy",
|
||||
"OPENAI_PROXY",
|
||||
default="",
|
||||
)
|
||||
openai_organization = get_from_dict_or_env(
|
||||
values["openai_organization"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_organization",
|
||||
"OPENAI_ORGANIZATION",
|
||||
@ -235,13 +235,6 @@ class BaseOpenAI(BaseLLM):
|
||||
try:
|
||||
import openai
|
||||
|
||||
openai.api_key = openai_api_key
|
||||
if openai_api_base:
|
||||
openai.api_base = openai_api_base
|
||||
if openai_organization:
|
||||
openai.organization = openai_organization
|
||||
if openai_proxy:
|
||||
openai.proxy = {"http": openai_proxy, "https": openai_proxy} # type: ignore[assignment] # noqa: E501
|
||||
values["client"] = openai.Completion
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
@ -452,7 +445,17 @@ class BaseOpenAI(BaseLLM):
|
||||
@property
|
||||
def _invocation_params(self) -> Dict[str, Any]:
|
||||
"""Get the parameters used to invoke the model."""
|
||||
return self._default_params
|
||||
openai_creds: Dict[str, Any] = {
|
||||
"api_key": self.openai_api_key,
|
||||
"api_base": self.openai_api_base,
|
||||
"organization": self.openai_organization,
|
||||
}
|
||||
if self.openai_proxy:
|
||||
openai_creds["proxy"] = {
|
||||
"http": self.openai_proxy,
|
||||
"https": self.openai_proxy,
|
||||
}
|
||||
return {**openai_creds, **self._default_params}
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
@ -596,6 +599,22 @@ class AzureOpenAI(BaseOpenAI):
|
||||
|
||||
deployment_name: str = ""
|
||||
"""Deployment name to use."""
|
||||
openai_api_type: str = "azure"
|
||||
openai_api_version: str = ""
|
||||
|
||||
@root_validator()
|
||||
def validate_azure_settings(cls, values: Dict) -> Dict:
|
||||
values["openai_api_version"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_api_version",
|
||||
"OPENAI_API_VERSION",
|
||||
)
|
||||
values["openai_api_type"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_api_type",
|
||||
"OPENAI_API_TYPE",
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
@ -606,7 +625,12 @@ class AzureOpenAI(BaseOpenAI):
|
||||
|
||||
@property
|
||||
def _invocation_params(self) -> Dict[str, Any]:
|
||||
return {**{"engine": self.deployment_name}, **super()._invocation_params}
|
||||
openai_params = {
|
||||
"engine": self.deployment_name,
|
||||
"api_type": self.openai_api_type,
|
||||
"api_version": self.openai_api_version,
|
||||
}
|
||||
return {**openai_params, **super()._invocation_params}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
|
Loading…
Reference in New Issue
Block a user