diff --git a/libs/langchain/langchain/chains/moderation.py b/libs/langchain/langchain/chains/moderation.py index 9467b8894c1..f72cae6cc8d 100644 --- a/libs/langchain/langchain/chains/moderation.py +++ b/libs/langchain/langchain/chains/moderation.py @@ -67,8 +67,8 @@ class OpenAIModerationChain(Chain): if values["openai_pre_1_0"]: values["client"] = openai.Moderation else: - values["client"] = openai.OpenAI() - values["async_client"] = openai.AsyncOpenAI() + values["client"] = openai.OpenAI(api_key=openai_api_key) + values["async_client"] = openai.AsyncOpenAI(api_key=openai_api_key) except ImportError: raise ImportError( diff --git a/libs/langchain/tests/integration_tests/chains/openai_functions/test_openapi.py b/libs/langchain/tests/integration_tests/chains/openai_functions/test_openapi.py index d3c294cbd56..9db8b957698 100644 --- a/libs/langchain/tests/integration_tests/chains/openai_functions/test_openapi.py +++ b/libs/langchain/tests/integration_tests/chains/openai_functions/test_openapi.py @@ -2,6 +2,7 @@ import json import pytest +from langchain.chains import OpenAIModerationChain from langchain.chains.openai_functions.openapi import get_openapi_chain api_spec = { @@ -36,3 +37,13 @@ def test_openai_openapi_chain() -> None: chain = get_openapi_chain(json.dumps(api_spec), llm) output = chain.invoke({"query": "Fetch the top two posts."}) assert len(output["response"]) == 2 + + +@pytest.mark.requires("openai") +def test_openai_moderation_chain_instantiation() -> None: + """Test OpenAIModerationChain.""" + api_key = "foo" + + moderation = OpenAIModerationChain(openai_api_key=api_key) + + assert isinstance(moderation, OpenAIModerationChain)