Mirror of https://github.com/hwchase17/langchain.git (synced 2025-09-14 14:05:37 +00:00)
LM Format Enforcer Integration + Sample Notebook (#12625)
## Description

This PR adds support for [lm-format-enforcer](https://github.com/noamgat/lm-format-enforcer) to LangChain.

The library is similar to jsonformer / RELLM, which are already supported in LangChain, but has several advantages:
- Batching and beam search support
- More complete JSON Schema support
- The LLM keeps control over whitespace, improving output quality
- Better runtime performance, because the LLM's generate() function is called only once per generate() call

The integration is loosely based on the jsonformer integration in terms of project structure.

## Dependencies

No compile-time dependency was added, but if `lm-format-enforcer` is not installed, a runtime error is raised when the integration is used.

## Tests

Because the integration modifies the internal parameters of the underlying Hugging Face transformers LLM, it cannot be tested without loading a real model, which requires internet access. So, as with the jsonformer and RELLM integrations, testing is done via the notebook.

## Twitter Handle

[@noamgat](https://twitter.com/noamgat)

Looking forward to hearing feedback!

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
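A minimal usage sketch, not part of the PR itself: the model name, prompt, and schema below are placeholders, and the call pattern follows the existing jsonformer-style decoders (a local Hugging Face pipeline wrapped by the constrained-decoding class).

```python
from transformers import pipeline

from langchain_experimental.llms import LMFormatEnforcer

# Placeholder model; any local text-generation pipeline should work.
hf_pipeline = pipeline("text-generation", model="gpt2", max_new_tokens=200)

# Placeholder JSON Schema that the generated text must conform to.
person_schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"},
    },
    "required": ["name", "age"],
}

llm = LMFormatEnforcer(pipeline=hf_pipeline, json_schema=person_schema)
print(llm.predict("Describe a person as JSON: "))
```

Alternatively, passing `regex=...` instead of `json_schema=...` constrains the output to a regular expression; exactly one of the two must be set.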
langchain_experimental/llms/__init__.py:

```diff
@@ -2,6 +2,7 @@
 from langchain_experimental.llms.jsonformer_decoder import JsonFormer
 from langchain_experimental.llms.llamaapi import ChatLlamaAPI
+from langchain_experimental.llms.lmformatenforcer_decoder import LMFormatEnforcer
 from langchain_experimental.llms.rellm_decoder import RELLM
 
-__all__ = ["RELLM", "JsonFormer", "ChatLlamaAPI"]
+__all__ = ["RELLM", "JsonFormer", "ChatLlamaAPI", "LMFormatEnforcer"]
```
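With the export added above, the wrapper is reachable from the experimental package namespace; a one-line check (illustrative only):

```python
import langchain_experimental.llms as exp_llms

assert "LMFormatEnforcer" in exp_llms.__all__  # exported alongside RELLM and JsonFormer
```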
langchain_experimental/llms/lmformatenforcer_decoder.py (new file, @@ -0,0 +1,83 @@):

```python
"""Experimental implementation of lm-format-enforcer wrapped LLM."""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, List, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.schema import LLMResult

from langchain_experimental.pydantic_v1 import Field

if TYPE_CHECKING:
    import lmformatenforcer


def import_lmformatenforcer() -> lmformatenforcer:
    """Lazily import lmformatenforcer."""
    try:
        import lmformatenforcer
    except ImportError:
        raise ImportError(
            "Could not import lmformatenforcer python package. "
            "Please install it with `pip install lm-format-enforcer`."
        )
    return lmformatenforcer


class LMFormatEnforcer(HuggingFacePipeline):
    """LMFormatEnforcer wrapped LLM using HuggingFace Pipeline API.

    This pipeline is experimental and not yet stable.
    """

    json_schema: Optional[dict] = Field(
        description="The JSON Schema to complete.", default=None
    )
    regex: Optional[str] = Field(
        description="The regular expression to complete.", default=None
    )

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        lmformatenforcer = import_lmformatenforcer()
        import lmformatenforcer.integrations.transformers as hf_integration

        # We integrate lmformatenforcer by adding a prefix_allowed_tokens_fn.
        # It has to be done on each call, because the prefix function is stateful.
        if "prefix_allowed_tokens_fn" in self.pipeline._forward_params:
            raise ValueError(
                "prefix_allowed_tokens_fn param is forbidden with LMFormatEnforcer."
            )

        has_json_schema = self.json_schema is not None
        has_regex = self.regex is not None
        if has_json_schema == has_regex:
            raise ValueError(
                "You must specify exactly one of json_schema or a regex, but not both."
            )

        if has_json_schema:
            parser = lmformatenforcer.JsonSchemaParser(self.json_schema)
        else:
            parser = lmformatenforcer.RegexParser(self.regex)

        prefix_function = hf_integration.build_transformers_prefix_allowed_tokens_fn(
            self.pipeline.tokenizer, parser
        )
        self.pipeline._forward_params["prefix_allowed_tokens_fn"] = prefix_function

        result = super()._generate(
            prompts,
            stop=stop,
            run_manager=run_manager,
            **kwargs,
        )

        del self.pipeline._forward_params["prefix_allowed_tokens_fn"]
        return result
```
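The integration relies on the `prefix_allowed_tokens_fn` hook that Hugging Face `generate()` already exposes: at each decoding step the function returns the token ids that may come next, and everything else is masked out. A toy sketch of that hook, not the library's implementation (the token ids are placeholders; lm-format-enforcer derives the allowed set from a JsonSchemaParser or RegexParser instead):

```python
from typing import List

import torch


def allow_only_placeholder_tokens(batch_id: int, input_ids: torch.Tensor) -> List[int]:
    # Called by transformers' generate() before sampling each token.
    # Returning a subset of vocabulary ids constrains what can be generated next.
    return [15, 16, 17]  # placeholder token ids

# Set via the pipeline's forward params, this reaches model.generate(), e.g.:
# model.generate(input_ids, prefix_allowed_tokens_fn=allow_only_placeholder_tokens)
```

Because the constraint is expressed through this hook rather than by re-prompting, the model is invoked only once per generate() call, which is the source of the performance advantage mentioned in the description.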