Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-22 23:00:00 +00:00)
make trajectory eval chain stricter and add unit tests (#8909)

- update trajectory eval logic to be stricter
- add tests to trajectory eval chain
parent b8df15cd64
commit 3adb1e12ca
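The core of the change is in the second hunk below: instead of scanning score_str for its first digit character (which quietly turned "Score: 10" into 1 and "Score: 3.5" into 3), the parser now matches the whole number with a regex and rejects floats and anything outside 1-5. A minimal, dependency-free sketch of that extraction rule (the standalone extract_score helper and the use of ValueError are illustrative only; the real logic is a method on TrajectoryOutputParser and raises OutputParserException):

import re


def extract_score(score_str: str) -> int:
    # Match the full number, including any decimal part, rather than
    # grabbing only the first digit character.
    match = re.search(r"(\d+(\.\d+)?)", score_str)
    # Reject missing scores and floats such as "3.5".
    if match is None or "." in match.group(1):
        raise ValueError(f"Score is not an integer digit in the range 1-5: {score_str}")
    score = int(match.group(1))
    # Reject integers outside 1-5, e.g. 9 or 10.
    if not 1 <= score <= 5:
        raise ValueError(f"Score is not a digit in the range 1-5: {score_str}")
    return score


assert extract_score("Score: 2") == 2
for bad in ("Score: 3.5", "Score: 10", "Score: One"):
    try:
        extract_score(bad)
    except ValueError:
        pass  # rejected under the stricter rules, as intended
    else:
        raise AssertionError(f"{bad!r} should have been rejected")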
@@ -5,6 +5,7 @@ the sequence of actions taken and their outcomes. It uses a language model
 chain (LLMChain) to generate the reasoning and scores.
 """
 
+import re
 from typing import (
     Any,
     Dict,
@@ -74,15 +75,24 @@ class TrajectoryOutputParser(BaseOutputParser):
         reasoning, score_str = reasoning.strip(), score_str.strip()
 
-        score_str = next(
-            (char for char in score_str if char.isdigit()), "0"
-        )  # Scan for first digit
-        if not 1 <= int(score_str) <= 5:
+        # Use regex to extract the score.
+        # This will get the number in the string, even if it is a float or more than 10.
+        # E.g. "Score: 1" will return 1, "Score: 3.5" will return 3.5, and
+        # "Score: 10" will return 10.
+        # The score should be an integer digit in the range 1-5.
+        _score = re.search(r"(\d+(\.\d+)?)", score_str)
+        # If the score is not found or is a float, raise an exception.
+        if _score is None or "." in _score.group(1):
+            raise OutputParserException(
+                f"Score is not an integer digit in the range 1-5: {text}"
+            )
+        score = int(_score.group(1))
+        # If the score is not in the range 1-5, raise an exception.
+        if not 1 <= score <= 5:
             raise OutputParserException(
                 f"Score is not a digit in the range 1-5: {text}"
             )
-        normalized_score = (int(score_str) - 1) / 4
+        normalized_score = (score - 1) / 4
         return TrajectoryEval(score=normalized_score, reasoning=reasoning)
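With this change the parser accepts only a plain integer from 1 to 5 after "Score:", and it normalizes that value onto a 0-1 scale via (score - 1) / 4, which is why the new test below expects 0.25 for a score of 2. A small usage sketch, assuming TrajectoryOutputParser and OutputParserException are importable exactly as in the test imports that follow:

from langchain.evaluation.agents.trajectory_eval_chain import TrajectoryOutputParser
from langchain.schema import OutputParserException

parser = TrajectoryOutputParser()

# A well-formed verdict parses, and the 1-5 score is normalized onto [0, 1].
result = parser.parse("Judgment: mostly sound reasoning.\n\nScore: 2")
assert result["score"] == 0.25

# Floats and out-of-range integers now raise instead of being silently
# truncated to their first digit.
for bad_text in ("Score: 3.5", "Score: 10"):
    try:
        parser.parse(f"Judgment: mostly sound reasoning.\n\n{bad_text}")
    except OutputParserException:
        pass  # expected under the stricter parsing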
@@ -6,8 +6,12 @@ import pytest
 from pydantic import Field
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.evaluation.agents.trajectory_eval_chain import TrajectoryEvalChain
-from langchain.schema import AgentAction, BaseMessage
+from langchain.evaluation.agents.trajectory_eval_chain import (
+    TrajectoryEval,
+    TrajectoryEvalChain,
+    TrajectoryOutputParser,
+)
+from langchain.schema import AgentAction, BaseMessage, OutputParserException
 from langchain.tools.base import tool
 from tests.unit_tests.llms.fake_chat_model import FakeChatModel
@@ -53,6 +57,61 @@ class _FakeTrajectoryChatModel(FakeChatModel):
         return self.queries[prompt]
 
 
+def test_trajectory_output_parser_parse() -> None:
+    trajectory_output_parser = TrajectoryOutputParser()
+    text = """Judgment: Given the good reasoning in the final answer
+but otherwise poor performance, we give the model a score of 2.
+
+Score: 2"""
+    got = trajectory_output_parser.parse(text)
+    want = TrajectoryEval(
+        score=0.25,
+        reasoning="""Judgment: Given the good reasoning in the final answer
+but otherwise poor performance, we give the model a score of 2.""",
+    )
+
+    assert got["score"] == want["score"]
+    assert got["reasoning"] == want["reasoning"]
+
+    with pytest.raises(OutputParserException):
+        trajectory_output_parser.parse(
+            """Judgment: Given the good reasoning in the final answer
+but otherwise poor performance, we give the model a score of 2."""
+        )
+
+    with pytest.raises(OutputParserException):
+        trajectory_output_parser.parse(
+            """Judgment: Given the good reasoning in the final answer
+but otherwise poor performance, we give the model a score of 2.
+
+Score: 9"""
+        )
+
+    with pytest.raises(OutputParserException):
+        trajectory_output_parser.parse(
+            """Judgment: Given the good reasoning in the final answer
+but otherwise poor performance, we give the model a score of 2.
+
+Score: 10"""
+        )
+
+    with pytest.raises(OutputParserException):
+        trajectory_output_parser.parse(
+            """Judgment: Given the good reasoning in the final answer
+but otherwise poor performance, we give the model a score of 2.
+
+Score: 0.1"""
+        )
+
+    with pytest.raises(OutputParserException):
+        trajectory_output_parser.parse(
+            """Judgment: Given the good reasoning in the final answer
+but otherwise poor performance, we give the model a score of 2.
+
+Score: One"""
+        )
+
+
 def test_trajectory_eval_chain(
     intermediate_steps: List[Tuple[AgentAction, str]]
 ) -> None:
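To check the new behavior locally, the added test can be selected by name; a sketch using pytest's Python entry point (no test module path is assumed here, pytest discovers it):

import sys

import pytest

# Run only the parser test added in this commit, letting pytest discover
# the test module rather than hard-coding its path.
sys.exit(pytest.main(["-q", "-k", "test_trajectory_output_parser_parse"]))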