mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 06:39:52 +00:00
experimental[patch]: Pass enum only to openai in llm graph transformer (#21860)
Some models like Groq return bad request if you pass in `enum` parameter in tool definition
This commit is contained in:
parent
aab9cb666f
commit
a43515ca65
@ -150,16 +150,24 @@ def optional_enum_field(
|
|||||||
enum_values: Optional[List[str]] = None,
|
enum_values: Optional[List[str]] = None,
|
||||||
description: str = "",
|
description: str = "",
|
||||||
input_type: str = "node",
|
input_type: str = "node",
|
||||||
|
llm_type: Optional[str] = None,
|
||||||
**field_kwargs: Any,
|
**field_kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Utility function to conditionally create a field with an enum constraint."""
|
"""Utility function to conditionally create a field with an enum constraint."""
|
||||||
if enum_values:
|
# Only openai supports enum param
|
||||||
|
if enum_values and llm_type == "openai-chat":
|
||||||
return Field(
|
return Field(
|
||||||
...,
|
...,
|
||||||
enum=enum_values,
|
enum=enum_values,
|
||||||
description=f"{description}. Available options are {enum_values}",
|
description=f"{description}. Available options are {enum_values}",
|
||||||
**field_kwargs,
|
**field_kwargs,
|
||||||
)
|
)
|
||||||
|
elif enum_values:
|
||||||
|
return Field(
|
||||||
|
...,
|
||||||
|
description=f"{description}. Available options are {enum_values}",
|
||||||
|
**field_kwargs,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
additional_info = _get_additional_info(input_type)
|
additional_info = _get_additional_info(input_type)
|
||||||
return Field(..., description=description + additional_info, **field_kwargs)
|
return Field(..., description=description + additional_info, **field_kwargs)
|
||||||
@ -271,6 +279,7 @@ def create_simple_model(
|
|||||||
node_labels: Optional[List[str]] = None,
|
node_labels: Optional[List[str]] = None,
|
||||||
rel_types: Optional[List[str]] = None,
|
rel_types: Optional[List[str]] = None,
|
||||||
node_properties: Union[bool, List[str]] = False,
|
node_properties: Union[bool, List[str]] = False,
|
||||||
|
llm_type: Optional[str] = None,
|
||||||
) -> Type[_Graph]:
|
) -> Type[_Graph]:
|
||||||
"""
|
"""
|
||||||
Simple model allows to limit node and/or relationship types.
|
Simple model allows to limit node and/or relationship types.
|
||||||
@ -288,6 +297,7 @@ def create_simple_model(
|
|||||||
node_labels,
|
node_labels,
|
||||||
description="The type or label of the node.",
|
description="The type or label of the node.",
|
||||||
input_type="node",
|
input_type="node",
|
||||||
|
llm_type=llm_type,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
@ -325,6 +335,7 @@ def create_simple_model(
|
|||||||
node_labels,
|
node_labels,
|
||||||
description="The type or label of the source node.",
|
description="The type or label of the source node.",
|
||||||
input_type="node",
|
input_type="node",
|
||||||
|
llm_type=llm_type,
|
||||||
)
|
)
|
||||||
target_node_id: str = Field(
|
target_node_id: str = Field(
|
||||||
description="Name or human-readable unique identifier of target node"
|
description="Name or human-readable unique identifier of target node"
|
||||||
@ -333,11 +344,13 @@ def create_simple_model(
|
|||||||
node_labels,
|
node_labels,
|
||||||
description="The type or label of the target node.",
|
description="The type or label of the target node.",
|
||||||
input_type="node",
|
input_type="node",
|
||||||
|
llm_type=llm_type,
|
||||||
)
|
)
|
||||||
type: str = optional_enum_field(
|
type: str = optional_enum_field(
|
||||||
rel_types,
|
rel_types,
|
||||||
description="The type of the relationship.",
|
description="The type of the relationship.",
|
||||||
input_type="relationship",
|
input_type="relationship",
|
||||||
|
llm_type=llm_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
class DynamicGraph(_Graph):
|
class DynamicGraph(_Graph):
|
||||||
@ -572,8 +585,12 @@ class LLMGraphTransformer:
|
|||||||
self.chain = prompt | llm
|
self.chain = prompt | llm
|
||||||
else:
|
else:
|
||||||
# Define chain
|
# Define chain
|
||||||
|
try:
|
||||||
|
llm_type = llm._llm_type # type: ignore
|
||||||
|
except AttributeError:
|
||||||
|
llm_type = None
|
||||||
schema = create_simple_model(
|
schema = create_simple_model(
|
||||||
allowed_nodes, allowed_relationships, node_properties
|
allowed_nodes, allowed_relationships, node_properties, llm_type
|
||||||
)
|
)
|
||||||
structured_llm = llm.with_structured_output(schema, include_raw=True)
|
structured_llm = llm.with_structured_output(schema, include_raw=True)
|
||||||
prompt = prompt or default_prompt
|
prompt = prompt or default_prompt
|
||||||
|
Loading…
Reference in New Issue
Block a user