Compare commits

...

4 Commits

Author SHA1 Message Date
Eugene Yurtsev
f8c20e18e7 x 2024-08-15 12:26:25 -04:00
Eugene Yurtsev
831708beb7 together[patch]: Update @root_validator for pydantic 2 compatibility (#25423)
This PR updates usage of @root_validator to be compatible with pydantic 2.
2024-08-15 11:27:42 -04:00
Eugene Yurtsev
a114255b82 ai21[patch]: Update @root_validators for pydantic2 migration (#25401)
Update @root_validators for pydantic 2 migration.
2024-08-15 11:26:44 -04:00
Eugene Yurtsev
6f68c8d6ab mistralai[patch]: Update root validator for compatibility with pydantic 2 (#25403) 2024-08-15 11:26:24 -04:00
16 changed files with 149 additions and 588 deletions

View File

@@ -55,12 +55,7 @@
"id": "433e8d2b-9519-4b49-b2c4-7ab65b046c94",
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"ANTHROPIC_API_KEY\"] = getpass.getpass(\"Enter your Anthropic API key: \")"
]
"source": ["import getpass\nimport os\n\nif \"ANTHROPIC_API_KEY\" not in os.environ:\n os.environ[\"ANTHROPIC_API_KEY\"] = getpass.getpass(\"Enter your Anthropic API key: \")"]
},
{
"cell_type": "markdown",
@@ -76,10 +71,7 @@
"id": "a15d341e-3e26-4ca3-830b-5aab30ed66de",
"metadata": {},
"outputs": [],
"source": [
"# os.environ[\"LANGSMITH_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")\n",
"# os.environ[\"LANGSMITH_TRACING\"] = \"true\""
]
"source": ["# os.environ[\"LANGSMITH_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")\n# os.environ[\"LANGSMITH_TRACING\"] = \"true\""]
},
{
"cell_type": "markdown",
@@ -97,9 +89,7 @@
"id": "652d6238-1f87-422a-b135-f5abbb8652fc",
"metadata": {},
"outputs": [],
"source": [
"%pip install -qU langchain-anthropic"
]
"source": ["%pip install -qU langchain-anthropic"]
},
{
"cell_type": "markdown",
@@ -117,18 +107,7 @@
"id": "cb09c344-1836-4e0c-acf8-11d13ac1dbae",
"metadata": {},
"outputs": [],
"source": [
"from langchain_anthropic import ChatAnthropic\n",
"\n",
"llm = ChatAnthropic(\n",
" model=\"claude-3-5-sonnet-20240620\",\n",
" temperature=0,\n",
" max_tokens=1024,\n",
" timeout=None,\n",
" max_retries=2,\n",
" # other params...\n",
")"
]
"source": ["from langchain_anthropic import ChatAnthropic\n\nllm = ChatAnthropic(\n model=\"claude-3-5-sonnet-20240620\",\n temperature=0,\n max_tokens=1024,\n timeout=None,\n max_retries=2,\n # other params...\n)"]
},
{
"cell_type": "markdown",
@@ -157,17 +136,7 @@
"output_type": "execute_result"
}
],
"source": [
"messages = [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates English to French. Translate the user sentence.\",\n",
" ),\n",
" (\"human\", \"I love programming.\"),\n",
"]\n",
"ai_msg = llm.invoke(messages)\n",
"ai_msg"
]
"source": ["messages = [\n (\n \"system\",\n \"You are a helpful assistant that translates English to French. Translate the user sentence.\",\n ),\n (\"human\", \"I love programming.\"),\n]\nai_msg = llm.invoke(messages)\nai_msg"]
},
{
"cell_type": "code",
@@ -183,9 +152,7 @@
]
}
],
"source": [
"print(ai_msg.content)"
]
"source": ["print(ai_msg.content)"]
},
{
"cell_type": "markdown",
@@ -214,28 +181,7 @@
"output_type": "execute_result"
}
],
"source": [
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n",
" ),\n",
" (\"human\", \"{input}\"),\n",
" ]\n",
")\n",
"\n",
"chain = prompt | llm\n",
"chain.invoke(\n",
" {\n",
" \"input_language\": \"English\",\n",
" \"output_language\": \"German\",\n",
" \"input\": \"I love programming.\",\n",
" }\n",
")"
]
"source": ["from langchain_core.prompts import ChatPromptTemplate\n\nprompt = ChatPromptTemplate.from_messages(\n [\n (\n \"system\",\n \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n ),\n (\"human\", \"{input}\"),\n ]\n)\n\nchain = prompt | llm\nchain.invoke(\n {\n \"input_language\": \"English\",\n \"output_language\": \"German\",\n \"input\": \"I love programming.\",\n }\n)"]
},
{
"cell_type": "markdown",
@@ -273,20 +219,7 @@
"output_type": "execute_result"
}
],
"source": [
"from langchain_core.pydantic_v1 import BaseModel, Field\n",
"\n",
"\n",
"class GetWeather(BaseModel):\n",
" \"\"\"Get the current weather in a given location\"\"\"\n",
"\n",
" location: str = Field(..., description=\"The city and state, e.g. San Francisco, CA\")\n",
"\n",
"\n",
"llm_with_tools = llm.bind_tools([GetWeather])\n",
"ai_msg = llm_with_tools.invoke(\"Which city is hotter today: LA or NY?\")\n",
"ai_msg.content"
]
"source": ["from langchain_core.pydantic_v1 import BaseModel, Field\n\n\nclass GetWeather(BaseModel):\n \"\"\"Get the current weather in a given location\"\"\"\n\n location: str = Field(..., description=\"The city and state, e.g. San Francisco, CA\")\n\n\nllm_with_tools = llm.bind_tools([GetWeather])\nai_msg = llm_with_tools.invoke(\"Which city is hotter today: LA or NY?\")\nai_msg.content"]
},
{
"cell_type": "code",
@@ -310,9 +243,7 @@
"output_type": "execute_result"
}
],
"source": [
"ai_msg.tool_calls"
]
"source": ["ai_msg.tool_calls"]
},
{
"cell_type": "markdown",

View File

@@ -58,7 +58,10 @@
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"AZURE_OPENAI_API_KEY\"] = getpass.getpass(\"Enter your AzureOpenAI API key: \")\n",
"if \"AZURE_OPENAI_API_KEY\" not in os.environ:\n",
" os.environ[\"AZURE_OPENAI_API_KEY\"] = getpass.getpass(\n",
" \"Enter your AzureOpenAI API key: \"\n",
" )\n",
"os.environ[\"AZURE_OPENAI_ENDPOINT\"] = \"https://YOUR-ENDPOINT.openai.azure.com/\""
]
},

