diff --git a/libs/langchain/langchain/chains/moderation.py b/libs/langchain/langchain/chains/moderation.py index 70e43990ffa..a4b3551491c 100644 --- a/libs/langchain/langchain/chains/moderation.py +++ b/libs/langchain/langchain/chains/moderation.py @@ -38,7 +38,7 @@ class OpenAIModerationChain(Chain): output_key: str = "output" #: :meta private: openai_api_key: Optional[str] = None openai_organization: Optional[str] = None - _openai_pre_1_0: bool = Field(default=None) + openai_pre_1_0: bool = Field(default=None) @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: @@ -58,16 +58,17 @@ class OpenAIModerationChain(Chain): openai.api_key = openai_api_key if openai_organization: openai.organization = openai_organization - values["_openai_pre_1_0"] = False + values["openai_pre_1_0"] = False try: check_package_version("openai", gte_version="1.0") except ValueError: - values["_openai_pre_1_0"] = True - if values["_openai_pre_1_0"]: + values["openai_pre_1_0"] = True + if values["openai_pre_1_0"]: values["client"] = openai.Moderation else: values["client"] = openai.OpenAI() values["async_client"] = openai.AsyncOpenAI() + except ImportError: raise ImportError( "Could not import openai python package. " @@ -92,7 +93,7 @@ class OpenAIModerationChain(Chain): return [self.output_key] def _moderate(self, text: str, results: Any) -> str: - if self._openai_pre_1_0: + if self.openai_pre_1_0: condition = results["flagged"] else: condition = results.flagged @@ -110,7 +111,7 @@ class OpenAIModerationChain(Chain): run_manager: Optional[CallbackManagerForChainRun] = None, ) -> Dict[str, Any]: text = inputs[self.input_key] - if self._openai_pre_1_0: + if self.openai_pre_1_0: results = self.client.create(text) output = self._moderate(text, results["results"][0]) else: @@ -123,7 +124,7 @@ class OpenAIModerationChain(Chain): inputs: Dict[str, Any], run_manager: Optional[AsyncCallbackManagerForChainRun] = None, ) -> Dict[str, Any]: - if self._openai_pre_1_0: + if self.openai_pre_1_0: return await super()._acall(inputs, run_manager=run_manager) text = inputs[self.input_key] results = await self.async_client.moderations.create(input=text)