mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
feat(openai): add callable support for openai_api_key parameter (#33532)
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from collections.abc import AsyncIterator
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
@@ -64,6 +65,57 @@ def test_chat_openai_model() -> None:
|
||||
assert chat.model_name == "bar"
|
||||
|
||||
|
||||
def test_callable_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
original_key = os.environ["OPENAI_API_KEY"]
|
||||
|
||||
calls = {"sync": 0}
|
||||
|
||||
def get_openai_api_key() -> str:
|
||||
calls["sync"] += 1
|
||||
return original_key
|
||||
|
||||
monkeypatch.delenv("OPENAI_API_KEY")
|
||||
|
||||
model = ChatOpenAI(model="gpt-4.1-mini", api_key=get_openai_api_key)
|
||||
response = model.invoke("hello")
|
||||
assert isinstance(response, AIMessage)
|
||||
assert calls["sync"] == 1
|
||||
|
||||
|
||||
async def test_callable_api_key_async(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
original_key = os.environ["OPENAI_API_KEY"]
|
||||
|
||||
calls = {"sync": 0, "async": 0}
|
||||
|
||||
def get_openai_api_key() -> str:
|
||||
calls["sync"] += 1
|
||||
return original_key
|
||||
|
||||
async def get_openai_api_key_async() -> str:
|
||||
calls["async"] += 1
|
||||
return original_key
|
||||
|
||||
monkeypatch.delenv("OPENAI_API_KEY")
|
||||
|
||||
model = ChatOpenAI(model="gpt-4.1-mini", api_key=get_openai_api_key)
|
||||
response = model.invoke("hello")
|
||||
assert isinstance(response, AIMessage)
|
||||
assert calls["sync"] == 1
|
||||
|
||||
response = await model.ainvoke("hello")
|
||||
assert isinstance(response, AIMessage)
|
||||
assert calls["sync"] == 2
|
||||
|
||||
model = ChatOpenAI(model="gpt-4.1-mini", api_key=get_openai_api_key_async)
|
||||
async_response = await model.ainvoke("hello")
|
||||
assert isinstance(async_response, AIMessage)
|
||||
assert calls["async"] == 1
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# We do not create a sync callable from an async one
|
||||
_ = model.invoke("hello")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_responses_api", [False, True])
|
||||
def test_chat_openai_system_message(use_responses_api: bool) -> None:
|
||||
"""Test ChatOpenAI wrapper with system message."""
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
"""Test OpenAI embeddings."""
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from langchain_openai.embeddings.base import OpenAIEmbeddings
|
||||
|
||||
@@ -67,3 +70,56 @@ def test_langchain_openai_embeddings_dimensions_large_num() -> None:
|
||||
output = embedding.embed_documents(documents)
|
||||
assert len(output) == 2000
|
||||
assert len(output[0]) == 128
|
||||
|
||||
|
||||
def test_callable_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
original_key = os.environ["OPENAI_API_KEY"]
|
||||
|
||||
calls = {"sync": 0}
|
||||
|
||||
def get_openai_api_key() -> str:
|
||||
calls["sync"] += 1
|
||||
return original_key
|
||||
|
||||
monkeypatch.delenv("OPENAI_API_KEY")
|
||||
|
||||
model = OpenAIEmbeddings(
|
||||
model="text-embedding-3-small", dimensions=128, api_key=get_openai_api_key
|
||||
)
|
||||
_ = model.embed_query("hello")
|
||||
assert calls["sync"] == 1
|
||||
|
||||
|
||||
async def test_callable_api_key_async(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
original_key = os.environ["OPENAI_API_KEY"]
|
||||
|
||||
calls = {"sync": 0, "async": 0}
|
||||
|
||||
def get_openai_api_key() -> str:
|
||||
calls["sync"] += 1
|
||||
return original_key
|
||||
|
||||
async def get_openai_api_key_async() -> str:
|
||||
calls["async"] += 1
|
||||
return original_key
|
||||
|
||||
monkeypatch.delenv("OPENAI_API_KEY")
|
||||
|
||||
model = OpenAIEmbeddings(
|
||||
model="text-embedding-3-small", dimensions=128, api_key=get_openai_api_key
|
||||
)
|
||||
_ = model.embed_query("hello")
|
||||
assert calls["sync"] == 1
|
||||
|
||||
_ = await model.aembed_query("hello")
|
||||
assert calls["sync"] == 2
|
||||
|
||||
model = OpenAIEmbeddings(
|
||||
model="text-embedding-3-small", dimensions=128, api_key=get_openai_api_key_async
|
||||
)
|
||||
_ = await model.aembed_query("hello")
|
||||
assert calls["async"] == 1
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# We do not create a sync callable from an async one
|
||||
_ = model.embed_query("hello")
|
||||
|
||||
@@ -187,6 +187,18 @@ def test_openai_uses_actual_secret_value_from_secretstr(model_class: type) -> No
|
||||
assert cast(SecretStr, model.openai_api_key).get_secret_value() == "secret-api-key"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [ChatOpenAI, OpenAI, OpenAIEmbeddings])
|
||||
def test_openai_api_key_accepts_callable(model_class: type) -> None:
|
||||
"""Test that the API key can be passed as a callable."""
|
||||
|
||||
def get_api_key() -> str:
|
||||
return "secret-api-key-from-callable"
|
||||
|
||||
model = model_class(openai_api_key=get_api_key)
|
||||
assert callable(model.openai_api_key)
|
||||
assert model.openai_api_key() == "secret-api-key-from-callable"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [AzureChatOpenAI, AzureOpenAI])
|
||||
def test_azure_serialized_secrets(model_class: type) -> None:
|
||||
"""Test that the actual secret value is correctly retrieved."""
|
||||
|
||||
Reference in New Issue
Block a user