View File

@@ -90,7 +90,10 @@
"import os\n",
"\n",
"os.environ[\"DATABRICKS_HOST\"] = \"https://your-workspace.cloud.databricks.com\"\n",
"os.environ[\"DATABRICKS_TOKEN\"] = getpass.getpass(\"Enter your Databricks access token: \")"
"if \"DATABRICKS_TOKEN\" not in os.environ:\n",
" os.environ[\"DATABRICKS_TOKEN\"] = getpass.getpass(\n",
" \"Enter your Databricks access token: \"\n",
" )"
]
},
{

View File

@@ -48,12 +48,7 @@
"id": "433e8d2b-9519-4b49-b2c4-7ab65b046c94",
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"FIREWORKS_API_KEY\"] = getpass.getpass(\"Enter your Fireworks API key: \")"
]
"source": ["import getpass\nimport os\n\nif \"FIREWORKS_API_KEY\" not in os.environ:\n os.environ[\"FIREWORKS_API_KEY\"] = getpass.getpass(\"Enter your Fireworks API key: \")"]
},
{
"cell_type": "markdown",
@@ -69,10 +64,7 @@
"id": "a15d341e-3e26-4ca3-830b-5aab30ed66de",
"metadata": {},
"outputs": [],
"source": [
"# os.environ[\"LANGSMITH_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")\n",
"# os.environ[\"LANGSMITH_TRACING\"] = \"true\""
]
"source": ["# os.environ[\"LANGSMITH_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")\n# os.environ[\"LANGSMITH_TRACING\"] = \"true\""]
},
{
"cell_type": "markdown",
@@ -90,9 +82,7 @@
"id": "652d6238-1f87-422a-b135-f5abbb8652fc",
"metadata": {},
"outputs": [],
"source": [
"%pip install -qU langchain-fireworks"
]
"source": ["%pip install -qU langchain-fireworks"]
},
{
"cell_type": "markdown",
@@ -112,18 +102,7 @@
"id": "cb09c344-1836-4e0c-acf8-11d13ac1dbae",
"metadata": {},
"outputs": [],
"source": [
"from langchain_fireworks import ChatFireworks\n",
"\n",
"llm = ChatFireworks(\n",
" model=\"accounts/fireworks/models/llama-v3-70b-instruct\",\n",
" temperature=0,\n",
" max_tokens=None,\n",
" timeout=None,\n",
" max_retries=2,\n",
" # other params...\n",
")"
]
"source": ["from langchain_fireworks import ChatFireworks\n\nllm = ChatFireworks(\n model=\"accounts/fireworks/models/llama-v3-70b-instruct\",\n temperature=0,\n max_tokens=None,\n timeout=None,\n max_retries=2,\n # other params...\n)"]
},
{
"cell_type": "markdown",
@@ -152,17 +131,7 @@
"output_type": "execute_result"
}
],
"source": [
"messages = [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates English to French. Translate the user sentence.\",\n",
" ),\n",
" (\"human\", \"I love programming.\"),\n",
"]\n",
"ai_msg = llm.invoke(messages)\n",
"ai_msg"
]
"source": ["messages = [\n (\n \"system\",\n \"You are a helpful assistant that translates English to French. Translate the user sentence.\",\n ),\n (\"human\", \"I love programming.\"),\n]\nai_msg = llm.invoke(messages)\nai_msg"]
},
{
"cell_type": "code",
@@ -178,9 +147,7 @@
]
}
],
"source": [
"print(ai_msg.content)"
]
"source": ["print(ai_msg.content)"]
},
{
"cell_type": "markdown",
@@ -209,28 +176,7 @@
"output_type": "execute_result"
}
],
"source": [
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n",
" ),\n",
" (\"human\", \"{input}\"),\n",
" ]\n",
")\n",
"\n",
"chain = prompt | llm\n",
"chain.invoke(\n",
" {\n",
" \"input_language\": \"English\",\n",
" \"output_language\": \"German\",\n",
" \"input\": \"I love programming.\",\n",
" }\n",
")"
]
"source": ["from langchain_core.prompts import ChatPromptTemplate\n\nprompt = ChatPromptTemplate.from_messages(\n [\n (\n \"system\",\n \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n ),\n (\"human\", \"{input}\"),\n ]\n)\n\nchain = prompt | llm\nchain.invoke(\n {\n \"input_language\": \"English\",\n \"output_language\": \"German\",\n \"input\": \"I love programming.\",\n }\n)"]
},
{
"cell_type": "markdown",

View File

@@ -40,12 +40,7 @@
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"FRIENDLI_TOKEN\"] = getpass.getpass(\"Friendi Personal Access Token: \")"
]
"source": ["import getpass\nimport os\n\nif \"FRIENDLI_TOKEN\" not in os.environ:\n os.environ[\"FRIENDLI_TOKEN\"] = getpass.getpass(\"Friendi Personal Access Token: \")"]
},
{
"cell_type": "markdown",
@@ -59,11 +54,7 @@
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.chat_models.friendli import ChatFriendli\n",
"\n",
"chat = ChatFriendli(model=\"llama-2-13b-chat\", max_tokens=100, temperature=0)"
]
"source": ["from langchain_community.chat_models.friendli import ChatFriendli\n\nchat = ChatFriendli(model=\"llama-2-13b-chat\", max_tokens=100, temperature=0)"]
},
{
"cell_type": "markdown",
@@ -97,16 +88,7 @@
"output_type": "execute_result"
}
],
"source": [
"from langchain_core.messages.human import HumanMessage\n",
"from langchain_core.messages.system import SystemMessage\n",
"\n",
"system_message = SystemMessage(content=\"Answer questions as short as you can.\")\n",
"human_message = HumanMessage(content=\"Tell me a joke.\")\n",
"messages = [system_message, human_message]\n",
"\n",
"chat.invoke(messages)"
]
"source": ["from langchain_core.messages.human import HumanMessage\nfrom langchain_core.messages.system import SystemMessage\n\nsystem_message = SystemMessage(content=\"Answer questions as short as you can.\")\nhuman_message = HumanMessage(content=\"Tell me a joke.\")\nmessages = [system_message, human_message]\n\nchat.invoke(messages)"]
},
{
"cell_type": "code",
@@ -125,9 +107,7 @@
"output_type": "execute_result"
}
],
"source": [
"chat.batch([messages, messages])"
]
"source": ["chat.batch([messages, messages])"]
},
{
"cell_type": "code",
@@ -145,9 +125,7 @@
"output_type": "execute_result"
}
],
"source": [
"chat.generate([messages, messages])"
]
"source": ["chat.generate([messages, messages])"]
},
{
"cell_type": "code",
@@ -166,10 +144,7 @@
]
}
],
"source": [
"for chunk in chat.stream(messages):\n",
" print(chunk.content, end=\"\", flush=True)"
]
"source": ["for chunk in chat.stream(messages):\n print(chunk.content, end=\"\", flush=True)"]
},
{
"cell_type": "markdown",
@@ -194,9 +169,7 @@
"output_type": "execute_result"
}
],
"source": [
"await chat.ainvoke(messages)"
]
"source": ["await chat.ainvoke(messages)"]
},
{
"cell_type": "code",
@@ -215,9 +188,7 @@
"output_type": "execute_result"
}
],
"source": [
"await chat.abatch([messages, messages])"
]
"source": ["await chat.abatch([messages, messages])"]
},
{
"cell_type": "code",
@@ -235,9 +206,7 @@
"output_type": "execute_result"
}
],
"source": [
"await chat.agenerate([messages, messages])"
]
"source": ["await chat.agenerate([messages, messages])"]
},
{
"cell_type": "code",
@@ -256,10 +225,7 @@
]
}
],
"source": [
"async for chunk in chat.astream(messages):\n",
" print(chunk.content, end=\"\", flush=True)"
]
"source": ["async for chunk in chat.astream(messages):\n print(chunk.content, end=\"\", flush=True)"]
}
],
"metadata": {

View File

@@ -56,12 +56,7 @@
"id": "433e8d2b-9519-4b49-b2c4-7ab65b046c94",
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"GOOGLE_API_KEY\"] = getpass.getpass(\"Enter your Google AI API key: \")"
]
"source": ["import getpass\nimport os\n\nif \"GOOGLE_API_KEY\" not in os.environ:\n os.environ[\"GOOGLE_API_KEY\"] = getpass.getpass(\"Enter your Google AI API key: \")"]
},
{
"cell_type": "markdown",
@@ -77,10 +72,7 @@
"id": "a15d341e-3e26-4ca3-830b-5aab30ed66de",
"metadata": {},
"outputs": [],
"source": [
"# os.environ[\"LANGSMITH_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")\n",
"# os.environ[\"LANGSMITH_TRACING\"] = \"true\""
]
"source": ["# os.environ[\"LANGSMITH_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")\n# os.environ[\"LANGSMITH_TRACING\"] = \"true\""]
},
{
"cell_type": "markdown",
@@ -98,9 +90,7 @@
"id": "652d6238-1f87-422a-b135-f5abbb8652fc",
"metadata": {},
"outputs": [],
"source": [
"%pip install -qU langchain-google-genai"
]
"source": ["%pip install -qU langchain-google-genai"]
},
{
"cell_type": "markdown",
@@ -118,18 +108,7 @@
"id": "cb09c344-1836-4e0c-acf8-11d13ac1dbae",
"metadata": {},
"outputs": [],
"source": [
"from langchain_google_genai import ChatGoogleGenerativeAI\n",
"\n",
"llm = ChatGoogleGenerativeAI(\n",
" model=\"gemini-1.5-pro\",\n",
" temperature=0,\n",
" max_tokens=None,\n",
" timeout=None,\n",
" max_retries=2,\n",
" # other params...\n",
")"
]
"source": ["from langchain_google_genai import ChatGoogleGenerativeAI\n\nllm = ChatGoogleGenerativeAI(\n model=\"gemini-1.5-pro\",\n temperature=0,\n max_tokens=None,\n timeout=None,\n max_retries=2,\n # other params...\n)"]
},
{
"cell_type": "markdown",
@@ -158,17 +137,7 @@
"output_type": "execute_result"
}
],
"source": [
"messages = [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates English to French. Translate the user sentence.\",\n",
" ),\n",
" (\"human\", \"I love programming.\"),\n",
"]\n",
"ai_msg = llm.invoke(messages)\n",
"ai_msg"
]
"source": ["messages = [\n (\n \"system\",\n \"You are a helpful assistant that translates English to French. Translate the user sentence.\",\n ),\n (\"human\", \"I love programming.\"),\n]\nai_msg = llm.invoke(messages)\nai_msg"]
},
{
"cell_type": "code",
@@ -185,9 +154,7 @@
]
}
],
"source": [
"print(ai_msg.content)"
]
"source": ["print(ai_msg.content)"]
},
{
"cell_type": "markdown",
@@ -216,28 +183,7 @@
"output_type": "execute_result"
}
],
"source": [
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n",
" ),\n",
" (\"human\", \"{input}\"),\n",
" ]\n",
")\n",
"\n",
"chain = prompt | llm\n",
"chain.invoke(\n",
" {\n",
" \"input_language\": \"English\",\n",
" \"output_language\": \"German\",\n",
" \"input\": \"I love programming.\",\n",
" }\n",
")"
]
"source": ["from langchain_core.prompts import ChatPromptTemplate\n\nprompt = ChatPromptTemplate.from_messages(\n [\n (\n \"system\",\n \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n ),\n (\"human\", \"{input}\"),\n ]\n)\n\nchain = prompt | llm\nchain.invoke(\n {\n \"input_language\": \"English\",\n \"output_language\": \"German\",\n \"input\": \"I love programming.\",\n }\n)"]
},
{
"cell_type": "markdown",
@@ -255,20 +201,7 @@
"id": "238b2f96-e573-4fac-bbf2-7e52ad926833",
"metadata": {},
"outputs": [],
"source": [
"from langchain_google_genai import (\n",
" ChatGoogleGenerativeAI,\n",
" HarmBlockThreshold,\n",
" HarmCategory,\n",
")\n",
"\n",
"llm = ChatGoogleGenerativeAI(\n",
" model=\"gemini-1.5-pro\",\n",
" safety_settings={\n",
" HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,\n",
" },\n",
")"
]
"source": ["from langchain_google_genai import (\n ChatGoogleGenerativeAI,\n HarmBlockThreshold,\n HarmCategory,\n)\n\nllm = ChatGoogleGenerativeAI(\n model=\"gemini-1.5-pro\",\n safety_settings={\n HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,\n },\n)"]
},
{
"cell_type": "markdown",

View File

@@ -46,12 +46,7 @@
"id": "433e8d2b-9519-4b49-b2c4-7ab65b046c94",
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"GROQ_API_KEY\"] = getpass.getpass(\"Enter your Groq API key: \")"
]
"source": ["import getpass\nimport os\n\nif \"GROQ_API_KEY\" not in os.environ:\n os.environ[\"GROQ_API_KEY\"] = getpass.getpass(\"Enter your Groq API key: \")"]
},
{
"cell_type": "markdown",
@@ -67,10 +62,7 @@
"id": "a15d341e-3e26-4ca3-830b-5aab30ed66de",
"metadata": {},
"outputs": [],
"source": [
"# os.environ[\"LANGSMITH_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")\n",
"# os.environ[\"LANGSMITH_TRACING\"] = \"true\""
]
"source": ["# os.environ[\"LANGSMITH_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")\n# os.environ[\"LANGSMITH_TRACING\"] = \"true\""]
},
{
"cell_type": "markdown",
@@ -99,9 +91,7 @@
]
}
],
"source": [
"%pip install -qU langchain-groq"
]
"source": ["%pip install -qU langchain-groq"]
},
{
"cell_type": "markdown",
@@ -119,18 +109,7 @@
"id": "cb09c344-1836-4e0c-acf8-11d13ac1dbae",
"metadata": {},
"outputs": [],
"source": [
"from langchain_groq import ChatGroq\n",
"\n",
"llm = ChatGroq(\n",
" model=\"mixtral-8x7b-32768\",\n",
" temperature=0,\n",
" max_tokens=None,\n",
" timeout=None,\n",
" max_retries=2,\n",
" # other params...\n",
")"
]
"source": ["from langchain_groq import ChatGroq\n\nllm = ChatGroq(\n model=\"mixtral-8x7b-32768\",\n temperature=0,\n max_tokens=None,\n timeout=None,\n max_retries=2,\n # other params...\n)"]
},
{
"cell_type": "markdown",
@@ -159,17 +138,7 @@
"output_type": "execute_result"
}
],
"source": [
"messages = [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates English to French. Translate the user sentence.\",\n",
" ),\n",
" (\"human\", \"I love programming.\"),\n",
"]\n",
"ai_msg = llm.invoke(messages)\n",
"ai_msg"
]
"source": ["messages = [\n (\n \"system\",\n \"You are a helpful assistant that translates English to French. Translate the user sentence.\",\n ),\n (\"human\", \"I love programming.\"),\n]\nai_msg = llm.invoke(messages)\nai_msg"]
},
{
"cell_type": "code",
@@ -187,9 +156,7 @@
]
}
],
"source": [
"print(ai_msg.content)"
]
"source": ["print(ai_msg.content)"]
},
{
"cell_type": "markdown",
@@ -218,28 +185,7 @@
"output_type": "execute_result"
}
],
"source": [
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n",
" ),\n",
" (\"human\", \"{input}\"),\n",
" ]\n",
")\n",
"\n",
"chain = prompt | llm\n",
"chain.invoke(\n",
" {\n",
" \"input_language\": \"English\",\n",
" \"output_language\": \"German\",\n",
" \"input\": \"I love programming.\",\n",
" }\n",
")"
]
"source": ["from langchain_core.prompts import ChatPromptTemplate\n\nprompt = ChatPromptTemplate.from_messages(\n [\n (\n \"system\",\n \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n ),\n (\"human\", \"{input}\"),\n ]\n)\n\nchain = prompt | llm\nchain.invoke(\n {\n \"input_language\": \"English\",\n \"output_language\": \"German\",\n \"input\": \"I love programming.\",\n }\n)"]
},
{
"cell_type": "markdown",

View File

@@ -36,13 +36,7 @@
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"if not os.getenv(\"HUGGINGFACEHUB_API_TOKEN\"):\n",
" os.environ[\"HUGGINGFACEHUB_API_TOKEN\"] = getpass.getpass(\"Enter your token: \")"
]
"source": ["import getpass\nimport os\n\nif not os.getenv(\"HUGGINGFACEHUB_API_TOKEN\"):\n os.environ[\"HUGGINGFACEHUB_API_TOKEN\"] = getpass.getpass(\"Enter your token: \")"]
},
{
"cell_type": "markdown",
@@ -73,14 +67,7 @@
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"HUGGINGFACEHUB_API_TOKEN\"] = getpass.getpass(\n",
" \"Enter your Hugging Face API key: \"\n",
")"
]
"source": ["import getpass\nimport os\n\nif \"HUGGINGFACEHUB_API_TOKEN\" not in os.environ:\n os.environ[\"HUGGINGFACEHUB_API_TOKEN\"] = getpass.getpass(\n \"Enter your Hugging Face API key: \"\n )"]
},
{
"cell_type": "code",
@@ -98,9 +85,7 @@
]
}
],
"source": [
"%pip install --upgrade --quiet langchain-huggingface text-generation transformers google-search-results numexpr langchainhub sentencepiece jinja2 bitsandbytes accelerate"
]
"source": ["%pip install --upgrade --quiet langchain-huggingface text-generation transformers google-search-results numexpr langchainhub sentencepiece jinja2 bitsandbytes accelerate"]
},
{
"cell_type": "markdown",
@@ -134,19 +119,7 @@
]
}
],
"source": [
"from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint\n",
"\n",
"llm = HuggingFaceEndpoint(\n",
" repo_id=\"HuggingFaceH4/zephyr-7b-beta\",\n",
" task=\"text-generation\",\n",
" max_new_tokens=512,\n",
" do_sample=False,\n",
" repetition_penalty=1.03,\n",
")\n",
"\n",
"chat_model = ChatHuggingFace(llm=llm)"
]
"source": ["from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint\n\nllm = HuggingFaceEndpoint(\n repo_id=\"HuggingFaceH4/zephyr-7b-beta\",\n task=\"text-generation\",\n max_new_tokens=512,\n do_sample=False,\n repetition_penalty=1.03,\n)\n\nchat_model = ChatHuggingFace(llm=llm)"]
},
{
"cell_type": "markdown",
@@ -343,21 +316,7 @@
"output_type": "display_data"
}
],
"source": [
"from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline\n",
"\n",
"llm = HuggingFacePipeline.from_model_id(\n",
" model_id=\"HuggingFaceH4/zephyr-7b-beta\",\n",
" task=\"text-generation\",\n",
" pipeline_kwargs=dict(\n",
" max_new_tokens=512,\n",
" do_sample=False,\n",
" repetition_penalty=1.03,\n",
" ),\n",
")\n",
"\n",
"chat_model = ChatHuggingFace(llm=llm)"
]
"source": ["from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline\n\nllm = HuggingFacePipeline.from_model_id(\n model_id=\"HuggingFaceH4/zephyr-7b-beta\",\n task=\"text-generation\",\n pipeline_kwargs=dict(\n max_new_tokens=512,\n do_sample=False,\n repetition_penalty=1.03,\n ),\n)\n\nchat_model = ChatHuggingFace(llm=llm)"]
},
{
"cell_type": "markdown",
@@ -373,16 +332,7 @@
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from transformers import BitsAndBytesConfig\n",
"\n",
"quantization_config = BitsAndBytesConfig(\n",
" load_in_4bit=True,\n",
" bnb_4bit_quant_type=\"nf4\",\n",
" bnb_4bit_compute_dtype=\"float16\",\n",
" bnb_4bit_use_double_quant=True,\n",
")"
]
"source": ["from transformers import BitsAndBytesConfig\n\nquantization_config = BitsAndBytesConfig(\n load_in_4bit=True,\n bnb_4bit_quant_type=\"nf4\",\n bnb_4bit_compute_dtype=\"float16\",\n bnb_4bit_use_double_quant=True,\n)"]
},
{
"cell_type": "markdown",
@@ -396,20 +346,7 @@
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm = HuggingFacePipeline.from_model_id(\n",
" model_id=\"HuggingFaceH4/zephyr-7b-beta\",\n",
" task=\"text-generation\",\n",
" pipeline_kwargs=dict(\n",
" max_new_tokens=512,\n",
" do_sample=False,\n",
" repetition_penalty=1.03,\n",
" ),\n",
" model_kwargs={\"quantization_config\": quantization_config},\n",
")\n",
"\n",
"chat_model = ChatHuggingFace(llm=llm)"
]
"source": ["llm = HuggingFacePipeline.from_model_id(\n model_id=\"HuggingFaceH4/zephyr-7b-beta\",\n task=\"text-generation\",\n pipeline_kwargs=dict(\n max_new_tokens=512,\n do_sample=False,\n repetition_penalty=1.03,\n ),\n model_kwargs={\"quantization_config\": quantization_config},\n)\n\nchat_model = ChatHuggingFace(llm=llm)"]
},
{
"cell_type": "markdown",
@@ -423,21 +360,7 @@
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.messages import (\n",
" HumanMessage,\n",
" SystemMessage,\n",
")\n",
"\n",
"messages = [\n",
" SystemMessage(content=\"You're a helpful assistant\"),\n",
" HumanMessage(\n",
" content=\"What happens when an unstoppable force meets an immovable object?\"\n",
" ),\n",
"]\n",
"\n",
"ai_msg = chat_model.invoke(messages)"
]
"source": ["from langchain_core.messages import (\n HumanMessage,\n SystemMessage,\n)\n\nmessages = [\n SystemMessage(content=\"You're a helpful assistant\"),\n HumanMessage(\n content=\"What happens when an unstoppable force meets an immovable object?\"\n ),\n]\n\nai_msg = chat_model.invoke(messages)"]
},
{
"cell_type": "code",
@@ -454,9 +377,7 @@
]
}
],
"source": [
"print(ai_msg.content)"
]
"source": ["print(ai_msg.content)"]
},
{
"cell_type": "markdown",

View File

@@ -48,12 +48,7 @@
"id": "2461605e",
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"MISTRAL_API_KEY\"] = getpass.getpass(\"Enter your Mistral API key: \")"
]
"source": ["import getpass\nimport os\n\nif \"MISTRAL_API_KEY\" not in os.environ:\n os.environ[\"MISTRAL_API_KEY\"] = getpass.getpass(\"Enter your Mistral API key: \")"]
},
{
"cell_type": "markdown",
@@ -69,10 +64,7 @@
"id": "007209d5",
"metadata": {},
"outputs": [],
"source": [
"# os.environ[\"LANGSMITH_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")\n",
"# os.environ[\"LANGSMITH_TRACING\"] = \"true\""
]
"source": ["# os.environ[\"LANGSMITH_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")\n# os.environ[\"LANGSMITH_TRACING\"] = \"true\""]
},
{
"cell_type": "markdown",
@@ -90,9 +82,7 @@
"id": "1ab11a65",
"metadata": {},
"outputs": [],
"source": [
"%pip install -qU langchain_mistralai"
]
"source": ["%pip install -qU langchain_mistralai"]
},
{
"cell_type": "markdown",
@@ -110,16 +100,7 @@
"id": "e6c38580",
"metadata": {},
"outputs": [],
"source": [
"from langchain_mistralai import ChatMistralAI\n",
"\n",
"llm = ChatMistralAI(\n",
" model=\"mistral-large-latest\",\n",
" temperature=0,\n",
" max_retries=2,\n",
" # other params...\n",
")"
]
"source": ["from langchain_mistralai import ChatMistralAI\n\nllm = ChatMistralAI(\n model=\"mistral-large-latest\",\n temperature=0,\n max_retries=2,\n # other params...\n)"]
},
{
"cell_type": "markdown",
@@ -146,17 +127,7 @@
"output_type": "execute_result"
}
],
"source": [
"messages = [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates English to French. Translate the user sentence.\",\n",
" ),\n",
" (\"human\", \"I love programming.\"),\n",
"]\n",
"ai_msg = llm.invoke(messages)\n",
"ai_msg"
]
"source": ["messages = [\n (\n \"system\",\n \"You are a helpful assistant that translates English to French. Translate the user sentence.\",\n ),\n (\"human\", \"I love programming.\"),\n]\nai_msg = llm.invoke(messages)\nai_msg"]
},
{
"cell_type": "code",
@@ -172,9 +143,7 @@
]
}
],
"source": [
"print(ai_msg.content)"
]
"source": ["print(ai_msg.content)"]
},
{
"cell_type": "markdown",
@@ -203,28 +172,7 @@
"output_type": "execute_result"
}
],
"source": [
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n",
" ),\n",
" (\"human\", \"{input}\"),\n",
" ]\n",
")\n",
"\n",
"chain = prompt | llm\n",
"chain.invoke(\n",
" {\n",
" \"input_language\": \"English\",\n",
" \"output_language\": \"German\",\n",
" \"input\": \"I love programming.\",\n",
" }\n",
")"
]
"source": ["from langchain_core.prompts import ChatPromptTemplate\n\nprompt = ChatPromptTemplate.from_messages(\n [\n (\n \"system\",\n \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n ),\n (\"human\", \"{input}\"),\n ]\n)\n\nchain = prompt | llm\nchain.invoke(\n {\n \"input_language\": \"English\",\n \"output_language\": \"German\",\n \"input\": \"I love programming.\",\n }\n)"]
},
{
"cell_type": "markdown",

View File

@@ -37,12 +37,7 @@
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"YI_API_KEY\"] = getpass.getpass(\"Enter your Yi API key: \")"
]
"source": ["import getpass\nimport os\n\nif \"YI_API_KEY\" not in os.environ:\n os.environ[\"YI_API_KEY\"] = getpass.getpass(\"Enter your Yi API key: \")"]
},
{
"cell_type": "markdown",
@@ -56,10 +51,7 @@
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# os.environ[\"LANGSMITH_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")\n",
"# os.environ[\"LANGSMITH_TRACING\"] = \"true\""
]
"source": ["# os.environ[\"LANGSMITH_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")\n# os.environ[\"LANGSMITH_TRACING\"] = \"true\""]
},
{
"cell_type": "markdown",
@@ -75,9 +67,7 @@
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install -qU langchain_community"
]
"source": ["%pip install -qU langchain_community"]
},
{
"cell_type": "markdown",
@@ -95,17 +85,7 @@
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.chat_models.yi import ChatYi\n",
"\n",
"llm = ChatYi(\n",
" model=\"yi-large\",\n",
" temperature=0,\n",
" timeout=60,\n",
" yi_api_base=\"https://api.01.ai/v1/chat/completions\",\n",
" # other params...\n",
")"
]
"source": ["from langchain_community.chat_models.yi import ChatYi\n\nllm = ChatYi(\n model=\"yi-large\",\n temperature=0,\n timeout=60,\n yi_api_base=\"https://api.01.ai/v1/chat/completions\",\n # other params...\n)"]
},
{
"cell_type": "markdown",
@@ -130,19 +110,7 @@
"output_type": "execute_result"
}
],
"source": [
"from langchain_core.messages import HumanMessage, SystemMessage\n",
"\n",
"messages = [\n",
" SystemMessage(content=\"You are an AI assistant specializing in technology trends.\"),\n",
" HumanMessage(\n",
" content=\"What are the potential applications of large language models in healthcare?\"\n",
" ),\n",
"]\n",
"\n",
"ai_msg = llm.invoke(messages)\n",
"ai_msg"
]
"source": ["from langchain_core.messages import HumanMessage, SystemMessage\n\nmessages = [\n SystemMessage(content=\"You are an AI assistant specializing in technology trends.\"),\n HumanMessage(\n content=\"What are the potential applications of large language models in healthcare?\"\n ),\n]\n\nai_msg = llm.invoke(messages)\nai_msg"]
},
{
"cell_type": "markdown",
@@ -169,28 +137,7 @@
"output_type": "execute_result"
}
],
"source": [
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n",
" ),\n",
" (\"human\", \"{input}\"),\n",
" ]\n",
")\n",
"\n",
"chain = prompt | llm\n",
"chain.invoke(\n",
" {\n",
" \"input_language\": \"English\",\n",
" \"output_language\": \"German\",\n",
" \"input\": \"I love programming.\",\n",
" }\n",
")"
]
"source": ["from langchain_core.prompts import ChatPromptTemplate\n\nprompt = ChatPromptTemplate.from_messages(\n [\n (\n \"system\",\n \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n ),\n (\"human\", \"{input}\"),\n ]\n)\n\nchain = prompt | llm\nchain.invoke(\n {\n \"input_language\": \"English\",\n \"output_language\": \"German\",\n \"input\": \"I love programming.\",\n }\n)"]
},
{
"cell_type": "markdown",

View File

@@ -28,7 +28,7 @@ class AI21Base(BaseModel):
num_retries: Optional[int] = None
"""Maximum number of retries for API requests before giving up."""
@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
api_key = convert_to_secret_str(
values.get("api_key") or os.getenv("AI21_API_KEY") or ""
@@ -46,7 +46,13 @@ class AI21Base(BaseModel):
os.getenv("AI21_TIMEOUT_SEC", _DEFAULT_TIMEOUT_SEC)
)
values["timeout_sec"] = timeout_sec
return values
@root_validator(pre=False, skip_on_failure=True)
def post_init(cls, values: Dict) -> Dict:
api_key = values["api_key"]
api_host = values["api_host"]
timeout_sec = values["timeout_sec"]
if values.get("client") is None:
values["client"] = AI21Client(
api_key=api_key.get_secret_value(),

View File

@@ -73,7 +73,7 @@ from langchain_core.pydantic_v1 import (
)
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils import secret_from_env
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
@@ -360,7 +360,10 @@ class ChatMistralAI(BaseChatModel):
client: httpx.Client = Field(default=None) #: :meta private:
async_client: httpx.AsyncClient = Field(default=None) #: :meta private:
mistral_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
mistral_api_key: Optional[SecretStr] = Field(
alias="api_key",
default_factory=secret_from_env("MISTRAL_API_KEY", default=None),
)
endpoint: str = "https://api.mistral.ai/v1"
max_retries: int = 5
timeout: int = 120
@@ -465,15 +468,9 @@ class ChatMistralAI(BaseChatModel):
combined = {"token_usage": overall_token_usage, "model_name": self.model}
return combined
@root_validator()
@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate api key, python package exists, temperature, and top_p."""
values["mistral_api_key"] = convert_to_secret_str(
get_from_dict_or_env(
values, "mistral_api_key", "MISTRAL_API_KEY", default=""
)
)
api_key_str = values["mistral_api_key"].get_secret_value()
# todo: handle retries
if not values.get("client"):

View File

@@ -1,6 +1,5 @@
"""Wrapper around Together AI's Chat Completions API."""
import os
from typing import (
Any,
Dict,
@@ -12,8 +11,8 @@ import openai
from langchain_core.language_models.chat_models import LangSmithParams
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
from_env,
secret_from_env,
)
from langchain_openai.chat_models.base import BaseChatOpenAI
@@ -311,13 +310,27 @@ class ChatTogether(BaseChatOpenAI):
model_name: str = Field(default="meta-llama/Llama-3-8b-chat-hf", alias="model")
"""Model name to use."""
together_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
"""Automatically inferred from env are `TOGETHER_API_KEY` if not provided."""
together_api_base: Optional[str] = Field(
default="https://api.together.ai/v1/", alias="base_url"
together_api_key: Optional[SecretStr] = Field(
alias="api_key",
default_factory=secret_from_env("TOGETHER_API_KEY", default=None),
)
"""Together AI API key.
Automatically read from env variable `TOGETHER_API_KEY` if not provided.
"""
together_api_base: str = Field(
default_factory=from_env(
"TOGETHER_API_BASE", default="https://api.together.ai/v1/"
),
alias="base_url",
)
@root_validator()
class Config:
"""Pydantic config."""
allow_population_by_field_name = True
@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
if values["n"] < 1:
@@ -325,13 +338,6 @@ class ChatTogether(BaseChatOpenAI):
if values["n"] > 1 and values["streaming"]:
raise ValueError("n must be 1 when streaming.")
values["together_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "together_api_key", "TOGETHER_API_KEY")
)
values["together_api_base"] = values["together_api_base"] or os.getenv(
"TOGETHER_API_BASE"
)
client_params = {
"api_key": (
values["together_api_key"].get_secret_value()

View File

@@ -1,7 +1,6 @@
"""Wrapper around Together AI's Embeddings API."""
import logging
import os
import warnings
from typing import (
Any,
@@ -25,9 +24,9 @@ from langchain_core.pydantic_v1 import (
root_validator,
)
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
from_env,
get_pydantic_field_names,
secret_from_env,
)
logger = logging.getLogger(__name__)
@@ -115,10 +114,19 @@ class TogetherEmbeddings(BaseModel, Embeddings):
Not yet supported.
"""
together_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
"""API Key for Solar API."""
together_api_key: Optional[SecretStr] = Field(
alias="api_key",
default_factory=secret_from_env("TOGETHER_API_KEY", default=None),
)
"""Together AI API key.
Automatically read from env variable `TOGETHER_API_KEY` if not provided.
"""
together_api_base: str = Field(
default="https://api.together.ai/v1/", alias="base_url"
default_factory=from_env(
"TOGETHER_API_BASE", default="https://api.together.ai/v1/"
),
alias="base_url",
)
"""Endpoint URL to use."""
embedding_ctx_length: int = 4096
@@ -198,18 +206,9 @@ class TogetherEmbeddings(BaseModel, Embeddings):
values["model_kwargs"] = extra
return values
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
together_api_key = get_from_dict_or_env(
values, "together_api_key", "TOGETHER_API_KEY"
)
values["together_api_key"] = (
convert_to_secret_str(together_api_key) if together_api_key else None
)
values["together_api_base"] = values["together_api_base"] or os.getenv(
"TOGETHER_API_BASE"
)
@root_validator(pre=False, skip_on_failure=True)
def post_init(cls, values: Dict) -> Dict:
"""Logic that will post Pydantic initialization."""
client_params = {
"api_key": (
values["together_api_key"].get_secret_value()

View File

@@ -11,8 +11,10 @@ from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import (
secret_from_env,
)
logger = logging.getLogger(__name__)
@@ -36,8 +38,14 @@ class Together(LLM):
base_url: str = "https://api.together.ai/v1/completions"
"""Base completions API URL."""
together_api_key: SecretStr
"""Together AI API key. Get it here: https://api.together.ai/settings/api-keys"""
together_api_key: SecretStr = Field(
alias="api_key",
default_factory=secret_from_env("TOGETHER_API_KEY"),
)
"""Together AI API key.
Automatically read from env variable `TOGETHER_API_KEY` if not provided.
"""
model: str
"""Model name. Available models listed here:
Base Models: https://docs.together.ai/docs/inference-models#language-models
@@ -74,21 +82,11 @@ class Together(LLM):
"""Configuration for this pydantic object."""
extra = "forbid"
allow_population_by_field_name = True
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key exists in environment."""
values["together_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "together_api_key", "TOGETHER_API_KEY")
)
return values
@root_validator()
def validate_max_tokens(cls, values: Dict) -> Dict:
"""The v1 completions endpoint, has max_tokens as required parameter.
Set a default value and warn if the parameter is missing.
"""
if values.get("max_tokens") is None:
warnings.warn(
"The completions endpoint, has 'max_tokens' as required argument. "

View File

@@ -9,7 +9,7 @@ from langchain_together import Together
def test_together_api_key_is_secret_string() -> None:
"""Test that the API key is stored as a SecretStr."""
llm = Together(
together_api_key="secret-api-key", # type: ignore[arg-type]
together_api_key="secret-api-key", # type: ignore[call-arg]
model="togethercomputer/RedPajama-INCITE-7B-Base",
temperature=0.2,
max_tokens=250,
@@ -38,7 +38,7 @@ def test_together_api_key_masked_when_passed_via_constructor(
) -> None:
"""Test that the API key is masked when passed via the constructor."""
llm = Together(
together_api_key="secret-api-key", # type: ignore[arg-type]
together_api_key="secret-api-key", # type: ignore[call-arg]
model="togethercomputer/RedPajama-INCITE-7B-Base",
temperature=0.2,
max_tokens=250,
@@ -52,7 +52,18 @@ def test_together_api_key_masked_when_passed_via_constructor(
def test_together_uses_actual_secret_value_from_secretstr() -> None:
"""Test that the actual secret value is correctly retrieved."""
llm = Together(
together_api_key="secret-api-key", # type: ignore[arg-type]
together_api_key="secret-api-key", # type: ignore[call-arg]
model="togethercomputer/RedPajama-INCITE-7B-Base",
temperature=0.2,
max_tokens=250,
)
assert cast(SecretStr, llm.together_api_key).get_secret_value() == "secret-api-key"
def test_together_uses_actual_secret_value_from_secretstr_api_key() -> None:
"""Test that the actual secret value is correctly retrieved."""
llm = Together(
api_key="secret-api-key", # type: ignore[arg-type]
model="togethercomputer/RedPajama-INCITE-7B-Base",
temperature=0.2,
max_tokens=250,