mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-12 12:59:07 +00:00
community: Outlines integration (#27449)
In collaboration with @rlouf I build an [outlines](https://dottxt-ai.github.io/outlines/latest/) integration for langchain! I think this is really useful for doing any type of structured output locally. [Dottxt](https://dottxt.co) spend alot of work optimising this process at a lower level ([outlines-core](https://pypi.org/project/outlines-core/0.1.14/) written in rust) so I think this is a better alternative over all current approaches in langchain to do structured output. It also implements the `.with_structured_output` method so it should be a drop in replacement for a lot of applications. The integration includes: - **Outlines LLM class** - **ChatOutlines class** - **Tutorial Cookbooks** - **Documentation Page** - **Validation and error messages** - **Exposes Outlines Structured output features** - **Support for multiple backends** - **Integration and Unit Tests** Dependencies: `outlines` + additional (depending on backend used) I am not sure if the unit-tests comply with all requirements, if not I suggest to just remove them since I don't see a useful way to do it differently. ### Quick overview: Chat Models: <img width="698" alt="image" src="https://github.com/user-attachments/assets/05a499b9-858c-4397-a9ff-165c2b3e7acc"> Structured Output: <img width="955" alt="image" src="https://github.com/user-attachments/assets/b9fcac11-d3e5-4698-b1ae-8c4cb3d54c45"> --------- Co-authored-by: Vadym Barda <vadym@langchain.dev>
This commit is contained in:
@@ -0,0 +1,177 @@
|
||||
# flake8: noqa
|
||||
"""Test ChatOutlines wrapper."""
|
||||
|
||||
from typing import Generator
|
||||
import re
|
||||
import platform
|
||||
import pytest
|
||||
|
||||
from langchain_community.chat_models.outlines import ChatOutlines
|
||||
from langchain_core.messages import AIMessage, HumanMessage, BaseMessage
|
||||
from langchain_core.messages import BaseMessageChunk
|
||||
from pydantic import BaseModel
|
||||
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
|
||||
MODEL = "microsoft/Phi-3-mini-4k-instruct"
|
||||
LLAMACPP_MODEL = "bartowski/qwen2.5-7b-ins-v3-GGUF/qwen2.5-7b-ins-v3-Q4_K_M.gguf"
|
||||
|
||||
BACKENDS = ["transformers", "llamacpp"]
|
||||
if platform.system() != "Darwin":
|
||||
BACKENDS.append("vllm")
|
||||
if platform.system() == "Darwin":
|
||||
BACKENDS.append("mlxlm")
|
||||
|
||||
|
||||
@pytest.fixture(params=BACKENDS)
|
||||
def chat_model(request: pytest.FixtureRequest) -> ChatOutlines:
|
||||
if request.param == "llamacpp":
|
||||
return ChatOutlines(model=LLAMACPP_MODEL, backend=request.param)
|
||||
else:
|
||||
return ChatOutlines(model=MODEL, backend=request.param)
|
||||
|
||||
|
||||
def test_chat_outlines_inference(chat_model: ChatOutlines) -> None:
|
||||
"""Test valid ChatOutlines inference."""
|
||||
messages = [HumanMessage(content="Say foo:")]
|
||||
output = chat_model.invoke(messages)
|
||||
assert isinstance(output, AIMessage)
|
||||
assert len(output.content) > 1
|
||||
|
||||
|
||||
def test_chat_outlines_streaming(chat_model: ChatOutlines) -> None:
|
||||
"""Test streaming tokens from ChatOutlines."""
|
||||
messages = [HumanMessage(content="How do you say 'hello' in Spanish?")]
|
||||
generator = chat_model.stream(messages)
|
||||
stream_results_string = ""
|
||||
assert isinstance(generator, Generator)
|
||||
|
||||
for chunk in generator:
|
||||
assert isinstance(chunk, BaseMessageChunk)
|
||||
if isinstance(chunk.content, str):
|
||||
stream_results_string += chunk.content
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid content type, only str is supported, "
|
||||
f"got {type(chunk.content)}"
|
||||
)
|
||||
assert len(stream_results_string.strip()) > 1
|
||||
|
||||
|
||||
def test_chat_outlines_streaming_callback(chat_model: ChatOutlines) -> None:
|
||||
"""Test that streaming correctly invokes on_llm_new_token callback."""
|
||||
MIN_CHUNKS = 5
|
||||
callback_handler = FakeCallbackHandler()
|
||||
chat_model.callbacks = [callback_handler]
|
||||
chat_model.verbose = True
|
||||
messages = [HumanMessage(content="Can you count to 10?")]
|
||||
chat_model.invoke(messages)
|
||||
assert callback_handler.llm_streams >= MIN_CHUNKS
|
||||
|
||||
|
||||
def test_chat_outlines_regex(chat_model: ChatOutlines) -> None:
|
||||
"""Test regex for generating a valid IP address"""
|
||||
ip_regex = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
|
||||
chat_model.regex = ip_regex
|
||||
assert chat_model.regex == ip_regex
|
||||
|
||||
messages = [HumanMessage(content="What is the IP address of Google's DNS server?")]
|
||||
output = chat_model.invoke(messages)
|
||||
|
||||
assert isinstance(output, AIMessage)
|
||||
assert re.match(
|
||||
ip_regex, str(output.content)
|
||||
), f"Generated output '{output.content}' is not a valid IP address"
|
||||
|
||||
|
||||
def test_chat_outlines_type_constraints(chat_model: ChatOutlines) -> None:
|
||||
"""Test type constraints for generating an integer"""
|
||||
chat_model.type_constraints = int
|
||||
messages = [
|
||||
HumanMessage(
|
||||
content="What is the answer to life, the universe, and everything?"
|
||||
)
|
||||
]
|
||||
output = chat_model.invoke(messages)
|
||||
assert isinstance(int(str(output.content)), int)
|
||||
|
||||
|
||||
def test_chat_outlines_json(chat_model: ChatOutlines) -> None:
|
||||
"""Test json for generating a valid JSON object"""
|
||||
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
|
||||
chat_model.json_schema = Person
|
||||
messages = [HumanMessage(content="Who are the main contributors to LangChain?")]
|
||||
output = chat_model.invoke(messages)
|
||||
person = Person.model_validate_json(str(output.content))
|
||||
assert isinstance(person, Person)
|
||||
|
||||
|
||||
def test_chat_outlines_grammar(chat_model: ChatOutlines) -> None:
|
||||
"""Test grammar for generating a valid arithmetic expression"""
|
||||
if chat_model.backend == "mlxlm":
|
||||
pytest.skip("MLX grammars not yet supported.")
|
||||
|
||||
chat_model.grammar = """
|
||||
?start: expression
|
||||
?expression: term (("+" | "-") term)*
|
||||
?term: factor (("*" | "/") factor)*
|
||||
?factor: NUMBER | "-" factor | "(" expression ")"
|
||||
%import common.NUMBER
|
||||
%import common.WS
|
||||
%ignore WS
|
||||
"""
|
||||
|
||||
messages = [HumanMessage(content="Give me a complex arithmetic expression:")]
|
||||
output = chat_model.invoke(messages)
|
||||
|
||||
# Validate the output is a non-empty string
|
||||
assert (
|
||||
isinstance(output.content, str) and output.content.strip()
|
||||
), "Output should be a non-empty string"
|
||||
|
||||
# Use a simple regex to check if the output contains basic arithmetic operations and numbers
|
||||
assert re.search(
|
||||
r"[\d\+\-\*/\(\)]+", output.content
|
||||
), f"Generated output '{output.content}' does not appear to be a valid arithmetic expression"
|
||||
|
||||
|
||||
def test_chat_outlines_with_structured_output(chat_model: ChatOutlines) -> None:
|
||||
"""Test that ChatOutlines can generate structured outputs"""
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
"""An answer to the user question along with justification for the answer."""
|
||||
|
||||
answer: str
|
||||
justification: str
|
||||
|
||||
structured_chat_model = chat_model.with_structured_output(AnswerWithJustification)
|
||||
|
||||
result = structured_chat_model.invoke(
|
||||
"What weighs more, a pound of bricks or a pound of feathers?"
|
||||
)
|
||||
|
||||
assert isinstance(result, AnswerWithJustification)
|
||||
assert isinstance(result.answer, str)
|
||||
assert isinstance(result.justification, str)
|
||||
assert len(result.answer) > 0
|
||||
assert len(result.justification) > 0
|
||||
|
||||
structured_chat_model_with_raw = chat_model.with_structured_output(
|
||||
AnswerWithJustification, include_raw=True
|
||||
)
|
||||
|
||||
result_with_raw = structured_chat_model_with_raw.invoke(
|
||||
"What weighs more, a pound of bricks or a pound of feathers?"
|
||||
)
|
||||
|
||||
assert isinstance(result_with_raw, dict)
|
||||
assert "raw" in result_with_raw
|
||||
assert "parsed" in result_with_raw
|
||||
assert "parsing_error" in result_with_raw
|
||||
assert isinstance(result_with_raw["raw"], BaseMessage)
|
||||
assert isinstance(result_with_raw["parsed"], AnswerWithJustification)
|
||||
assert result_with_raw["parsing_error"] is None
|
123
libs/community/tests/integration_tests/llms/test_outlines.py
Normal file
123
libs/community/tests/integration_tests/llms/test_outlines.py
Normal file
@@ -0,0 +1,123 @@
|
||||
# flake8: noqa
|
||||
"""Test Outlines wrapper."""
|
||||
|
||||
from typing import Generator
|
||||
import re
|
||||
import platform
|
||||
import pytest
|
||||
|
||||
from langchain_community.llms.outlines import Outlines
|
||||
from pydantic import BaseModel
|
||||
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
|
||||
MODEL = "microsoft/Phi-3-mini-4k-instruct"
|
||||
LLAMACPP_MODEL = "microsoft/Phi-3-mini-4k-instruct-gguf/Phi-3-mini-4k-instruct-q4.gguf"
|
||||
|
||||
BACKENDS = ["transformers", "llamacpp"]
|
||||
if platform.system() != "Darwin":
|
||||
BACKENDS.append("vllm")
|
||||
if platform.system() == "Darwin":
|
||||
BACKENDS.append("mlxlm")
|
||||
|
||||
|
||||
@pytest.fixture(params=BACKENDS)
|
||||
def llm(request: pytest.FixtureRequest) -> Outlines:
|
||||
if request.param == "llamacpp":
|
||||
return Outlines(model=LLAMACPP_MODEL, backend=request.param, max_tokens=100)
|
||||
else:
|
||||
return Outlines(model=MODEL, backend=request.param, max_tokens=100)
|
||||
|
||||
|
||||
def test_outlines_inference(llm: Outlines) -> None:
|
||||
"""Test valid outlines inference."""
|
||||
output = llm.invoke("Say foo:")
|
||||
assert isinstance(output, str)
|
||||
assert len(output) > 1
|
||||
|
||||
|
||||
def test_outlines_streaming(llm: Outlines) -> None:
|
||||
"""Test streaming tokens from Outlines."""
|
||||
generator = llm.stream("Q: How do you say 'hello' in Spanish?\n\nA:")
|
||||
stream_results_string = ""
|
||||
assert isinstance(generator, Generator)
|
||||
|
||||
for chunk in generator:
|
||||
print(chunk)
|
||||
assert isinstance(chunk, str)
|
||||
stream_results_string += chunk
|
||||
print(stream_results_string)
|
||||
assert len(stream_results_string.strip()) > 1
|
||||
|
||||
|
||||
def test_outlines_streaming_callback(llm: Outlines) -> None:
|
||||
"""Test that streaming correctly invokes on_llm_new_token callback."""
|
||||
MIN_CHUNKS = 5
|
||||
|
||||
callback_handler = FakeCallbackHandler()
|
||||
llm.callbacks = [callback_handler]
|
||||
llm.verbose = True
|
||||
llm.invoke("Q: Can you count to 10? A:'1, ")
|
||||
assert callback_handler.llm_streams >= MIN_CHUNKS
|
||||
|
||||
|
||||
def test_outlines_regex(llm: Outlines) -> None:
|
||||
"""Test regex for generating a valid IP address"""
|
||||
ip_regex = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
|
||||
llm.regex = ip_regex
|
||||
assert llm.regex == ip_regex
|
||||
|
||||
output = llm.invoke("Q: What is the IP address of googles dns server?\n\nA: ")
|
||||
|
||||
assert isinstance(output, str)
|
||||
|
||||
assert re.match(
|
||||
ip_regex, output
|
||||
), f"Generated output '{output}' is not a valid IP address"
|
||||
|
||||
|
||||
def test_outlines_type_constraints(llm: Outlines) -> None:
|
||||
"""Test type constraints for generating an integer"""
|
||||
llm.type_constraints = int
|
||||
output = llm.invoke(
|
||||
"Q: What is the answer to life, the universe, and everything?\n\nA: "
|
||||
)
|
||||
assert int(output)
|
||||
|
||||
|
||||
def test_outlines_json(llm: Outlines) -> None:
|
||||
"""Test json for generating a valid JSON object"""
|
||||
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
|
||||
llm.json_schema = Person
|
||||
output = llm.invoke("Q: Who is the author of LangChain?\n\nA: ")
|
||||
person = Person.model_validate_json(output)
|
||||
assert isinstance(person, Person)
|
||||
|
||||
|
||||
def test_outlines_grammar(llm: Outlines) -> None:
|
||||
"""Test grammar for generating a valid arithmetic expression"""
|
||||
llm.grammar = """
|
||||
?start: expression
|
||||
?expression: term (("+" | "-") term)*
|
||||
?term: factor (("*" | "/") factor)*
|
||||
?factor: NUMBER | "-" factor | "(" expression ")"
|
||||
%import common.NUMBER
|
||||
%import common.WS
|
||||
%ignore WS
|
||||
"""
|
||||
|
||||
output = llm.invoke("Here is a complex arithmetic expression: ")
|
||||
|
||||
# Validate the output is a non-empty string
|
||||
assert (
|
||||
isinstance(output, str) and output.strip()
|
||||
), "Output should be a non-empty string"
|
||||
|
||||
# Use a simple regex to check if the output contains basic arithmetic operations and numbers
|
||||
assert re.search(
|
||||
r"[\d\+\-\*/\(\)]+", output
|
||||
), f"Generated output '{output}' does not appear to be a valid arithmetic expression"
|
Reference in New Issue
Block a user