community: Outlines integration (#27449)

In collaboration with @rlouf I built an
[outlines](https://dottxt-ai.github.io/outlines/latest/) integration for
langchain!

I think this is really useful for doing any type of structured output
locally.
[Dottxt](https://dottxt.co) spent a lot of work optimising this process
at a lower level
([outlines-core](https://pypi.org/project/outlines-core/0.1.14/) written
in rust) so I think this is a better alternative to all current
approaches in langchain to do structured output.
It also implements the `.with_structured_output` method so it should be
a drop in replacement for a lot of applications.

The integration includes:
- **Outlines LLM class**
- **ChatOutlines class**
- **Tutorial Cookbooks**
- **Documentation Page**
- **Validation and error messages** 
- **Exposes Outlines Structured output features**
- **Support for multiple backends**
- **Integration and Unit Tests**

Dependencies: `outlines` + additional (depending on backend used)

I am not sure if the unit tests comply with all requirements; if not, I
suggest just removing them, since I don't see a useful way to do it
differently.

### Quick overview:

Chat Models:
<img width="698" alt="image"
src="https://github.com/user-attachments/assets/05a499b9-858c-4397-a9ff-165c2b3e7acc">

Structured Output:
<img width="955" alt="image"
src="https://github.com/user-attachments/assets/b9fcac11-d3e5-4698-b1ae-8c4cb3d54c45">

---------

Co-authored-by: Vadym Barda <vadym@langchain.dev>
This commit is contained in:
shroominic
2024-11-21 05:31:31 +08:00
committed by GitHub
parent 2901fa20cc
commit dee72c46c1
14 changed files with 2162 additions and 0 deletions

View File

@@ -0,0 +1,177 @@
# flake8: noqa
"""Test ChatOutlines wrapper."""
from typing import Generator
import re
import platform
import pytest
from langchain_community.chat_models.outlines import ChatOutlines
from langchain_core.messages import AIMessage, HumanMessage, BaseMessage
from langchain_core.messages import BaseMessageChunk
from pydantic import BaseModel
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
# Model identifiers handed to ChatOutlines by the `chat_model` fixture.
MODEL = "microsoft/Phi-3-mini-4k-instruct"
LLAMACPP_MODEL = "bartowski/qwen2.5-7b-ins-v3-GGUF/qwen2.5-7b-ins-v3-Q4_K_M.gguf"

# Two backends are always exercised; the third depends on the host OS:
# "mlxlm" when platform.system() reports "Darwin", "vllm" otherwise.
BACKENDS = ["transformers", "llamacpp"]
BACKENDS += ["mlxlm"] if platform.system() == "Darwin" else ["vllm"]
@pytest.fixture(params=BACKENDS)
def chat_model(request: pytest.FixtureRequest) -> ChatOutlines:
    """Build one ChatOutlines instance per entry in BACKENDS.

    The "llamacpp" backend is given LLAMACPP_MODEL; every other backend is
    given MODEL.
    """
    backend = request.param
    model_name = LLAMACPP_MODEL if backend == "llamacpp" else MODEL
    return ChatOutlines(model=model_name, backend=backend)
def test_chat_outlines_inference(chat_model: ChatOutlines) -> None:
    """Invoke the chat model once and check it returns a non-trivial AIMessage."""
    reply = chat_model.invoke([HumanMessage(content="Say foo:")])
    assert isinstance(reply, AIMessage)
    # More than a single character of content is required.
    content_length = len(reply.content)
    assert content_length > 1
def test_chat_outlines_streaming(chat_model: ChatOutlines) -> None:
    """Stream a reply and check every chunk is a message chunk with str content."""
    prompt = [HumanMessage(content="How do you say 'hello' in Spanish?")]
    token_stream = chat_model.stream(prompt)
    assert isinstance(token_stream, Generator)

    pieces = []
    for piece in token_stream:
        assert isinstance(piece, BaseMessageChunk)
        # Only plain-string content can be concatenated below.
        if not isinstance(piece.content, str):
            raise ValueError(
                f"Invalid content type, only str is supported, "
                f"got {type(piece.content)}"
            )
        pieces.append(piece.content)

    streamed_text = "".join(pieces)
    assert len(streamed_text.strip()) > 1
def test_chat_outlines_streaming_callback(chat_model: ChatOutlines) -> None:
    """Test that streaming correctly invokes on_llm_new_token callback.

    A FakeCallbackHandler is attached to the model, the model is invoked, and
    the handler's ``llm_streams`` value must reach the minimum chunk count.
    """
    min_chunks = 5

    handler = FakeCallbackHandler()
    chat_model.callbacks = [handler]
    chat_model.verbose = True

    chat_model.invoke([HumanMessage(content="Can you count to 10?")])

    assert handler.llm_streams >= min_chunks
def test_chat_outlines_regex(chat_model: ChatOutlines) -> None:
    """Constrain generation with an IPv4 regex and check the reply matches it."""
    ip_regex = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
    chat_model.regex = ip_regex
    # The attribute must read back exactly as assigned.
    assert chat_model.regex == ip_regex

    question = HumanMessage(content="What is the IP address of Google's DNS server?")
    reply = chat_model.invoke([question])
    assert isinstance(reply, AIMessage)

    matched = re.match(ip_regex, str(reply.content))
    assert matched, f"Generated output '{reply.content}' is not a valid IP address"
def test_chat_outlines_type_constraints(chat_model: ChatOutlines) -> None:
    """Constrain generation to ``int`` and check the reply parses as an integer."""
    chat_model.type_constraints = int
    question = HumanMessage(
        content="What is the answer to life, the universe, and everything?"
    )
    reply = chat_model.invoke([question])
    # int() raises ValueError if the content is not an integer literal.
    parsed = int(str(reply.content))
    assert isinstance(parsed, int)
def test_chat_outlines_json(chat_model: ChatOutlines) -> None:
    """Constrain generation with a pydantic schema and validate the JSON reply."""

    class Person(BaseModel):
        name: str

    chat_model.json_schema = Person
    prompt = [HumanMessage(content="Who are the main contributors to LangChain?")]
    reply = chat_model.invoke(prompt)
    # model_validate_json raises if the reply does not satisfy the schema.
    raw_json = str(reply.content)
    parsed = Person.model_validate_json(raw_json)
    assert isinstance(parsed, Person)
def test_chat_outlines_grammar(chat_model: ChatOutlines) -> None:
    """Constrain generation with an arithmetic grammar and sanity-check the reply.

    Skipped when the model's backend is "mlxlm".
    """
    if chat_model.backend == "mlxlm":
        pytest.skip("MLX grammars not yet supported.")

    arithmetic_grammar = """
?start: expression
?expression: term (("+" | "-") term)*
?term: factor (("*" | "/") factor)*
?factor: NUMBER | "-" factor | "(" expression ")"
%import common.NUMBER
%import common.WS
%ignore WS
"""
    chat_model.grammar = arithmetic_grammar

    prompt = [HumanMessage(content="Give me a complex arithmetic expression:")]
    reply = chat_model.invoke(prompt)

    # The reply must be a string with something other than whitespace in it.
    assert (
        isinstance(reply.content, str) and reply.content.strip()
    ), "Output should be a non-empty string"

    # Loose check only: at least one digit, operator or parenthesis is present.
    found = re.search(r"[\d\+\-\*/\(\)]+", reply.content)
    assert (
        found
    ), f"Generated output '{reply.content}' does not appear to be a valid arithmetic expression"
def test_chat_outlines_with_structured_output(chat_model: ChatOutlines) -> None:
    """Exercise ``with_structured_output`` with and without ``include_raw``."""

    class AnswerWithJustification(BaseModel):
        """An answer to the user question along with justification for the answer."""

        answer: str
        justification: str

    question = "What weighs more, a pound of bricks or a pound of feathers?"

    # Default mode: the runnable returns the parsed schema instance directly.
    parsed_only = chat_model.with_structured_output(AnswerWithJustification)
    answer_obj = parsed_only.invoke(question)
    assert isinstance(answer_obj, AnswerWithJustification)
    assert isinstance(answer_obj.answer, str)
    assert isinstance(answer_obj.justification, str)
    assert len(answer_obj.answer) > 0
    assert len(answer_obj.justification) > 0

    # include_raw=True: the runnable returns a dict with the raw message,
    # the parsed object and a parsing-error slot.
    with_raw = chat_model.with_structured_output(
        AnswerWithJustification, include_raw=True
    )
    payload = with_raw.invoke(question)
    assert isinstance(payload, dict)
    for expected_key in ("raw", "parsed", "parsing_error"):
        assert expected_key in payload
    assert isinstance(payload["raw"], BaseMessage)
    assert isinstance(payload["parsed"], AnswerWithJustification)
    assert payload["parsing_error"] is None

View File

@@ -0,0 +1,123 @@
# flake8: noqa
"""Test Outlines wrapper."""
from typing import Generator
import re
import platform
import pytest
from langchain_community.llms.outlines import Outlines
from pydantic import BaseModel
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
# Model identifiers handed to Outlines by the `llm` fixture.
MODEL = "microsoft/Phi-3-mini-4k-instruct"
LLAMACPP_MODEL = "microsoft/Phi-3-mini-4k-instruct-gguf/Phi-3-mini-4k-instruct-q4.gguf"

# Two backends are always exercised; the third depends on the host OS:
# "mlxlm" when platform.system() reports "Darwin", "vllm" otherwise.
BACKENDS = ["transformers", "llamacpp"]
BACKENDS += ["mlxlm"] if platform.system() == "Darwin" else ["vllm"]
@pytest.fixture(params=BACKENDS)
def llm(request: pytest.FixtureRequest) -> Outlines:
    """Build one Outlines instance (max_tokens=100) per entry in BACKENDS.

    The "llamacpp" backend is given LLAMACPP_MODEL; every other backend is
    given MODEL.
    """
    backend = request.param
    model_name = LLAMACPP_MODEL if backend == "llamacpp" else MODEL
    return Outlines(model=model_name, backend=backend, max_tokens=100)
def test_outlines_inference(llm: Outlines) -> None:
    """Invoke the LLM once and check it returns a string longer than one character."""
    completion = llm.invoke("Say foo:")
    assert isinstance(completion, str)
    completion_length = len(completion)
    assert completion_length > 1
def test_outlines_streaming(llm: Outlines) -> None:
    """Stream a completion and check every chunk is a string."""
    token_stream = llm.stream("Q: How do you say 'hello' in Spanish?\n\nA:")
    assert isinstance(token_stream, Generator)

    pieces = []
    for piece in token_stream:
        print(piece)
        assert isinstance(piece, str)
        pieces.append(piece)

    streamed_text = "".join(pieces)
    print(streamed_text)
    assert len(streamed_text.strip()) > 1
def test_outlines_streaming_callback(llm: Outlines) -> None:
    """Test that streaming correctly invokes on_llm_new_token callback.

    A FakeCallbackHandler is attached to the LLM, the LLM is invoked, and the
    handler's ``llm_streams`` value must reach the minimum chunk count.
    """
    min_chunks = 5

    handler = FakeCallbackHandler()
    llm.callbacks = [handler]
    llm.verbose = True

    llm.invoke("Q: Can you count to 10? A:'1, ")

    assert handler.llm_streams >= min_chunks
def test_outlines_regex(llm: Outlines) -> None:
    """Constrain generation with an IPv4 regex and check the completion matches it."""
    ip_regex = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
    llm.regex = ip_regex
    # The attribute must read back exactly as assigned.
    assert llm.regex == ip_regex

    completion = llm.invoke("Q: What is the IP address of googles dns server?\n\nA: ")
    assert isinstance(completion, str)

    matched = re.match(ip_regex, completion)
    assert matched, f"Generated output '{completion}' is not a valid IP address"
def test_outlines_type_constraints(llm: Outlines) -> None:
    """Test type constraints for generating an integer.

    The completion is parsed with ``int``; a non-integer completion raises
    ``ValueError`` and fails the test.
    """
    llm.type_constraints = int
    output = llm.invoke(
        "Q: What is the answer to life, the universe, and everything?\n\nA: "
    )
    # Check the parsed type rather than the value's truthiness: a bare
    # ``assert int(output)`` would reject the valid integer "0".
    assert isinstance(int(output), int)
def test_outlines_json(llm: Outlines) -> None:
    """Constrain generation with a pydantic schema and validate the JSON completion."""

    class Person(BaseModel):
        name: str

    llm.json_schema = Person
    completion = llm.invoke("Q: Who is the author of LangChain?\n\nA: ")
    # model_validate_json raises if the completion does not satisfy the schema.
    parsed = Person.model_validate_json(completion)
    assert isinstance(parsed, Person)
def test_outlines_grammar(llm: Outlines) -> None:
    """Test grammar for generating a valid arithmetic expression.

    Skipped when the LLM's backend is "mlxlm", because MLX grammars are not
    yet supported.
    """
    # NOTE(review): the `llm` fixture includes "mlxlm" on Darwin; confirm the
    # LLM wrapper has the same grammar limitation as the chat wrapper.
    if llm.backend == "mlxlm":
        pytest.skip("MLX grammars not yet supported.")

    llm.grammar = """
?start: expression
?expression: term (("+" | "-") term)*
?term: factor (("*" | "/") factor)*
?factor: NUMBER | "-" factor | "(" expression ")"
%import common.NUMBER
%import common.WS
%ignore WS
"""

    output = llm.invoke("Here is a complex arithmetic expression: ")

    # Validate the output is a non-empty string
    assert (
        isinstance(output, str) and output.strip()
    ), "Output should be a non-empty string"

    # Loose check only: at least one digit, operator or parenthesis is present.
    assert re.search(
        r"[\d\+\-\*/\(\)]+", output
    ), f"Generated output '{output}' does not appear to be a valid arithmetic expression"

View File

@@ -36,6 +36,7 @@ EXPECTED_ALL = [
"ChatOCIModelDeploymentTGI",
"ChatOllama",
"ChatOpenAI",
"ChatOutlines",
"ChatPerplexity",
"ChatPremAI",
"ChatSambaNovaCloud",

View File

@@ -0,0 +1,91 @@
import pytest
from _pytest.monkeypatch import MonkeyPatch
from pydantic import BaseModel, Field
from langchain_community.chat_models.outlines import ChatOutlines
def test_chat_outlines_initialization(monkeypatch: MonkeyPatch) -> None:
    """Check constructor arguments are stored and the backend defaults to transformers."""
    # Replace build_client with a stub that simply returns the instance.
    monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self)

    model_id = "microsoft/Phi-3-mini-4k-instruct"
    chat = ChatOutlines(model=model_id, max_tokens=42, stop=["\n"])

    assert chat.model == model_id
    assert chat.max_tokens == 42
    assert chat.backend == "transformers"
    assert chat.stop == ["\n"]
def test_chat_outlines_backend_llamacpp(monkeypatch: MonkeyPatch) -> None:
    """Check that ``backend="llamacpp"`` is stored on the model."""
    # Replace build_client with a stub that simply returns the instance.
    monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self)

    gguf_model = "TheBloke/Llama-2-7B-Chat-GGUF/llama-2-7b-chat.Q4_K_M.gguf"
    chat = ChatOutlines(model=gguf_model, backend="llamacpp")

    assert chat.backend == "llamacpp"
def test_chat_outlines_backend_vllm(monkeypatch: MonkeyPatch) -> None:
    """Check that ``backend="vllm"`` is stored on the model."""
    # Replace build_client with a stub that simply returns the instance.
    monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self)

    model_id = "microsoft/Phi-3-mini-4k-instruct"
    chat = ChatOutlines(model=model_id, backend="vllm")

    assert chat.backend == "vllm"
def test_chat_outlines_backend_mlxlm(monkeypatch: MonkeyPatch) -> None:
    """Check that ``backend="mlxlm"`` is stored on the model."""
    # Replace build_client with a stub that simply returns the instance.
    monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self)

    model_id = "microsoft/Phi-3-mini-4k-instruct"
    chat = ChatOutlines(model=model_id, backend="mlxlm")

    assert chat.backend == "mlxlm"
def test_chat_outlines_with_regex(monkeypatch: MonkeyPatch) -> None:
    """Check that a ``regex`` passed to the constructor is stored unchanged."""
    # Replace build_client with a stub that simply returns the instance.
    monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self)

    phone_pattern = r"\d{3}-\d{3}-\d{4}"
    chat = ChatOutlines(
        model="microsoft/Phi-3-mini-4k-instruct",
        regex=phone_pattern,
    )

    assert chat.regex == phone_pattern
