experimental[patch]: Pass enum only to openai in llm graph transformer (#21860)

Some models like Groq return bad request if you pass in `enum` parameter
in tool definition
This commit is contained in:
Tomaz Bratanic 2024-05-21 00:02:48 +02:00 committed by GitHub
parent aab9cb666f
commit a43515ca65
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -150,16 +150,24 @@ def optional_enum_field(
enum_values: Optional[List[str]] = None,
description: str = "",
input_type: str = "node",
llm_type: Optional[str] = None,
**field_kwargs: Any,
) -> Any:
"""Utility function to conditionally create a field with an enum constraint."""
if enum_values:
# Only openai supports enum param
if enum_values and llm_type == "openai-chat":
return Field(
...,
enum=enum_values,
description=f"{description}. Available options are {enum_values}",
**field_kwargs,
)
elif enum_values:
return Field(
...,
description=f"{description}. Available options are {enum_values}",
**field_kwargs,
)
else:
additional_info = _get_additional_info(input_type)
return Field(..., description=description + additional_info, **field_kwargs)
@ -271,6 +279,7 @@ def create_simple_model(
node_labels: Optional[List[str]] = None,
rel_types: Optional[List[str]] = None,
node_properties: Union[bool, List[str]] = False,
llm_type: Optional[str] = None,
) -> Type[_Graph]:
"""
Simple model allows to limit node and/or relationship types.
@ -288,6 +297,7 @@ def create_simple_model(
node_labels,
description="The type or label of the node.",
input_type="node",
llm_type=llm_type,
),
),
}
@ -325,6 +335,7 @@ def create_simple_model(
node_labels,
description="The type or label of the source node.",
input_type="node",
llm_type=llm_type,
)
target_node_id: str = Field(
description="Name or human-readable unique identifier of target node"
@ -333,11 +344,13 @@ def create_simple_model(
node_labels,
description="The type or label of the target node.",
input_type="node",
llm_type=llm_type,
)
type: str = optional_enum_field(
rel_types,
description="The type of the relationship.",
input_type="relationship",
llm_type=llm_type,
)
class DynamicGraph(_Graph):
@ -572,8 +585,12 @@ class LLMGraphTransformer:
self.chain = prompt | llm
else:
# Define chain
try:
llm_type = llm._llm_type # type: ignore
except AttributeError:
llm_type = None
schema = create_simple_model(
allowed_nodes, allowed_relationships, node_properties
allowed_nodes, allowed_relationships, node_properties, llm_type
)
structured_llm = llm.with_structured_output(schema, include_raw=True)
prompt = prompt or default_prompt