community: fixed critical bugs at Writer provider (#27879)

Yan 2024-11-25 20:03:37 +03:00 committed by GitHub
parent 6ed2d387bb
commit c60695a1c7
8 changed files with 1205 additions and 542 deletions

View File

@@ -17,7 +17,7 @@
"source": [ "source": [
"# ChatWriter\n", "# ChatWriter\n",
"\n", "\n",
"This notebook provides a quick overview for getting started with Writer [chat models](/docs/concepts/chat_models).\n", "This notebook provides a quick overview for getting started with Writer [chat models](/docs/concepts/#chat-models).\n",
"\n", "\n",
"Writer has several chat models. You can find information about their latest models and their costs, context windows, and supported input types in the [Writer docs](https://dev.writer.com/home/models).\n", "Writer has several chat models. You can find information about their latest models and their costs, context windows, and supported input types in the [Writer docs](https://dev.writer.com/home/models).\n",
"\n", "\n",
@@ -25,21 +25,20 @@
] ]
}, },
{ {
"cell_type": "markdown",
"id": "e49f1e0d",
"metadata": {}, "metadata": {},
"cell_type": "markdown",
"source": [ "source": [
"## Overview\n", "## Overview\n",
"\n", "\n",
"### Integration details\n", "### Integration details\n",
"| Class | Package | Local | Serializable | [JS support](https://js.langchain.com/docs/integrations/chat/openai) | Package downloads | Package latest |\n", "| Class | Package | Local | Serializable | JS support | Package downloads | Package latest |\n",
"| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n", "| :--- | :--- | :---: | :---: |:----------:| :---: | :---: |\n",
"| ChatWriter | langchain-community | ❌ | ❌ | ❌ | ❌ | ❌ |\n", "| ChatWriter | langchain-community | ❌ | ❌ | ❌ | ❌ | ❌ |\n",
"\n", "\n",
"### Model features\n", "### Model features\n",
"| [Tool calling](/docs/how_to/tool_calling) | [Structured output](/docs/how_to/structured_output/) | JSON mode | Image input | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n", "| [Tool calling](/docs/how_to/tool_calling) | Structured output | JSON mode | Image input | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | Logprobs |\n",
"| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n", "| :---: |:-----------------:| :---: | :---: | :---: | :---: | :---: | :---: |:--------------------------------:|:--------:|\n",
"| ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | \n", "| ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ |\n",
"\n", "\n",
"## Setup\n", "## Setup\n",
"\n", "\n",
@@ -48,15 +47,16 @@
"### Credentials\n", "### Credentials\n",
"\n", "\n",
"Head to [Writer AI Studio](https://app.writer.com/aistudio/signup?utm_campaign=devrel) to sign up to OpenAI and generate an API key. Once you've done this set the WRITER_API_KEY environment variable:" "Head to [Writer AI Studio](https://app.writer.com/aistudio/signup?utm_campaign=devrel) to sign up to OpenAI and generate an API key. Once you've done this set the WRITER_API_KEY environment variable:"
] ],
"id": "617a6e98205ab7c8"
}, },
{ {
"cell_type": "code", "cell_type": "code",
"id": "e817fe2e-4f1d-4533-b19e-2400b1cf6ce8", "id": "e817fe2e-4f1d-4533-b19e-2400b1cf6ce8",
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
"end_time": "2024-10-24T13:51:54.323678Z", "end_time": "2024-11-14T09:46:26.800627Z",
"start_time": "2024-10-24T13:51:42.127404Z" "start_time": "2024-11-14T09:27:59.652281Z"
} }
}, },
"source": [ "source": [
@@ -64,7 +64,7 @@
"import os\n", "import os\n",
"\n", "\n",
"if not os.environ.get(\"WRITER_API_KEY\"):\n", "if not os.environ.get(\"WRITER_API_KEY\"):\n",
" os.environ[\"WRITER_API_KEY\"] = getpass.getpass(\"Enter your Writer API key: \")" " os.environ[\"WRITER_API_KEY\"] = getpass.getpass(\"Enter your Writer API key:\")"
], ],
"outputs": [], "outputs": [],
"execution_count": 1 "execution_count": 1
@@ -84,23 +84,24 @@
"id": "2113471c-75d7-45df-b784-d78da4ef7aba", "id": "2113471c-75d7-45df-b784-d78da4ef7aba",
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
"end_time": "2024-10-24T13:52:49.262240Z", "end_time": "2024-11-14T09:46:32.415354Z",
"start_time": "2024-10-24T13:52:47.564879Z" "start_time": "2024-11-14T09:46:26.826112Z"
} }
}, },
"source": [ "source": "%pip install -qU langchain-community writer-sdk",
"%pip install -qU langchain-community writer-sdk"
],
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"\r\n",
"\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m A new release of pip is available: \u001B[0m\u001B[31;49m24.2\u001B[0m\u001B[39;49m -> \u001B[0m\u001B[32;49m24.3.1\u001B[0m\r\n",
"\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m To update, run: \u001B[0m\u001B[32;49mpip install --upgrade pip\u001B[0m\r\n",
"Note: you may need to restart the kernel to use updated packages.\n" "Note: you may need to restart the kernel to use updated packages.\n"
] ]
} }
], ],
"execution_count": 4 "execution_count": 2
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
@@ -118,8 +119,8 @@
"metadata": { "metadata": {
"tags": [], "tags": [],
"ExecuteTime": { "ExecuteTime": {
"end_time": "2024-10-24T13:52:38.822950Z", "end_time": "2024-11-14T09:46:33.504711Z",
"start_time": "2024-10-24T13:52:38.674441Z" "start_time": "2024-11-14T09:46:32.574505Z"
} }
}, },
"source": [ "source": [
@@ -129,24 +130,10 @@
" model=\"palmyra-x-004\",\n", " model=\"palmyra-x-004\",\n",
" temperature=0.7,\n", " temperature=0.7,\n",
" max_tokens=1000,\n", " max_tokens=1000,\n",
" # api_key=\"...\", # if you prefer to pass api key in directly instaed of using env vars\n",
" # base_url=\"...\",\n",
" # other params...\n", " # other params...\n",
")" ")"
], ],
"outputs": [ "outputs": [],
{
"ename": "ImportError",
"evalue": "cannot import name 'ChatWriter' from 'langchain_community.chat_models' (/home/yanomaly/PycharmProjects/whitesnake/writer/langсhain/libs/community/langchain_community/chat_models/__init__.py)",
"output_type": "error",
"traceback": [
"\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[0;31mImportError\u001B[0m Traceback (most recent call last)",
"Cell \u001B[0;32mIn[3], line 1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mlangchain_community\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mchat_models\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m ChatWriter\n\u001B[1;32m 3\u001B[0m llm \u001B[38;5;241m=\u001B[39m ChatWriter(\n\u001B[1;32m 4\u001B[0m model\u001B[38;5;241m=\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mpalmyra-x-004\u001B[39m\u001B[38;5;124m\"\u001B[39m,\n\u001B[1;32m 5\u001B[0m temperature\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m0.7\u001B[39m,\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 9\u001B[0m \u001B[38;5;66;03m# other params...\u001B[39;00m\n\u001B[1;32m 10\u001B[0m )\n",
"\u001B[0;31mImportError\u001B[0m: cannot import name 'ChatWriter' from 'langchain_community.chat_models' (/home/yanomaly/PycharmProjects/whitesnake/writer/langсhain/libs/community/langchain_community/chat_models/__init__.py)"
]
}
],
"execution_count": 3 "execution_count": 3
}, },
{ {
@@ -159,12 +146,14 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"id": "ce16ad78-8e6f-48cd-954e-98be75eb5836", "id": "ce16ad78-8e6f-48cd-954e-98be75eb5836",
"metadata": { "metadata": {
"tags": [] "tags": [],
"ExecuteTime": {
"end_time": "2024-11-14T09:46:38.856174Z",
"start_time": "2024-11-14T09:46:33.520062Z"
}
}, },
"outputs": [],
"source": [ "source": [
"messages = [\n", "messages = [\n",
" (\n", " (\n",
@@ -173,19 +162,127 @@
" ),\n", " ),\n",
" (\"human\", \"Write a poem about Python.\"),\n", " (\"human\", \"Write a poem about Python.\"),\n",
"]\n", "]\n",
"ai_msg = llm.invoke(messages)\n", "ai_msg = llm.invoke(messages)"
"ai_msg" ],
] "outputs": [],
"execution_count": 4
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"id": "2cd224b8-4499-41fb-a604-d53a7ff17b2e", "id": "2cd224b8-4499-41fb-a604-d53a7ff17b2e",
"metadata": {}, "metadata": {
"outputs": [], "ExecuteTime": {
"end_time": "2024-11-14T09:46:38.866651Z",
"start_time": "2024-11-14T09:46:38.863817Z"
}
},
"source": [ "source": [
"print(ai_msg.content)" "print(ai_msg.content)"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"In realms of code, where logic weaves and flows,\n",
"A language rises, Python by its name,\n",
"With syntax clear, where elegance it shows,\n",
"A serpent, wise, that time and space can tame.\n",
"\n",
"Born from the mind of Guido, pure and bright,\n",
"Its beauty lies in simplicity and grace,\n",
"A tool of power, yet gentle in its might,\n",
"In every programmer's heart, a cherished place.\n",
"\n",
"It dances through the data, vast and deep,\n",
"With libraries that span the digital realm,\n",
"From machine learning's secrets to keep,\n",
"To web development, it wields the helm.\n",
"\n",
"In the hands of the novice and the sage,\n",
"Python spins the threads of digital dreams,\n",
"A language that can turn the age,\n",
"With a gentle learning curve, its appeal gleams.\n",
"\n",
"It's more than code, a community it builds,\n",
"Where knowledge freely flows, and all are heard,\n",
"In Python's world, the future unfolds,\n",
"A language of the people, for the world.\n",
"\n",
"So here's to Python, in its gentle might,\n",
"A master of the modern coding art,\n",
"May it continue to light our path each night,\n",
"In the vast, evolving world of code, its heart.\n"
] ]
}
],
"execution_count": 5
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Streaming",
"id": "35b3a5b3dabef65"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-14T09:46:38.914883Z",
"start_time": "2024-11-14T09:46:38.912564Z"
}
},
"cell_type": "code",
"source": "ai_stream = llm.stream(messages)",
"id": "2725770182bf96dc",
"outputs": [],
"execution_count": 6
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-14T09:46:43.226449Z",
"start_time": "2024-11-14T09:46:38.955512Z"
}
},
"cell_type": "code",
"source": [
"for chunk in ai_stream:\n",
" print(chunk.content, end=\"\")"
],
"id": "a48410d9488162e3",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"In realms of code where logic weaves,\n",
"A language rises, Python, it breezes,\n",
"With syntax clear and simple to read,\n",
"Through its elegance, our spirits are fed.\n",
"\n",
"Like rivers flowing, smooth and serene,\n",
"Its structure harmonious, a coder's dream,\n",
"Indentations guide the flow of control,\n",
"In Python's world, confusion takes no toll.\n",
"\n",
"A vast library, a treasure trove so bright,\n",
"For web and data, it offers its might,\n",
"With modules and packages, a rich array,\n",
"Python empowers us to code in play.\n",
"\n",
"From AI to scripts, in flexibility it thrives,\n",
"A language of the future, as many now derive,\n",
"Its community, a beacon of support and cheer,\n",
"With Python, the possibilities are vast, far and near.\n",
"\n",
"So here's to Python, in its gentle grace,\n",
"A tool that enhances, a language that embraces,\n",
"The art of coding, with a fluent, flowing pen,\n",
"In the Python world, we code, and we begin."
]
}
],
"execution_count": 7
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
@@ -199,12 +296,14 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"id": "fbb043e6", "id": "fbb043e6",
"metadata": { "metadata": {
"tags": [] "tags": [],
"ExecuteTime": {
"end_time": "2024-11-14T09:46:50.721645Z",
"start_time": "2024-11-14T09:46:43.234590Z"
}
}, },
"outputs": [],
"source": [ "source": [
"from langchain_core.prompts import ChatPromptTemplate\n", "from langchain_core.prompts import ChatPromptTemplate\n",
"\n", "\n",
@@ -225,8 +324,21 @@
" \"input\": \"Write a poem about Java.\",\n", " \"input\": \"Write a poem about Java.\",\n",
" }\n", " }\n",
")" ")"
],
"outputs": [
{
"data": {
"text/plain": [
"AIMessageChunk(content='In the realm of code, where logic weaves and flows, \\nA language rises, like a phoenix from the code\\'s throes. \\nJava, the name, a cup of coffee\\'s steam, \\nBrewed in the minds, where digital dreams gleam.\\n\\nWith syntax clear, like morning\\'s misty hue, \\nIn classes and objects, it spins a tale so true. \\nA platform agnostic, with a byte to spare, \\nAcross the devices, it journeys everywhere.\\n\\nInheritance and polymorphism, its power\\'s core, \\nLike ancient runes, in every line they bore. \\nEncapsulation, a shield, with data it does hide, \\nIn the vast jungle of code, it stands as a guide.\\n\\nFrom applets small, to vast, server-side apps, \\nIts threads run swift, through the computing traps. \\nA language of the people, by the people, for the peoples use, \\nBuilt on the principle, \"write once, run anywhere, with no excuse.\"\\n\\nIn the heart of Android, it beats, a steady drum, \\nCrafting experiences, in every smartphone\\'s hum. \\nIn the cloud, in the enterprise, its presence is vast, \\nA cornerstone of computing, built to last.\\n\\nOh Java, thy elegance, thy robust design, \\nA language that stands, in any computing line. \\nWith every update, with every new release, \\nThy community grows, with a vibrant, diverse peace.\\n\\nSo here\\'s to Java, the versatile, the grand, \\nA language that shapes the digital land. \\nMay it continue to evolve, to grow, to inspire, \\nIn the endless quest of turning thoughts into digital fire.', additional_kwargs={}, response_metadata={'token_usage': {'completion_tokens': 345, 'prompt_tokens': 33, 'total_tokens': 378, 'completion_tokens_details': None, 'prompt_token_details': None}, 'model_name': 'palmyra-x-004', 'system_fingerprint': 'v1', 'finish_reason': 'stop'}, id='run-a5b4be59-0eb0-41bd-80f7-72477861b0bd-0')"
] ]
}, },
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 8
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "0b1b52a5-b58d-40c9-bcdd-88eb8fb351e2", "id": "0b1b52a5-b58d-40c9-bcdd-88eb8fb351e2",
@@ -251,10 +363,13 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6,
"id": "b7ea7690-ec7a-4337-b392-e87d1f39a6ec", "id": "b7ea7690-ec7a-4337-b392-e87d1f39a6ec",
"metadata": {}, "metadata": {
"outputs": [], "ExecuteTime": {
"end_time": "2024-11-14T09:46:50.891937Z",
"start_time": "2024-11-14T09:46:50.733463Z"
}
},
"source": [ "source": [
"from pydantic import BaseModel, Field\n", "from pydantic import BaseModel, Field\n",
"\n", "\n",
@@ -266,20 +381,26 @@
"\n", "\n",
"\n", "\n",
"llm_with_tools = llm.bind_tools([GetWeather])" "llm_with_tools = llm.bind_tools([GetWeather])"
] ],
"outputs": [],
"execution_count": 9
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"id": "1d1ab955-6a68-42f8-bb5d-86eb1111478a", "id": "1d1ab955-6a68-42f8-bb5d-86eb1111478a",
"metadata": {}, "metadata": {
"outputs": [], "ExecuteTime": {
"end_time": "2024-11-14T09:46:51.725422Z",
"start_time": "2024-11-14T09:46:50.904699Z"
}
},
"source": [ "source": [
"ai_msg = llm_with_tools.invoke(\n", "ai_msg = llm_with_tools.invoke(\n",
" \"what is the weather like in New York City\",\n", " \"what is the weather like in New York City\",\n",
")\n", ")"
"ai_msg" ],
] "outputs": [],
"execution_count": 10
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
@@ -292,14 +413,31 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"id": "166cb7ce-831d-4a7c-9721-abc107f11084", "id": "166cb7ce-831d-4a7c-9721-abc107f11084",
"metadata": {}, "metadata": {
"outputs": [], "ExecuteTime": {
"source": [ "end_time": "2024-11-14T09:46:51.744202Z",
"ai_msg.tool_calls" "start_time": "2024-11-14T09:46:51.738431Z"
}
},
"source": "print(ai_msg.tool_calls)",
"outputs": [
{
"data": {
"text/plain": [
"[{'name': 'GetWeather',\n",
" 'args': {'location': 'New York City, NY'},\n",
" 'id': 'chatcmpl-tool-fe70912c800d40fc8700d604d4823001',\n",
" 'type': 'tool_call'}]"
] ]
}, },
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 11
},
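The notebook excerpt stops at printing `ai_msg.tool_calls`; a minimal sketch of the round trip that typically follows, reusing the `GetWeather`-bound `llm_with_tools` from the cells above (the stubbed weather string and message flow are illustrative assumptions, not part of this commit):

```python
from langchain_core.messages import HumanMessage, ToolMessage

messages = [HumanMessage("what is the weather like in New York City")]
ai_msg = llm_with_tools.invoke(messages)
messages.append(ai_msg)

for tool_call in ai_msg.tool_calls:
    # A real app would call a weather service here; this result is a stub.
    result = f"Sunny, 22 C in {tool_call['args']['location']}"
    messages.append(ToolMessage(content=result, tool_call_id=tool_call["id"]))

# Second pass: the model folds the tool result into a final answer.
final_msg = llm_with_tools.invoke(messages)
print(final_msg.content)
```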
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "e082c9ac-c7c7-4aff-a8ec-8e220262a59c", "id": "e082c9ac-c7c7-4aff-a8ec-8e220262a59c",

View File

@@ -4,120 +4,161 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"# Writer\n", "# Writer LLM\n",
"\n", "\n",
"[Writer](https://writer.com/) is a platform to generate different language content.\n", "[Writer](https://writer.com/) is a platform to generate different language content.\n",
"\n", "\n",
"This example goes over how to use LangChain to interact with `Writer` [models](https://dev.writer.com/docs/models).\n", "This example goes over how to use LangChain to interact with `Writer` [models](https://dev.writer.com/docs/models).\n",
"\n", "\n",
"You have to get the WRITER_API_KEY [here](https://dev.writer.com/docs)." "## Setup\n",
"\n",
"To access Writer models you'll need to create a Writer account, get an API key, and install the `writer-sdk` and `langchain-community` packages.\n",
"\n",
"### Credentials\n",
"\n",
"Head to [Writer AI Studio](https://app.writer.com/aistudio/signup?utm_campaign=devrel) to sign up to OpenAI and generate an API key. Once you've done this set the WRITER_API_KEY environment variable:"
] ]
}, },
{ {
"cell_type": "code",
"execution_count": 4,
"metadata": { "metadata": {
"tags": [] "ExecuteTime": {
"end_time": "2024-11-14T11:10:46.824961Z",
"start_time": "2024-11-14T11:10:44.864137Z"
}
}, },
"cell_type": "code",
"source": [
"import getpass\n",
"import os\n",
"\n",
"if not os.environ.get(\"WRITER_API_KEY\"):\n",
" os.environ[\"WRITER_API_KEY\"] = getpass.getpass(\"Enter your Writer API key:\")"
],
"outputs": [],
"execution_count": 1
},
{
"metadata": {},
"cell_type": "markdown",
"source": [
"## Installation\n",
"\n",
"The LangChain Writer integration lives in the `langchain-community` package:"
]
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-14T11:10:48.297429Z",
"start_time": "2024-11-14T11:10:46.843983Z"
}
},
"cell_type": "code",
"source": "%pip install -qU langchain-community writer-sdk",
"outputs": [ "outputs": [
{ {
"name": "stdin", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
" ········\n" "\r\n",
"\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m A new release of pip is available: \u001B[0m\u001B[31;49m24.2\u001B[0m\u001B[39;49m -> \u001B[0m\u001B[32;49m24.3.1\u001B[0m\r\n",
"\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m To update, run: \u001B[0m\u001B[32;49mpip install --upgrade pip\u001B[0m\r\n",
"Note: you may need to restart the kernel to use updated packages.\n"
] ]
} }
], ],
"source": [ "execution_count": 2
"from getpass import getpass\n",
"\n",
"WRITER_API_KEY = getpass()"
]
}, },
{ {
"cell_type": "code",
"execution_count": 5,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"WRITER_API_KEY\"] = WRITER_API_KEY"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.chains import LLMChain\n",
"from langchain_community.llms import Writer\n",
"from langchain_core.prompts import PromptTemplate"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"template = \"\"\"Question: {question}\n",
"\n",
"Answer: Let's think step by step.\"\"\"\n",
"\n",
"prompt = PromptTemplate.from_template(template)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# If you get an error, probably, you need to set up the \"base_url\" parameter that can be taken from the error log.\n",
"\n",
"llm = Writer()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"llm_chain = LLMChain(prompt=prompt, llm=llm)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"question = \"What NFL team won the Super Bowl in the year Justin Beiber was born?\"\n",
"\n",
"llm_chain.run(question)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {}, "metadata": {},
"cell_type": "markdown",
"source": "Now we can initialize our model object to interact with writer LLMs"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-14T11:10:49.818902Z",
"start_time": "2024-11-14T11:10:48.580516Z"
}
},
"cell_type": "code",
"source": [
"from langchain_community.llms import Writer as WriterLLM\n",
"\n",
"llm = WriterLLM(\n",
" temperature=0.7,\n",
" max_tokens=1000,\n",
" # other params...\n",
")"
],
"outputs": [], "outputs": [],
"source": [] "execution_count": 3
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Invocation"
},
{
"metadata": {
"jupyter": {
"is_executing": true
},
"ExecuteTime": {
"start_time": "2024-11-14T11:10:49.832822Z"
}
},
"cell_type": "code",
"source": "response_text = llm.invoke(input=\"Write a poem\")",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": "print(response_text)",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Streaming"
},
{
"metadata": {},
"cell_type": "code",
"source": "stream_response = llm.stream(input=\"Tell me a fairytale\")",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"for chunk in stream_response:\n",
" print(chunk, end=\"\")"
],
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "markdown",
"source": [
"## Async\n",
"\n",
"Writer support asynchronous calls via **ainvoke()** and **astream()** methods"
]
},
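The async cell above names the methods but shows no runnable code at this point; a minimal, self-contained sketch of both calls, assuming `WRITER_API_KEY` is set in the environment as in the setup section:

```python
import asyncio

from langchain_community.llms import Writer as WriterLLM

llm = WriterLLM(temperature=0.7)  # picks up WRITER_API_KEY from the environment

async def main() -> None:
    # Single asynchronous completion via ainvoke().
    print(await llm.ainvoke("Write a poem"))

    # Asynchronous token streaming via astream().
    async for chunk in llm.astream("Tell me a fairytale"):
        print(chunk, end="")

asyncio.run(main())
```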
{
"metadata": {},
"cell_type": "markdown",
"source": [
"## API reference\n",
"\n",
"For detailed documentation of all Writer features, head to our [API reference](https://dev.writer.com/api-guides/api-reference/completion-api/text-generation#text-generation)."
]
} }
], ],
"metadata": { "metadata": {

View File

@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import json
import logging import logging
from typing import ( from typing import (
Any, Any,
@@ -11,7 +12,6 @@ from typing import (
Iterator, Iterator,
List, List,
Literal, Literal,
Mapping,
Optional, Optional,
Sequence, Sequence,
Tuple, Tuple,
@@ -26,8 +26,6 @@ from langchain_core.callbacks import (
from langchain_core.language_models import LanguageModelInput from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import ( from langchain_core.language_models.chat_models import (
BaseChatModel, BaseChatModel,
agenerate_from_stream,
generate_from_stream,
) )
from langchain_core.messages import ( from langchain_core.messages import (
AIMessage, AIMessage,
@@ -40,13 +38,148 @@ from langchain_core.messages import (
) )
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable from langchain_core.runnables import Runnable
from langchain_core.utils import get_from_dict_or_env
from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import BaseModel, ConfigDict, Field, SecretStr from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _convert_message_to_dict(message: BaseMessage) -> dict: class ChatWriter(BaseChatModel):
"""Writer chat model.
To use, you should have the ``writer-sdk`` Python package installed, and the
environment variable ``WRITER_API_KEY`` set with your API key, or pass the
``api_key`` init param.
Example:
.. code-block:: python
from langchain_community.chat_models import ChatWriter
chat = ChatWriter(
api_key="your key"
model="palmyra-x-004"
)
"""
client: Any = Field(default=None, exclude=True) #: :meta private:
async_client: Any = Field(default=None, exclude=True) #: :meta private:
api_key: Optional[SecretStr] = Field(default=None)
"""Writer API key."""
model_name: str = Field(default="palmyra-x-004", alias="model")
"""Model name to use."""
temperature: float = 0.7
"""What sampling temperature to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
n: int = 1
"""Number of chat completions to generate for each prompt."""
max_tokens: Optional[int] = None
"""Maximum number of tokens to generate."""
model_config = ConfigDict(populate_by_name=True)
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "writer-chat"
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {
"model_name": self.model_name,
"temperature": self.temperature,
**self.model_kwargs,
}
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Writer API."""
return {
"model": self.model_name,
"temperature": self.temperature,
"n": self.n,
"max_tokens": self.max_tokens,
**self.model_kwargs,
}
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Any:
"""Validates that api key is passed and creates Writer clients."""
try:
from writerai import AsyncClient, Client
except ImportError as e:
raise ImportError(
"Could not import writerai python package. "
"Please install it with `pip install writerai`."
) from e
if not values.get("client"):
values.update(
{
"client": Client(
api_key=get_from_dict_or_env(
values, "api_key", "WRITER_API_KEY"
)
)
}
)
if not values.get("async_client"):
values.update(
{
"async_client": AsyncClient(
api_key=get_from_dict_or_env(
values, "api_key", "WRITER_API_KEY"
)
)
}
)
if not (
type(values.get("client")) is Client
and type(values.get("async_client")) is AsyncClient
):
raise ValueError(
"'client' attribute must be with type 'Client' and "
"'async_client' must be with type 'AsyncClient' from 'writerai' package"
)
return values
def _create_chat_result(self, response: Any) -> ChatResult:
generations = []
for choice in response.choices:
message = self._convert_writer_to_langchain(choice.message)
gen = ChatGeneration(
message=message,
generation_info=dict(finish_reason=choice.finish_reason),
)
generations.append(gen)
token_usage = {}
if response.usage:
token_usage = response.usage.__dict__
llm_output = {
"token_usage": token_usage,
"model_name": self.model_name,
"system_fingerprint": response.system_fingerprint,
}
return ChatResult(generations=generations, llm_output=llm_output)
@staticmethod
def _convert_langchain_to_writer(message: BaseMessage) -> dict:
"""Convert a LangChain message to a Writer message dict.""" """Convert a LangChain message to a Writer message dict."""
message_dict = {"role": "", "content": message.content} message_dict = {"role": "", "content": message.content}
@@ -78,17 +211,24 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
return message_dict return message_dict
@staticmethod
def _convert_writer_to_langchain(response_message: Any) -> BaseMessage:
"""Convert a Writer message to a LangChain message."""
if not isinstance(response_message, dict):
response_message = json.loads(
json.dumps(response_message, default=lambda o: o.__dict__)
)
def _convert_dict_to_message(response_dict: Dict[str, Any]) -> BaseMessage: role = response_message.get("role", "")
"""Convert a Writer message dict to a LangChain message.""" content = response_message.get("content")
role = response_dict["role"] if not content:
content = response_dict.get("content", "") content = ""
if role == "user": if role == "user":
return HumanMessage(content=content) return HumanMessage(content=content)
elif role == "assistant": elif role == "assistant":
additional_kwargs = {} additional_kwargs = {}
if tool_calls := response_dict.get("tool_calls"): if tool_calls := response_message.get("tool_calls", []):
additional_kwargs["tool_calls"] = tool_calls additional_kwargs["tool_calls"] = tool_calls
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs) return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
elif role == "system": elif role == "system":
@@ -96,90 +236,20 @@ def _convert_dict_to_message(response_dict: Dict[str, Any]) -> BaseMessage:
elif role == "tool": elif role == "tool":
return ToolMessage( return ToolMessage(
content=content, content=content,
tool_call_id=response_dict["tool_call_id"], tool_call_id=response_message.get("tool_call_id", ""),
name=response_dict.get("name"), name=response_message.get("name", ""),
) )
else: else:
return ChatMessage(content=content, role=role) return ChatMessage(content=content, role=role)
def _convert_messages_to_writer(
class ChatWriter(BaseChatModel):
"""Writer chat model.
To use, you should have the ``writer-sdk`` Python package installed, and the
environment variable ``WRITER_API_KEY`` set with your API key.
Example:
.. code-block:: python
from langchain_community.chat_models import ChatWriter
chat = ChatWriter(model="palmyra-x-004")
"""
client: Any = Field(default=None, exclude=True) #: :meta private:
async_client: Any = Field(default=None, exclude=True) #: :meta private:
model_name: str = Field(default="palmyra-x-004", alias="model")
"""Model name to use."""
temperature: float = 0.7
"""What sampling temperature to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
writer_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
"""Writer API key."""
writer_api_base: Optional[str] = Field(default=None, alias="base_url")
"""Base URL for API requests."""
streaming: bool = False
"""Whether to stream the results or not."""
n: int = 1
"""Number of chat completions to generate for each prompt."""
max_tokens: Optional[int] = None
"""Maximum number of tokens to generate."""
model_config = ConfigDict(populate_by_name=True)
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "writer-chat"
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {
"model_name": self.model_name,
"temperature": self.temperature,
"streaming": self.streaming,
**self.model_kwargs,
}
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
generations = []
for choice in response["choices"]:
message = _convert_dict_to_message(choice["message"])
gen = ChatGeneration(
message=message,
generation_info=dict(finish_reason=choice.get("finish_reason")),
)
generations.append(gen)
token_usage = response.get("usage", {})
llm_output = {
"token_usage": token_usage,
"model_name": self.model_name,
"system_fingerprint": response.get("system_fingerprint", ""),
}
return ChatResult(generations=generations, llm_output=llm_output)
def _convert_messages_to_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None self, messages: List[BaseMessage], stop: Optional[List[str]] = None
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
"""Convert a list of LangChain messages to List of Writer dicts."""
params = { params = {
"model": self.model_name, "model": self.model_name,
"temperature": self.temperature, "temperature": self.temperature,
"n": self.n, "n": self.n,
"stream": self.streaming,
**self.model_kwargs, **self.model_kwargs,
} }
if stop: if stop:
@@ -187,7 +257,7 @@ class ChatWriter(BaseChatModel):
if self.max_tokens is not None: if self.max_tokens is not None:
params["max_tokens"] = self.max_tokens params["max_tokens"] = self.max_tokens
message_dicts = [_convert_message_to_dict(m) for m in messages] message_dicts = [self._convert_langchain_to_writer(m) for m in messages]
return message_dicts, params return message_dicts, params
def _stream( def _stream(
@@ -197,17 +267,17 @@ class ChatWriter(BaseChatModel):
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[ChatGenerationChunk]: ) -> Iterator[ChatGenerationChunk]:
message_dicts, params = self._convert_messages_to_dicts(messages, stop) message_dicts, params = self._convert_messages_to_writer(messages, stop)
params = {**params, **kwargs, "stream": True} params = {**params, **kwargs, "stream": True}
response = self.client.chat.chat(messages=message_dicts, **params) response = self.client.chat.chat(messages=message_dicts, **params)
for chunk in response: for chunk in response:
delta = chunk["choices"][0].get("delta") delta = chunk.choices[0].delta
if not delta or not delta.get("content"): if not delta or not delta.content:
continue continue
chunk = _convert_dict_to_message( chunk = self._convert_writer_to_langchain(
{"role": "assistant", "content": delta["content"]} {"role": "assistant", "content": delta.content}
) )
chunk = ChatGenerationChunk(message=chunk) chunk = ChatGenerationChunk(message=chunk)
@@ -223,17 +293,17 @@ class ChatWriter(BaseChatModel):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]: ) -> AsyncIterator[ChatGenerationChunk]:
message_dicts, params = self._convert_messages_to_dicts(messages, stop) message_dicts, params = self._convert_messages_to_writer(messages, stop)
params = {**params, **kwargs, "stream": True} params = {**params, **kwargs, "stream": True}
response = await self.async_client.chat.chat(messages=message_dicts, **params) response = await self.async_client.chat.chat(messages=message_dicts, **params)
async for chunk in response: async for chunk in response:
delta = chunk["choices"][0].get("delta") delta = chunk.choices[0].delta
if not delta or not delta.get("content"): if not delta or not delta.content:
continue continue
chunk = _convert_dict_to_message( chunk = self._convert_writer_to_langchain(
{"role": "assistant", "content": delta["content"]} {"role": "assistant", "content": delta.content}
) )
chunk = ChatGenerationChunk(message=chunk) chunk = ChatGenerationChunk(message=chunk)
@@ -249,12 +319,7 @@ class ChatWriter(BaseChatModel):
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
if self.streaming: message_dicts, params = self._convert_messages_to_writer(messages, stop)
return generate_from_stream(
self._stream(messages, stop, run_manager, **kwargs)
)
message_dicts, params = self._convert_messages_to_dicts(messages, stop)
params = {**params, **kwargs} params = {**params, **kwargs}
response = self.client.chat.chat(messages=message_dicts, **params) response = self.client.chat.chat(messages=message_dicts, **params)
return self._create_chat_result(response) return self._create_chat_result(response)
@@ -266,28 +331,11 @@ class ChatWriter(BaseChatModel):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
if self.streaming: message_dicts, params = self._convert_messages_to_writer(messages, stop)
return await agenerate_from_stream(
self._astream(messages, stop, run_manager, **kwargs)
)
message_dicts, params = self._convert_messages_to_dicts(messages, stop)
params = {**params, **kwargs} params = {**params, **kwargs}
response = await self.async_client.chat.chat(messages=message_dicts, **params) response = await self.async_client.chat.chat(messages=message_dicts, **params)
return self._create_chat_result(response) return self._create_chat_result(response)
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Writer API."""
return {
"model": self.model_name,
"temperature": self.temperature,
"stream": self.streaming,
"n": self.n,
"max_tokens": self.max_tokens,
**self.model_kwargs,
}
def bind_tools( def bind_tools(
self, self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]], tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
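To summarize what the rewritten validator means for callers: `ChatWriter` now constructs both Writer clients itself, from either an explicit `api_key` or the `WRITER_API_KEY` environment variable. A minimal sketch of the two construction paths (key values are placeholders):

```python
import os

from langchain_community.chat_models import ChatWriter

# Path 1: explicit key; validate_environment builds Client and AsyncClient from it.
chat = ChatWriter(model="palmyra-x-004", api_key="placeholder-key")

# Path 2: no key argument; the validator falls back to the environment variable.
os.environ["WRITER_API_KEY"] = "placeholder-key"
chat_from_env = ChatWriter(model="palmyra-x-004")

# Either instance exposes the same surface: invoke/ainvoke, stream/astream.
```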

View File

@@ -1,108 +1,89 @@
from typing import Any, Dict, List, Mapping, Optional from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional
import requests from langchain_core.callbacks import (
from langchain_core.callbacks import CallbackManagerForLLMRun AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM from langchain_core.language_models.llms import LLM
from langchain_core.utils import get_from_dict_or_env, pre_init from langchain_core.outputs import GenerationChunk
from pydantic import ConfigDict from langchain_core.utils import get_from_dict_or_env
from pydantic import ConfigDict, Field, SecretStr, model_validator
from langchain_community.llms.utils import enforce_stop_tokens
class Writer(LLM): class Writer(LLM):
"""Writer large language models. """Writer large language models.
To use, you should have the environment variable ``WRITER_API_KEY`` and To use, you should have the ``writer-sdk`` Python package installed, and the
``WRITER_ORG_ID`` set with your API key and organization ID respectively. environment variable ``WRITER_API_KEY`` set with your API key.
Example: Example:
.. code-block:: python .. code-block:: python
from langchain_community.llms import Writer from langchain_community.llms import Writer as WriterLLM
writer = Writer(model_id="palmyra-base") from writerai import Writer, AsyncWriter
client = Writer()
async_client = AsyncWriter()
chat = WriterLLM(
client=client,
async_client=async_client
)
""" """
writer_org_id: Optional[str] = None client: Any = Field(default=None, exclude=True) #: :meta private:
"""Writer organization ID.""" async_client: Any = Field(default=None, exclude=True) #: :meta private:
model_id: str = "palmyra-instruct" api_key: Optional[SecretStr] = Field(default=None)
"""Model name to use."""
min_tokens: Optional[int] = None
"""Minimum number of tokens to generate."""
max_tokens: Optional[int] = None
"""Maximum number of tokens to generate."""
temperature: Optional[float] = None
"""What sampling temperature to use."""
top_p: Optional[float] = None
"""Total probability mass of tokens to consider at each step."""
stop: Optional[List[str]] = None
"""Sequences when completion generation will stop."""
presence_penalty: Optional[float] = None
"""Penalizes repeated tokens regardless of frequency."""
repetition_penalty: Optional[float] = None
"""Penalizes repeated tokens according to frequency."""
best_of: Optional[int] = None
"""Generates this many completions server-side and returns the "best"."""
logprobs: bool = False
"""Whether to return log probabilities."""
n: Optional[int] = None
"""How many completions to generate."""
writer_api_key: Optional[str] = None
"""Writer API key.""" """Writer API key."""
base_url: Optional[str] = None model_name: str = Field(default="palmyra-x-003-instruct", alias="model")
"""Base url to use, if None decides based on model name.""" """Model name to use."""
model_config = ConfigDict( max_tokens: Optional[int] = None
extra="forbid", """The maximum number of tokens that the model can generate in the response."""
)
@pre_init temperature: Optional[float] = 0.7
def validate_environment(cls, values: Dict) -> Dict: """Controls the randomness of the model's outputs. Higher values lead to more
"""Validate that api key and organization id exist in environment.""" random outputs, while lower values make the model more deterministic."""
writer_api_key = get_from_dict_or_env( top_p: Optional[float] = None
values, "writer_api_key", "WRITER_API_KEY" """Used to control the nucleus sampling, where only the most probable tokens
) with a cumulative probability of top_p are considered for sampling, providing
values["writer_api_key"] = writer_api_key a way to fine-tune the randomness of predictions."""
writer_org_id = get_from_dict_or_env(values, "writer_org_id", "WRITER_ORG_ID") stop: Optional[List[str]] = None
values["writer_org_id"] = writer_org_id """Specifies stopping conditions for the model's output generation. This can
be an array of strings or a single string that the model will look for as a
signal to stop generating further tokens."""
return values best_of: Optional[int] = None
"""Specifies the number of completions to generate and return the best one.
Useful for generating multiple outputs and choosing the best based on some
criteria."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
model_config = ConfigDict(populate_by_name=True)
@property @property
def _default_params(self) -> Mapping[str, Any]: def _default_params(self) -> Mapping[str, Any]:
"""Get the default parameters for calling Writer API.""" """Get the default parameters for calling Writer API."""
return { return {
"minTokens": self.min_tokens, "max_tokens": self.max_tokens,
"maxTokens": self.max_tokens,
"temperature": self.temperature, "temperature": self.temperature,
"topP": self.top_p, "top_p": self.top_p,
"stop": self.stop, "stop": self.stop,
"presencePenalty": self.presence_penalty, "best_of": self.best_of,
"repetitionPenalty": self.repetition_penalty, **self.model_kwargs,
"bestOf": self.best_of,
"logprobs": self.logprobs,
"n": self.n,
} }
@property @property
def _identifying_params(self) -> Mapping[str, Any]: def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters.""" """Get the identifying parameters."""
return { return {
**{"model_id": self.model_id, "writer_org_id": self.writer_org_id}, "model": self.model_name,
**self._default_params, **self._default_params,
} }
@@ -111,6 +92,51 @@ class Writer(LLM):
"""Return type of llm.""" """Return type of llm."""
return "writer" return "writer"
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Any:
"""Validates that api key is passed and creates Writer clients."""
try:
from writerai import AsyncClient, Client
except ImportError as e:
raise ImportError(
"Could not import writerai python package. "
"Please install it with `pip install writerai`."
) from e
if not values.get("client"):
values.update(
{
"client": Client(
api_key=get_from_dict_or_env(
values, "api_key", "WRITER_API_KEY"
)
)
}
)
if not values.get("async_client"):
values.update(
{
"async_client": AsyncClient(
api_key=get_from_dict_or_env(
values, "api_key", "WRITER_API_KEY"
)
)
}
)
if not (
type(values.get("client")) is Client
and type(values.get("async_client")) is AsyncClient
):
raise ValueError(
"'client' attribute must be with type 'Client' and "
"'async_client' must be with type 'AsyncClient' from 'writerai' package"
)
return values
def _call( def _call(
self, self,
prompt: str, prompt: str,
@@ -118,41 +144,54 @@ class Writer(LLM):
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
"""Call out to Writer's completions endpoint. params = {**self._identifying_params, **kwargs}
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
The string generated by the model.
Example:
.. code-block:: python
response = Writer("Tell me a joke.")
"""
if self.base_url is not None:
base_url = self.base_url
else:
base_url = (
"https://enterprise-api.writer.com/llm"
f"/organization/{self.writer_org_id}"
f"/model/{self.model_id}/completions"
)
params = {**self._default_params, **kwargs}
response = requests.post(
url=base_url,
headers={
"Authorization": f"{self.writer_api_key}",
"Content-Type": "application/json",
"Accept": "application/json",
},
json={"prompt": prompt, **params},
)
text = response.text
if stop is not None: if stop is not None:
# I believe this is required since the stop tokens params.update({"stop": stop})
# are not enforced by the model parameters text = self.client.completions.create(prompt=prompt, **params).choices[0].text
text = enforce_stop_tokens(text, stop)
return text return text
async def _acall(
self,
prompt: str,
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
params = {**self._identifying_params, **kwargs}
if stop is not None:
params.update({"stop": stop})
response = await self.async_client.completions.create(prompt=prompt, **params)
text = response.choices[0].text
return text
def _stream(
self,
prompt: str,
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
params = {**self._identifying_params, **kwargs, "stream": True}
if stop is not None:
params.update({"stop": stop})
response = self.client.completions.create(prompt=prompt, **params)
for chunk in response:
if run_manager:
run_manager.on_llm_new_token(chunk.value)
yield GenerationChunk(text=chunk.value)
async def _astream(
self,
prompt: str,
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
params = {**self._identifying_params, **kwargs, "stream": True}
if stop is not None:
params.update({"stop": stop})
response = await self.async_client.completions.create(prompt=prompt, **params)
async for chunk in response:
if run_manager:
await run_manager.on_llm_new_token(chunk.value)
yield GenerationChunk(text=chunk.value)
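With the hand-rolled `requests` call replaced by the SDK clients, the completions `Writer` LLM now covers sync, async, and streaming paths. A short usage sketch, assuming `WRITER_API_KEY` is set (the prompt and stop sequence are illustrative):

```python
from langchain_community.llms import Writer as WriterLLM

llm = WriterLLM(model="palmyra-x-003-instruct", max_tokens=500)

# _call merges `stop` into the request params before calling the SDK.
print(llm.invoke("List three colors:", stop=["\n\n"]))

# _stream yields GenerationChunk objects, surfaced here as plain strings.
for chunk in llm.stream("Tell me a fairytale"):
    print(chunk, end="")
```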

View File

@@ -20,7 +20,7 @@ count=$(git grep -E '(@root_validator)|(@validator)|(@field_validator)|(@pre_ini
# PRs that increase the current count will not be accepted. # PRs that increase the current count will not be accepted.
# PRs that decrease update the code in the repository # PRs that decrease update the code in the repository
# and allow decreasing the count of are welcome! # and allow decreasing the count of are welcome!
current_count=126 current_count=125
if [ "$count" -gt "$current_count" ]; then if [ "$count" -gt "$current_count" ]; then
echo "The PR seems to be introducing new usage of @root_validator and/or @field_validator." echo "The PR seems to be introducing new usage of @root_validator and/or @field_validator."

View File

@@ -1,10 +0,0 @@
"""Test Writer API wrapper."""
from langchain_community.llms.writer import Writer
def test_writer_call() -> None:
"""Test valid call to Writer."""
llm = Writer()
output = llm.invoke("Say foo:")
assert isinstance(output, str)

View File

@@ -1,61 +1,251 @@
"""Unit tests for Writer chat model integration."""
import json import json
from typing import Any, Dict, List from typing import Any, Dict, List, Literal, Optional, Tuple, Type
from unittest.mock import AsyncMock, MagicMock, patch from unittest import mock
from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
from langchain_core.callbacks.manager import CallbackManager from langchain_core.callbacks.manager import CallbackManager
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_standard_tests.unit_tests import ChatModelUnitTests
from pydantic import SecretStr from pydantic import SecretStr
from langchain_community.chat_models.writer import ChatWriter, _convert_dict_to_message from langchain_community.chat_models.writer import ChatWriter
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
"""Classes for mocking Writer responses."""
class ChoiceDelta:
def __init__(self, content: str):
self.content = content
class ChunkChoice:
def __init__(self, index: int, finish_reason: str, delta: ChoiceDelta):
self.index = index
self.finish_reason = finish_reason
self.delta = delta
class ChatCompletionChunk:
def __init__(
self,
id: str,
object: str,
created: int,
model: str,
choices: List[ChunkChoice],
):
self.id = id
self.object = object
self.created = created
self.model = model
self.choices = choices
class ToolCallFunction:
def __init__(self, name: str, arguments: str):
self.name = name
self.arguments = arguments
class ChoiceMessageToolCall:
def __init__(self, id: str, type: str, function: ToolCallFunction):
self.id = id
self.type = type
self.function = function
class Usage:
def __init__(
self,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
):
self.prompt_tokens = prompt_tokens
self.completion_tokens = completion_tokens
self.total_tokens = total_tokens
class ChoiceMessage:
def __init__(
self,
role: str,
content: str,
tool_calls: Optional[List[ChoiceMessageToolCall]] = None,
):
self.role = role
self.content = content
self.tool_calls = tool_calls
class Choice:
def __init__(self, index: int, finish_reason: str, message: ChoiceMessage):
self.index = index
self.finish_reason = finish_reason
self.message = message
class Chat:
def __init__(
self,
id: str,
object: str,
created: int,
system_fingerprint: str,
model: str,
usage: Usage,
choices: List[Choice],
):
self.id = id
self.object = object
self.created = created
self.system_fingerprint = system_fingerprint
self.model = model
self.usage = usage
self.choices = choices
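These hand-rolled mocks exist because the reworked code reads SDK responses by attribute (`response.choices[0].message`, `response.usage.__dict__`) rather than by dict key; a quick shape check, assuming the classes above:

```python
response = Chat(
    id="chat-1",
    object="chat.completion",
    created=0,
    system_fingerprint="v1",
    model="palmyra-x-004",
    usage=Usage(prompt_tokens=1, completion_tokens=1, total_tokens=2),
    choices=[
        Choice(
            index=0,
            finish_reason="stop",
            message=ChoiceMessage(role="assistant", content="ok"),
        )
    ],
)

# Attribute access mirrors what _create_chat_result does with a real SDK response.
assert response.choices[0].message.content == "ok"
assert response.usage.__dict__["total_tokens"] == 2
```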
@pytest.mark.requires("writerai")
class TestChatWriterCustom:
"""Test case for ChatWriter"""
@pytest.fixture(autouse=True)
def mock_unstreaming_completion(self) -> Chat:
"""Fixture providing a mock API response."""
return Chat(
id="chat-12345",
object="chat.completion",
created=1699000000,
model="palmyra-x-004",
system_fingerprint="v1",
usage=Usage(prompt_tokens=10, completion_tokens=8, total_tokens=18),
choices=[
Choice(
index=0,
finish_reason="stop",
message=ChoiceMessage(
role="assistant",
content="Hello! How can I help you?",
),
)
],
)
@pytest.fixture(autouse=True)
def mock_tool_call_choice_response(self) -> Chat:
return Chat(
id="chat-12345",
object="chat.completion",
created=1699000000,
model="palmyra-x-004",
system_fingerprint="v1",
usage=Usage(prompt_tokens=29, completion_tokens=32, total_tokens=61),
choices=[
Choice(
index=0,
finish_reason="tool_calls",
message=ChoiceMessage(
role="assistant",
content="",
tool_calls=[
ChoiceMessageToolCall(
id="call_abc123",
type="function",
function=ToolCallFunction(
name="GetWeather",
arguments='{"location": "London"}',
),
)
],
),
)
],
)
@pytest.fixture(autouse=True)
def mock_streaming_chunks(self) -> List[ChatCompletionChunk]:
"""Fixture providing mock streaming response chunks."""
return [
ChatCompletionChunk(
id="chat-12345",
object="chat.completion",
created=1699000000,
model="palmyra-x-004",
choices=[
ChunkChoice(
index=0,
finish_reason="stop",
delta=ChoiceDelta(content="Hello! "),
)
],
),
ChatCompletionChunk(
id="chat-12345",
object="chat.completion",
created=1699000000,
model="palmyra-x-004",
choices=[
ChunkChoice(
index=0,
finish_reason="stop",
delta=ChoiceDelta(content="How can I help you?"),
)
],
),
]
class TestChatWriter:
def test_writer_model_param(self) -> None: def test_writer_model_param(self) -> None:
"""Test different ways to initialize the chat model.""" """Test different ways to initialize the chat model."""
test_cases: List[dict] = [ test_cases: List[dict] = [
{"model_name": "palmyra-x-004", "writer_api_key": "test-key"}, {
{"model": "palmyra-x-004", "writer_api_key": "test-key"}, "model_name": "palmyra-x-004",
{"model_name": "palmyra-x-004", "writer_api_key": "test-key"}, "api_key": "key",
},
{
"model": "palmyra-x-004",
"api_key": "key",
},
{
"model_name": "palmyra-x-004",
"api_key": "key",
},
{ {
"model": "palmyra-x-004", "model": "palmyra-x-004",
"writer_api_key": "test-key",
"temperature": 0.5, "temperature": 0.5,
"api_key": "key",
}, },
] ]
for case in test_cases: for case in test_cases:
chat = ChatWriter(**case) chat = ChatWriter(**case)
assert chat.model_name == "palmyra-x-004" assert chat.model_name == "palmyra-x-004"
assert chat.writer_api_key
assert chat.writer_api_key.get_secret_value() == "test-key"
assert chat.temperature == (0.5 if "temperature" in case else 0.7) assert chat.temperature == (0.5 if "temperature" in case else 0.7)
def test_convert_dict_to_message_human(self) -> None: def test_convert_writer_to_langchain_human(self) -> None:
"""Test converting a human message dict to a LangChain message.""" """Test converting a human message dict to a LangChain message."""
message = {"role": "user", "content": "Hello"} message = {"role": "user", "content": "Hello"}
result = _convert_dict_to_message(message) result = ChatWriter._convert_writer_to_langchain(message)
assert isinstance(result, HumanMessage) assert isinstance(result, HumanMessage)
assert result.content == "Hello" assert result.content == "Hello"
def test_convert_dict_to_message_ai(self) -> None: def test_convert_writer_to_langchain_ai(self) -> None:
"""Test converting an AI message dict to a LangChain message.""" """Test converting an AI message dict to a LangChain message."""
message = {"role": "assistant", "content": "Hello"} message = {"role": "assistant", "content": "Hello"}
result = _convert_dict_to_message(message) result = ChatWriter._convert_writer_to_langchain(message)
assert isinstance(result, AIMessage) assert isinstance(result, AIMessage)
assert result.content == "Hello" assert result.content == "Hello"
def test_convert_dict_to_message_system(self) -> None: def test_convert_writer_to_langchain_system(self) -> None:
"""Test converting a system message dict to a LangChain message.""" """Test converting a system message dict to a LangChain message."""
message = {"role": "system", "content": "You are a helpful assistant"} message = {"role": "system", "content": "You are a helpful assistant"}
result = _convert_dict_to_message(message) result = ChatWriter._convert_writer_to_langchain(message)
assert isinstance(result, SystemMessage) assert isinstance(result, SystemMessage)
assert result.content == "You are a helpful assistant" assert result.content == "You are a helpful assistant"
def test_convert_dict_to_message_tool_call(self) -> None: def test_convert_writer_to_langchain_tool_call(self) -> None:
"""Test converting a tool call message dict to a LangChain message.""" """Test converting a tool call message dict to a LangChain message."""
content = json.dumps({"result": 42}) content = json.dumps({"result": 42})
message = { message = {
@@ -64,12 +254,12 @@ class TestChatWriter:
"content": content, "content": content,
"tool_call_id": "call_abc123", "tool_call_id": "call_abc123",
} }
result = _convert_dict_to_message(message) result = ChatWriter._convert_writer_to_langchain(message)
assert isinstance(result, ToolMessage) assert isinstance(result, ToolMessage)
assert result.name == "get_number" assert result.name == "get_number"
assert result.content == content assert result.content == content
def test_convert_dict_to_message_with_tool_calls(self) -> None: def test_convert_writer_to_langchain_with_tool_calls(self) -> None:
"""Test converting an AIMessage with tool calls.""" """Test converting an AIMessage with tool calls."""
message = { message = {
"role": "assistant", "role": "assistant",
@@ -85,131 +275,55 @@ class TestChatWriter:
} }
], ],
} }
result = _convert_dict_to_message(message) result = ChatWriter._convert_writer_to_langchain(message)
assert isinstance(result, AIMessage) assert isinstance(result, AIMessage)
assert result.tool_calls assert result.tool_calls
assert len(result.tool_calls) == 1 assert len(result.tool_calls) == 1
assert result.tool_calls[0]["name"] == "get_weather" assert result.tool_calls[0]["name"] == "get_weather"
assert result.tool_calls[0]["args"]["location"] == "London" assert result.tool_calls[0]["args"]["location"] == "London"
    def test_sync_completion(
        self, mock_unstreaming_completion: List[ChatCompletionChunk]
    ) -> None:
        """Test basic chat completion with mocked response."""
        chat = ChatWriter(api_key=SecretStr("key"))

        mock_client = MagicMock()
        mock_client.chat.chat.return_value = mock_unstreaming_completion

        with mock.patch.object(chat, "client", mock_client):
            message = HumanMessage(content="Hi there!")
            response = chat.invoke([message])
            assert isinstance(response, AIMessage)
            assert response.content == "Hello! How can I help you?"

    @pytest.mark.asyncio
    async def test_async_completion(
        self, mock_unstreaming_completion: List[ChatCompletionChunk]
    ) -> None:
        """Test async chat completion with mocked response."""
        chat = ChatWriter(api_key=SecretStr("key"))

        mock_async_client = AsyncMock()
        mock_async_client.chat.chat.return_value = mock_unstreaming_completion

        with mock.patch.object(chat, "async_client", mock_async_client):
            message = HumanMessage(content="Hi there!")
            response = await chat.ainvoke([message])
            assert isinstance(response, AIMessage)
            assert response.content == "Hello! How can I help you?"
    def test_sync_streaming(
        self, mock_streaming_chunks: List[ChatCompletionChunk]
    ) -> None:
        """Test sync streaming with callback handler."""
        callback_handler = FakeCallbackHandler()
        callback_manager = CallbackManager([callback_handler])

        chat = ChatWriter(
            api_key=SecretStr("key"),
            callback_manager=callback_manager,
            max_tokens=10,
        )

        mock_client = MagicMock()
        mock_response = MagicMock()
        mock_response.__iter__.return_value = mock_streaming_chunks
        mock_client.chat.chat.return_value = mock_response

        with mock.patch.object(chat, "client", mock_client):
            message = HumanMessage(content="Hi")
            response = chat.stream([message])

            response_message = ""
            for chunk in response:
                response_message += str(chunk.content)

            assert callback_handler.llm_streams > 0
            assert response_message == "Hello! How can I help you?"

    @pytest.mark.asyncio
    async def test_async_streaming(
        self, mock_streaming_chunks: List[ChatCompletionChunk]
    ) -> None:
        """Test async streaming with callback handler."""
        callback_handler = FakeCallbackHandler()
        callback_manager = CallbackManager([callback_handler])

        chat = ChatWriter(
            api_key=SecretStr("key"),
            callback_manager=callback_manager,
            max_tokens=10,
        )

        mock_async_client = AsyncMock()
        mock_response = AsyncMock()
        mock_response.__aiter__.return_value = mock_streaming_chunks
        mock_async_client.chat.chat.return_value = mock_response

        with mock.patch.object(chat, "async_client", mock_async_client):
            message = HumanMessage(content="Hi")
            response = chat.astream([message])

            response_message = ""
            async for chunk in response:
                response_message += str(chunk.content)

            assert callback_handler.llm_streams > 0
            assert response_message == "Hello! How can I help you?"
    def test_sync_tool_calling(
        self, mock_tool_call_choice_response: Dict[str, Any]
    ) -> None:
        """Test synchronous tool calling functionality."""
        from pydantic import BaseModel, Field

        class GetWeather(BaseModel):
            location: str = Field(..., description="The location to get weather for")

        chat = ChatWriter(api_key=SecretStr("key"))

        mock_client = MagicMock()
        mock_client.chat.chat.return_value = mock_tool_call_choice_response

        chat_with_tools = chat.bind_tools(
            tools=[GetWeather],
            tool_choice="GetWeather",
        )

        with mock.patch.object(chat, "client", mock_client):
            response = chat_with_tools.invoke("What's the weather in London?")
            assert isinstance(response, AIMessage)
            assert response.tool_calls
            assert response.tool_calls[0]["name"] == "GetWeather"
            assert response.tool_calls[0]["args"]["location"] == "London"

    @pytest.mark.asyncio
    async def test_async_tool_calling(
        self, mock_tool_call_choice_response: Dict[str, Any]
    ) -> None:
        """Test asynchronous tool calling functionality."""
        from pydantic import BaseModel, Field

        class GetWeather(BaseModel):
            location: str = Field(..., description="The location to get weather for")

        mock_async_client = AsyncMock()
        mock_async_client.chat.chat.return_value = mock_tool_call_choice_response

        chat = ChatWriter(api_key=SecretStr("key"))

        chat_with_tools = chat.bind_tools(
            tools=[GetWeather],
            tool_choice="GetWeather",
        )

        with mock.patch.object(chat, "async_client", mock_async_client):
            response = await chat_with_tools.ainvoke("What's the weather in London?")
            assert isinstance(response, AIMessage)
            assert response.tool_calls
            assert response.tool_calls[0]["name"] == "GetWeather"
            assert response.tool_calls[0]["args"]["location"] == "London"
@pytest.mark.requires("writerai")
class TestChatWriterStandart(ChatModelUnitTests):
"""Test case for ChatWriter that inherits from standard LangChain tests."""
@property
def chat_model_class(self) -> Type[BaseChatModel]:
"""Return ChatWriter model class."""
return ChatWriter
@property
def chat_model_params(self) -> Dict:
"""Return any additional parameters needed."""
return {
"api_key": "fake-api-key",
"model_name": "palmyra-x-004",
}
@property
def has_tool_calling(self) -> bool:
"""Writer supports tool/function calling."""
return True
@property
def tool_choice_value(self) -> Optional[str]:
"""Value to use for tool choice in tests."""
return "auto"
@property
def has_structured_output(self) -> bool:
"""Writer does not yet support structured output."""
return False
@property
def supports_image_inputs(self) -> bool:
"""Writer does not support image inputs."""
return False
@property
def supports_video_inputs(self) -> bool:
"""Writer does not support video inputs."""
return False
@property
def returns_usage_metadata(self) -> bool:
"""Writer returns token usage information."""
return True
@property
def supports_anthropic_inputs(self) -> bool:
"""Writer does not support anthropic inputs."""
return False
@property
def supports_image_tool_message(self) -> bool:
"""Writer does not support image tool message."""
return False
@property
def supported_usage_metadata_details(
self,
) -> Dict[
Literal["invoke", "stream"],
List[
Literal[
"audio_input",
"audio_output",
"reasoning_output",
"cache_read_input",
"cache_creation_input",
]
],
]:
"""Return which types of usage metadata your model supports."""
return {"invoke": ["cache_creation_input"], "stream": ["reasoning_output"]}
@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
"""Return env vars, init args, and expected instance attrs for initializing
from env vars."""
return {"WRITER_API_KEY": "key"}, {"api_key": "key"}, {"api_key": "key"}


@ -0,0 +1,202 @@
from typing import List
from unittest import mock
from unittest.mock import AsyncMock, MagicMock
import pytest
from langchain_core.callbacks import CallbackManager
from pydantic import SecretStr
from langchain_community.llms.writer import Writer
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
"""Classes for mocking Writer responses."""
class Choice:
def __init__(self, text: str):
self.text = text
class Completion:
def __init__(self, choices: List[Choice]):
self.choices = choices
class StreamingData:
def __init__(self, value: str):
self.value = value
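
# Editor's note: these stand-ins model only the two attributes the Writer LLM
# reads back, `.choices[n].text` on a completion and `.value` on a streamed
# chunk; e.g. Completion(choices=[Choice(text="hi")]).choices[0].text == "hi".
# That is all the mocked `completions.create` calls below need to return.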
@pytest.mark.requires("writerai")
class TestWriterLLM:
"""Unit tests for Writer LLM integration."""
@pytest.fixture(autouse=True)
def mock_unstreaming_completion(self) -> Completion:
"""Fixture providing a mock API response."""
return Completion(choices=[Choice(text="Hello! How can I help you?")])
@pytest.fixture(autouse=True)
def mock_streaming_completion(self) -> List[StreamingData]:
"""Fixture providing mock streaming response chunks."""
return [
StreamingData(value="Hello! "),
StreamingData(value="How can I"),
StreamingData(value=" help you?"),
]
def test_sync_unstream_completion(
self, mock_unstreaming_completion: Completion
) -> None:
"""Test basic llm call with mocked response."""
mock_client = MagicMock()
mock_client.completions.create.return_value = mock_unstreaming_completion
llm = Writer(api_key=SecretStr("key"))
with mock.patch.object(llm, "client", mock_client):
response_text = llm.invoke(input="Hello")
assert response_text == "Hello! How can I help you?"
def test_sync_unstream_completion_with_params(
self, mock_unstreaming_completion: Completion
) -> None:
"""Test llm call with passed params with mocked response."""
mock_client = MagicMock()
mock_client.completions.create.return_value = mock_unstreaming_completion
llm = Writer(api_key=SecretStr("key"), temperature=1)
with mock.patch.object(llm, "client", mock_client):
response_text = llm.invoke(input="Hello")
assert response_text == "Hello! How can I help you?"
@pytest.mark.asyncio
async def test_async_unstream_completion(
self, mock_unstreaming_completion: Completion
) -> None:
"""Test async chat completion with mocked response."""
mock_async_client = AsyncMock()
mock_async_client.completions.create.return_value = mock_unstreaming_completion
llm = Writer(api_key=SecretStr("key"))
with mock.patch.object(llm, "async_client", mock_async_client):
response_text = await llm.ainvoke(input="Hello")
assert response_text == "Hello! How can I help you?"
@pytest.mark.asyncio
async def test_async_unstream_completion_with_params(
self, mock_unstreaming_completion: Completion
) -> None:
"""Test async llm call with passed params with mocked response."""
mock_async_client = AsyncMock()
mock_async_client.completions.create.return_value = mock_unstreaming_completion
llm = Writer(api_key=SecretStr("key"), temperature=1)
with mock.patch.object(llm, "async_client", mock_async_client):
response_text = await llm.ainvoke(input="Hello")
assert response_text == "Hello! How can I help you?"
def test_sync_streaming_completion(
self, mock_streaming_completion: List[StreamingData]
) -> None:
"""Test sync streaming."""
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.__iter__.return_value = mock_streaming_completion
mock_client.completions.create.return_value = mock_response
llm = Writer(api_key=SecretStr("key"))
with mock.patch.object(llm, "client", mock_client):
response = llm.stream(input="Hello")
response_message = ""
for chunk in response:
response_message += chunk
assert response_message == "Hello! How can I help you?"
def test_sync_streaming_completion_with_callback_handler(
self, mock_streaming_completion: List[StreamingData]
) -> None:
"""Test sync streaming with callback handler."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.__iter__.return_value = mock_streaming_completion
mock_client.completions.create.return_value = mock_response
llm = Writer(
api_key=SecretStr("key"),
callback_manager=callback_manager,
)
with mock.patch.object(llm, "client", mock_client):
response = llm.stream(input="Hello")
response_message = ""
for chunk in response:
response_message += chunk
assert callback_handler.llm_streams == 3
assert response_message == "Hello! How can I help you?"
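
    # Editor's note: `llm_streams == 3` matches the three StreamingData chunks
    # in the fixture; the fake handler counts one on_llm_new_token callback per
    # streamed chunk (the async variant below asserts the same).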
@pytest.mark.asyncio
async def test_async_streaming_completion(
        self, mock_streaming_completion: List[StreamingData]
) -> None:
"""Test async streaming with callback handler."""
mock_async_client = AsyncMock()
mock_response = AsyncMock()
mock_response.__aiter__.return_value = mock_streaming_completion
mock_async_client.completions.create.return_value = mock_response
llm = Writer(api_key=SecretStr("key"))
with mock.patch.object(llm, "async_client", mock_async_client):
response = llm.astream(input="Hello")
response_message = ""
async for chunk in response:
response_message += str(chunk)
assert response_message == "Hello! How can I help you?"
@pytest.mark.asyncio
async def test_async_streaming_completion_with_callback_handler(
        self, mock_streaming_completion: List[StreamingData]
) -> None:
"""Test async streaming with callback handler."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
mock_async_client = AsyncMock()
mock_response = AsyncMock()
mock_response.__aiter__.return_value = mock_streaming_completion
mock_async_client.completions.create.return_value = mock_response
llm = Writer(
api_key=SecretStr("key"),
callback_manager=callback_manager,
)
with mock.patch.object(llm, "async_client", mock_async_client):
response = llm.astream(input="Hello")
response_message = ""
async for chunk in response:
response_message += str(chunk)
assert callback_handler.llm_streams == 3
assert response_message == "Hello! How can I help you?"
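
# Editor's sketch (not part of the commit): the patching pattern used
# throughout this file, condensed. The instance's lazily created SDK client is
# swapped for a mock, so no real WRITER_API_KEY or network access is needed;
# the helper name and stub text are hypothetical.
def _demo_mocked_invoke() -> str:
    llm = Writer(api_key=SecretStr("key"))
    fake_client = MagicMock()
    fake_client.completions.create.return_value = Completion(
        choices=[Choice(text="stubbed")]
    )
    with mock.patch.object(llm, "client", fake_client):
        # Writer.invoke returns the completion text, per the tests above.
        return llm.invoke(input="ping")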