[Evaluation] Pass in seed directly (#24403)

Adding a test right now.
This commit is contained in:
William FH 2024-07-18 19:12:28 -07:00 committed by GitHub
parent 62b6965d2a
commit 0ee6ed76ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 16 additions and 1 deletions

View File

@ -146,7 +146,7 @@ def load_evaluator(
) )
llm = llm or ChatOpenAI( # type: ignore[call-arg] llm = llm or ChatOpenAI( # type: ignore[call-arg]
model="gpt-4", model_kwargs={"seed": 42}, temperature=0 model="gpt-4", seed=42, temperature=0
) )
except Exception as e: except Exception as e:
raise ValueError( raise ValueError(

View File

@ -1,11 +1,14 @@
"""Test LLM Bash functionality.""" """Test LLM Bash functionality."""
import os
import sys import sys
from typing import Type from typing import Type
from unittest.mock import patch
import pytest import pytest
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.evaluation.loading import load_evaluator
from langchain.evaluation.qa.eval_chain import ( from langchain.evaluation.qa.eval_chain import (
ContextQAEvalChain, ContextQAEvalChain,
CotQAEvalChain, CotQAEvalChain,
@ -50,6 +53,18 @@ def test_context_eval_chain(chain_cls: Type[ContextQAEvalChain]) -> None:
assert outputs[0]["text"] == "foo" assert outputs[0]["text"] == "foo"
def test_load_criteria_evaluator() -> None:
    """Check that ``load_evaluator`` accepts the plain string ``"criteria"``.

    The test is skipped when ``ChatOpenAI`` cannot be imported from
    ``langchain_openai``.
    """
    try:
        from langchain_openai import ChatOpenAI  # noqa: F401
    except ImportError:
        pytest.skip("langchain-openai not installed")

    # Provide a placeholder OpenAI API key through the environment for the
    # duration of the call; patch.dict restores os.environ on exit.
    env_override = {"OPENAI_API_KEY": "foo"}
    with patch.dict(os.environ, env_override):
        # A bare string is passed on purpose, even though that is not how
        # the argument is typed.
        load_evaluator("criteria")  # type: ignore
@pytest.mark.parametrize("chain_cls", [QAEvalChain, ContextQAEvalChain, CotQAEvalChain]) @pytest.mark.parametrize("chain_cls", [QAEvalChain, ContextQAEvalChain, CotQAEvalChain])
def test_implements_string_evaluator_protocol( def test_implements_string_evaluator_protocol(
chain_cls: Type[LLMChain], chain_cls: Type[LLMChain],