mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 15:43:54 +00:00
langchain: Fix broken OpenAIModerationChain
and implement async (#18537)
Thank you for contributing to LangChain! ## PR title langchain[patch]: fix `OpenAIModerationChain` and implement async ## PR message Description: fix `OpenAIModerationChain` and implement async Issues: - https://github.com/langchain-ai/langchain/issues/18533 - https://github.com/langchain-ai/langchain/issues/13685 Dependencies: none Twitter handle: mattflo ## Add tests and docs Existing documentation is broken: https://python.langchain.com/docs/guides/safety/moderation - [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ --------- Co-authored-by: Emilia Katari <emilia@outpace.com> Co-authored-by: ccurme <chester.curme@gmail.com> Co-authored-by: Erick Friis <erickfriis@gmail.com>
This commit is contained in:
parent
4170e72a42
commit
d3ca2cc8c3
@ -1,9 +1,13 @@
|
|||||||
"""Pass input through a moderation endpoint."""
|
"""Pass input through a moderation endpoint."""
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from langchain_core.callbacks import CallbackManagerForChainRun
|
from langchain_core.callbacks import (
|
||||||
from langchain_core.pydantic_v1 import root_validator
|
AsyncCallbackManagerForChainRun,
|
||||||
from langchain_core.utils import get_from_dict_or_env
|
CallbackManagerForChainRun,
|
||||||
|
)
|
||||||
|
from langchain_core.pydantic_v1 import Field, root_validator
|
||||||
|
from langchain_core.utils import check_package_version, get_from_dict_or_env
|
||||||
|
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
|
|
||||||
@ -25,6 +29,7 @@ class OpenAIModerationChain(Chain):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
client: Any #: :meta private:
|
client: Any #: :meta private:
|
||||||
|
async_client: Any #: :meta private:
|
||||||
model_name: Optional[str] = None
|
model_name: Optional[str] = None
|
||||||
"""Moderation model name to use."""
|
"""Moderation model name to use."""
|
||||||
error: bool = False
|
error: bool = False
|
||||||
@ -33,6 +38,7 @@ class OpenAIModerationChain(Chain):
|
|||||||
output_key: str = "output" #: :meta private:
|
output_key: str = "output" #: :meta private:
|
||||||
openai_api_key: Optional[str] = None
|
openai_api_key: Optional[str] = None
|
||||||
openai_organization: Optional[str] = None
|
openai_organization: Optional[str] = None
|
||||||
|
_openai_pre_1_0: bool = Field(default=None)
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
@ -52,7 +58,16 @@ class OpenAIModerationChain(Chain):
|
|||||||
openai.api_key = openai_api_key
|
openai.api_key = openai_api_key
|
||||||
if openai_organization:
|
if openai_organization:
|
||||||
openai.organization = openai_organization
|
openai.organization = openai_organization
|
||||||
values["client"] = openai.Moderation # type: ignore
|
values["_openai_pre_1_0"] = False
|
||||||
|
try:
|
||||||
|
check_package_version("openai", gte_version="1.0")
|
||||||
|
except ValueError:
|
||||||
|
values["_openai_pre_1_0"] = True
|
||||||
|
if values["_openai_pre_1_0"]:
|
||||||
|
values["client"] = openai.Moderation
|
||||||
|
else:
|
||||||
|
values["client"] = openai.OpenAI()
|
||||||
|
values["async_client"] = openai.AsyncOpenAI()
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Could not import openai python package. "
|
"Could not import openai python package. "
|
||||||
@ -76,8 +91,12 @@ class OpenAIModerationChain(Chain):
|
|||||||
"""
|
"""
|
||||||
return [self.output_key]
|
return [self.output_key]
|
||||||
|
|
||||||
def _moderate(self, text: str, results: dict) -> str:
|
def _moderate(self, text: str, results: Any) -> str:
|
||||||
if results["flagged"]:
|
if self._openai_pre_1_0:
|
||||||
|
condition = results["flagged"]
|
||||||
|
else:
|
||||||
|
condition = results.flagged
|
||||||
|
if condition:
|
||||||
error_str = "Text was found that violates OpenAI's content policy."
|
error_str = "Text was found that violates OpenAI's content policy."
|
||||||
if self.error:
|
if self.error:
|
||||||
raise ValueError(error_str)
|
raise ValueError(error_str)
|
||||||
@ -87,10 +106,26 @@ class OpenAIModerationChain(Chain):
|
|||||||
|
|
||||||
def _call(
    self,
    inputs: Dict[str, Any],
    run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
    """Moderate the text stored under ``self.input_key``.

    Dispatches to the pre-1.0 or post-1.0 openai client API depending on
    the installed package version, then delegates flag handling to
    ``self._moderate``.

    Args:
        inputs: Chain inputs; must contain ``self.input_key``.
        run_manager: Optional callback manager (unused here).

    Returns:
        A dict mapping ``self.output_key`` to the moderated output.
    """
    text = inputs[self.input_key]
    if self._openai_pre_1_0:
        # Legacy (<1.0) module-level API: positional input, dict-style response.
        response = self.client.create(text)
        moderated = self._moderate(text, response["results"][0])
    else:
        # Modern (>=1.0) client API: keyword input, attribute-style response.
        response = self.client.moderations.create(input=text)
        moderated = self._moderate(text, response.results[0])
    return {self.output_key: moderated}
||||||
|
|
||||||
|
async def _acall(
    self,
    inputs: Dict[str, Any],
    run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
    """Asynchronously moderate the text stored under ``self.input_key``.

    Args:
        inputs: Chain inputs; must contain ``self.input_key``.
        run_manager: Optional async callback manager (unused here).

    Returns:
        A dict mapping ``self.output_key`` to the moderated output.
    """
    if self._openai_pre_1_0:
        # Pre-1.0 openai exposes no async client, so fall back to the base
        # Chain's default _acall (which runs the sync _call).
        return await super()._acall(inputs, run_manager=run_manager)
    text = inputs[self.input_key]
    response = await self.async_client.moderations.create(input=text)
    return {self.output_key: self._moderate(text, response.results[0])}
|
||||||
|
Loading…
Reference in New Issue
Block a user