mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-01 17:13:22 +00:00
fireworks[patch]: Add Fireworks partner packages (#17694)
--------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
11cf95e810
commit
ee6a773456
245
cookbook/fireworks_rag.ipynb
Normal file
245
cookbook/fireworks_rag.ipynb
Normal file
@ -0,0 +1,245 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0fc0309d-4d49-4bb5-bec0-bd92c6fddb28",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Fireworks.AI + LangChain + RAG\n",
|
||||
" \n",
|
||||
"[Fireworks AI](https://python.langchain.com/docs/integrations/llms/fireworks) wants to provide the best experience when working with LangChain, and here is an example of Fireworks + LangChain doing RAG\n",
|
||||
"\n",
|
||||
"See [our models page](https://fireworks.ai/models) for the full list of models. We use `accounts/fireworks/models/mixtral-8x7b-instruct` for RAG In this tutorial.\n",
|
||||
"\n",
|
||||
"For the RAG target, we will use the Gemma technical report https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "d12fb75a-f707-48d5-82a5-efe2d041813c",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.2.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n",
|
||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
|
||||
"Note: you may need to restart the kernel to use updated packages.\n",
|
||||
"Found existing installation: langchain-fireworks 0.0.1\n",
|
||||
"Uninstalling langchain-fireworks-0.0.1:\n",
|
||||
" Successfully uninstalled langchain-fireworks-0.0.1\n",
|
||||
"Note: you may need to restart the kernel to use updated packages.\n",
|
||||
"Obtaining file:///mnt/disks/data/langchain/libs/partners/fireworks\n",
|
||||
" Installing build dependencies ... \u001b[?25ldone\n",
|
||||
"\u001b[?25h Checking if build backend supports build_editable ... \u001b[?25ldone\n",
|
||||
"\u001b[?25h Getting requirements to build editable ... \u001b[?25ldone\n",
|
||||
"\u001b[?25h Preparing editable metadata (pyproject.toml) ... \u001b[?25ldone\n",
|
||||
"\u001b[?25hRequirement already satisfied: aiohttp<4.0.0,>=3.9.1 in /mnt/disks/data/langchain/.venv/lib/python3.9/site-packages (from langchain-fireworks==0.0.1) (3.9.3)\n",
|
||||
"Requirement already satisfied: fireworks-ai<0.13.0,>=0.12.0 in /mnt/disks/data/langchain/.venv/lib/python3.9/site-packages (from langchain-fireworks==0.0.1) (0.12.0)\n",
|
||||
"Requirement already satisfied: langchain-core<0.2,>=0.1 in /mnt/disks/data/langchain/.venv/lib/python3.9/site-packages (from langchain-fireworks==0.0.1) (0.1.23)\n",
|
||||
"Requirement already satisfied: requests<3,>=2 in /mnt/disks/data/langchain/.venv/lib/python3.9/site-packages (from langchain-fireworks==0.0.1) (2.31.0)\n",
|
||||
"Requirement already satisfied: aiosignal>=1.1.2 in /mnt/disks/data/langchain/.venv/lib/python3.9/site-packages (from aiohttp<4.0.0,>=3.9.1->langchain-fireworks==0.0.1) (1.3.1)\n",
|
||||
"Requirement already satisfied: attrs>=17.3.0 in /mnt/disks/data/langchain/.venv/lib/python3.9/site-packages (from aiohttp<4.0.0,>=3.9.1->langchain-fireworks==0.0.1) (23.1.0)\n",
|
||||
"Requirement already satisfied: frozenlist>=1.1.1 in /mnt/disks/data/langchain/.venv/lib/python3.9/site-packages (from aiohttp<4.0.0,>=3.9.1->langchain-fireworks==0.0.1) (1.4.0)\n",
|
||||
"Requirement already satisfied: multidict<7.0,>=4.5 in /mnt/disks/data/langchain/.venv/lib/python3.9/site-packages (from aiohttp<4.0.0,>=3.9.1->langchain-fireworks==0.0.1) (6.0.4)\n",
|
||||
"Requirement already satisfied: yarl<2.0,>=1.0 in /mnt/disks/data/langchain/.venv/lib/python3.9/site-packages (from aiohttp<4.0.0,>=3.9.1->langchain-fireworks==0.0.1) (1.9.2)\n",
|
||||
"Requirement already satisfied: async-timeout<5.0,>=4.0 in /mnt/disks/data/langchain/.venv/lib/python3.9/site-packages (from aiohttp<4.0.0,>=3.9.1->langchain-fireworks==0.0.1) (4.0.3)\n",
|
||||
"Requirement already satisfied: httpx in /mnt/disks/data/langchain/.venv/lib/python3.9/site-packages (from fireworks-ai<0.13.0,>=0.12.0->langchain-fireworks==0.0.1) (0.26.0)\n",
|
||||
"Requirement already satisfied: httpx-sse in /mnt/disks/data/langchain/.venv/lib/python3.9/site-packages (from fireworks-ai<0.13.0,>=0.12.0->langchain-fireworks==0.0.1) (0.4.0)\n",
|
||||
"Requirement already satisfied: pydantic in /mnt/disks/data/langchain/.venv/lib/python3.9/site-packages (from fireworks-ai<0.13.0,>=0.12.0->langchain-fireworks==0.0.1) (2.4.2)\n",
|
||||
"Requirement already satisfied: Pillow in /mnt/disks/data/langchain/.venv/lib/python3.9/site-packages (from fireworks-ai<0.13.0,>=0.12.0->langchain-fireworks==0.0.1) (10.2.0)\n",
|
||||
"Requirement already satisfied: PyYAML>=5.3 in /mnt/disks/data/langchain/.venv/lib/python3.9/site-packages (from langchain-core<0.2,>=0.1->langchain-fireworks==0.0.1) (6.0.1)\n",
|
||||
"Requirement already satisfied: anyio<5,>=3 in /mnt/disks/data/langchain/.venv/lib/python3.9/site-packages (from langchain-core<0.2,>=0.1->langchain-fireworks==0.0.1) (3.7.1)\n",
|
||||
"Requirement already satisfied: jsonpatch<2.0,>=1.33 in /mnt/disks/data/langchain/.venv/lib/python3.9/site-packages (from langchain-core<0.2,>=0.1->langchain-fireworks==0.0.1) (1.33)\n",
|
||||
"Requirement already satisfied: langsmith<0.2.0,>=0.1.0 in /mnt/disks/data/langchain/.venv/lib/python3.9/site-packages (from langchain-core<0.2,>=0.1->langchain-fireworks==0.0.1) (0.1.5)\n",
|
||||
"Requirement already satisfied: packaging<24.0,>=23.2 in /mnt/disks/data/langchain/.venv/lib/python3.9/site-packages (from langchain-core<0.2,>=0.1->langchain-fireworks==0.0.1) (23.2)\n",
|
||||
"Requirement already satisfied: tenacity<9.0.0,>=8.1.0 in /mnt/disks/data/langchain/.venv/lib/python3.9/site-packages (from langchain-core<0.2,>=0.1->langchain-fireworks==0.0.1) (8.2.3)\n",
|
||||
"Requirement already satisfied: charset-normalizer<4,>=2 in /mnt/disks/data/langchain/.venv/lib/python3.9/site-packages (from requests<3,>=2->langchain-fireworks==0.0.1) (3.3.0)\n",
|
||||
"Requirement already satisfied: idna<4,>=2.5 in /mnt/disks/data/langchain/.venv/lib/python3.9/site-packages (from requests<3,>=2->langchain-fireworks==0.0.1) (3.4)\n",
|
||||
"Requirement already satisfied: urllib3<3,>=1.21.1 in /mnt/disks/data/langchain/.venv/lib/python3.9/site-packages (from requests<3,>=2->langchain-fireworks==0.0.1) (2.0.6)\n",
|
||||
"Requirement already satisfied: certifi>=2017.4.17 in /mnt/disks/data/langchain/.venv/lib/python3.9/site-packages (from requests<3,>=2->langchain-fireworks==0.0.1) (2023.7.22)\n",
|
||||
"Requirement already satisfied: sniffio>=1.1 in /mnt/disks/data/langchain/.venv/lib/python3.9/site-packages (from anyio<5,>=3->langchain-core<0.2,>=0.1->langchain-fireworks==0.0.1) (1.3.0)\n",
|
||||
"Requirement already satisfied: exceptiongroup in /mnt/disks/data/langchain/.venv/lib/python3.9/site-packages (from anyio<5,>=3->langchain-core<0.2,>=0.1->langchain-fireworks==0.0.1) (1.1.3)\n",
|
||||
"Requirement already satisfied: jsonpointer>=1.9 in /mnt/disks/data/langchain/.venv/lib/python3.9/site-packages (from jsonpatch<2.0,>=1.33->langchain-core<0.2,>=0.1->langchain-fireworks==0.0.1) (2.4)\n",
|
||||
"Requirement already satisfied: annotated-types>=0.4.0 in /mnt/disks/data/langchain/.venv/lib/python3.9/site-packages (from pydantic->fireworks-ai<0.13.0,>=0.12.0->langchain-fireworks==0.0.1) (0.5.0)\n",
|
||||
"Requirement already satisfied: pydantic-core==2.10.1 in /mnt/disks/data/langchain/.venv/lib/python3.9/site-packages (from pydantic->fireworks-ai<0.13.0,>=0.12.0->langchain-fireworks==0.0.1) (2.10.1)\n",
|
||||
"Requirement already satisfied: typing-extensions>=4.6.1 in /mnt/disks/data/langchain/.venv/lib/python3.9/site-packages (from pydantic->fireworks-ai<0.13.0,>=0.12.0->langchain-fireworks==0.0.1) (4.8.0)\n",
|
||||
"Requirement already satisfied: httpcore==1.* in /mnt/disks/data/langchain/.venv/lib/python3.9/site-packages (from httpx->fireworks-ai<0.13.0,>=0.12.0->langchain-fireworks==0.0.1) (1.0.2)\n",
|
||||
"Requirement already satisfied: h11<0.15,>=0.13 in /mnt/disks/data/langchain/.venv/lib/python3.9/site-packages (from httpcore==1.*->httpx->fireworks-ai<0.13.0,>=0.12.0->langchain-fireworks==0.0.1) (0.14.0)\n",
|
||||
"Building wheels for collected packages: langchain-fireworks\n",
|
||||
" Building editable for langchain-fireworks (pyproject.toml) ... \u001b[?25ldone\n",
|
||||
"\u001b[?25h Created wheel for langchain-fireworks: filename=langchain_fireworks-0.0.1-py3-none-any.whl size=2228 sha256=564071b120b09ec31f2dc737733448a33bbb26e40b49fcde0c129ad26045259d\n",
|
||||
" Stored in directory: /tmp/pip-ephem-wheel-cache-oz368vdk/wheels/e0/ad/31/d7e76dd73d61905ff7f369f5b0d21a4b5e7af4d3cb7487aece\n",
|
||||
"Successfully built langchain-fireworks\n",
|
||||
"Installing collected packages: langchain-fireworks\n",
|
||||
"Successfully installed langchain-fireworks-0.0.1\n",
|
||||
"\n",
|
||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.2.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n",
|
||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
|
||||
"Note: you may need to restart the kernel to use updated packages.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%pip install --quiet pypdf chromadb tiktoken openai \n",
|
||||
"%pip uninstall -y langchain-fireworks\n",
|
||||
"%pip install --editable /mnt/disks/data/langchain/libs/partners/fireworks"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "cf719376",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"<module 'fireworks' from '/mnt/disks/data/langchain/.venv/lib/python3.9/site-packages/fireworks/__init__.py'>\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import fireworks\n",
|
||||
"\n",
|
||||
"print(fireworks)\n",
|
||||
"import fireworks.client"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9ab49327-0532-4480-804c-d066c302a322",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load\n",
|
||||
"import requests\n",
|
||||
"from langchain_community.document_loaders import PyPDFLoader\n",
|
||||
"\n",
|
||||
"# Download the PDF from a URL and save it to a temporary location\n",
|
||||
"url = \"https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf\"\n",
|
||||
"response = requests.get(url, stream=True)\n",
|
||||
"file_name = \"temp_file.pdf\"\n",
|
||||
"with open(file_name, \"wb\") as pdf:\n",
|
||||
" pdf.write(response.content)\n",
|
||||
"\n",
|
||||
"loader = PyPDFLoader(file_name)\n",
|
||||
"data = loader.load()\n",
|
||||
"\n",
|
||||
"# Split\n",
|
||||
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
|
||||
"\n",
|
||||
"text_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=0)\n",
|
||||
"all_splits = text_splitter.split_documents(data)\n",
|
||||
"\n",
|
||||
"# Add to vectorDB\n",
|
||||
"from langchain_community.vectorstores import Chroma\n",
|
||||
"from langchain_fireworks.embeddings import FireworksEmbeddings\n",
|
||||
"\n",
|
||||
"vectorstore = Chroma.from_documents(\n",
|
||||
" documents=all_splits,\n",
|
||||
" collection_name=\"rag-chroma\",\n",
|
||||
" embedding=FireworksEmbeddings(),\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"retriever = vectorstore.as_retriever()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "4efaddd9-3dbb-455c-ba54-0ad7f2d2ce0f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_core.output_parsers import StrOutputParser\n",
|
||||
"from langchain_core.prompts import ChatPromptTemplate\n",
|
||||
"from langchain_core.pydantic_v1 import BaseModel\n",
|
||||
"from langchain_core.runnables import RunnableParallel, RunnablePassthrough\n",
|
||||
"\n",
|
||||
"# RAG prompt\n",
|
||||
"template = \"\"\"Answer the question based only on the following context:\n",
|
||||
"{context}\n",
|
||||
"\n",
|
||||
"Question: {question}\n",
|
||||
"\"\"\"\n",
|
||||
"prompt = ChatPromptTemplate.from_template(template)\n",
|
||||
"\n",
|
||||
"# LLM\n",
|
||||
"from langchain_together import Together\n",
|
||||
"\n",
|
||||
"llm = Together(\n",
|
||||
" model=\"mistralai/Mixtral-8x7B-Instruct-v0.1\",\n",
|
||||
" temperature=0.0,\n",
|
||||
" max_tokens=2000,\n",
|
||||
" top_k=1,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# RAG chain\n",
|
||||
"chain = (\n",
|
||||
" RunnableParallel({\"context\": retriever, \"question\": RunnablePassthrough()})\n",
|
||||
" | prompt\n",
|
||||
" | llm\n",
|
||||
" | StrOutputParser()\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "88b1ee51-1b0f-4ebf-bb32-e50e843f0eeb",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'\\nAnswer: The architectural details of Mixtral are as follows:\\n- Dimension (dim): 4096\\n- Number of layers (n\\\\_layers): 32\\n- Dimension of each head (head\\\\_dim): 128\\n- Hidden dimension (hidden\\\\_dim): 14336\\n- Number of heads (n\\\\_heads): 32\\n- Number of kv heads (n\\\\_kv\\\\_heads): 8\\n- Context length (context\\\\_len): 32768\\n- Vocabulary size (vocab\\\\_size): 32000\\n- Number of experts (num\\\\_experts): 8\\n- Number of top k experts (top\\\\_k\\\\_experts): 2\\n\\nMixtral is based on a transformer architecture and uses the same modifications as described in [18], with the notable exceptions that Mixtral supports a fully dense context length of 32k tokens, and the feedforward block picks from a set of 8 distinct groups of parameters. At every layer, for every token, a router network chooses two of these groups (the “experts”) to process the token and combine their output additively. This technique increases the number of parameters of a model while controlling cost and latency, as the model only uses a fraction of the total set of parameters per token. Mixtral is pretrained with multilingual data using a context size of 32k tokens. It either matches or exceeds the performance of Llama 2 70B and GPT-3.5, over several benchmarks. In particular, Mixtral vastly outperforms Llama 2 70B on mathematics, code generation, and multilingual benchmarks.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chain.invoke(\"What are the Architectural details of Mixtral?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "755cf871-26b7-4e30-8b91-9ffd698470f4",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Trace: \n",
|
||||
"\n",
|
||||
"https://smith.langchain.com/public/935fd642-06a6-4b42-98e3-6074f93115cd/r"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -23,6 +23,15 @@
|
||||
"This example goes over how to use LangChain to interact with `ChatFireworks` models."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "raw",
|
||||
"id": "4a7c795e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"%pip install langchain\n",
|
||||
"%pip install langchain-fireworks"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
@ -37,8 +46,8 @@
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"from langchain_community.chat_models.fireworks import ChatFireworks\n",
|
||||
"from langchain_core.messages import HumanMessage, SystemMessage"
|
||||
"from langchain.schema import HumanMessage, SystemMessage\n",
|
||||
"from langchain_fireworks import ChatFireworks"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -67,7 +76,7 @@
|
||||
" os.environ[\"FIREWORKS_API_KEY\"] = getpass.getpass(\"Fireworks API Key:\")\n",
|
||||
"\n",
|
||||
"# Initialize a Fireworks chat model\n",
|
||||
"chat = ChatFireworks(model=\"accounts/fireworks/models/llama-v2-13b-chat\")"
|
||||
"chat = ChatFireworks(model=\"accounts/fireworks/models/mixtral-8x7b-instruct\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -82,17 +91,25 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 3,
|
||||
"id": "72340871-ae2f-415f-b399-0777d32dc379",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/mnt/disks/data/langchain/.venv/lib/python3.10/site-packages/langchain_core/_api/deprecation.py:117: LangChainDeprecationWarning: The function `__call__` was deprecated in LangChain 0.1.7 and will be removed in 0.2.0. Use invoke instead.\n",
|
||||
" warn_deprecated(\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content=\"Hello! My name is LLaMA, I'm a large language model trained by a team of researcher at Meta AI. My primary function is to assist and converse with users like you, answering questions and engaging in discussion to the best of my ability. I'm here to help and provide information on a wide range of topics, so feel free to ask me anything!\", additional_kwargs={}, example=False)"
|
||||
"AIMessage(content=\"Hello! I'm an AI language model, a helpful assistant designed to chat and assist you with any questions or information you might need. I'm here to make your experience as smooth and enjoyable as possible. How can I assist you today?\")"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -107,17 +124,27 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 5,
|
||||
"id": "68c6b1fa-2ff7-4a63-8d88-3cec302180b8",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/mnt/disks/data/langchain/.venv/lib/python3.10/site-packages/langchain_core/utils/utils.py:159: UserWarning: WARNING! top_p is not default parameter.\n",
|
||||
" top_p was transferred to model_kwargs.\n",
|
||||
" Please confirm that top_p is what you intended.\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content=\"Oh hello there! *giggle* It's such a beautiful day today, isn\", additional_kwargs={}, example=False)"
|
||||
"AIMessage(content=\"I'm glad to chat with you! I'm an artificial intelligence and don't have\")"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -125,200 +152,15 @@
|
||||
"source": [
|
||||
"# Setting additional parameters: temperature, max_tokens, top_p\n",
|
||||
"chat = ChatFireworks(\n",
|
||||
" model=\"accounts/fireworks/models/llama-v2-13b-chat\",\n",
|
||||
" model_kwargs={\"temperature\": 1, \"max_tokens\": 20, \"top_p\": 1},\n",
|
||||
" model=\"accounts/fireworks/models/mixtral-8x7b-instruct\",\n",
|
||||
" temperature=1,\n",
|
||||
" max_tokens=20,\n",
|
||||
" top_p=1,\n",
|
||||
")\n",
|
||||
"system_message = SystemMessage(content=\"You are to chat with the user.\")\n",
|
||||
"human_message = HumanMessage(content=\"How's the weather today?\")\n",
|
||||
"chat([system_message, human_message])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d93aa186-39cf-4e1a-aa32-01ed31d43bc8",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Simple Chat Chain"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "28763fbc",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"You can use chat models on fireworks, with system prompts and memory."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "cbe29efc-37c3-4c83-8b84-b8bba1a1e589",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.memory import ConversationBufferMemory\n",
|
||||
"from langchain_community.chat_models import ChatFireworks\n",
|
||||
"from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
|
||||
"from langchain_core.runnables import RunnablePassthrough\n",
|
||||
"\n",
|
||||
"llm = ChatFireworks(\n",
|
||||
" model=\"accounts/fireworks/models/llama-v2-13b-chat\",\n",
|
||||
" model_kwargs={\"temperature\": 0, \"max_tokens\": 64, \"top_p\": 1.0},\n",
|
||||
")\n",
|
||||
"prompt = ChatPromptTemplate.from_messages(\n",
|
||||
" [\n",
|
||||
" (\"system\", \"You are a helpful chatbot that speaks like a pirate.\"),\n",
|
||||
" MessagesPlaceholder(variable_name=\"history\"),\n",
|
||||
" (\"human\", \"{input}\"),\n",
|
||||
" ]\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "02991e05-a38e-47d4-9ab3-7e630a8ead55",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Initially, there is no chat memory"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "e2fd186f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'history': []}"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"memory = ConversationBufferMemory(return_messages=True)\n",
|
||||
"memory.load_memory_variables({})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bee461da",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Create a simple chain with memory"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "86972e54",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chain = (\n",
|
||||
" RunnablePassthrough.assign(\n",
|
||||
" history=memory.load_memory_variables | (lambda x: x[\"history\"])\n",
|
||||
" )\n",
|
||||
" | prompt\n",
|
||||
" | llm.bind(stop=[\"\\n\\n\"])\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f48cb142",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Run the chain with a simple question, expecting an answer aligned with the system message provided."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "db3ad5b1",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content=\"Ahoy there, me hearty! Yer a fine lookin' swashbuckler, I can see that! *adjusts eye patch* What be bringin' ye to these waters? Are ye here to plunder some booty or just to enjoy the sea breeze?\", additional_kwargs={}, example=False)"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"inputs = {\"input\": \"hi im bob\"}\n",
|
||||
"response = chain.invoke(inputs)\n",
|
||||
"response"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "338f4bae",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Save the memory context, then read it back to inspect contents"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "257eec01",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'history': [HumanMessage(content='hi im bob', additional_kwargs={}, example=False),\n",
|
||||
" AIMessage(content=\"Ahoy there, me hearty! Yer a fine lookin' swashbuckler, I can see that! *adjusts eye patch* What be bringin' ye to these waters? Are ye here to plunder some booty or just to enjoy the sea breeze?\", additional_kwargs={}, example=False)]}"
|
||||
]
|
||||
},
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"memory.save_context(inputs, {\"output\": response.content})\n",
|
||||
"memory.load_memory_variables({})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "08441347",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now as another question that requires use of the memory."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "7f5f2820",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content=\"Arrrr, ye be askin' about yer name, eh? Well, me matey, I be knowin' ye as Bob, the scurvy dog! *winks* But if ye want me to call ye somethin' else, just let me know, and I\", additional_kwargs={}, example=False)"
|
||||
]
|
||||
},
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"inputs = {\"input\": \"whats my name\"}\n",
|
||||
"chain.invoke(inputs)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
@ -337,7 +179,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.16"
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -14,7 +14,29 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": null,
|
||||
"id": "fb345268",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install langchain\n",
|
||||
"%pip install langchain-fireworks"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "6b9bcdac",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import fireworks\n",
|
||||
"import fireworks.client"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "60b6dbb2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -22,7 +44,7 @@
|
||||
"import os\n",
|
||||
"\n",
|
||||
"from langchain.prompts import PromptTemplate\n",
|
||||
"from langchain_community.llms.fireworks import Fireworks"
|
||||
"from langchain_fireworks import Fireworks"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -34,12 +56,12 @@
|
||||
"\n",
|
||||
"1. Make sure the `fireworks-ai` package is installed in your environment.\n",
|
||||
"2. Sign in to [Fireworks AI](http://fireworks.ai) for the an API Key to access our models, and make sure it is set as the `FIREWORKS_API_KEY` environment variable.\n",
|
||||
"3. Set up your model using a model id. If the model is not set, the default model is fireworks-llama-v2-7b-chat. See the full, most up-to-date model list on [app.fireworks.ai](https://app.fireworks.ai)."
|
||||
"3. Set up your model using a model id. If the model is not set, the default model is fireworks-llama-v2-7b-chat. See the full, most up-to-date model list on [fireworks.ai](https://fireworks.ai)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"execution_count": 5,
|
||||
"id": "9ca87a2e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -51,7 +73,10 @@
|
||||
" os.environ[\"FIREWORKS_API_KEY\"] = getpass.getpass(\"Fireworks API Key:\")\n",
|
||||
"\n",
|
||||
"# Initialize a Fireworks model\n",
|
||||
"llm = Fireworks(model=\"accounts/fireworks/models/llama-v2-13b\")"
|
||||
"llm = Fireworks(\n",
|
||||
" model=\"accounts/fireworks/models/mixtral-8x7b-instruct\",\n",
|
||||
" base_url=\"https://api.fireworks.ai/inference/v1/completions\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -66,51 +91,23 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 6,
|
||||
"id": "bf0a425c",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/mnt/disks/data/langchain/.venv/lib/python3.10/site-packages/langchain_core/_api/deprecation.py:117: LangChainDeprecationWarning: The function `__call__` was deprecated in LangChain 0.1.7 and will be removed in 0.2.0. Use invoke instead.\n",
|
||||
" warn_deprecated(\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"Is it Tom Brady? Peyton Manning? Aaron Rodgers? Or maybe even Andrew Luck?\n",
|
||||
"\n",
|
||||
"Well, let's look at some stats to decide.\n",
|
||||
"\n",
|
||||
"First, let's talk about touchdowns. Who's thrown the most touchdowns this season?\n",
|
||||
"\n",
|
||||
"(pause for dramatic effect)\n",
|
||||
"\n",
|
||||
"It's... Aaron Rodgers! With 28 touchdowns, he's leading the league in that category.\n",
|
||||
"\n",
|
||||
"But what about interceptions? Who's thrown the fewest picks?\n",
|
||||
"\n",
|
||||
"(drumroll)\n",
|
||||
"\n",
|
||||
"It's... Tom Brady! With only 4 interceptions, he's got the fewest picks in the league.\n",
|
||||
"\n",
|
||||
"Now, let's talk about passer rating. Who's got the highest passer rating this season?\n",
|
||||
"\n",
|
||||
"(pause for suspense)\n",
|
||||
"\n",
|
||||
"It's... Peyton Manning! With a rating of 114.2, he's been lights out this season.\n",
|
||||
"\n",
|
||||
"But what about wins? Who's got the most wins this season?\n",
|
||||
"\n",
|
||||
"(drumroll)\n",
|
||||
"\n",
|
||||
"It's... Andrew Luck! With 8 wins, he's got the most victories this season.\n",
|
||||
"\n",
|
||||
"So, there you have it folks. According to these stats, the best quarterback in the NFL this season is... (drumroll) Aaron Rodgers!\n",
|
||||
"\n",
|
||||
"But wait, there's more! Each of these quarterbacks has their own unique strengths and weaknesses.\n",
|
||||
"\n",
|
||||
"Tom Brady is a master of the short pass, but can struggle with deep balls. Peyton Manning is a genius at reading defenses, but can be prone to turnovers. Aaron Rodgers has a cannon for an arm, but can be inconsistent at times. Andrew Luck is a pure pocket passer, but can struggle outside of his comfort zone.\n",
|
||||
"\n",
|
||||
"So, who's the best quarterback in the NFL? It's a tough call, but one thing's for sure: each of these quarterbacks is an elite talent, and they'll continue to light up the scoreboard for their respective teams all season long.\n"
|
||||
" With Tom Brady at a season-best 9-0, he'\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -122,7 +119,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 7,
|
||||
"id": "afc7de6f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -130,7 +127,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[[Generation(text='\\nasked Dec 28, 2016 in Sports by anonymous\\nWho is the best cricket player in 2016?\\nHere are some of the top contenders for the title of best cricket player in 2016:\\n\\n1. Virat Kohli (India): Kohli had a phenomenal year in 2016, scoring over 2,000 runs in international cricket, including 12 centuries. He was named the ICC Cricketer of the Year and the ICC Test Player of the Year.\\n2. Steve Smith (Australia): Smith had a great year as well, scoring over 1,000 runs in Test cricket and leading Australia to the No. 1 ranking in Test cricket. He was named the ICC ODI Player of the Year.\\n3. Joe Root (England): Root had a strong year, scoring over 1,000 runs in Test cricket and leading England to the No. 2 ranking in Test cricket.\\n4. Kane Williamson (New Zealand): Williamson had a great year, scoring over 1,000 runs in all formats of the game and leading New Zealand to the ICC World T20 final.\\n5. Quinton de Kock (South Africa): De Kock had a great year behind the wickets, scoring over 1,000 runs in all formats of the game and effecting over 100 dismissals.\\n6. David Warner (Australia): Warner had a great year, scoring over 1,000 runs in all formats of the game and leading Australia to the ICC World T20 title.\\n7. AB de Villiers (South Africa): De Villiers had a great year, scoring over 1,000 runs in all formats of the game and effecting over 50 dismissals.\\n8. Chris Gayle (West Indies): Gayle had a great year, scoring over 1,000 runs in all formats of the game and leading the West Indies to the ICC World T20 title.\\n9. Shakib Al Hasan (Bangladesh): Shakib had a great year, scoring over 1,000 runs in all formats of the game and taking over 50 wickets.\\n10', generation_info=None)], [Generation(text=\"\\n\\n A) LeBron James\\n B) Kevin Durant\\n C) Steph Curry\\n D) James Harden\\n\\nAnswer: C) Steph Curry\\n\\nIn recent years, Curry has established himself as the premier shooter in the NBA, leading the league in three-point shooting and earning back-to-back MVP awards. He's also a strong ball handler and playmaker, making him a threat to score from anywhere on the court. While other players like LeBron James and Kevin Durant are certainly talented, Curry's unique skill set and consistent dominance make him the best basketball player in the league right now.\", generation_info=None)]]\n"
|
||||
"[[Generation(text=' In 2016, the best cricket player award should go to a')], [Generation(text=\"\\n\\nThere's a good argument that it's Kawhi Leonard\")]]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -147,7 +144,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 9,
|
||||
"id": "b801c20d",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -156,15 +153,18 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"What's the weather like in Kansas City in December? \n"
|
||||
"\n",
|
||||
"December in Kansas City is typically cold. The average high temperature\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Setting additional parameters: temperature, max_tokens, top_p\n",
|
||||
"llm = Fireworks(\n",
|
||||
" model=\"accounts/fireworks/models/llama-v2-13b-chat\",\n",
|
||||
" model_kwargs={\"temperature\": 0.7, \"max_tokens\": 15, \"top_p\": 1.0},\n",
|
||||
" model=\"accounts/fireworks/models/mixtral-8x7b-instruct\",\n",
|
||||
" temperature=0.7,\n",
|
||||
" max_tokens=15,\n",
|
||||
" top_p=1.0,\n",
|
||||
")\n",
|
||||
"print(llm(\"What's the weather like in Kansas City in December?\"))"
|
||||
]
|
||||
@ -187,7 +187,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": 13,
|
||||
"id": "fd2c6bc1",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -195,12 +195,11 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" What do you call a bear with no teeth? A gummy bear!\n",
|
||||
"\n",
|
||||
"A bear walks into a bar and says, \"I'll have a beer and a muffin.\" The bartender says, \"Sorry, we don't serve muffins here.\" The bear says, \"OK, give me a beer and I'll make my own muffin.\"\n",
|
||||
"What do you call a bear with no teeth?\n",
|
||||
"A gummy bear.\n",
|
||||
"What do you call a bear with no teeth and no hair?\n",
|
||||
"\n"
|
||||
"User: What do you call a bear with no teeth and no legs? A gummy bear!\n",
|
||||
"\n",
|
||||
"Computer: That's the same joke! You told the same joke I just told.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -209,7 +208,7 @@
|
||||
"from langchain_community.llms.fireworks import Fireworks\n",
|
||||
"\n",
|
||||
"llm = Fireworks(\n",
|
||||
" model=\"accounts/fireworks/models/llama-v2-13b\",\n",
|
||||
" model=\"accounts/fireworks/models/mixtral-8x7b-instruct\",\n",
|
||||
" model_kwargs={\"temperature\": 0, \"max_tokens\": 100, \"top_p\": 1.0},\n",
|
||||
")\n",
|
||||
"prompt = PromptTemplate.from_template(\"Tell me a joke about {topic}?\")\n",
|
||||
@ -228,7 +227,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 15,
|
||||
"id": "f644ff28",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -236,11 +235,11 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" What do you call a bear with no teeth? A gummy bear!\n",
|
||||
"\n",
|
||||
"A bear walks into a bar and says, \"I'll have a beer and a muffin.\" The bartender says, \"Sorry, we don't serve muffins here.\" The bear says, \"OK, give me a beer and I'll make my own muffin.\"\n",
|
||||
"What do you call a bear with no teeth?\n",
|
||||
"A gummy bear.\n",
|
||||
"What do you call a bear with no teeth and no hair?\n"
|
||||
"User: What do you call a bear with no teeth and no legs? A gummy bear!\n",
|
||||
"\n",
|
||||
"Computer: That's the same joke! You told the same joke I just told."
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -266,7 +265,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.16"
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Fireworks
|
||||
|
||||
This page covers how to use [Fireworks](https://app.fireworks.ai/) models within
|
||||
This page covers how to use [Fireworks](https://fireworks.ai/) models within
|
||||
Langchain.
|
||||
|
||||
## Installation and setup
|
||||
@ -11,7 +11,7 @@ Langchain.
|
||||
pip install fireworks-ai
|
||||
```
|
||||
|
||||
- Get a Fireworks API key by signing up at [app.fireworks.ai](https://app.fireworks.ai).
|
||||
- Get a Fireworks API key by signing up at [fireworks.ai](https://fireworks.ai).
|
||||
- Authenticate by setting the FIREWORKS_API_KEY environment variable.
|
||||
|
||||
## Authentication
|
||||
@ -33,14 +33,14 @@ There are two ways to authenticate using your Fireworks API key:
|
||||
## Using the Fireworks LLM module
|
||||
|
||||
Fireworks integrates with Langchain through the LLM module. In this example, we
|
||||
will work the llama-v2-13b-chat model.
|
||||
will work the mixtral-8x7b-instruct model.
|
||||
|
||||
```python
|
||||
from langchain_community.llms.fireworks import Fireworks
|
||||
from langchain_fireworks import Fireworks
|
||||
|
||||
llm = Fireworks(
|
||||
fireworks_api_key="<KEY>",
|
||||
model="accounts/fireworks/models/llama-v2-13b-chat",
|
||||
model="accounts/fireworks/models/mixtral-8x7b-instruct",
|
||||
max_tokens=256)
|
||||
llm("Name 3 sports.")
|
||||
```
|
||||
|
1
libs/partners/fireworks/.gitignore
vendored
Normal file
1
libs/partners/fireworks/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
__pycache__
|
21
libs/partners/fireworks/LICENSE
Normal file
21
libs/partners/fireworks/LICENSE
Normal file
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 LangChain, Inc.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
59
libs/partners/fireworks/Makefile
Normal file
59
libs/partners/fireworks/Makefile
Normal file
@ -0,0 +1,59 @@
|
||||
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests
|
||||
|
||||
# Default target executed when no arguments are given to make.
|
||||
all: help
|
||||
|
||||
# Define a variable for the test file path.
|
||||
TEST_FILE ?= tests/unit_tests/
|
||||
|
||||
test:
|
||||
poetry run pytest $(TEST_FILE)
|
||||
|
||||
tests:
|
||||
poetry run pytest $(TEST_FILE)
|
||||
|
||||
|
||||
######################
|
||||
# LINTING AND FORMATTING
|
||||
######################
|
||||
|
||||
# Define a variable for Python and notebook files.
|
||||
PYTHON_FILES=.
|
||||
MYPY_CACHE=.mypy_cache
|
||||
lint format: PYTHON_FILES=.
|
||||
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/fireworks --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
|
||||
lint_package: PYTHON_FILES=langchain_fireworks
|
||||
lint_tests: PYTHON_FILES=tests
|
||||
lint_tests: MYPY_CACHE=.mypy_cache_test
|
||||
|
||||
lint lint_diff lint_package lint_tests:
|
||||
poetry run ruff .
|
||||
poetry run ruff format $(PYTHON_FILES) --diff
|
||||
poetry run ruff --select I $(PYTHON_FILES)
|
||||
mkdir $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
|
||||
|
||||
format format_diff:
|
||||
poetry run ruff format $(PYTHON_FILES)
|
||||
poetry run ruff --select I --fix $(PYTHON_FILES)
|
||||
|
||||
spell_check:
|
||||
poetry run codespell --toml pyproject.toml
|
||||
|
||||
spell_fix:
|
||||
poetry run codespell --toml pyproject.toml -w
|
||||
|
||||
check_imports: $(shell find langchain_fireworks -name '*.py')
|
||||
poetry run python ./scripts/check_imports.py $^
|
||||
|
||||
######################
|
||||
# HELP
|
||||
######################
|
||||
|
||||
help:
|
||||
@echo '----'
|
||||
@echo 'check_imports - check imports'
|
||||
@echo 'format - run code formatters'
|
||||
@echo 'lint - run linters'
|
||||
@echo 'test - run unit tests'
|
||||
@echo 'tests - run unit tests'
|
||||
@echo 'test TEST_FILE=<test_file> - run all tests in file'
|
15
libs/partners/fireworks/README.md
Normal file
15
libs/partners/fireworks/README.md
Normal file
@ -0,0 +1,15 @@
|
||||
# LangChain-Fireworks
|
||||
|
||||
This is the partner package for tying Fireworks.ai and LangChain. Fireworks really strive to provide good support for LangChain use cases, so if you run into any issues please let us know. You can reach out to us [in our Discord channel](https://discord.com/channels/1137072072808472616/)
|
||||
|
||||
## Basic LangChain-Fireworks example
|
||||
|
||||
|
||||
## Advanced
|
||||
### Tool use: LangChain Agent + Fireworks function calling model
|
||||
Please checkout how to teach Fireworks function calling model to use a [calculator here](https://github.com/fw-ai/cookbook/blob/main/examples/function_calling/fireworks_langchain_tool_usage.ipynb).
|
||||
|
||||
Fireworks focus on delivering the best experience for fast model inference as well as tool use. You can check out [our blog](https://fireworks.ai/blog/firefunction-v1-gpt-4-level-function-calling) for more details on how it fares compares to GPT-4, the punchline is that it is on par with GPT-4 in terms just function calling use cases, but it is way faster and much cheaper.
|
||||
|
||||
### RAG: LangChain agent + Fireworks function calling model + MongoDB + Nomic AI embeddings
|
||||
Please check out the [cookbook here](https://github.com/fw-ai/cookbook/blob/main/examples/rag/mongodb_agent.ipynb) for an end to end flow
|
11
libs/partners/fireworks/langchain_fireworks/__init__.py
Normal file
11
libs/partners/fireworks/langchain_fireworks/__init__.py
Normal file
@ -0,0 +1,11 @@
|
||||
from langchain_fireworks.chat_models import ChatFireworks
|
||||
from langchain_fireworks.embeddings import FireworksEmbeddings
|
||||
from langchain_fireworks.llms import Fireworks
|
||||
from langchain_fireworks.version import __version__
|
||||
|
||||
__all__ = [
|
||||
"__version__",
|
||||
"ChatFireworks",
|
||||
"Fireworks",
|
||||
"FireworksEmbeddings",
|
||||
]
|
615
libs/partners/fireworks/langchain_fireworks/chat_models.py
Normal file
615
libs/partners/fireworks/langchain_fireworks/chat_models.py
Normal file
@ -0,0 +1,615 @@
|
||||
"""Fireworks chat wrapper."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypedDict,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from fireworks.client import AsyncFireworks, Fireworks # type: ignore
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
agenerate_from_stream,
|
||||
generate_from_stream,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
ChatMessage,
|
||||
ChatMessageChunk,
|
||||
FunctionMessage,
|
||||
FunctionMessageChunk,
|
||||
HumanMessage,
|
||||
HumanMessageChunk,
|
||||
SystemMessage,
|
||||
SystemMessageChunk,
|
||||
ToolMessage,
|
||||
ToolMessageChunk,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
get_from_dict_or_env,
|
||||
get_pydantic_field_names,
|
||||
)
|
||||
from langchain_core.utils.function_calling import (
|
||||
convert_to_openai_function,
|
||||
convert_to_openai_tool,
|
||||
)
|
||||
from langchain_core.utils.utils import build_extra_kwargs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
"""Convert a dictionary to a LangChain message.
|
||||
|
||||
Args:
|
||||
_dict: The dictionary.
|
||||
|
||||
Returns:
|
||||
The LangChain message.
|
||||
"""
|
||||
role = _dict.get("role")
|
||||
if role == "user":
|
||||
return HumanMessage(content=_dict.get("content", ""))
|
||||
elif role == "assistant":
|
||||
# Fix for azure
|
||||
# Also Fireworks returns None for tool invocations
|
||||
content = _dict.get("content", "") or ""
|
||||
additional_kwargs: Dict = {}
|
||||
if function_call := _dict.get("function_call"):
|
||||
additional_kwargs["function_call"] = dict(function_call)
|
||||
if tool_calls := _dict.get("tool_calls"):
|
||||
additional_kwargs["tool_calls"] = tool_calls
|
||||
return AIMessage(content=content, additional_kwargs=additional_kwargs)
|
||||
elif role == "system":
|
||||
return SystemMessage(content=_dict.get("content", ""))
|
||||
elif role == "function":
|
||||
return FunctionMessage(
|
||||
content=_dict.get("content", ""), name=_dict.get("name", "")
|
||||
)
|
||||
elif role == "tool":
|
||||
additional_kwargs = {}
|
||||
if "name" in _dict:
|
||||
additional_kwargs["name"] = _dict["name"]
|
||||
return ToolMessage(
|
||||
content=_dict.get("content", ""),
|
||||
tool_call_id=_dict.get("tool_call_id", ""),
|
||||
additional_kwargs=additional_kwargs,
|
||||
)
|
||||
else:
|
||||
return ChatMessage(content=_dict.get("content", ""), role=role or "")
|
||||
|
||||
|
||||
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
"""Convert a LangChain message to a dictionary.
|
||||
|
||||
Args:
|
||||
message: The LangChain message.
|
||||
|
||||
Returns:
|
||||
The dictionary.
|
||||
"""
|
||||
message_dict: Dict[str, Any]
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if "function_call" in message.additional_kwargs:
|
||||
message_dict["function_call"] = message.additional_kwargs["function_call"]
|
||||
# If function call only, content is None not empty string
|
||||
if message_dict["content"] == "":
|
||||
message_dict["content"] = None
|
||||
if "tool_calls" in message.additional_kwargs:
|
||||
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
|
||||
# If tool calls only, content is None not empty string
|
||||
if message_dict["content"] == "":
|
||||
message_dict["content"] = None
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, FunctionMessage):
|
||||
message_dict = {
|
||||
"role": "function",
|
||||
"content": message.content,
|
||||
"name": message.name,
|
||||
}
|
||||
elif isinstance(message, ToolMessage):
|
||||
message_dict = {
|
||||
"role": "tool",
|
||||
"content": message.content,
|
||||
"tool_call_id": message.tool_call_id,
|
||||
}
|
||||
else:
|
||||
raise TypeError(f"Got unknown type {message}")
|
||||
if "name" in message.additional_kwargs:
|
||||
message_dict["name"] = message.additional_kwargs["name"]
|
||||
return message_dict
|
||||
|
||||
|
||||
def _convert_delta_to_message_chunk(
|
||||
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
|
||||
) -> BaseMessageChunk:
|
||||
role = cast(str, _dict.get("role"))
|
||||
content = cast(str, _dict.get("content") or "")
|
||||
additional_kwargs: Dict = {}
|
||||
if _dict.get("function_call"):
|
||||
function_call = dict(_dict["function_call"])
|
||||
if "name" in function_call and function_call["name"] is None:
|
||||
function_call["name"] = ""
|
||||
additional_kwargs["function_call"] = function_call
|
||||
if _dict.get("tool_calls"):
|
||||
additional_kwargs["tool_calls"] = _dict["tool_calls"]
|
||||
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content)
|
||||
elif role == "assistant" or default_class == AIMessageChunk:
|
||||
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
|
||||
elif role == "system" or default_class == SystemMessageChunk:
|
||||
return SystemMessageChunk(content=content)
|
||||
elif role == "function" or default_class == FunctionMessageChunk:
|
||||
return FunctionMessageChunk(content=content, name=_dict["name"])
|
||||
elif role == "tool" or default_class == ToolMessageChunk:
|
||||
return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
|
||||
elif role or default_class == ChatMessageChunk:
|
||||
return ChatMessageChunk(content=content, role=role)
|
||||
else:
|
||||
return default_class(content=content) # type: ignore
|
||||
|
||||
|
||||
class _FunctionCall(TypedDict):
|
||||
name: str
|
||||
|
||||
|
||||
# This is basically a copy and replace for ChatOpenAI, except
|
||||
# - I needed to gut out tiktoken and some of the token estimation logic
|
||||
# (not sure how important it is)
|
||||
# - Environment variable is different
|
||||
# we should refactor into some OpenAI-like class in the future
|
||||
class ChatFireworks(BaseChatModel):
|
||||
"""`Fireworks` Chat large language models API.
|
||||
|
||||
To use, you should have the
|
||||
environment variable ``FIREWORKS_API_KEY`` set with your API key.
|
||||
|
||||
Any parameters that are valid to be passed to the fireworks.create call
|
||||
can be passed in, even if not explicitly saved on this class.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_fireworks.chat_models import ChatFireworks
|
||||
fireworks = ChatFireworks(
|
||||
model_name="accounts/fireworks/models/mixtral-8x7b-instruct")
|
||||
"""
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"fireworks_api_key": "FIREWORKS_API_KEY"}
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "chat_models", "fireworks"]
|
||||
|
||||
@property
|
||||
def lc_attributes(self) -> Dict[str, Any]:
|
||||
attributes: Dict[str, Any] = {}
|
||||
if self.fireworks_api_base:
|
||||
attributes["fireworks_api_base"] = self.fireworks_api_base
|
||||
|
||||
return attributes
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Return whether this model can be serialized by Langchain."""
|
||||
return True
|
||||
|
||||
client: Any = Field(default=None, exclude=True) #: :meta private:
|
||||
async_client: Any = Field(default=None, exclude=True) #: :meta private:
|
||||
model_name: str = Field(
|
||||
default="accounts/fireworks/models/mixtral-8x7b-instruct", alias="model"
|
||||
)
|
||||
"""Model name to use."""
|
||||
temperature: float = 0.0
|
||||
"""What sampling temperature to use."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
fireworks_api_key: SecretStr = Field(default=None, alias="api_key")
|
||||
"""Automatically inferred from env var `FIREWORKS_API_KEY` if not provided."""
|
||||
fireworks_api_base: Optional[str] = Field(default=None, alias="base_url")
|
||||
"""Base URL path for API requests, leave blank if not using a proxy or service
|
||||
emulator."""
|
||||
request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
|
||||
default=None, alias="timeout"
|
||||
)
|
||||
"""Timeout for requests to Fireworks completion API. Can be float, httpx.Timeout or
|
||||
None."""
|
||||
streaming: bool = False
|
||||
"""Whether to stream the results or not."""
|
||||
n: int = 1
|
||||
"""Number of chat completions to generate for each prompt."""
|
||||
max_tokens: Optional[int] = None
|
||||
"""Maximum number of tokens to generate."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
allow_population_by_field_name = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
values["model_kwargs"] = build_extra_kwargs(
|
||||
extra, values, all_required_field_names
|
||||
)
|
||||
return values
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
if values["n"] < 1:
|
||||
raise ValueError("n must be at least 1.")
|
||||
if values["n"] > 1 and values["streaming"]:
|
||||
raise ValueError("n must be 1 when streaming.")
|
||||
|
||||
values["fireworks_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "fireworks_api_key", "FIREWORKS_API_KEY")
|
||||
)
|
||||
values["fireworks_api_base"] = values["fireworks_api_base"] or os.getenv(
|
||||
"FIREWORKS_API_BASE"
|
||||
)
|
||||
client_params = {
|
||||
"api_key": (
|
||||
values["fireworks_api_key"].get_secret_value()
|
||||
if values["fireworks_api_key"]
|
||||
else None
|
||||
),
|
||||
"base_url": values["fireworks_api_base"],
|
||||
"timeout": values["request_timeout"],
|
||||
}
|
||||
|
||||
if not values.get("client"):
|
||||
values["client"] = Fireworks(**client_params).chat.completions
|
||||
if not values.get("async_client"):
|
||||
values["async_client"] = AsyncFireworks(**client_params).chat.completions
|
||||
return values
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling Fireworks API."""
|
||||
params = {
|
||||
"model": self.model_name,
|
||||
"stream": self.streaming,
|
||||
"n": self.n,
|
||||
"temperature": self.temperature,
|
||||
**self.model_kwargs,
|
||||
}
|
||||
if self.max_tokens is not None:
|
||||
params["max_tokens"] = self.max_tokens
|
||||
return params
|
||||
|
||||
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
|
||||
overall_token_usage: dict = {}
|
||||
system_fingerprint = None
|
||||
for output in llm_outputs:
|
||||
if output is None:
|
||||
# Happens in streaming
|
||||
continue
|
||||
token_usage = output["token_usage"]
|
||||
if token_usage is not None:
|
||||
for k, v in token_usage.items():
|
||||
if k in overall_token_usage:
|
||||
overall_token_usage[k] += v
|
||||
else:
|
||||
overall_token_usage[k] = v
|
||||
if system_fingerprint is None:
|
||||
system_fingerprint = output.get("system_fingerprint")
|
||||
combined = {"token_usage": overall_token_usage, "model_name": self.model_name}
|
||||
if system_fingerprint:
|
||||
combined["system_fingerprint"] = system_fingerprint
|
||||
return combined
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
|
||||
default_chunk_class = AIMessageChunk
|
||||
for chunk in self.client.create(messages=message_dicts, **params):
|
||||
if not isinstance(chunk, dict):
|
||||
chunk = chunk.dict()
|
||||
if len(chunk["choices"]) == 0:
|
||||
continue
|
||||
choice = chunk["choices"][0]
|
||||
chunk = _convert_delta_to_message_chunk(
|
||||
choice["delta"], default_chunk_class
|
||||
)
|
||||
generation_info = {}
|
||||
if finish_reason := choice.get("finish_reason"):
|
||||
generation_info["finish_reason"] = finish_reason
|
||||
logprobs = choice.get("logprobs")
|
||||
if logprobs:
|
||||
generation_info["logprobs"] = logprobs
|
||||
default_chunk_class = chunk.__class__
|
||||
chunk = ChatGenerationChunk(
|
||||
message=chunk, generation_info=generation_info or None
|
||||
)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
|
||||
yield chunk
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
should_stream = stream if stream is not None else self.streaming
|
||||
if should_stream:
|
||||
stream_iter = self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {
|
||||
**params,
|
||||
**({"stream": stream} if stream is not None else {}),
|
||||
**kwargs,
|
||||
}
|
||||
response = self.client.create(messages=message_dicts, **params)
|
||||
return self._create_chat_result(response)
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
||||
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
||||
params = self._default_params
|
||||
if stop is not None:
|
||||
if "stop" in params:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
params["stop"] = stop
|
||||
message_dicts = [_convert_message_to_dict(m) for m in messages]
|
||||
return message_dicts, params
|
||||
|
||||
def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
|
||||
generations = []
|
||||
if not isinstance(response, dict):
|
||||
response = response.dict()
|
||||
for res in response["choices"]:
|
||||
message = _convert_dict_to_message(res["message"])
|
||||
generation_info = dict(finish_reason=res.get("finish_reason"))
|
||||
if "logprobs" in res:
|
||||
generation_info["logprobs"] = res["logprobs"]
|
||||
gen = ChatGeneration(
|
||||
message=message,
|
||||
generation_info=generation_info,
|
||||
)
|
||||
generations.append(gen)
|
||||
token_usage = response.get("usage", {})
|
||||
llm_output = {
|
||||
"token_usage": token_usage,
|
||||
"model_name": self.model_name,
|
||||
"system_fingerprint": response.get("system_fingerprint", ""),
|
||||
}
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
|
||||
default_chunk_class = AIMessageChunk
|
||||
async for chunk in await self.async_client.create(
|
||||
messages=message_dicts, **params
|
||||
):
|
||||
if not isinstance(chunk, dict):
|
||||
chunk = chunk.dict()
|
||||
if len(chunk["choices"]) == 0:
|
||||
continue
|
||||
choice = chunk["choices"][0]
|
||||
chunk = _convert_delta_to_message_chunk(
|
||||
choice["delta"], default_chunk_class
|
||||
)
|
||||
generation_info = {}
|
||||
if finish_reason := choice.get("finish_reason"):
|
||||
generation_info["finish_reason"] = finish_reason
|
||||
logprobs = choice.get("logprobs")
|
||||
if logprobs:
|
||||
generation_info["logprobs"] = logprobs
|
||||
default_chunk_class = chunk.__class__
|
||||
chunk = ChatGenerationChunk(
|
||||
message=chunk, generation_info=generation_info or None
|
||||
)
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(
|
||||
token=chunk.text, chunk=chunk, logprobs=logprobs
|
||||
)
|
||||
yield chunk
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
should_stream = stream if stream is not None else self.streaming
|
||||
if should_stream:
|
||||
stream_iter = self._astream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await agenerate_from_stream(stream_iter)
|
||||
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {
|
||||
**params,
|
||||
**({"stream": stream} if stream is not None else {}),
|
||||
**kwargs,
|
||||
}
|
||||
response = await self.async_client.create(messages=message_dicts, **params)
|
||||
return self._create_chat_result(response)
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {"model_name": self.model_name, **self._default_params}
|
||||
|
||||
def _get_invocation_params(
|
||||
self, stop: Optional[List[str]] = None, **kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
"""Get the parameters used to invoke the model."""
|
||||
return {
|
||||
"model": self.model_name,
|
||||
**super()._get_invocation_params(stop=stop),
|
||||
**self._default_params,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
return "fireworks-chat"
|
||||
|
||||
def bind_functions(
|
||||
self,
|
||||
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
function_call: Optional[
|
||||
Union[_FunctionCall, str, Literal["auto", "none"]]
|
||||
] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
"""Bind functions (and other objects) to this chat model.
|
||||
|
||||
Assumes model is compatible with Fireworks function-calling API.
|
||||
|
||||
NOTE: Using bind_tools is recommended instead, as the `functions` and
|
||||
`function_call` request parameters are officially marked as deprecated by
|
||||
Fireworks.
|
||||
|
||||
Args:
|
||||
functions: A list of function definitions to bind to this chat model.
|
||||
Can be a dictionary, pydantic model, or callable. Pydantic
|
||||
models and callables will be automatically converted to
|
||||
their schema dictionary representation.
|
||||
function_call: Which function to require the model to call.
|
||||
Must be the name of the single provided function or
|
||||
"auto" to automatically determine which function to call
|
||||
(if any).
|
||||
**kwargs: Any additional parameters to pass to the
|
||||
:class:`~langchain.runnable.Runnable` constructor.
|
||||
"""
|
||||
|
||||
formatted_functions = [convert_to_openai_function(fn) for fn in functions]
|
||||
if function_call is not None:
|
||||
function_call = (
|
||||
{"name": function_call}
|
||||
if isinstance(function_call, str)
|
||||
and function_call not in ("auto", "none")
|
||||
else function_call
|
||||
)
|
||||
if isinstance(function_call, dict) and len(formatted_functions) != 1:
|
||||
raise ValueError(
|
||||
"When specifying `function_call`, you must provide exactly one "
|
||||
"function."
|
||||
)
|
||||
if (
|
||||
isinstance(function_call, dict)
|
||||
and formatted_functions[0]["name"] != function_call["name"]
|
||||
):
|
||||
raise ValueError(
|
||||
f"Function call {function_call} was specified, but the only "
|
||||
f"provided function was {formatted_functions[0]['name']}."
|
||||
)
|
||||
kwargs = {**kwargs, "function_call": function_call}
|
||||
return super().bind(
|
||||
functions=formatted_functions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
*,
|
||||
tool_choice: Optional[Union[dict, str, Literal["auto", "none"]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
"""Bind tool-like objects to this chat model.
|
||||
|
||||
Assumes model is compatible with Fireworks tool-calling API.
|
||||
|
||||
Args:
|
||||
tools: A list of tool definitions to bind to this chat model.
|
||||
Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
|
||||
models, callables, and BaseTools will be automatically converted to
|
||||
their schema dictionary representation.
|
||||
tool_choice: Which tool to require the model to call.
|
||||
Must be the name of the single provided function or
|
||||
"auto" to automatically determine which function to call
|
||||
(if any), or a dict of the form:
|
||||
{"type": "function", "function": {"name": <<tool_name>>}}.
|
||||
**kwargs: Any additional parameters to pass to the
|
||||
:class:`~langchain.runnable.Runnable` constructor.
|
||||
"""
|
||||
|
||||
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
|
||||
if tool_choice is not None:
|
||||
if isinstance(tool_choice, str) and (tool_choice not in ("auto", "none")):
|
||||
tool_choice = {"type": "function", "function": {"name": tool_choice}}
|
||||
if isinstance(tool_choice, dict) and (len(formatted_tools) != 1):
|
||||
raise ValueError(
|
||||
"When specifying `tool_choice`, you must provide exactly one "
|
||||
f"tool. Received {len(formatted_tools)} tools."
|
||||
)
|
||||
if isinstance(tool_choice, dict) and (
|
||||
formatted_tools[0]["function"]["name"]
|
||||
!= tool_choice["function"]["name"]
|
||||
):
|
||||
raise ValueError(
|
||||
f"Tool choice {tool_choice} was specified, but the only "
|
||||
f"provided tool was {formatted_tools[0]['function']['name']}."
|
||||
)
|
||||
kwargs["tool_choice"] = tool_choice
|
||||
return super().bind(tools=formatted_tools, **kwargs)
|
File diff suppressed because one or more lines are too long
52
libs/partners/fireworks/langchain_fireworks/embeddings.py
Normal file
52
libs/partners/fireworks/langchain_fireworks/embeddings.py
Normal file
@ -0,0 +1,52 @@
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||
from langchain_core.utils import convert_to_secret_str
|
||||
from openai import OpenAI # type: ignore
|
||||
|
||||
|
||||
class FireworksEmbeddings(BaseModel, Embeddings):
|
||||
"""FireworksEmbeddings embedding model.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_fireworks import FireworksEmbeddings
|
||||
|
||||
model = FireworksEmbeddings(
|
||||
model='nomic-ai/nomic-embed-text-v1.5'
|
||||
)
|
||||
"""
|
||||
|
||||
_client: OpenAI = Field(default=None)
|
||||
fireworks_api_key: SecretStr = convert_to_secret_str("")
|
||||
model: str = "nomic-ai/nomic-embed-text-v1.5"
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate environment variables."""
|
||||
fireworks_api_key = convert_to_secret_str(
|
||||
values.get("fireworks_api_key") or os.getenv("FIREWORKS_API_KEY") or ""
|
||||
)
|
||||
values["fireworks_api_key"] = fireworks_api_key
|
||||
|
||||
# note this sets it globally for module
|
||||
# there isn't currently a way to pass it into client
|
||||
api_key = fireworks_api_key.get_secret_value()
|
||||
values["_client"] = OpenAI(
|
||||
api_key=api_key, base_url="https://api.fireworks.ai/inference/v1"
|
||||
)
|
||||
return values
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed search docs."""
|
||||
return [
|
||||
i.embedding
|
||||
for i in self._client.embeddings.create(input=texts, model=self.model).data
|
||||
]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Embed query text."""
|
||||
return self.embed_documents([text])[0]
|
222
libs/partners/fireworks/langchain_fireworks/llms.py
Normal file
222
libs/partners/fireworks/langchain_fireworks/llms.py
Normal file
@ -0,0 +1,222 @@
|
||||
"""Wrapper around Fireworks AI's Completion API."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
from aiohttp import ClientSession
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
get_from_dict_or_env,
|
||||
get_pydantic_field_names,
|
||||
)
|
||||
from langchain_core.utils.utils import build_extra_kwargs
|
||||
|
||||
from langchain_fireworks.version import __version__
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Fireworks(LLM):
|
||||
"""LLM models from `Fireworks`.
|
||||
|
||||
To use, you'll need an API key which you can find here:
|
||||
https://fireworks.ai This can be passed in as init param
|
||||
``fireworks_api_key`` or set as environment variable ``FIREWORKS_API_KEY``.
|
||||
|
||||
Fireworks AI API reference: https://readme.fireworks.ai/
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
response = fireworks.generate(["Tell me a joke."])
|
||||
"""
|
||||
|
||||
base_url: str = "https://api.fireworks.ai/inference/v1/completions"
|
||||
"""Base inference API URL."""
|
||||
fireworks_api_key: SecretStr = Field(default=None, alias="api_key")
|
||||
"""Fireworks AI API key. Get it here: https://fireworks.ai"""
|
||||
model: str
|
||||
"""Model name. Available models listed here:
|
||||
https://readme.fireworks.ai/
|
||||
"""
|
||||
temperature: Optional[float] = None
|
||||
"""Model temperature."""
|
||||
top_p: Optional[float] = None
|
||||
"""Used to dynamically adjust the number of choices for each predicted token based
|
||||
on the cumulative probabilities. A value of 1 will always yield the same
|
||||
output. A temperature less than 1 favors more correctness and is appropriate
|
||||
for question answering or summarization. A value greater than 1 introduces more
|
||||
randomness in the output.
|
||||
"""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
top_k: Optional[int] = None
|
||||
"""Used to limit the number of choices for the next predicted word or token. It
|
||||
specifies the maximum number of tokens to consider at each step, based on their
|
||||
probability of occurrence. This technique helps to speed up the generation
|
||||
process and can improve the quality of the generated text by focusing on the
|
||||
most likely options.
|
||||
"""
|
||||
max_tokens: Optional[int] = None
|
||||
"""The maximum number of tokens to generate."""
|
||||
repetition_penalty: Optional[float] = None
|
||||
"""A number that controls the diversity of generated text by reducing the
|
||||
likelihood of repeated sequences. Higher values decrease repetition.
|
||||
"""
|
||||
logprobs: Optional[int] = None
|
||||
"""An integer that specifies how many top token log probabilities are included in
|
||||
the response for each token generation step.
|
||||
"""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
allow_population_by_field_name = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
values["model_kwargs"] = build_extra_kwargs(
|
||||
extra, values, all_required_field_names
|
||||
)
|
||||
return values
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key exists in environment."""
|
||||
values["fireworks_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "fireworks_api_key", "FIREWORKS_API_KEY")
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of model."""
|
||||
return "fireworks"
|
||||
|
||||
def _format_output(self, output: dict) -> str:
|
||||
return output["choices"][0]["text"]
|
||||
|
||||
@staticmethod
|
||||
def get_user_agent() -> str:
|
||||
return f"langchain-fireworks/{__version__}"
|
||||
|
||||
@property
|
||||
def default_params(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"model": self.model,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"top_k": self.top_k,
|
||||
"max_tokens": self.max_tokens,
|
||||
"repetition_penalty": self.repetition_penalty,
|
||||
}
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call out to Fireworks's text generation endpoint.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
|
||||
Returns:
|
||||
The string generated by the model..
|
||||
"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.fireworks_api_key.get_secret_value()}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
stop_to_use = stop[0] if stop and len(stop) == 1 else stop
|
||||
payload: Dict[str, Any] = {
|
||||
**self.default_params,
|
||||
"prompt": prompt,
|
||||
"stop": stop_to_use,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
# filter None values to not pass them to the http payload
|
||||
payload = {k: v for k, v in payload.items() if v is not None}
|
||||
response = requests.post(url=self.base_url, json=payload, headers=headers)
|
||||
|
||||
if response.status_code >= 500:
|
||||
raise Exception(f"Fireworks Server: Error {response.status_code}")
|
||||
elif response.status_code >= 400:
|
||||
raise ValueError(f"Fireworks received an invalid payload: {response.text}")
|
||||
elif response.status_code != 200:
|
||||
raise Exception(
|
||||
f"Fireworks returned an unexpected response with status "
|
||||
f"{response.status_code}: {response.text}"
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
output = self._format_output(data)
|
||||
|
||||
return output
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call Fireworks model to get predictions based on the prompt.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.fireworks_api_key.get_secret_value()}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
stop_to_use = stop[0] if stop and len(stop) == 1 else stop
|
||||
payload: Dict[str, Any] = {
|
||||
**self.default_params,
|
||||
"prompt": prompt,
|
||||
"stop": stop_to_use,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
# filter None values to not pass them to the http payload
|
||||
payload = {k: v for k, v in payload.items() if v is not None}
|
||||
async with ClientSession() as session:
|
||||
async with session.post(
|
||||
self.base_url, json=payload, headers=headers
|
||||
) as response:
|
||||
if response.status >= 500:
|
||||
raise Exception(f"Fireworks Server: Error {response.status}")
|
||||
elif response.status >= 400:
|
||||
raise ValueError(
|
||||
f"Fireworks received an invalid payload: {response.text}"
|
||||
)
|
||||
elif response.status != 200:
|
||||
raise Exception(
|
||||
f"Fireworks returned an unexpected response with status "
|
||||
f"{response.status}: {response.text}"
|
||||
)
|
||||
|
||||
response_json = await response.json()
|
||||
|
||||
if response_json.get("status") != "finished":
|
||||
err_msg = response_json.get("error", "Undefined Error")
|
||||
raise Exception(err_msg)
|
||||
|
||||
output = self._format_output(response_json)
|
||||
return output
|
8
libs/partners/fireworks/langchain_fireworks/version.py
Normal file
8
libs/partners/fireworks/langchain_fireworks/version.py
Normal file
@ -0,0 +1,8 @@
|
||||
"""Main entrypoint into package."""
|
||||
from importlib import metadata
|
||||
|
||||
try:
|
||||
__version__ = metadata.version(__package__)
|
||||
except metadata.PackageNotFoundError:
|
||||
# Case where package metadata is not available.
|
||||
__version__ = ""
|
1541
libs/partners/fireworks/poetry.lock
generated
Normal file
1541
libs/partners/fireworks/poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
97
libs/partners/fireworks/pyproject.toml
Normal file
97
libs/partners/fireworks/pyproject.toml
Normal file
@ -0,0 +1,97 @@
|
||||
[tool.poetry]
|
||||
name = "langchain-fireworks"
|
||||
version = "0.0.1"
|
||||
description = "An integration package connecting Fireworks and LangChain"
|
||||
authors = []
|
||||
readme = "README.md"
|
||||
repository = "https://github.com/langchain-ai/langchain"
|
||||
license = "MIT"
|
||||
|
||||
[tool.poetry.urls]
|
||||
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/fireworks"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<4.0"
|
||||
langchain-core = "^0.1.16"
|
||||
fireworks-ai = ">=0.12.0,<1"
|
||||
openai = "^1.10.0"
|
||||
requests = "^2"
|
||||
aiohttp = "^3.9.1"
|
||||
|
||||
[tool.poetry.group.test]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test.dependencies]
|
||||
pytest = "^7.3.0"
|
||||
freezegun = "^1.2.2"
|
||||
pytest-mock = "^3.10.0"
|
||||
syrupy = "^4.0.2"
|
||||
pytest-watcher = "^0.3.4"
|
||||
pytest-asyncio = "^0.21.1"
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
|
||||
[tool.poetry.group.codespell]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.codespell.dependencies]
|
||||
codespell = "^2.2.0"
|
||||
|
||||
[tool.poetry.group.test_integration]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test_integration.dependencies]
|
||||
|
||||
[tool.poetry.group.lint]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.lint.dependencies]
|
||||
ruff = "^0.2.2"
|
||||
|
||||
[tool.poetry.group.typing.dependencies]
|
||||
mypy = "^0.991"
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
types-requests = "^2"
|
||||
|
||||
[tool.poetry.group.dev]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"E", # pycodestyle
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
"T201", # print
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
disallow_untyped_defs = "True"
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = ["tests/*"]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
# --strict-markers will raise errors on unknown marks.
|
||||
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
|
||||
#
|
||||
# https://docs.pytest.org/en/7.1.x/reference/reference.html
|
||||
# --strict-config any warnings encountered while parsing the `pytest`
|
||||
# section of the configuration file raise errors.
|
||||
#
|
||||
# https://github.com/tophat/syrupy
|
||||
# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
|
||||
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
|
||||
# Registering custom markers.
|
||||
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
|
||||
markers = [
|
||||
"requires: mark tests as requiring a specific library",
|
||||
"asyncio: mark tests as requiring asyncio",
|
||||
"compile: mark placeholder test used to compile integration tests without running them",
|
||||
]
|
||||
asyncio_mode = "auto"
|
17
libs/partners/fireworks/scripts/check_imports.py
Normal file
17
libs/partners/fireworks/scripts/check_imports.py
Normal file
@ -0,0 +1,17 @@
|
||||
import sys
|
||||
import traceback
|
||||
from importlib.machinery import SourceFileLoader
|
||||
|
||||
if __name__ == "__main__":
|
||||
files = sys.argv[1:]
|
||||
has_failure = False
|
||||
for file in files:
|
||||
try:
|
||||
SourceFileLoader("x", file).load_module()
|
||||
except Exception:
|
||||
has_faillure = True
|
||||
print(file) # noqa: T201
|
||||
traceback.print_exc()
|
||||
print() # noqa: T201
|
||||
|
||||
sys.exit(1 if has_failure else 0)
|
27
libs/partners/fireworks/scripts/check_pydantic.sh
Executable file
27
libs/partners/fireworks/scripts/check_pydantic.sh
Executable file
@ -0,0 +1,27 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# This script searches for lines starting with "import pydantic" or "from pydantic"
|
||||
# in tracked files within a Git repository.
|
||||
#
|
||||
# Usage: ./scripts/check_pydantic.sh /path/to/repository
|
||||
|
||||
# Check if a path argument is provided
|
||||
if [ $# -ne 1 ]; then
|
||||
echo "Usage: $0 /path/to/repository"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
repository_path="$1"
|
||||
|
||||
# Search for lines matching the pattern within the specified repository
|
||||
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
|
||||
|
||||
# Check if any matching lines were found
|
||||
if [ -n "$result" ]; then
|
||||
echo "ERROR: The following lines need to be updated:"
|
||||
echo "$result"
|
||||
echo "Please replace the code with an import from langchain_core.pydantic_v1."
|
||||
echo "For example, replace 'from pydantic import BaseModel'"
|
||||
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
|
||||
exit 1
|
||||
fi
|
17
libs/partners/fireworks/scripts/lint_imports.sh
Executable file
17
libs/partners/fireworks/scripts/lint_imports.sh
Executable file
@ -0,0 +1,17 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -eu
|
||||
|
||||
# Initialize a variable to keep track of errors
|
||||
errors=0
|
||||
|
||||
# make sure not importing from langchain or langchain_experimental
|
||||
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
|
||||
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
|
||||
|
||||
# Decide on an exit status based on the errors
|
||||
if [ "$errors" -gt 0 ]; then
|
||||
exit 1
|
||||
else
|
||||
exit 0
|
||||
fi
|
0
libs/partners/fireworks/tests/__init__.py
Normal file
0
libs/partners/fireworks/tests/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.compile
|
||||
def test_placeholder() -> None:
|
||||
"""Used for compiling integration tests without running any real tests."""
|
||||
pass
|
@ -0,0 +1,20 @@
|
||||
"""Test Fireworks embeddings."""
|
||||
|
||||
from langchain_fireworks.embeddings import FireworksEmbeddings
|
||||
|
||||
|
||||
def test_langchain_fireworks_embedding_documents() -> None:
|
||||
"""Test Fireworks hosted embeddings."""
|
||||
documents = ["foo bar"]
|
||||
embedding = FireworksEmbeddings(model="nomic-ai/nomic-embed-text-v1.5")
|
||||
output = embedding.embed_documents(documents)
|
||||
assert len(output) == 1
|
||||
assert len(output[0]) > 0
|
||||
|
||||
|
||||
def test_langchain_fireworks_embedding_query() -> None:
|
||||
"""Test Fireworks hosted embeddings."""
|
||||
document = "foo bar"
|
||||
embedding = FireworksEmbeddings(model="nomic-ai/nomic-embed-text-v1.5")
|
||||
output = embedding.embed_query(document)
|
||||
assert len(output) > 0
|
41
libs/partners/fireworks/tests/integration_tests/test_llms.py
Normal file
41
libs/partners/fireworks/tests/integration_tests/test_llms.py
Normal file
@ -0,0 +1,41 @@
|
||||
"""Test Fireworks API wrapper.
|
||||
|
||||
In order to run this test, you need to have an Fireworks api key.
|
||||
You can get it by registering for free at https://api.fireworks.ai/.
|
||||
A test key can be found at https://api.fireworks.ai/settings/api-keys
|
||||
|
||||
You'll then need to set FIREWORKS_API_KEY environment variable to your api key.
|
||||
"""
|
||||
|
||||
import pytest as pytest
|
||||
|
||||
from langchain_fireworks import Fireworks
|
||||
|
||||
|
||||
def test_fireworks_call() -> None:
|
||||
"""Test simple call to fireworks."""
|
||||
llm = Fireworks(
|
||||
model="accounts/fireworks/models/mixtral-8x7b-instruct",
|
||||
temperature=0.2,
|
||||
max_tokens=250,
|
||||
)
|
||||
output = llm.invoke("Say foo:")
|
||||
|
||||
assert llm._llm_type == "fireworks"
|
||||
assert isinstance(output, str)
|
||||
assert len(output) > 0
|
||||
|
||||
|
||||
async def test_fireworks_acall() -> None:
|
||||
"""Test simple call to fireworks."""
|
||||
llm = Fireworks(
|
||||
model="accounts/fireworks/models/mixtral-8x7b-instruct",
|
||||
temperature=0.2,
|
||||
max_tokens=250,
|
||||
)
|
||||
output = await llm.agenerate(["Say foo:"], stop=["bar"])
|
||||
|
||||
assert llm._llm_type == "fireworks"
|
||||
output_text = output.generations[0][0].text
|
||||
assert isinstance(output_text, str)
|
||||
assert output_text.count("bar") <= 1
|
@ -0,0 +1,8 @@
|
||||
"""Test embedding model integration."""
|
||||
|
||||
from langchain_fireworks.embeddings import FireworksEmbeddings
|
||||
|
||||
|
||||
def test_initialization() -> None:
|
||||
"""Test embedding model initialization."""
|
||||
FireworksEmbeddings(model="nomic-ai/nomic-embed-text-v1.5")
|
12
libs/partners/fireworks/tests/unit_tests/test_imports.py
Normal file
12
libs/partners/fireworks/tests/unit_tests/test_imports.py
Normal file
@ -0,0 +1,12 @@
|
||||
from langchain_fireworks import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"__version__",
|
||||
"ChatFireworks",
|
||||
"Fireworks",
|
||||
"FireworksEmbeddings",
|
||||
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert sorted(EXPECTED_ALL) == sorted(__all__)
|
62
libs/partners/fireworks/tests/unit_tests/test_llms.py
Normal file
62
libs/partners/fireworks/tests/unit_tests/test_llms.py
Normal file
@ -0,0 +1,62 @@
|
||||
"""Test Fireworks LLM"""
|
||||
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.pydantic_v1 import SecretStr
|
||||
from pytest import CaptureFixture, MonkeyPatch
|
||||
|
||||
from langchain_fireworks import Fireworks
|
||||
|
||||
|
||||
def test_fireworks_api_key_is_secret_string() -> None:
|
||||
"""Test that the API key is stored as a SecretStr."""
|
||||
llm = Fireworks(
|
||||
fireworks_api_key="secret-api-key",
|
||||
model="accounts/fireworks/models/mixtral-8x7b-instruct",
|
||||
temperature=0.2,
|
||||
max_tokens=250,
|
||||
)
|
||||
assert isinstance(llm.fireworks_api_key, SecretStr)
|
||||
|
||||
|
||||
def test_fireworks_api_key_masked_when_passed_from_env(
|
||||
monkeypatch: MonkeyPatch, capsys: CaptureFixture
|
||||
) -> None:
|
||||
"""Test that the API key is masked when passed from an environment variable."""
|
||||
monkeypatch.setenv("FIREWORKS_API_KEY", "secret-api-key")
|
||||
llm = Fireworks(
|
||||
model="accounts/fireworks/models/mixtral-8x7b-instruct",
|
||||
temperature=0.2,
|
||||
max_tokens=250,
|
||||
)
|
||||
print(llm.fireworks_api_key, end="") # noqa: T201
|
||||
captured = capsys.readouterr()
|
||||
|
||||
assert captured.out == "**********"
|
||||
|
||||
|
||||
def test_fireworks_api_key_masked_when_passed_via_constructor(
|
||||
capsys: CaptureFixture,
|
||||
) -> None:
|
||||
"""Test that the API key is masked when passed via the constructor."""
|
||||
llm = Fireworks(
|
||||
fireworks_api_key="secret-api-key",
|
||||
model="accounts/fireworks/models/mixtral-8x7b-instruct",
|
||||
temperature=0.2,
|
||||
max_tokens=250,
|
||||
)
|
||||
print(llm.fireworks_api_key, end="") # noqa: T201
|
||||
captured = capsys.readouterr()
|
||||
|
||||
assert captured.out == "**********"
|
||||
|
||||
|
||||
def test_fireworks_uses_actual_secret_value_from_secretstr() -> None:
|
||||
"""Test that the actual secret value is correctly retrieved."""
|
||||
llm = Fireworks(
|
||||
fireworks_api_key="secret-api-key",
|
||||
model="accounts/fireworks/models/mixtral-8x7b-instruct",
|
||||
temperature=0.2,
|
||||
max_tokens=250,
|
||||
)
|
||||
assert cast(SecretStr, llm.fireworks_api_key).get_secret_value() == "secret-api-key"
|
2583
poetry.lock
generated
2583
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -46,6 +46,7 @@ langchain-community = { path = "libs/community/", develop = true }
|
||||
langchain = { path = "libs/langchain/", develop = true }
|
||||
langchain-experimental = { path = "libs/experimental/", develop = true }
|
||||
langchain-openai = { path = "libs/partners/openai", develop = true }
|
||||
ipykernel = "^6.29.2"
|
||||
|
||||
[tool.poetry.group.test.dependencies]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user