From a43515ca65307bfef917339db343a32e0ba0c91d Mon Sep 17 00:00:00 2001 From: Tomaz Bratanic Date: Tue, 21 May 2024 00:02:48 +0200 Subject: [PATCH] experimental[patch]: Pass enum only to openai in llm graph transformer (#21860) Some models like Groq return bad request if you pass in `enum` parameter in tool definition --- .../graph_transformers/llm.py | 21 +++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/libs/experimental/langchain_experimental/graph_transformers/llm.py b/libs/experimental/langchain_experimental/graph_transformers/llm.py index 6cd97941a9c..29af044237f 100644 --- a/libs/experimental/langchain_experimental/graph_transformers/llm.py +++ b/libs/experimental/langchain_experimental/graph_transformers/llm.py @@ -150,16 +150,24 @@ def optional_enum_field( enum_values: Optional[List[str]] = None, description: str = "", input_type: str = "node", + llm_type: Optional[str] = None, **field_kwargs: Any, ) -> Any: """Utility function to conditionally create a field with an enum constraint.""" - if enum_values: + # Only openai supports enum param + if enum_values and llm_type == "openai-chat": return Field( ..., enum=enum_values, description=f"{description}. Available options are {enum_values}", **field_kwargs, ) + elif enum_values: + return Field( + ..., + description=f"{description}. Available options are {enum_values}", + **field_kwargs, + ) else: additional_info = _get_additional_info(input_type) return Field(..., description=description + additional_info, **field_kwargs) @@ -271,6 +279,7 @@ def create_simple_model( node_labels: Optional[List[str]] = None, rel_types: Optional[List[str]] = None, node_properties: Union[bool, List[str]] = False, + llm_type: Optional[str] = None, ) -> Type[_Graph]: """ Simple model allows to limit node and/or relationship types. @@ -288,6 +297,7 @@ def create_simple_model( node_labels, description="The type or label of the node.", input_type="node", + llm_type=llm_type, ), ), } @@ -325,6 +335,7 @@ def create_simple_model( node_labels, description="The type or label of the source node.", input_type="node", + llm_type=llm_type, ) target_node_id: str = Field( description="Name or human-readable unique identifier of target node" @@ -333,11 +344,13 @@ def create_simple_model( node_labels, description="The type or label of the target node.", input_type="node", + llm_type=llm_type, ) type: str = optional_enum_field( rel_types, description="The type of the relationship.", input_type="relationship", + llm_type=llm_type, ) class DynamicGraph(_Graph): @@ -572,8 +585,12 @@ class LLMGraphTransformer: self.chain = prompt | llm else: # Define chain + try: + llm_type = llm._llm_type # type: ignore + except AttributeError: + llm_type = None schema = create_simple_model( - allowed_nodes, allowed_relationships, node_properties + allowed_nodes, allowed_relationships, node_properties, llm_type ) structured_llm = llm.with_structured_output(schema, include_raw=True) prompt = prompt or default_prompt