langchain/libs/community/tests/unit_tests/chat_models/test_outlines.py
shroominic dee72c46c1
community: Outlines integration (#27449)
In collaboration with @rlouf, I built an
[outlines](https://dottxt-ai.github.io/outlines/latest/) integration for
LangChain!

I think this is really useful for doing any type of structured output
locally.
[Dottxt](https://dottxt.co) has put a lot of work into optimising this process
at a lower level
([outlines-core](https://pypi.org/project/outlines-core/0.1.14/), written
in Rust), so I think this is a better alternative to all current
approaches in LangChain for structured output.
It also implements the `.with_structured_output` method, so it should be
a drop-in replacement for a lot of applications.

The integration includes:
- **Outlines LLM class**
- **ChatOutlines class**
- **Tutorial Cookbooks**
- **Documentation Page**
- **Validation and error messages** 
- **Exposes Outlines Structured output features**
- **Support for multiple backends**
- **Integration and Unit Tests**

Dependencies: `outlines` + additional (depending on backend used)

I am not sure whether the unit tests comply with all requirements; if not, I
suggest just removing them, since I don't see a useful way to test this
differently.

### Quick overview:

Chat Models:
<img width="698" alt="image"
src="https://github.com/user-attachments/assets/05a499b9-858c-4397-a9ff-165c2b3e7acc">

Structured Output:
<img width="955" alt="image"
src="https://github.com/user-attachments/assets/b9fcac11-d3e5-4698-b1ae-8c4cb3d54c45">

---------

Co-authored-by: Vadym Barda <vadym@langchain.dev>
2024-11-20 16:31:31 -05:00

92 lines
3.2 KiB
Python

import pytest
from _pytest.monkeypatch import MonkeyPatch
from pydantic import BaseModel, Field
from langchain_community.chat_models.outlines import ChatOutlines
def test_chat_outlines_initialization(monkeypatch: MonkeyPatch) -> None:
    """Constructor arguments are kept and ``backend`` defaults to "transformers"."""
    # Stub out client construction; presumably this avoids loading a real
    # model during unit tests (verify against ChatOutlines.build_client).
    monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self)

    model_id = "microsoft/Phi-3-mini-4k-instruct"
    stop_sequences = ["\n"]
    chat = ChatOutlines(model=model_id, max_tokens=42, stop=stop_sequences)

    assert chat.model == model_id
    assert chat.max_tokens == 42
    # No backend was passed, so this checks the default.
    assert chat.backend == "transformers"
    assert chat.stop == ["\n"]
def test_chat_outlines_backend_llamacpp(monkeypatch: MonkeyPatch) -> None:
    """The ``llamacpp`` backend is accepted and recorded on the instance."""
    monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self)

    gguf_model = "TheBloke/Llama-2-7B-Chat-GGUF/llama-2-7b-chat.Q4_K_M.gguf"
    chat = ChatOutlines(model=gguf_model, backend="llamacpp")

    assert chat.backend == "llamacpp"
def test_chat_outlines_backend_vllm(monkeypatch: MonkeyPatch) -> None:
    """The ``vllm`` backend is accepted and recorded on the instance."""
    monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self)

    backend_name = "vllm"
    chat = ChatOutlines(
        model="microsoft/Phi-3-mini-4k-instruct",
        backend=backend_name,
    )

    assert chat.backend == backend_name
def test_chat_outlines_backend_mlxlm(monkeypatch: MonkeyPatch) -> None:
    """The ``mlxlm`` backend is accepted and recorded on the instance."""
    monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self)

    backend_name = "mlxlm"
    chat = ChatOutlines(
        model="microsoft/Phi-3-mini-4k-instruct",
        backend=backend_name,
    )

    assert chat.backend == backend_name
def test_chat_outlines_with_regex(monkeypatch: MonkeyPatch) -> None:
    """A regex output constraint is stored verbatim on the model."""
    monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self)

    # Pattern of the form 123-456-7890.
    digits_pattern = r"\d{3}-\d{3}-\d{4}"
    chat = ChatOutlines(
        model="microsoft/Phi-3-mini-4k-instruct",
        regex=digits_pattern,
    )

    assert chat.regex == digits_pattern
def test_chat_outlines_with_type_constraints(monkeypatch: MonkeyPatch) -> None:
    """A Python type passed as ``type_constraints`` is stored unchanged."""
    monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self)

    chat = ChatOutlines(
        model="microsoft/Phi-3-mini-4k-instruct",
        type_constraints=int,
    )

    # Compare types by identity: it is the precise check and, unlike
    # ``== int``, does not need a lint suppression (ruff/pycodestyle E721).
    assert chat.type_constraints is int
def test_chat_outlines_with_json_schema(monkeypatch: MonkeyPatch) -> None:
    """A pydantic model passed as ``json_schema`` is stored as that same class."""

    # Local schema used only as the constraint value in this test.
    class TestSchema(BaseModel):
        name: str = Field(description="A person's name")
        age: int = Field(description="A person's age")

    monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self)

    expected_schema = TestSchema
    chat = ChatOutlines(
        model="microsoft/Phi-3-mini-4k-instruct",
        json_schema=expected_schema,
    )

    assert chat.json_schema == expected_schema
def test_chat_outlines_with_grammar(monkeypatch: MonkeyPatch) -> None:
    """A grammar output constraint is stored verbatim on the model."""
    monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self)

    # Small arithmetic-expression grammar (Lark-style, cf. `%import common.NUMBER`),
    # assembled line by line so the exact text is explicit.
    grammar = (
        "\n"
        "?start: expression\n"
        '?expression: term (("+" | "-") term)*\n'
        '?term: factor (("*" | "/") factor)*\n'
        '?factor: NUMBER | "-" factor | "(" expression ")"\n'
        "%import common.NUMBER\n"
    )
    chat = ChatOutlines(
        model="microsoft/Phi-3-mini-4k-instruct",
        grammar=grammar,
    )

    assert chat.grammar == grammar
def test_raise_for_multiple_output_constraints(monkeypatch: MonkeyPatch) -> None:
    """Combining two output constraints (type + regex) raises ``ValueError``."""
    monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self)

    digits_pattern = r"\d{3}-\d{3}-\d{4}"
    model_id = "microsoft/Phi-3-mini-4k-instruct"

    with pytest.raises(ValueError):
        ChatOutlines(
            model=model_id,
            type_constraints=int,
            regex=digits_pattern,
        )