mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-12 06:13:36 +00:00
experimental[minor]: Add bind_tools and with_structured_output functions to OllamaFunctions (#20881)
Implemented bind_tools for OllamaFunctions. Made OllamaFunctions sub class of ChatOllama. Implemented with_structured_output for OllamaFunctions. integration unit test has been updated. notebook has been updated. --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
d781560722
commit
2ddac9a7c3
@ -17,7 +17,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"This notebook shows how to use an experimental wrapper around Ollama that gives it the same API as OpenAI Functions.\n",
|
"This notebook shows how to use an experimental wrapper around Ollama that gives it the same API as OpenAI Functions.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Note that more powerful and capable models will perform better with complex schema and/or multiple functions. The examples below use Mistral.\n",
|
"Note that more powerful and capable models will perform better with complex schema and/or multiple functions. The examples below use llama3 and phi3 models.\n",
|
||||||
"For a complete list of supported models and model variants, see the [Ollama model library](https://ollama.ai/library).\n",
|
"For a complete list of supported models and model variants, see the [Ollama model library](https://ollama.ai/library).\n",
|
||||||
"\n",
|
"\n",
|
||||||
"## Setup\n",
|
"## Setup\n",
|
||||||
@ -32,12 +32,18 @@
|
|||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 1,
|
"execution_count": 1,
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-04-28T00:53:25.276543Z",
|
||||||
|
"start_time": "2024-04-28T00:53:24.881202Z"
|
||||||
|
},
|
||||||
|
"scrolled": true
|
||||||
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from langchain_experimental.llms.ollama_functions import OllamaFunctions\n",
|
"from langchain_experimental.llms.ollama_functions import OllamaFunctions\n",
|
||||||
"\n",
|
"\n",
|
||||||
"model = OllamaFunctions(model=\"mistral\")"
|
"model = OllamaFunctions(model=\"llama3\", format=\"json\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -50,11 +56,16 @@
|
|||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 2,
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-04-26T04:59:17.270931Z",
|
||||||
|
"start_time": "2024-04-26T04:59:17.263347Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"model = model.bind(\n",
|
"model = model.bind_tools(\n",
|
||||||
" functions=[\n",
|
" tools=[\n",
|
||||||
" {\n",
|
" {\n",
|
||||||
" \"name\": \"get_current_weather\",\n",
|
" \"name\": \"get_current_weather\",\n",
|
||||||
" \"description\": \"Get the current weather in a given location\",\n",
|
" \"description\": \"Get the current weather in a given location\",\n",
|
||||||
@ -88,12 +99,17 @@
|
|||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 3,
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-04-26T04:59:26.092428Z",
|
||||||
|
"start_time": "2024-04-26T04:59:17.272627Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"AIMessage(content='', additional_kwargs={'function_call': {'name': 'get_current_weather', 'arguments': '{\"location\": \"Boston, MA\", \"unit\": \"celsius\"}'}})"
|
"AIMessage(content='', additional_kwargs={'function_call': {'name': 'get_current_weather', 'arguments': '{\"location\": \"Boston, MA\"}'}}, id='run-1791f9fe-95ad-4ca4-bdf7-9f73eab31e6f-0')"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 3,
|
"execution_count": 3,
|
||||||
@ -111,54 +127,119 @@
|
|||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"## Using for extraction\n",
|
"## Structured Output\n",
|
||||||
"\n",
|
"\n",
|
||||||
"One useful thing you can do with function calling here is extracting properties from a given input in a structured format:"
|
"One useful thing you can do with function calling using `with_structured_output()` function is extracting properties from a given input in a structured format:"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 4,
|
||||||
|
"metadata": {
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-04-26T04:59:26.098828Z",
|
||||||
|
"start_time": "2024-04-26T04:59:26.094021Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain_core.prompts import PromptTemplate\n",
|
||||||
|
"from langchain_core.pydantic_v1 import BaseModel, Field\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"# Schema for structured response\n",
|
||||||
|
"class Person(BaseModel):\n",
|
||||||
|
" name: str = Field(description=\"The person's name\", required=True)\n",
|
||||||
|
" height: float = Field(description=\"The person's height\", required=True)\n",
|
||||||
|
" hair_color: str = Field(description=\"The person's hair color\")\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"# Prompt template\n",
|
||||||
|
"prompt = PromptTemplate.from_template(\n",
|
||||||
|
" \"\"\"Alex is 5 feet tall. \n",
|
||||||
|
"Claudia is 1 feet taller than Alex and jumps higher than him. \n",
|
||||||
|
"Claudia is a brunette and Alex is blonde.\n",
|
||||||
|
"\n",
|
||||||
|
"Human: {question}\n",
|
||||||
|
"AI: \"\"\"\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"# Chain\n",
|
||||||
|
"llm = OllamaFunctions(model=\"phi3\", format=\"json\", temperature=0)\n",
|
||||||
|
"structured_llm = llm.with_structured_output(Person)\n",
|
||||||
|
"chain = prompt | structured_llm"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Extracting data about Alex"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-04-26T04:59:30.164955Z",
|
||||||
|
"start_time": "2024-04-26T04:59:26.099790Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"[{'name': 'Alex', 'height': 5, 'hair_color': 'blonde'},\n",
|
"Person(name='Alex', height=5.0, hair_color='blonde')"
|
||||||
" {'name': 'Claudia', 'height': 6, 'hair_color': 'brunette'}]"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 4,
|
"execution_count": 5,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from langchain.chains import create_extraction_chain\n",
|
"alex = chain.invoke(\"Describe Alex\")\n",
|
||||||
"\n",
|
"alex"
|
||||||
"# Schema\n",
|
]
|
||||||
"schema = {\n",
|
},
|
||||||
" \"properties\": {\n",
|
{
|
||||||
" \"name\": {\"type\": \"string\"},\n",
|
"cell_type": "markdown",
|
||||||
" \"height\": {\"type\": \"integer\"},\n",
|
"metadata": {},
|
||||||
" \"hair_color\": {\"type\": \"string\"},\n",
|
"source": [
|
||||||
" },\n",
|
"### Extracting data about Claudia"
|
||||||
" \"required\": [\"name\", \"height\"],\n",
|
]
|
||||||
"}\n",
|
},
|
||||||
"\n",
|
{
|
||||||
"# Input\n",
|
"cell_type": "code",
|
||||||
"input = \"\"\"Alex is 5 feet tall. Claudia is 1 feet taller than Alex and jumps higher than him. Claudia is a brunette and Alex is blonde.\"\"\"\n",
|
"execution_count": 6,
|
||||||
"\n",
|
"metadata": {
|
||||||
"# Run chain\n",
|
"ExecuteTime": {
|
||||||
"llm = OllamaFunctions(model=\"mistral\", temperature=0)\n",
|
"end_time": "2024-04-26T04:59:31.509846Z",
|
||||||
"chain = create_extraction_chain(schema, llm)\n",
|
"start_time": "2024-04-26T04:59:30.165662Z"
|
||||||
"chain.run(input)"
|
}
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"Person(name='Claudia', height=6.0, hair_color='brunette')"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 6,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"claudia = chain.invoke(\"Describe Claudia\")\n",
|
||||||
|
"claudia"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": ".venv",
|
"display_name": "Python 3 (ipykernel)",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
"name": "python3"
|
"name": "python3"
|
||||||
},
|
},
|
||||||
@ -172,9 +253,9 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.5"
|
"version": "3.9.1"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
"nbformat_minor": 2
|
"nbformat_minor": 4
|
||||||
}
|
}
|
||||||
|
@ -1,14 +1,34 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Any, Dict, List, Optional
|
from operator import itemgetter
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Type,
|
||||||
|
TypedDict,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
overload,
|
||||||
|
)
|
||||||
|
|
||||||
from langchain_community.chat_models.ollama import ChatOllama
|
from langchain_community.chat_models.ollama import ChatOllama
|
||||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||||
from langchain_core.language_models import BaseChatModel
|
from langchain_core.language_models import LanguageModelInput
|
||||||
from langchain_core.messages import AIMessage, BaseMessage
|
from langchain_core.messages import AIMessage, BaseMessage
|
||||||
|
from langchain_core.output_parsers.base import OutputParserLike
|
||||||
|
from langchain_core.output_parsers.json import JsonOutputParser
|
||||||
|
from langchain_core.output_parsers.pydantic import PydanticOutputParser
|
||||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||||
from langchain_core.prompts import SystemMessagePromptTemplate
|
from langchain_core.prompts import SystemMessagePromptTemplate
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
from langchain_experimental.pydantic_v1 import root_validator
|
from langchain_core.runnables import Runnable, RunnableLambda
|
||||||
|
from langchain_core.runnables.base import RunnableMap
|
||||||
|
from langchain_core.runnables.passthrough import RunnablePassthrough
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
DEFAULT_SYSTEM_TEMPLATE = """You have access to the following tools:
|
DEFAULT_SYSTEM_TEMPLATE = """You have access to the following tools:
|
||||||
|
|
||||||
@ -22,7 +42,6 @@ You must always select one of the above tools and respond with only a JSON objec
|
|||||||
}}
|
}}
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_RESPONSE_FUNCTION = {
|
DEFAULT_RESPONSE_FUNCTION = {
|
||||||
"name": "__conversational_response",
|
"name": "__conversational_response",
|
||||||
"description": (
|
"description": (
|
||||||
@ -40,26 +59,219 @@ DEFAULT_RESPONSE_FUNCTION = {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_BM = TypeVar("_BM", bound=BaseModel)
|
||||||
|
_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM]]
|
||||||
|
_DictOrPydantic = Union[Dict, _BM]
|
||||||
|
|
||||||
class OllamaFunctions(BaseChatModel):
|
|
||||||
|
def _is_pydantic_class(obj: Any) -> bool:
|
||||||
|
return isinstance(obj, type) and (
|
||||||
|
issubclass(obj, BaseModel) or BaseModel in obj.__bases__
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_ollama_tool(tool: Any) -> Dict:
|
||||||
|
"""Convert a tool to an Ollama tool."""
|
||||||
|
if _is_pydantic_class(tool):
|
||||||
|
schema = tool.construct().schema()
|
||||||
|
definition = {"name": schema["title"], "properties": schema["properties"]}
|
||||||
|
if "required" in schema:
|
||||||
|
definition["required"] = schema["required"]
|
||||||
|
|
||||||
|
return definition
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot convert {tool} to an Ollama tool. {tool} needs to be a Pydantic model."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _AllReturnType(TypedDict):
|
||||||
|
raw: BaseMessage
|
||||||
|
parsed: Optional[_DictOrPydantic]
|
||||||
|
parsing_error: Optional[BaseException]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_response(message: BaseMessage) -> str:
|
||||||
|
"""Extract `function_call` from `AIMessage`."""
|
||||||
|
if isinstance(message, AIMessage):
|
||||||
|
kwargs = message.additional_kwargs
|
||||||
|
if "function_call" in kwargs:
|
||||||
|
if "arguments" in kwargs["function_call"]:
|
||||||
|
return kwargs["function_call"]["arguments"]
|
||||||
|
raise ValueError(
|
||||||
|
f"`arguments` missing from `function_call` within AIMessage: {message}"
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
"`function_call` missing from `additional_kwargs` "
|
||||||
|
f"within AIMessage: {message}"
|
||||||
|
)
|
||||||
|
raise ValueError(f"`message` is not an instance of `AIMessage`: {message}")
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaFunctions(ChatOllama):
|
||||||
"""Function chat model that uses Ollama API."""
|
"""Function chat model that uses Ollama API."""
|
||||||
|
|
||||||
llm: ChatOllama
|
tool_system_prompt_template: str = DEFAULT_SYSTEM_TEMPLATE
|
||||||
|
|
||||||
tool_system_prompt_template: str
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
@root_validator(pre=True)
|
def bind_tools(
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
self,
|
||||||
values["llm"] = values.get("llm") or ChatOllama(**values, format="json")
|
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||||
values["tool_system_prompt_template"] = (
|
**kwargs: Any,
|
||||||
values.get("tool_system_prompt_template") or DEFAULT_SYSTEM_TEMPLATE
|
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||||
|
return self.bind(functions=tools, **kwargs)
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def with_structured_output(
|
||||||
|
self,
|
||||||
|
schema: Optional[_DictOrPydanticClass] = None,
|
||||||
|
*,
|
||||||
|
include_raw: Literal[True] = True,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Runnable[LanguageModelInput, _AllReturnType]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def with_structured_output(
|
||||||
|
self,
|
||||||
|
schema: Optional[_DictOrPydanticClass] = None,
|
||||||
|
*,
|
||||||
|
include_raw: Literal[False] = False,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Runnable[LanguageModelInput, _DictOrPydantic]:
|
||||||
|
...
|
||||||
|
|
||||||
|
def with_structured_output(
|
||||||
|
self,
|
||||||
|
schema: Optional[_DictOrPydanticClass] = None,
|
||||||
|
*,
|
||||||
|
include_raw: bool = False,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Runnable[LanguageModelInput, _DictOrPydantic]:
|
||||||
|
"""Model wrapper that returns outputs formatted to match the given schema.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
schema: The output schema as a dict or a Pydantic class. If a Pydantic class
|
||||||
|
then the model output will be an object of that class. If a dict then
|
||||||
|
the model output will be a dict. With a Pydantic class the returned
|
||||||
|
attributes will be validated, whereas with a dict they will not be.
|
||||||
|
include_raw: If False then only the parsed structured output is returned. If
|
||||||
|
an error occurs during model output parsing it will be raised. If True
|
||||||
|
then both the raw model response (a BaseMessage) and the parsed model
|
||||||
|
response will be returned. If an error occurs during output parsing it
|
||||||
|
will be caught and returned as well. The final output is always a dict
|
||||||
|
with keys "raw", "parsed", and "parsing_error".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Runnable that takes any ChatModel input and returns as output:
|
||||||
|
|
||||||
|
If include_raw is True then a dict with keys:
|
||||||
|
raw: BaseMessage
|
||||||
|
parsed: Optional[_DictOrPydantic]
|
||||||
|
parsing_error: Optional[BaseException]
|
||||||
|
|
||||||
|
If include_raw is False then just _DictOrPydantic is returned,
|
||||||
|
where _DictOrPydantic depends on the schema:
|
||||||
|
|
||||||
|
If schema is a Pydantic class then _DictOrPydantic is the Pydantic
|
||||||
|
class.
|
||||||
|
|
||||||
|
If schema is a dict then _DictOrPydantic is a dict.
|
||||||
|
|
||||||
|
Example: Pydantic schema (include_raw=False):
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain_experimental.llms import OllamaFunctions
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
|
|
||||||
|
class AnswerWithJustification(BaseModel):
|
||||||
|
'''An answer to the user question along with justification for the answer.'''
|
||||||
|
answer: str
|
||||||
|
justification: str
|
||||||
|
|
||||||
|
llm = OllamaFunctions(model="phi3", format="json", temperature=0)
|
||||||
|
structured_llm = llm.with_structured_output(AnswerWithJustification)
|
||||||
|
|
||||||
|
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
|
||||||
|
|
||||||
|
# -> AnswerWithJustification(
|
||||||
|
# answer='They weigh the same',
|
||||||
|
# justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'
|
||||||
|
# )
|
||||||
|
|
||||||
|
Example: Pydantic schema (include_raw=True):
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain_experimental.llms import OllamaFunctions
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
|
|
||||||
|
class AnswerWithJustification(BaseModel):
|
||||||
|
'''An answer to the user question along with justification for the answer.'''
|
||||||
|
answer: str
|
||||||
|
justification: str
|
||||||
|
|
||||||
|
llm = OllamaFunctions(model="phi3", format="json", temperature=0)
|
||||||
|
structured_llm = llm.with_structured_output(AnswerWithJustification, include_raw=True)
|
||||||
|
|
||||||
|
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
|
||||||
|
# -> {
|
||||||
|
# 'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}),
|
||||||
|
# 'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'),
|
||||||
|
# 'parsing_error': None
|
||||||
|
# }
|
||||||
|
|
||||||
|
Example: dict schema (method="include_raw=False):
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain_experimental.llms import OllamaFunctions, convert_to_ollama_tool
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
|
|
||||||
|
class AnswerWithJustification(BaseModel):
|
||||||
|
'''An answer to the user question along with justification for the answer.'''
|
||||||
|
answer: str
|
||||||
|
justification: str
|
||||||
|
|
||||||
|
dict_schema = convert_to_ollama_tool(AnswerWithJustification)
|
||||||
|
llm = OllamaFunctions(model="phi3", format="json", temperature=0)
|
||||||
|
structured_llm = llm.with_structured_output(dict_schema)
|
||||||
|
|
||||||
|
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
|
||||||
|
# -> {
|
||||||
|
# 'answer': 'They weigh the same',
|
||||||
|
# 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
|
||||||
|
# }
|
||||||
|
|
||||||
|
|
||||||
|
""" # noqa: E501
|
||||||
|
if kwargs:
|
||||||
|
raise ValueError(f"Received unsupported arguments {kwargs}")
|
||||||
|
is_pydantic_schema = _is_pydantic_class(schema)
|
||||||
|
if schema is None:
|
||||||
|
raise ValueError(
|
||||||
|
"schema must be specified when method is 'function_calling'. "
|
||||||
|
"Received None."
|
||||||
)
|
)
|
||||||
return values
|
llm = self.bind_tools(tools=[schema], format="json")
|
||||||
|
if is_pydantic_schema:
|
||||||
|
output_parser: OutputParserLike = PydanticOutputParser(
|
||||||
|
pydantic_object=schema
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output_parser = JsonOutputParser()
|
||||||
|
|
||||||
@property
|
parser_chain = RunnableLambda(parse_response) | output_parser
|
||||||
def model(self) -> BaseChatModel:
|
if include_raw:
|
||||||
"""For backwards compatibility."""
|
parser_assign = RunnablePassthrough.assign(
|
||||||
return self.llm
|
parsed=itemgetter("raw") | parser_chain, parsing_error=lambda _: None
|
||||||
|
)
|
||||||
|
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
|
||||||
|
parser_with_fallback = parser_assign.with_fallbacks(
|
||||||
|
[parser_none], exception_key="parsing_error"
|
||||||
|
)
|
||||||
|
return RunnableMap(raw=llm) | parser_with_fallback
|
||||||
|
else:
|
||||||
|
return llm | parser_chain
|
||||||
|
|
||||||
def _generate(
|
def _generate(
|
||||||
self,
|
self,
|
||||||
@ -69,37 +281,41 @@ class OllamaFunctions(BaseChatModel):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> ChatResult:
|
) -> ChatResult:
|
||||||
functions = kwargs.get("functions", [])
|
functions = kwargs.get("functions", [])
|
||||||
|
if "functions" in kwargs:
|
||||||
|
del kwargs["functions"]
|
||||||
if "function_call" in kwargs:
|
if "function_call" in kwargs:
|
||||||
functions = [
|
functions = [
|
||||||
fn for fn in functions if fn["name"] == kwargs["function_call"]["name"]
|
fn for fn in functions if fn["name"] == kwargs["function_call"]["name"]
|
||||||
]
|
]
|
||||||
if not functions:
|
if not functions:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'If "function_call" is specified, you must also pass a matching \
|
"If `function_call` is specified, you must also pass a "
|
||||||
function in "functions".'
|
"matching function in `functions`."
|
||||||
)
|
)
|
||||||
del kwargs["function_call"]
|
del kwargs["function_call"]
|
||||||
elif not functions:
|
elif not functions:
|
||||||
functions.append(DEFAULT_RESPONSE_FUNCTION)
|
functions.append(DEFAULT_RESPONSE_FUNCTION)
|
||||||
|
if _is_pydantic_class(functions[0]):
|
||||||
|
functions = [convert_to_ollama_tool(fn) for fn in functions]
|
||||||
system_message_prompt_template = SystemMessagePromptTemplate.from_template(
|
system_message_prompt_template = SystemMessagePromptTemplate.from_template(
|
||||||
self.tool_system_prompt_template
|
self.tool_system_prompt_template
|
||||||
)
|
)
|
||||||
system_message = system_message_prompt_template.format(
|
system_message = system_message_prompt_template.format(
|
||||||
tools=json.dumps(functions, indent=2)
|
tools=json.dumps(functions, indent=2)
|
||||||
)
|
)
|
||||||
if "functions" in kwargs:
|
response_message = super()._generate(
|
||||||
del kwargs["functions"]
|
[system_message] + messages, stop=stop, run_manager=run_manager, **kwargs
|
||||||
response_message = self.llm.invoke(
|
|
||||||
[system_message] + messages, stop=stop, callbacks=run_manager, **kwargs
|
|
||||||
)
|
)
|
||||||
chat_generation_content = response_message.content
|
chat_generation_content = response_message.generations[0].text
|
||||||
if not isinstance(chat_generation_content, str):
|
if not isinstance(chat_generation_content, str):
|
||||||
raise ValueError("OllamaFunctions does not support non-string output.")
|
raise ValueError("OllamaFunctions does not support non-string output.")
|
||||||
try:
|
try:
|
||||||
parsed_chat_result = json.loads(chat_generation_content)
|
parsed_chat_result = json.loads(chat_generation_content)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f'"{self.llm.model}" did not respond with valid JSON. Please try again.'
|
f"""'{self.model}' did not respond with valid JSON.
|
||||||
|
Please try again.
|
||||||
|
Response: {chat_generation_content}"""
|
||||||
)
|
)
|
||||||
called_tool_name = parsed_chat_result["tool"]
|
called_tool_name = parsed_chat_result["tool"]
|
||||||
called_tool_arguments = parsed_chat_result["tool_input"]
|
called_tool_arguments = parsed_chat_result["tool_input"]
|
||||||
@ -108,8 +324,8 @@ function in "functions".'
|
|||||||
)
|
)
|
||||||
if called_tool is None:
|
if called_tool is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Failed to parse a function call from {self.llm.model} \
|
f"Failed to parse a function call from {self.model} output: "
|
||||||
output: {chat_generation_content}"
|
f"{chat_generation_content}"
|
||||||
)
|
)
|
||||||
if called_tool["name"] == DEFAULT_RESPONSE_FUNCTION["name"]:
|
if called_tool["name"] == DEFAULT_RESPONSE_FUNCTION["name"]:
|
||||||
return ChatResult(
|
return ChatResult(
|
||||||
|
@ -2,9 +2,18 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from langchain_community.chat_models.ollama import ChatOllama
|
from langchain_core.messages import AIMessage
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||||
|
|
||||||
from langchain_experimental.llms.ollama_functions import OllamaFunctions
|
from langchain_experimental.llms.ollama_functions import (
|
||||||
|
OllamaFunctions,
|
||||||
|
convert_to_ollama_tool,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Joke(BaseModel):
|
||||||
|
setup: str = Field(description="The setup of the joke")
|
||||||
|
punchline: str = Field(description="The punchline to the joke")
|
||||||
|
|
||||||
|
|
||||||
class TestOllamaFunctions(unittest.TestCase):
|
class TestOllamaFunctions(unittest.TestCase):
|
||||||
@ -13,12 +22,11 @@ class TestOllamaFunctions(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def test_default_ollama_functions(self) -> None:
|
def test_default_ollama_functions(self) -> None:
|
||||||
base_model = OllamaFunctions(model="mistral")
|
base_model = OllamaFunctions(model="llama3", format="json")
|
||||||
self.assertIsInstance(base_model.model, ChatOllama)
|
|
||||||
|
|
||||||
# bind functions
|
# bind functions
|
||||||
model = base_model.bind(
|
model = base_model.bind_tools(
|
||||||
functions=[
|
tools=[
|
||||||
{
|
{
|
||||||
"name": "get_current_weather",
|
"name": "get_current_weather",
|
||||||
"description": "Get the current weather in a given location",
|
"description": "Get the current weather in a given location",
|
||||||
@ -47,3 +55,29 @@ class TestOllamaFunctions(unittest.TestCase):
|
|||||||
function_call = res.additional_kwargs.get("function_call")
|
function_call = res.additional_kwargs.get("function_call")
|
||||||
assert function_call
|
assert function_call
|
||||||
self.assertEqual(function_call.get("name"), "get_current_weather")
|
self.assertEqual(function_call.get("name"), "get_current_weather")
|
||||||
|
|
||||||
|
def test_ollama_structured_output(self) -> None:
|
||||||
|
model = OllamaFunctions(model="phi3")
|
||||||
|
structured_llm = model.with_structured_output(Joke, include_raw=False)
|
||||||
|
|
||||||
|
res = structured_llm.invoke("Tell me a joke about cats")
|
||||||
|
assert isinstance(res, Joke)
|
||||||
|
|
||||||
|
def test_ollama_structured_output_with_json(self) -> None:
|
||||||
|
model = OllamaFunctions(model="phi3")
|
||||||
|
joke_schema = convert_to_ollama_tool(Joke)
|
||||||
|
structured_llm = model.with_structured_output(joke_schema, include_raw=False)
|
||||||
|
|
||||||
|
res = structured_llm.invoke("Tell me a joke about cats")
|
||||||
|
assert "setup" in res
|
||||||
|
assert "punchline" in res
|
||||||
|
|
||||||
|
def test_ollama_structured_output_raw(self) -> None:
|
||||||
|
model = OllamaFunctions(model="phi3")
|
||||||
|
structured_llm = model.with_structured_output(Joke, include_raw=True)
|
||||||
|
|
||||||
|
res = structured_llm.invoke("Tell me a joke about cars")
|
||||||
|
assert "raw" in res
|
||||||
|
assert "parsed" in res
|
||||||
|
assert isinstance(res["raw"], AIMessage)
|
||||||
|
assert isinstance(res["parsed"], Joke)
|
||||||
|
Loading…
Reference in New Issue
Block a user