Mirror of https://github.com/hwchase17/langchain.git, synced 2025-08-10 05:20:39 +00:00.
experimental[minor]: Add bind_tools and with_structured_output functions to OllamaFunctions (#20881)
Implemented bind_tools for OllamaFunctions, made OllamaFunctions a subclass of ChatOllama, and implemented with_structured_output for OllamaFunctions. The integration test and the notebook have been updated accordingly.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
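In short, the new surface looks like this (a minimal sketch drawn from the updated notebook and tests; it assumes a local Ollama server with the llama3 model pulled):

    from langchain_core.pydantic_v1 import BaseModel, Field

    from langchain_experimental.llms.ollama_functions import OllamaFunctions


    class Joke(BaseModel):
        setup: str = Field(description="The setup of the joke")
        punchline: str = Field(description="The punchline to the joke")


    # Structured output: the model is constrained to answer as a validated Joke object.
    llm = OllamaFunctions(model="llama3", format="json")
    structured_llm = llm.with_structured_output(Joke)
    structured_llm.invoke("Tell me a joke about cats")  # -> Joke(setup=..., punchline=...)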
parent: d781560722
commit: 2ddac9a7c3
@@ -17,7 +17,7 @@
     "\n",
     "This notebook shows how to use an experimental wrapper around Ollama that gives it the same API as OpenAI Functions.\n",
     "\n",
-    "Note that more powerful and capable models will perform better with complex schema and/or multiple functions. The examples below use Mistral.\n",
+    "Note that more powerful and capable models will perform better with complex schema and/or multiple functions. The examples below use llama3 and phi3 models.\n",
     "For a complete list of supported models and model variants, see the [Ollama model library](https://ollama.ai/library).\n",
     "\n",
     "## Setup\n",
@@ -32,12 +32,18 @@
   {
    "cell_type": "code",
    "execution_count": 1,
-   "metadata": {},
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-04-28T00:53:25.276543Z",
+     "start_time": "2024-04-28T00:53:24.881202Z"
+    },
+    "scrolled": true
+   },
    "outputs": [],
    "source": [
     "from langchain_experimental.llms.ollama_functions import OllamaFunctions\n",
     "\n",
-    "model = OllamaFunctions(model=\"mistral\")"
+    "model = OllamaFunctions(model=\"llama3\", format=\"json\")"
    ]
   },
   {
@@ -50,11 +56,16 @@
   {
    "cell_type": "code",
    "execution_count": 2,
-   "metadata": {},
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-04-26T04:59:17.270931Z",
+     "start_time": "2024-04-26T04:59:17.263347Z"
+    }
+   },
    "outputs": [],
    "source": [
-    "model = model.bind(\n",
-    "    functions=[\n",
+    "model = model.bind_tools(\n",
+    "    tools=[\n",
     "        {\n",
     "            \"name\": \"get_current_weather\",\n",
     "            \"description\": \"Get the current weather in a given location\",\n",
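The hunk above cuts off inside the tool definition. For readability, a plausible completion of the cell, assuming the conventional get_current_weather parameter schema (the location and unit fields are inferred from the example output in the next hunk, not shown here):

    model = model.bind_tools(
        tools=[
            {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                # Hypothetical completion; the diff truncates the real schema here.
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location"],
                },
            }
        ],
    )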
@@ -88,12 +99,17 @@
   {
    "cell_type": "code",
    "execution_count": 3,
-   "metadata": {},
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-04-26T04:59:26.092428Z",
+     "start_time": "2024-04-26T04:59:17.272627Z"
+    }
+   },
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "AIMessage(content='', additional_kwargs={'function_call': {'name': 'get_current_weather', 'arguments': '{\"location\": \"Boston, MA\", \"unit\": \"celsius\"}'}})"
+       "AIMessage(content='', additional_kwargs={'function_call': {'name': 'get_current_weather', 'arguments': '{\"location\": \"Boston, MA\"}'}}, id='run-1791f9fe-95ad-4ca4-bdf7-9f73eab31e6f-0')"
      ]
     },
     "execution_count": 3,
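The arguments in that AIMessage arrive as a JSON string; a small sketch of how a caller can unpack them (this mirrors what parse_response does in the module diff further down):

    import json

    # res is the AIMessage shown in the output above
    function_call = res.additional_kwargs["function_call"]
    args = json.loads(function_call["arguments"])  # {'location': 'Boston, MA'}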
@@ -111,54 +127,119 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Using for extraction\n",
+    "## Structured Output\n",
     "\n",
-    "One useful thing you can do with function calling here is extracting properties from a given input in a structured format:"
+    "One useful thing you can do with function calling using `with_structured_output()` function is extracting properties from a given input in a structured format:"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": 4,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "[{'name': 'Alex', 'height': 5, 'hair_color': 'blonde'},\n",
-       " {'name': 'Claudia', 'height': 6, 'hair_color': 'brunette'}]"
-      ]
-     },
-     "execution_count": 4,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "from langchain.chains import create_extraction_chain\n",
-    "\n",
-    "# Schema\n",
-    "schema = {\n",
-    "    \"properties\": {\n",
-    "        \"name\": {\"type\": \"string\"},\n",
-    "        \"height\": {\"type\": \"integer\"},\n",
-    "        \"hair_color\": {\"type\": \"string\"},\n",
-    "    },\n",
-    "    \"required\": [\"name\", \"height\"],\n",
-    "}\n",
-    "\n",
-    "# Input\n",
-    "input = \"\"\"Alex is 5 feet tall. Claudia is 1 feet taller than Alex and jumps higher than him. Claudia is a brunette and Alex is blonde.\"\"\"\n",
-    "\n",
-    "# Run chain\n",
-    "llm = OllamaFunctions(model=\"mistral\", temperature=0)\n",
-    "chain = create_extraction_chain(schema, llm)\n",
-    "chain.run(input)"
-   ]
-  },
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-04-26T04:59:26.098828Z",
+     "start_time": "2024-04-26T04:59:26.094021Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "from langchain_core.prompts import PromptTemplate\n",
+    "from langchain_core.pydantic_v1 import BaseModel, Field\n",
+    "\n",
+    "\n",
+    "# Schema for structured response\n",
+    "class Person(BaseModel):\n",
+    "    name: str = Field(description=\"The person's name\", required=True)\n",
+    "    height: float = Field(description=\"The person's height\", required=True)\n",
+    "    hair_color: str = Field(description=\"The person's hair color\")\n",
+    "\n",
+    "\n",
+    "# Prompt template\n",
+    "prompt = PromptTemplate.from_template(\n",
+    "    \"\"\"Alex is 5 feet tall. \n",
+    "Claudia is 1 feet taller than Alex and jumps higher than him. \n",
+    "Claudia is a brunette and Alex is blonde.\n",
+    "\n",
+    "Human: {question}\n",
+    "AI: \"\"\"\n",
+    ")\n",
+    "\n",
+    "# Chain\n",
+    "llm = OllamaFunctions(model=\"phi3\", format=\"json\", temperature=0)\n",
+    "structured_llm = llm.with_structured_output(Person)\n",
+    "chain = prompt | structured_llm"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Extracting data about Alex"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-04-26T04:59:30.164955Z",
+     "start_time": "2024-04-26T04:59:26.099790Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "Person(name='Alex', height=5.0, hair_color='blonde')"
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "alex = chain.invoke(\"Describe Alex\")\n",
+    "alex"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Extracting data about Claudia"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-04-26T04:59:31.509846Z",
+     "start_time": "2024-04-26T04:59:30.165662Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "Person(name='Claudia', height=6.0, hair_color='brunette')"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "claudia = chain.invoke(\"Describe Claudia\")\n",
+    "claudia"
+   ]
+  }
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": ".venv",
+   "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
   },
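The same chain also works with a plain dict schema instead of the Person class; a sketch using convert_to_ollama_tool from the module diff below (the result is then an unvalidated dict, and the concrete output shown is illustrative):

    from langchain_experimental.llms.ollama_functions import convert_to_ollama_tool

    dict_schema = convert_to_ollama_tool(Person)
    structured_llm = llm.with_structured_output(dict_schema)
    chain = prompt | structured_llm
    chain.invoke("Describe Alex")
    # -> {'name': 'Alex', 'height': 5.0, 'hair_color': 'blonde'}  (illustrative)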
@@ -172,9 +253,9 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.5"
+   "version": "3.9.1"
   }
  },
  "nbformat": 4,
- "nbformat_minor": 2
+ "nbformat_minor": 4
 }
@@ -1,14 +1,34 @@
 import json
-from typing import Any, Dict, List, Optional
+from operator import itemgetter
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Sequence,
+    Type,
+    TypedDict,
+    TypeVar,
+    Union,
+    overload,
+)
 
+from langchain_community.chat_models.ollama import ChatOllama
 from langchain_core.callbacks import CallbackManagerForLLMRun
-from langchain_core.language_models import BaseChatModel
+from langchain_core.language_models import LanguageModelInput
 from langchain_core.messages import AIMessage, BaseMessage
+from langchain_core.output_parsers.base import OutputParserLike
+from langchain_core.output_parsers.json import JsonOutputParser
+from langchain_core.output_parsers.pydantic import PydanticOutputParser
 from langchain_core.outputs import ChatGeneration, ChatResult
 from langchain_core.prompts import SystemMessagePromptTemplate
-
-from langchain_experimental.pydantic_v1 import root_validator
+from langchain_core.pydantic_v1 import BaseModel
+from langchain_core.runnables import Runnable, RunnableLambda
+from langchain_core.runnables.base import RunnableMap
+from langchain_core.runnables.passthrough import RunnablePassthrough
+from langchain_core.tools import BaseTool
 
 DEFAULT_SYSTEM_TEMPLATE = """You have access to the following tools:
@@ -22,7 +42,6 @@ You must always select one of the above tools and respond with only a JSON object
 }}
 """  # noqa: E501
 
-
 DEFAULT_RESPONSE_FUNCTION = {
     "name": "__conversational_response",
     "description": (
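When no tools are bound, _generate falls back to DEFAULT_RESPONSE_FUNCTION so plain chat still travels through the same JSON tool protocol. A sketch of that round trip; the diff truncates the default function's parameter schema, so the `response` key below is an assumption:

    llm = OllamaFunctions(model="llama3", format="json")
    # With nothing bound, the system prompt advertises only
    # "__conversational_response", and the model is expected to reply with e.g.
    #   {"tool": "__conversational_response",
    #    "tool_input": {"response": "Hello! How can I help?"}}  # "response" key assumed
    # which _generate unwraps back into an ordinary AIMessage.
    llm.invoke("Hi there!")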
@@ -40,26 +59,219 @@ DEFAULT_RESPONSE_FUNCTION = {
     },
 }
 
-
-class OllamaFunctions(BaseChatModel):
-
-    llm: ChatOllama
-
-    tool_system_prompt_template: str
-
-    @root_validator(pre=True)
-    def validate_environment(cls, values: Dict) -> Dict:
-        values["llm"] = values.get("llm") or ChatOllama(**values, format="json")
-        values["tool_system_prompt_template"] = (
-            values.get("tool_system_prompt_template") or DEFAULT_SYSTEM_TEMPLATE
-        )
-        return values
-
-    @property
-    def model(self) -> BaseChatModel:
-        """For backwards compatibility."""
-        return self.llm
+_BM = TypeVar("_BM", bound=BaseModel)
+_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM]]
+_DictOrPydantic = Union[Dict, _BM]
+
+
+def _is_pydantic_class(obj: Any) -> bool:
+    return isinstance(obj, type) and (
+        issubclass(obj, BaseModel) or BaseModel in obj.__bases__
+    )
+
+
+def convert_to_ollama_tool(tool: Any) -> Dict:
+    """Convert a tool to an Ollama tool."""
+    if _is_pydantic_class(tool):
+        schema = tool.construct().schema()
+        definition = {"name": schema["title"], "properties": schema["properties"]}
+        if "required" in schema:
+            definition["required"] = schema["required"]
+
+        return definition
+    raise ValueError(
+        f"Cannot convert {tool} to an Ollama tool. {tool} needs to be a Pydantic model."
+    )
+
+
+class _AllReturnType(TypedDict):
+    raw: BaseMessage
+    parsed: Optional[_DictOrPydantic]
+    parsing_error: Optional[BaseException]
+
+
+def parse_response(message: BaseMessage) -> str:
+    """Extract `function_call` from `AIMessage`."""
+    if isinstance(message, AIMessage):
+        kwargs = message.additional_kwargs
+        if "function_call" in kwargs:
+            if "arguments" in kwargs["function_call"]:
+                return kwargs["function_call"]["arguments"]
+            raise ValueError(
+                f"`arguments` missing from `function_call` within AIMessage: {message}"
+            )
+        raise ValueError(
+            "`function_call` missing from `additional_kwargs` "
+            f"within AIMessage: {message}"
+        )
+    raise ValueError(f"`message` is not an instance of `AIMessage`: {message}")
+
+
+class OllamaFunctions(ChatOllama):
+    """Function chat model that uses Ollama API."""
+
+    tool_system_prompt_template: str = DEFAULT_SYSTEM_TEMPLATE
+
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+
+    def bind_tools(
+        self,
+        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, BaseMessage]:
+        return self.bind(functions=tools, **kwargs)
+
+    @overload
+    def with_structured_output(
+        self,
+        schema: Optional[_DictOrPydanticClass] = None,
+        *,
+        include_raw: Literal[True] = True,
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, _AllReturnType]:
+        ...
+
+    @overload
+    def with_structured_output(
+        self,
+        schema: Optional[_DictOrPydanticClass] = None,
+        *,
+        include_raw: Literal[False] = False,
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, _DictOrPydantic]:
+        ...
+
+    def with_structured_output(
+        self,
+        schema: Optional[_DictOrPydanticClass] = None,
+        *,
+        include_raw: bool = False,
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, _DictOrPydantic]:
+        """Model wrapper that returns outputs formatted to match the given schema.
+
+        Args:
+            schema: The output schema as a dict or a Pydantic class. If a Pydantic class
+                then the model output will be an object of that class. If a dict then
+                the model output will be a dict. With a Pydantic class the returned
+                attributes will be validated, whereas with a dict they will not be.
+            include_raw: If False then only the parsed structured output is returned. If
+                an error occurs during model output parsing it will be raised. If True
+                then both the raw model response (a BaseMessage) and the parsed model
+                response will be returned. If an error occurs during output parsing it
+                will be caught and returned as well. The final output is always a dict
+                with keys "raw", "parsed", and "parsing_error".
+
+        Returns:
+            A Runnable that takes any ChatModel input and returns as output:
+
+                If include_raw is True then a dict with keys:
+                    raw: BaseMessage
+                    parsed: Optional[_DictOrPydantic]
+                    parsing_error: Optional[BaseException]
+
+                If include_raw is False then just _DictOrPydantic is returned,
+                where _DictOrPydantic depends on the schema:
+
+                If schema is a Pydantic class then _DictOrPydantic is the Pydantic
+                class.
+
+                If schema is a dict then _DictOrPydantic is a dict.
+
+        Example: Pydantic schema (include_raw=False):
+            .. code-block:: python
+
+                from langchain_experimental.llms import OllamaFunctions
+                from langchain_core.pydantic_v1 import BaseModel
+
+                class AnswerWithJustification(BaseModel):
+                    '''An answer to the user question along with justification for the answer.'''
+                    answer: str
+                    justification: str
+
+                llm = OllamaFunctions(model="phi3", format="json", temperature=0)
+                structured_llm = llm.with_structured_output(AnswerWithJustification)
+
+                structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
+
+                # -> AnswerWithJustification(
+                #     answer='They weigh the same',
+                #     justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'
+                # )
+
+        Example: Pydantic schema (include_raw=True):
+            .. code-block:: python
+
+                from langchain_experimental.llms import OllamaFunctions
+                from langchain_core.pydantic_v1 import BaseModel
+
+                class AnswerWithJustification(BaseModel):
+                    '''An answer to the user question along with justification for the answer.'''
+                    answer: str
+                    justification: str
+
+                llm = OllamaFunctions(model="phi3", format="json", temperature=0)
+                structured_llm = llm.with_structured_output(AnswerWithJustification, include_raw=True)
+
+                structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
+                # -> {
+                #     'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}),
+                #     'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'),
+                #     'parsing_error': None
+                # }
+
+        Example: dict schema (include_raw=False):
+            .. code-block:: python
+
+                from langchain_experimental.llms import OllamaFunctions, convert_to_ollama_tool
+                from langchain_core.pydantic_v1 import BaseModel
+
+                class AnswerWithJustification(BaseModel):
+                    '''An answer to the user question along with justification for the answer.'''
+                    answer: str
+                    justification: str
+
+                dict_schema = convert_to_ollama_tool(AnswerWithJustification)
+                llm = OllamaFunctions(model="phi3", format="json", temperature=0)
+                structured_llm = llm.with_structured_output(dict_schema)
+
+                structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
+                # -> {
+                #     'answer': 'They weigh the same',
+                #     'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
+                # }
+
+        """  # noqa: E501
+        if kwargs:
+            raise ValueError(f"Received unsupported arguments {kwargs}")
+        is_pydantic_schema = _is_pydantic_class(schema)
+        if schema is None:
+            raise ValueError(
+                "schema must be specified when method is 'function_calling'. "
+                "Received None."
+            )
+        llm = self.bind_tools(tools=[schema], format="json")
+        if is_pydantic_schema:
+            output_parser: OutputParserLike = PydanticOutputParser(
+                pydantic_object=schema
+            )
+        else:
+            output_parser = JsonOutputParser()
+
+        parser_chain = RunnableLambda(parse_response) | output_parser
+        if include_raw:
+            parser_assign = RunnablePassthrough.assign(
+                parsed=itemgetter("raw") | parser_chain, parsing_error=lambda _: None
+            )
+            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
+            parser_with_fallback = parser_assign.with_fallbacks(
+                [parser_none], exception_key="parsing_error"
+            )
+            return RunnableMap(raw=llm) | parser_with_fallback
+        else:
+            return llm | parser_chain
 
     def _generate(
         self,
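parse_response can also be used on its own; a quick sketch against the AIMessage shape seen in the notebook hunk earlier:

    from langchain_core.messages import AIMessage

    from langchain_experimental.llms.ollama_functions import parse_response

    msg = AIMessage(
        content="",
        additional_kwargs={
            "function_call": {
                "name": "get_current_weather",
                "arguments": '{"location": "Boston, MA"}',
            }
        },
    )
    parse_response(msg)  # -> '{"location": "Boston, MA"}' (the raw arguments string)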
@@ -69,37 +281,41 @@
         **kwargs: Any,
     ) -> ChatResult:
         functions = kwargs.get("functions", [])
-        if "functions" in kwargs:
-            del kwargs["functions"]
         if "function_call" in kwargs:
             functions = [
                 fn for fn in functions if fn["name"] == kwargs["function_call"]["name"]
             ]
             if not functions:
                 raise ValueError(
-                    'If "function_call" is specified, you must also pass a matching \
-function in "functions".'
+                    "If `function_call` is specified, you must also pass a "
+                    "matching function in `functions`."
                 )
             del kwargs["function_call"]
         elif not functions:
             functions.append(DEFAULT_RESPONSE_FUNCTION)
+        if _is_pydantic_class(functions[0]):
+            functions = [convert_to_ollama_tool(fn) for fn in functions]
         system_message_prompt_template = SystemMessagePromptTemplate.from_template(
             self.tool_system_prompt_template
         )
         system_message = system_message_prompt_template.format(
             tools=json.dumps(functions, indent=2)
         )
-        response_message = self.llm.invoke(
-            [system_message] + messages, stop=stop, callbacks=run_manager, **kwargs
+        if "functions" in kwargs:
+            del kwargs["functions"]
+        response_message = super()._generate(
+            [system_message] + messages, stop=stop, run_manager=run_manager, **kwargs
         )
-        chat_generation_content = response_message.content
+        chat_generation_content = response_message.generations[0].text
         if not isinstance(chat_generation_content, str):
            raise ValueError("OllamaFunctions does not support non-string output.")
         try:
             parsed_chat_result = json.loads(chat_generation_content)
         except json.JSONDecodeError:
             raise ValueError(
-                f'"{self.llm.model}" did not respond with valid JSON. Please try again.'
+                f"""'{self.model}' did not respond with valid JSON.
+Please try again.
+Response: {chat_generation_content}"""
             )
         called_tool_name = parsed_chat_result["tool"]
         called_tool_arguments = parsed_chat_result["tool_input"]
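The contract _generate parses out of the model's reply is just two keys, as the last two lines above show; a minimal sketch:

    import json

    reply = '{"tool": "get_current_weather", "tool_input": {"location": "Boston, MA"}}'
    parsed = json.loads(reply)
    called_tool_name = parsed["tool"]             # matched against the bound functions
    called_tool_arguments = parsed["tool_input"]  # surfaced as function_call arguments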
@@ -108,8 +324,8 @@
         )
         if called_tool is None:
             raise ValueError(
-                f"Failed to parse a function call from {self.llm.model} \
-output: {chat_generation_content}"
+                f"Failed to parse a function call from {self.model} output: "
+                f"{chat_generation_content}"
             )
         if called_tool["name"] == DEFAULT_RESPONSE_FUNCTION["name"]:
             return ChatResult(
@@ -2,9 +2,18 @@
 
 import unittest
 
-from langchain_community.chat_models.ollama import ChatOllama
+from langchain_core.messages import AIMessage
+from langchain_core.pydantic_v1 import BaseModel, Field
 
-from langchain_experimental.llms.ollama_functions import OllamaFunctions
+from langchain_experimental.llms.ollama_functions import (
+    OllamaFunctions,
+    convert_to_ollama_tool,
+)
+
+
+class Joke(BaseModel):
+    setup: str = Field(description="The setup of the joke")
+    punchline: str = Field(description="The punchline to the joke")
 
 
 class TestOllamaFunctions(unittest.TestCase):
@@ -13,12 +22,11 @@ class TestOllamaFunctions(unittest.TestCase):
     """
 
     def test_default_ollama_functions(self) -> None:
-        base_model = OllamaFunctions(model="mistral")
-        self.assertIsInstance(base_model.model, ChatOllama)
+        base_model = OllamaFunctions(model="llama3", format="json")
 
         # bind functions
-        model = base_model.bind(
-            functions=[
+        model = base_model.bind_tools(
+            tools=[
                 {
                     "name": "get_current_weather",
                     "description": "Get the current weather in a given location",
@@ -47,3 +55,29 @@ class TestOllamaFunctions(unittest.TestCase):
         function_call = res.additional_kwargs.get("function_call")
         assert function_call
         self.assertEqual(function_call.get("name"), "get_current_weather")
+
+    def test_ollama_structured_output(self) -> None:
+        model = OllamaFunctions(model="phi3")
+        structured_llm = model.with_structured_output(Joke, include_raw=False)
+
+        res = structured_llm.invoke("Tell me a joke about cats")
+        assert isinstance(res, Joke)
+
+    def test_ollama_structured_output_with_json(self) -> None:
+        model = OllamaFunctions(model="phi3")
+        joke_schema = convert_to_ollama_tool(Joke)
+        structured_llm = model.with_structured_output(joke_schema, include_raw=False)
+
+        res = structured_llm.invoke("Tell me a joke about cats")
+        assert "setup" in res
+        assert "punchline" in res
+
+    def test_ollama_structured_output_raw(self) -> None:
+        model = OllamaFunctions(model="phi3")
+        structured_llm = model.with_structured_output(Joke, include_raw=True)
+
+        res = structured_llm.invoke("Tell me a joke about cars")
+        assert "raw" in res
+        assert "parsed" in res
+        assert isinstance(res["raw"], AIMessage)
+        assert isinstance(res["parsed"], Joke)