experimental[patch]: Pass enum only to openai in llm graph transformer (#21860)

Some models like Groq return bad request if you pass in `enum` parameter
in tool definition
This commit is contained in:
Tomaz Bratanic 2024-05-21 00:02:48 +02:00 committed by GitHub
parent aab9cb666f
commit a43515ca65
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -150,16 +150,24 @@ def optional_enum_field(
enum_values: Optional[List[str]] = None, enum_values: Optional[List[str]] = None,
description: str = "", description: str = "",
input_type: str = "node", input_type: str = "node",
llm_type: Optional[str] = None,
**field_kwargs: Any, **field_kwargs: Any,
) -> Any: ) -> Any:
"""Utility function to conditionally create a field with an enum constraint.""" """Utility function to conditionally create a field with an enum constraint."""
if enum_values: # Only openai supports enum param
if enum_values and llm_type == "openai-chat":
return Field( return Field(
..., ...,
enum=enum_values, enum=enum_values,
description=f"{description}. Available options are {enum_values}", description=f"{description}. Available options are {enum_values}",
**field_kwargs, **field_kwargs,
) )
elif enum_values:
return Field(
...,
description=f"{description}. Available options are {enum_values}",
**field_kwargs,
)
else: else:
additional_info = _get_additional_info(input_type) additional_info = _get_additional_info(input_type)
return Field(..., description=description + additional_info, **field_kwargs) return Field(..., description=description + additional_info, **field_kwargs)
@ -271,6 +279,7 @@ def create_simple_model(
node_labels: Optional[List[str]] = None, node_labels: Optional[List[str]] = None,
rel_types: Optional[List[str]] = None, rel_types: Optional[List[str]] = None,
node_properties: Union[bool, List[str]] = False, node_properties: Union[bool, List[str]] = False,
llm_type: Optional[str] = None,
) -> Type[_Graph]: ) -> Type[_Graph]:
""" """
Simple model allows to limit node and/or relationship types. Simple model allows to limit node and/or relationship types.
@ -288,6 +297,7 @@ def create_simple_model(
node_labels, node_labels,
description="The type or label of the node.", description="The type or label of the node.",
input_type="node", input_type="node",
llm_type=llm_type,
), ),
), ),
} }
@ -325,6 +335,7 @@ def create_simple_model(
node_labels, node_labels,
description="The type or label of the source node.", description="The type or label of the source node.",
input_type="node", input_type="node",
llm_type=llm_type,
) )
target_node_id: str = Field( target_node_id: str = Field(
description="Name or human-readable unique identifier of target node" description="Name or human-readable unique identifier of target node"
@ -333,11 +344,13 @@ def create_simple_model(
node_labels, node_labels,
description="The type or label of the target node.", description="The type or label of the target node.",
input_type="node", input_type="node",
llm_type=llm_type,
) )
type: str = optional_enum_field( type: str = optional_enum_field(
rel_types, rel_types,
description="The type of the relationship.", description="The type of the relationship.",
input_type="relationship", input_type="relationship",
llm_type=llm_type,
) )
class DynamicGraph(_Graph): class DynamicGraph(_Graph):
@ -572,8 +585,12 @@ class LLMGraphTransformer:
self.chain = prompt | llm self.chain = prompt | llm
else: else:
# Define chain # Define chain
try:
llm_type = llm._llm_type # type: ignore
except AttributeError:
llm_type = None
schema = create_simple_model( schema = create_simple_model(
allowed_nodes, allowed_relationships, node_properties allowed_nodes, allowed_relationships, node_properties, llm_type
) )
structured_llm = llm.with_structured_output(schema, include_raw=True) structured_llm = llm.with_structured_output(schema, include_raw=True)
prompt = prompt or default_prompt prompt = prompt or default_prompt