def test_chat_outlines_with_type_constraints(monkeypatch: MonkeyPatch) -> None:
    """Check that ``type_constraints`` passed to the constructor is stored."""
    # Replace build_client with a stub that simply returns the instance.
    monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self)

    chat = ChatOutlines(
        model="microsoft/Phi-3-mini-4k-instruct",
        type_constraints=int,
    )

    assert chat.type_constraints == int  # noqa
def test_chat_outlines_with_json_schema(monkeypatch: MonkeyPatch) -> None:
    """Check that a pydantic class passed as ``json_schema`` is stored."""
    # Replace build_client with a stub that simply returns the instance.
    monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self)

    class TestSchema(BaseModel):
        name: str = Field(description="A person's name")
        age: int = Field(description="A person's age")

    model_id = "microsoft/Phi-3-mini-4k-instruct"
    chat = ChatOutlines(model=model_id, json_schema=TestSchema)

    assert chat.json_schema == TestSchema
def test_chat_outlines_with_grammar(monkeypatch: MonkeyPatch) -> None:
    """Check that a ``grammar`` passed to the constructor is stored unchanged."""
    # Replace build_client with a stub that simply returns the instance.
    monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self)

    arithmetic_grammar = """
?start: expression
?expression: term (("+" | "-") term)*
?term: factor (("*" | "/") factor)*
?factor: NUMBER | "-" factor | "(" expression ")"
%import common.NUMBER
"""
    chat = ChatOutlines(
        model="microsoft/Phi-3-mini-4k-instruct",
        grammar=arithmetic_grammar,
    )

    assert chat.grammar == arithmetic_grammar
