upstage: Upstage Groundedness Check parameter update (#20914)

* Groundedness Check takes `str` or `list[Document]` as input.

* Deprecate `GroundednessCheck` due to its naming.
* Added `UpstageGroundednessCheck`. 

* Hotfix for Groundedness Check parameter. 
  The name `query` was misleading and it should be `answer` instead.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Sean
2024-04-26 10:34:05 -07:00
committed by GitHub
parent 84b8e67c9c
commit e1c2e2fdfa
8 changed files with 101 additions and 69 deletions

View File

@@ -2,12 +2,16 @@ from langchain_upstage.chat_models import ChatUpstage
from langchain_upstage.embeddings import UpstageEmbeddings
from langchain_upstage.layout_analysis import UpstageLayoutAnalysisLoader
from langchain_upstage.layout_analysis_parsers import UpstageLayoutAnalysisParser
from langchain_upstage.tools.groundedness_check import GroundednessCheck
from langchain_upstage.tools.groundedness_check import (
GroundednessCheck,
UpstageGroundednessCheck,
)
__all__ = [
"ChatUpstage",
"UpstageEmbeddings",
"UpstageLayoutAnalysisLoader",
"UpstageLayoutAnalysisParser",
"UpstageGroundednessCheck",
"GroundednessCheck",
]

View File

@@ -1,10 +1,12 @@
import os
from typing import Any, Literal, Optional, Type, Union
from typing import Any, List, Literal, Optional, Type, Union
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.documents import Document
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr
from langchain_core.tools import BaseTool
@@ -13,16 +15,18 @@ from langchain_core.utils import convert_to_secret_str
from langchain_upstage import ChatUpstage
class GroundednessCheckInput(BaseModel):
class UpstageGroundednessCheckInput(BaseModel):
"""Input for the Groundedness Check tool."""
context: str = Field(description="context in which the answer should be verified")
query: str = Field(
context: Union[str, List[Document]] = Field(
description="context in which the answer should be verified"
)
answer: str = Field(
description="assistant's reply or a text that is subject to groundedness check"
)
class GroundednessCheck(BaseTool):
class UpstageGroundednessCheck(BaseTool):
"""Tool that checks the groundedness of a context and an assistant message.
To use, you should have the environment variable `UPSTAGE_API_KEY`
@@ -31,15 +35,15 @@ class GroundednessCheck(BaseTool):
Example:
.. code-block:: python
from langchain_upstage import GroundednessCheck
from langchain_upstage import UpstageGroundednessCheck
tool = GroundednessCheck()
tool = UpstageGroundednessCheck()
"""
name: str = "groundedness_check"
description: str = (
"A tool that checks the groundedness of an assistant response "
"to user-provided context. GroundednessCheck ensures that "
"to user-provided context. UpstageGroundednessCheck ensures that "
"the assistants response is not only relevant but also "
"precisely aligned with the user's initial context, "
"promoting a more reliable and context-aware interaction. "
@@ -50,7 +54,7 @@ class GroundednessCheck(BaseTool):
upstage_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
api_wrapper: ChatUpstage
args_schema: Type[BaseModel] = GroundednessCheckInput
args_schema: Type[BaseModel] = UpstageGroundednessCheckInput
def __init__(self, **kwargs: Any) -> None:
upstage_api_key = kwargs.get("upstage_api_key", None)
@@ -73,25 +77,41 @@ class GroundednessCheck(BaseTool):
)
super().__init__(upstage_api_key=upstage_api_key, api_wrapper=api_wrapper)
def formatDocumentsAsString(self, docs: List[Document]) -> str:
return "\n".join([doc.page_content for doc in docs])
def _run(
self,
context: str,
query: str,
context: Union[str, List[Document]],
answer: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> Union[str, Literal["grounded", "notGrounded", "notSure"]]:
"""Use the tool."""
if isinstance(context, List):
context = self.formatDocumentsAsString(context)
response = self.api_wrapper.invoke(
[HumanMessage(context), AIMessage(query)], stream=False
[HumanMessage(context), AIMessage(answer)], stream=False
)
return str(response.content)
async def _arun(
self,
context: str,
query: str,
context: Union[str, List[Document]],
answer: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> Union[str, Literal["grounded", "notGrounded", "notSure"]]:
if isinstance(context, List):
context = self.formatDocumentsAsString(context)
response = await self.api_wrapper.ainvoke(
[HumanMessage(context), AIMessage(query)], stream=False
[HumanMessage(context), AIMessage(answer)], stream=False
)
return str(response.content)
@deprecated(
since="0.1.3",
removal="0.2.0",
alternative_import="langchain_upstage.UpstageGroundednessCheck",
)
class GroundednessCheck(UpstageGroundednessCheck):
pass

View File

@@ -2,34 +2,62 @@ import os
import openai
import pytest
from langchain_core.documents import Document
from langchain_upstage import GroundednessCheck
from langchain_upstage import GroundednessCheck, UpstageGroundednessCheck
def test_langchain_upstage_groundedness_check() -> None:
def test_langchain_upstage_groundedness_check_deprecated() -> None:
"""Test Upstage Groundedness Check."""
tool = GroundednessCheck()
output = tool.run({"context": "foo bar", "query": "bar foo"})
output = tool.invoke({"context": "foo bar", "answer": "bar foo"})
assert output in ["grounded", "notGrounded", "notSure"]
api_key = os.environ.get("UPSTAGE_API_KEY", None)
tool = GroundednessCheck(upstage_api_key=api_key)
output = tool.run({"context": "foo bar", "query": "bar foo"})
output = tool.invoke({"context": "foo bar", "answer": "bar foo"})
assert output in ["grounded", "notGrounded", "notSure"]
def test_langchain_upstage_groundedness_check() -> None:
"""Test Upstage Groundedness Check."""
tool = UpstageGroundednessCheck()
output = tool.invoke({"context": "foo bar", "answer": "bar foo"})
assert output in ["grounded", "notGrounded", "notSure"]
api_key = os.environ.get("UPSTAGE_API_KEY", None)
tool = UpstageGroundednessCheck(upstage_api_key=api_key)
output = tool.invoke({"context": "foo bar", "answer": "bar foo"})
assert output in ["grounded", "notGrounded", "notSure"]
def test_langchain_upstage_groundedness_check_with_documents_input() -> None:
"""Test Upstage Groundedness Check."""
tool = UpstageGroundednessCheck()
docs = [
Document(page_content="foo bar"),
Document(page_content="bar foo"),
]
output = tool.invoke({"context": docs, "answer": "bar foo"})
assert output in ["grounded", "notGrounded", "notSure"]
def test_langchain_upstage_groundedness_check_fail_with_wrong_api_key() -> None:
tool = GroundednessCheck(api_key="wrong-key")
tool = UpstageGroundednessCheck(api_key="wrong-key")
with pytest.raises(openai.AuthenticationError):
tool.run({"context": "foo bar", "query": "bar foo"})
tool.invoke({"context": "foo bar", "answer": "bar foo"})
async def test_langchain_upstage_groundedness_check_async() -> None:
"""Test Upstage Groundedness Check asynchronous."""
tool = GroundednessCheck()
output = await tool.arun({"context": "foo bar", "query": "bar foo"})
tool = UpstageGroundednessCheck()
output = await tool.ainvoke({"context": "foo bar", "answer": "bar foo"})
assert output in ["grounded", "notGrounded", "notSure"]

View File

@@ -1,12 +1,12 @@
import os
from langchain_upstage import GroundednessCheck
from langchain_upstage import UpstageGroundednessCheck
os.environ["UPSTAGE_API_KEY"] = "foo"
def test_initialization() -> None:
"""Test embedding model initialization."""
GroundednessCheck()
GroundednessCheck(upstage_api_key="key")
GroundednessCheck(api_key="key")
UpstageGroundednessCheck()
UpstageGroundednessCheck(upstage_api_key="key")
UpstageGroundednessCheck(api_key="key")

View File

@@ -2,10 +2,11 @@ from langchain_upstage import __all__
EXPECTED_ALL = [
"ChatUpstage",
"GroundednessCheck",
"UpstageEmbeddings",
"UpstageLayoutAnalysisLoader",
"UpstageLayoutAnalysisParser",
"GroundednessCheck",
"UpstageGroundednessCheck",
]