diff --git a/libs/standard-tests/langchain_tests/base.py b/libs/standard-tests/langchain_tests/base.py index 57bf64b00cc..311c713d58d 100644 --- a/libs/standard-tests/langchain_tests/base.py +++ b/libs/standard-tests/langchain_tests/base.py @@ -46,10 +46,18 @@ class BaseStandardTests: m = getattr(self.__class__, method) if not hasattr(m, "pytestmark"): return False - marks = m.pytestmark - return any( - mark.name == "xfail" and mark.kwargs.get("reason") for mark in marks - ) + for mark in m.pytestmark: + if mark.name == "xfail" and mark.kwargs.get("reason"): + return True + # Also accept xfail marks on individual `pytest.param` entries + # within a `parametrize` - supports xfailing only a subset of + # parametrized cases. + if mark.name == "parametrize" and len(mark.args) >= 2: + for param in mark.args[1]: + for inner in getattr(param, "marks", ()): + if inner.name == "xfail" and inner.kwargs.get("reason"): + return True + return False overridden_not_xfail = [ method for method in overridden_tests if not is_xfail(method)