def test_raise_for_multiple_output_constraints(monkeypatch: MonkeyPatch) -> None:
    """Check that passing two output constraints at once raises ValueError."""
    # Replace build_client with a stub that simply returns the instance.
    monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self)

    # type_constraints and regex are supplied together on purpose.
    conflicting_kwargs = {
        "model": "microsoft/Phi-3-mini-4k-instruct",
        "type_constraints": int,
        "regex": r"\d{3}-\d{3}-\d{4}",
    }
    with pytest.raises(ValueError):
        ChatOutlines(**conflicting_kwargs)

View File

@@ -67,6 +67,7 @@ EXPECT_ALL = [
"OpenAIChat",
"OpenLLM",
"OpenLM",
"Outlines",
"PaiEasEndpoint",
"Petals",
"PipelineAI",

View File

@@ -0,0 +1,92 @@
import pytest
from _pytest.monkeypatch import MonkeyPatch
from langchain_community.llms.outlines import Outlines
def test_outlines_initialization(monkeypatch: MonkeyPatch) -> None:
    """Check constructor arguments are stored and the backend defaults to transformers."""
    # Replace build_client with a stub that simply returns the instance.
    monkeypatch.setattr(Outlines, "build_client", lambda self: self)

    model_id = "microsoft/Phi-3-mini-4k-instruct"
    llm = Outlines(model=model_id, max_tokens=42, stop=["\n"])

    assert llm.model == model_id
    assert llm.max_tokens == 42
    assert llm.backend == "transformers"
    assert llm.stop == ["\n"]
def test_outlines_backend_llamacpp(monkeypatch: MonkeyPatch) -> None:
    """Check that ``backend="llamacpp"`` is stored on the LLM."""
    # Replace build_client with a stub that simply returns the instance.
    monkeypatch.setattr(Outlines, "build_client", lambda self: self)

    gguf_model = "TheBloke/Llama-2-7B-Chat-GGUF/llama-2-7b-chat.Q4_K_M.gguf"
    llm = Outlines(model=gguf_model, backend="llamacpp")

    assert llm.backend == "llamacpp"
def test_outlines_backend_vllm(monkeypatch: MonkeyPatch) -> None:
    """Check that ``backend="vllm"`` is stored on the LLM."""
    # Replace build_client with a stub that simply returns the instance.
    monkeypatch.setattr(Outlines, "build_client", lambda self: self)

    model_id = "microsoft/Phi-3-mini-4k-instruct"
    llm = Outlines(model=model_id, backend="vllm")

    assert llm.backend == "vllm"
def test_outlines_backend_mlxlm(monkeypatch: MonkeyPatch) -> None:
    """Check that ``backend="mlxlm"`` is stored on the LLM."""
    # Replace build_client with a stub that simply returns the instance.
    monkeypatch.setattr(Outlines, "build_client", lambda self: self)

    model_id = "microsoft/Phi-3-mini-4k-instruct"
    llm = Outlines(model=model_id, backend="mlxlm")

    assert llm.backend == "mlxlm"
def test_outlines_with_regex(monkeypatch: MonkeyPatch) -> None:
    """Check that a ``regex`` passed to the constructor is stored unchanged."""
    # Replace build_client with a stub that simply returns the instance.
    monkeypatch.setattr(Outlines, "build_client", lambda self: self)

    phone_pattern = r"\d{3}-\d{3}-\d{4}"
    llm = Outlines(
        model="microsoft/Phi-3-mini-4k-instruct",
        regex=phone_pattern,
    )

    assert llm.regex == phone_pattern
def test_outlines_with_type_constraints(monkeypatch: MonkeyPatch) -> None:
    """Check that ``type_constraints`` passed to the constructor is stored."""
    # Replace build_client with a stub that simply returns the instance.
    monkeypatch.setattr(Outlines, "build_client", lambda self: self)

    llm = Outlines(
        model="microsoft/Phi-3-mini-4k-instruct",
        type_constraints=int,
    )

    assert llm.type_constraints == int  # noqa
def test_outlines_with_json_schema(monkeypatch: MonkeyPatch) -> None:
    """Check that a pydantic class passed as ``json_schema`` is stored."""
    # Replace build_client with a stub that simply returns the instance.
    monkeypatch.setattr(Outlines, "build_client", lambda self: self)

    from pydantic import BaseModel, Field

    class TestSchema(BaseModel):
        name: str = Field(description="A person's name")
        age: int = Field(description="A person's age")

    model_id = "microsoft/Phi-3-mini-4k-instruct"
    llm = Outlines(model=model_id, json_schema=TestSchema)

    assert llm.json_schema == TestSchema
def test_outlines_with_grammar(monkeypatch: MonkeyPatch) -> None:
    """Check that a ``grammar`` passed to the constructor is stored unchanged."""
    # Replace build_client with a stub that simply returns the instance.
    monkeypatch.setattr(Outlines, "build_client", lambda self: self)

    arithmetic_grammar = """
?start: expression
?expression: term (("+" | "-") term)*
?term: factor (("*" | "/") factor)*
?factor: NUMBER | "-" factor | "(" expression ")"
%import common.NUMBER
"""
    llm = Outlines(
        model="microsoft/Phi-3-mini-4k-instruct",
        grammar=arithmetic_grammar,
    )

    assert llm.grammar == arithmetic_grammar
def test_raise_for_multiple_output_constraints(monkeypatch: MonkeyPatch) -> None:
    """Check that passing two output constraints at once raises ValueError.

    ``type_constraints`` and ``regex`` are supplied together; construction is
    expected to fail with ``ValueError``.
    """
    # Replace build_client with a stub that simply returns the instance.
    monkeypatch.setattr(Outlines, "build_client", lambda self: self)

    # A single construction is enough: a second identical call inside the
    # ``raises`` block would never run once the first one has raised.
    with pytest.raises(ValueError):
        Outlines(
            model="microsoft/Phi-3-mini-4k-instruct",
            type_constraints=int,
            regex=r"\d{3}-\d{3}-\d{4}",
        )