From 65b404a2d1d28bf1e1a1ff65de9ce8a8c0a6ae7f Mon Sep 17 00:00:00 2001 From: Ashutosh Kumar Date: Tue, 4 Feb 2025 08:14:13 +0530 Subject: [PATCH] [oci_generative_ai] Option to pass auth_file_location (#29481) **PR title**: "community: Option to pass auth_file_location for oci_generative_ai" **Description:** Option to pass auth_file_location, to overwrite config file default location "~/.oci/config" where profile name configs present. This is not fixing any issues. Just added optional parameter called "auth_file_location", which internally supported by any OCI client including GenerativeAiInferenceClient. --- .../integrations/llms/oci_generative_ai.ipynb | 2 ++ .../text_embedding/oci_generative_ai.ipynb | 1 + .../chat_models/oci_generative_ai.py | 2 ++ .../embeddings/oci_generative_ai.py | 24 +++++++++++++------ .../llms/oci_generative_ai.py | 19 +++++++++++---- 5 files changed, 36 insertions(+), 12 deletions(-) diff --git a/docs/docs/integrations/llms/oci_generative_ai.ipynb b/docs/docs/integrations/llms/oci_generative_ai.ipynb index 3da80aef0e4..06d9b10aafc 100644 --- a/docs/docs/integrations/llms/oci_generative_ai.ipynb +++ b/docs/docs/integrations/llms/oci_generative_ai.ipynb @@ -135,6 +135,7 @@ " compartment_id=\"MY_OCID\",\n", " auth_type=\"SECURITY_TOKEN\",\n", " auth_profile=\"MY_PROFILE\", # replace with your profile name\n", + " auth_file_location=\"MY_CONFIG_FILE_LOCATION\", # replace with file location where profile name configs present\n", ")" ] }, @@ -159,6 +160,7 @@ " service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n", " compartment_id=\"DEDICATED_COMPARTMENT_OCID\",\n", " auth_profile=\"MY_PROFILE\", # replace with your profile name,\n", + " auth_file_location=\"MY_CONFIG_FILE_LOCATION\", # replace with file location where profile name configs present\n", " provider=\"MODEL_PROVIDER\", # e.g., \"cohere\" or \"meta\"\n", " context_size=\"MODEL_CONTEXT_SIZE\", # e.g., 128000\n", ")" diff --git a/docs/docs/integrations/text_embedding/oci_generative_ai.ipynb b/docs/docs/integrations/text_embedding/oci_generative_ai.ipynb index c9fd205930d..58bdcc23b9c 100755 --- a/docs/docs/integrations/text_embedding/oci_generative_ai.ipynb +++ b/docs/docs/integrations/text_embedding/oci_generative_ai.ipynb @@ -103,6 +103,7 @@ " compartment_id=\"MY_OCID\",\n", " auth_type=\"SECURITY_TOKEN\",\n", " auth_profile=\"MY_PROFILE\", # replace with your profile name\n", + " auth_file_location=\"MY_CONFIG_FILE_LOCATION\", # replace with file location where profile name configs present\n", ")\n", "\n", "\n", diff --git a/libs/community/langchain_community/chat_models/oci_generative_ai.py b/libs/community/langchain_community/chat_models/oci_generative_ai.py index 0c64592fb94..716766654d7 100644 --- a/libs/community/langchain_community/chat_models/oci_generative_ai.py +++ b/libs/community/langchain_community/chat_models/oci_generative_ai.py @@ -539,6 +539,8 @@ class ChatOCIGenAI(BaseChatModel, OCIGenAIBase): The authentication type to use, e.g., API_KEY (default), SECURITY_TOKEN, INSTANCE_PRINCIPAL, RESOURCE_PRINCIPAL. auth_profile: Optional[str] The name of the profile in ~/.oci/config, if not specified , DEFAULT will be used. + auth_file_location: Optional[str] + Path to the config file, If not specified, ~/.oci/config will be used. provider: str Provider name of the model. Default to None, will try to be derived from the model_id otherwise, requires user input. See full list of supported init args and their descriptions in the params section. diff --git a/libs/community/langchain_community/embeddings/oci_generative_ai.py b/libs/community/langchain_community/embeddings/oci_generative_ai.py index 10bef6d441d..4940f1747cc 100644 --- a/libs/community/langchain_community/embeddings/oci_generative_ai.py +++ b/libs/community/langchain_community/embeddings/oci_generative_ai.py @@ -29,6 +29,9 @@ class OCIGenAIEmbeddings(BaseModel, Embeddings): Make sure you have the required policies (profile/roles) to access the OCI Generative AI service. If a specific config profile is used, you must pass the name of the profile (~/.oci/config) through auth_profile. + If a specific config file location is used, you must pass + the file location where profile name configs present + through auth_file_location To use, you must provide the compartment id along with the endpoint url, and model id @@ -66,6 +69,11 @@ class OCIGenAIEmbeddings(BaseModel, Embeddings): If not specified , DEFAULT will be used """ + auth_file_location: Optional[str] = "~/.oci/config" + """Path to the config file. + If not specified, ~/.oci/config will be used + """ + model_id: Optional[str] = None """Id of the model to call, e.g., cohere.embed-english-light-v2.0""" @@ -108,7 +116,8 @@ class OCIGenAIEmbeddings(BaseModel, Embeddings): if values["auth_type"] == OCIAuthType(1).name: client_kwargs["config"] = oci.config.from_file( - profile_name=values["auth_profile"] + file_location=values["auth_file_location"], + profile_name=values["auth_profile"], ) client_kwargs.pop("signer", None) elif values["auth_type"] == OCIAuthType(2).name: @@ -124,7 +133,8 @@ class OCIGenAIEmbeddings(BaseModel, Embeddings): return oci.auth.signers.SecurityTokenSigner(st_string, pk) client_kwargs["config"] = oci.config.from_file( - profile_name=values["auth_profile"] + file_location=values["auth_file_location"], + profile_name=values["auth_profile"], ) client_kwargs["signer"] = make_security_token_signer( oci_config=client_kwargs["config"] @@ -151,11 +161,11 @@ class OCIGenAIEmbeddings(BaseModel, Embeddings): ) from ex except Exception as e: raise ValueError( - "Could not authenticate with OCI client. " - "Please check if ~/.oci/config exists. " - "If INSTANCE_PRINCIPLE or RESOURCE_PRINCIPLE is used, " - "Please check the specified " - "auth_profile and auth_type are valid." + """Could not authenticate with OCI client. + If INSTANCE_PRINCIPAL or RESOURCE_PRINCIPAL is used, + please check the specified + auth_profile, auth_file_location and auth_type are valid.""", + e, ) from e return values diff --git a/libs/community/langchain_community/llms/oci_generative_ai.py b/libs/community/langchain_community/llms/oci_generative_ai.py index a9f48a97528..8bea0ed3b01 100644 --- a/libs/community/langchain_community/llms/oci_generative_ai.py +++ b/libs/community/langchain_community/llms/oci_generative_ai.py @@ -79,6 +79,11 @@ class OCIGenAIBase(BaseModel, ABC): If not specified , DEFAULT will be used """ + auth_file_location: Optional[str] = "~/.oci/config" + """Path to the config file. + If not specified, ~/.oci/config will be used + """ + model_id: Optional[str] = None """Id of the model to call, e.g., cohere.command""" @@ -125,7 +130,8 @@ class OCIGenAIBase(BaseModel, ABC): if values["auth_type"] == OCIAuthType(1).name: client_kwargs["config"] = oci.config.from_file( - profile_name=values["auth_profile"] + file_location=values["auth_file_location"], + profile_name=values["auth_profile"], ) client_kwargs.pop("signer", None) elif values["auth_type"] == OCIAuthType(2).name: @@ -141,7 +147,8 @@ class OCIGenAIBase(BaseModel, ABC): return oci.auth.signers.SecurityTokenSigner(st_string, pk) client_kwargs["config"] = oci.config.from_file( - profile_name=values["auth_profile"] + file_location=values["auth_file_location"], + profile_name=values["auth_profile"], ) client_kwargs["signer"] = make_security_token_signer( oci_config=client_kwargs["config"] @@ -171,11 +178,10 @@ class OCIGenAIBase(BaseModel, ABC): ) from ex except Exception as e: raise ValueError( - """Could not authenticate with OCI client. - Please check if ~/.oci/config exists. + """Could not authenticate with OCI client. If INSTANCE_PRINCIPAL or RESOURCE_PRINCIPAL is used, please check the specified - auth_profile and auth_type are valid.""", + auth_profile, auth_file_location and auth_type are valid.""", e, ) from e @@ -223,6 +229,9 @@ class OCIGenAI(LLM, OCIGenAIBase): access the OCI Generative AI service. If a specific config profile is used, you must pass the name of the profile (from ~/.oci/config) through auth_profile. + If a specific config file location is used, you must pass + the file location where profile name configs present + through auth_file_location To use, you must provide the compartment id along with the endpoint url, and model id