docs: use init_chat_model (#29623)

This commit is contained in:
Erick Friis
2025-02-07 12:39:27 -08:00
committed by GitHub
parent bff25b552c
commit eb9eddae0c
16 changed files with 95 additions and 186 deletions

View File

@@ -91,29 +91,7 @@ export const CustomDropdown = ({ selectedOption, options, onSelect, modelType })
/**
* @typedef {Object} ChatModelTabsProps - Component props.
* @property {string} [openaiParams] - Parameters for OpenAI chat model. Defaults to `model="gpt-3.5-turbo-0125"`
* @property {string} [anthropicParams] - Parameters for Anthropic chat model. Defaults to `model="claude-3-sonnet-20240229"`
* @property {string} [cohereParams] - Parameters for Cohere chat model. Defaults to `model="command-r-plus"`
* @property {string} [fireworksParams] - Parameters for Fireworks chat model. Defaults to `model="accounts/fireworks/models/mixtral-8x7b-instruct"`
* @property {string} [groqParams] - Parameters for Groq chat model. Defaults to `model="llama3-8b-8192"`
* @property {string} [mistralParams] - Parameters for Mistral chat model. Defaults to `model="mistral-large-latest"`
* @property {string} [googleParams] - Parameters for Google chat model. Defaults to `model="gemini-pro"`
* @property {string} [togetherParams] - Parameters for Together chat model. Defaults to `model="mistralai/Mixtral-8x7B-Instruct-v0.1"`
* @property {string} [nvidiaParams] - Parameters for Nvidia NIM model. Defaults to `model="meta/llama3-70b-instruct"`
* @property {string} [databricksParams] - Parameters for Databricks model. Defaults to `endpoint="databricks-meta-llama-3-1-70b-instruct"`
* @property {string} [awsBedrockParams] - Parameters for AWS Bedrock chat model.
* @property {boolean} [hideOpenai] - Whether or not to hide OpenAI chat model.
* @property {boolean} [hideAnthropic] - Whether or not to hide Anthropic chat model.
* @property {boolean} [hideCohere] - Whether or not to hide Cohere chat model.
* @property {boolean} [hideFireworks] - Whether or not to hide Fireworks chat model.
* @property {boolean} [hideGroq] - Whether or not to hide Groq chat model.
* @property {boolean} [hideMistral] - Whether or not to hide Mistral chat model.
* @property {boolean} [hideGoogle] - Whether or not to hide Google VertexAI chat model.
* @property {boolean} [hideTogether] - Whether or not to hide Together chat model.
* @property {boolean} [hideAzure] - Whether or not to hide Microsoft Azure OpenAI chat model.
* @property {boolean} [hideNvidia] - Whether or not to hide NVIDIA NIM model.
* @property {boolean} [hideAWS] - Whether or not to hide AWS models.
* @property {boolean} [hideDatabricks] - Whether or not to hide Databricks models.
* @property {Object} [overrideParams] - An object for overriding the default parameters for each chat model, e.g. `{ openai: { model: "gpt-4o-mini" } }`
* @property {string} [customVarName] - Custom variable name for the model. Defaults to `model`.
*/
@@ -121,198 +99,151 @@ export const CustomDropdown = ({ selectedOption, options, onSelect, modelType })
* @param {ChatModelTabsProps} props - Component props.
*/
export default function ChatModelTabs(props) {
const [selectedModel, setSelectedModel] = useState("Groq");
const [selectedModel, setSelectedModel] = useState("groq");
const {
openaiParams,
anthropicParams,
cohereParams,
fireworksParams,
groqParams,
mistralParams,
googleParams,
togetherParams,
azureParams,
nvidiaParams,
awsBedrockParams,
databricksParams,
hideOpenai,
hideAnthropic,
hideCohere,
hideFireworks,
hideGroq,
hideMistral,
hideGoogle,
hideTogether,
hideAzure,
hideNvidia,
hideAWS,
hideDatabricks,
overrideParams,
customVarName,
} = props;
const openAIParamsOrDefault = openaiParams ?? `model="gpt-4o-mini"`;
const anthropicParamsOrDefault =
anthropicParams ?? `model="claude-3-5-sonnet-20240620"`;
const cohereParamsOrDefault = cohereParams ?? `model="command-r-plus"`;
const fireworksParamsOrDefault =
fireworksParams ??
`model="accounts/fireworks/models/llama-v3p1-70b-instruct"`;
const groqParamsOrDefault = groqParams ?? `model="llama3-8b-8192"`;
const mistralParamsOrDefault =
mistralParams ?? `model="mistral-large-latest"`;
const googleParamsOrDefault = googleParams ?? `model="gemini-1.5-flash"`;
const togetherParamsOrDefault =
togetherParams ??
`\n base_url="https://api.together.xyz/v1",\n api_key=os.environ["TOGETHER_API_KEY"],\n model="mistralai/Mixtral-8x7B-Instruct-v0.1",\n`;
const azureParamsOrDefault =
azureParams ??
`\n azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],\n azure_deployment=os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"],\n openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],\n`;
const nvidiaParamsOrDefault = nvidiaParams ?? `model="meta/llama3-70b-instruct"`
const awsBedrockParamsOrDefault = awsBedrockParams ?? `model="anthropic.claude-3-5-sonnet-20240620-v1:0",\n beta_use_converse_api=True`;
const databricksParamsOrDefault = databricksParams ?? `endpoint="databricks-meta-llama-3-1-70b-instruct"`
const llmVarName = customVarName ?? "model";
const tabItems = [
{
value: "Groq",
value: "groq",
label: "Groq",
text: `from langchain_groq import ChatGroq\n\n${llmVarName} = ChatGroq(${groqParamsOrDefault})`,
model: "llama3-8b-8192",
apiKeyName: "GROQ_API_KEY",
packageName: "langchain-groq",
shouldHide: hideGroq,
},
{
value: "OpenAI",
value: "openai",
label: "OpenAI",
text: `from langchain_openai import ChatOpenAI\n\n${llmVarName} = ChatOpenAI(${openAIParamsOrDefault})`,
model: "gpt-4o-mini",
apiKeyName: "OPENAI_API_KEY",
packageName: "langchain-openai",
shouldHide: hideOpenai,
},
{
value: "Anthropic",
value: "anthropic",
label: "Anthropic",
text: `from langchain_anthropic import ChatAnthropic\n\n${llmVarName} = ChatAnthropic(${anthropicParamsOrDefault})`,
model: "claude-3-5-sonnet-latest",
apiKeyName: "ANTHROPIC_API_KEY",
packageName: "langchain-anthropic",
shouldHide: hideAnthropic,
},
{
value: "Azure",
value: "azure",
label: "Azure",
text: `from langchain_openai import AzureChatOpenAI\n\n${llmVarName} = AzureChatOpenAI(${azureParamsOrDefault})`,
text: `from langchain_openai import AzureChatOpenAI
${llmVarName} = AzureChatOpenAI(
azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
azure_deployment=os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"],
openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
)`,
apiKeyName: "AZURE_OPENAI_API_KEY",
packageName: "langchain-openai",
shouldHide: hideAzure,
},
{
value: "Google",
label: "Google",
text: `from langchain_google_vertexai import ChatVertexAI\n\n${llmVarName} = ChatVertexAI(${googleParamsOrDefault})`,
value: "google_vertexai",
label: "Google Vertex",
model: "gemini-2.0-flash",
apiKeyText: "# Ensure your VertexAI credentials are configured",
packageName: "langchain-google-vertexai",
shouldHide: hideGoogle,
},
{
value: "AWS",
value: "aws",
label: "AWS",
text: `from langchain_aws import ChatBedrock\n\n${llmVarName} = ChatBedrock(${awsBedrockParamsOrDefault})`,
model: "anthropic.claude-3-5-sonnet-20240620-v1:0",
kwargs: "beta_use_converse_api=True",
apiKeyText: "# Ensure your AWS credentials are configured",
packageName: "langchain-aws",
shouldHide: hideAWS,
},
{
value: "Cohere",
value: "cohere",
label: "Cohere",
text: `from langchain_cohere import ChatCohere\n\n${llmVarName} = ChatCohere(${cohereParamsOrDefault})`,
model: "command-r-plus",
apiKeyName: "COHERE_API_KEY",
packageName: "langchain-cohere",
shouldHide: hideCohere,
},
{
value: "NVIDIA",
value: "nvidia",
label: "NVIDIA",
text: `from langchain_nvidia_ai_endpoints import ChatNVIDIA\n\n${llmVarName} = ChatNVIDIA(${nvidiaParamsOrDefault})`,
model: "meta/llama3-70b-instruct",
apiKeyName: "NVIDIA_API_KEY",
packageName: "langchain-nvidia-ai-endpoints",
shouldHide: hideNvidia,
},
{
value: "FireworksAI",
value: "fireworks",
label: "Fireworks AI",
text: `from langchain_fireworks import ChatFireworks\n\n${llmVarName} = ChatFireworks(${fireworksParamsOrDefault})`,
model: "accounts/fireworks/models/llama-v3p1-70b-instruct",
apiKeyName: "FIREWORKS_API_KEY",
packageName: "langchain-fireworks",
shouldHide: hideFireworks,
},
{
value: "MistralAI",
value: "mistralai",
label: "Mistral AI",
text: `from langchain_mistralai import ChatMistralAI\n\n${llmVarName} = ChatMistralAI(${mistralParamsOrDefault})`,
model: "mistral-large-latest",
apiKeyName: "MISTRAL_API_KEY",
packageName: "langchain-mistralai",
shouldHide: hideMistral,
},
{
value: "TogetherAI",
value: "together",
label: "Together AI",
text: `from langchain_openai import ChatOpenAI\n\n${llmVarName} = ChatOpenAI(${togetherParamsOrDefault})`,
model: "mistralai/Mixtral-8x7B-Instruct-v0.1",
apiKeyName: "TOGETHER_API_KEY",
packageName: "langchain-openai",
shouldHide: hideTogether,
packageName: "langchain-together",
},
{
value: "Databricks",
value: "databricks",
label: "Databricks",
text: `from databricks_langchain import ChatDatabricks\n\nos.environ["DATABRICKS_HOST"] = "https://example.staging.cloud.databricks.com/serving-endpoints"\n\n${llmVarName} = ChatDatabricks(${databricksParamsOrDefault})`,
text: `from databricks_langchain import ChatDatabricks\n\nos.environ["DATABRICKS_HOST"] = "https://example.staging.cloud.databricks.com/serving-endpoints"\n\n${llmVarName} = ChatDatabricks(endpoint="databricks-meta-llama-3-1-70b-instruct")`,
apiKeyName: "DATABRICKS_TOKEN",
packageName: "databricks-langchain",
shouldHide: hideDatabricks,
},
];
].map((item) => ({
...item,
...overrideParams?.[item.value],
}));
const modelOptions = tabItems
.filter((item) => !item.shouldHide)
.map((item) => ({
value: item.value,
label: item.label,
text: item.text,
apiKeyName: item.apiKeyName,
apiKeyText: item.apiKeyText,
packageName: item.packageName,
}));
const selectedOption = modelOptions.find(
(option) => option.value === selectedModel
);
const selectedTabItem = tabItems.find(
(option) => option.value === selectedModel
);
let apiKeyText = "";
if (selectedOption.apiKeyName) {
if (selectedTabItem.apiKeyName) {
apiKeyText = `import getpass
import os
if not os.environ.get("${selectedOption.apiKeyName}"):
os.environ["${selectedOption.apiKeyName}"] = getpass.getpass("Enter API key for ${selectedOption.label}: ")`;
} else if (selectedOption.apiKeyText) {
apiKeyText = selectedOption.apiKeyText;
if not os.environ.get("${selectedTabItem.apiKeyName}"):
os.environ["${selectedTabItem.apiKeyName}"] = getpass.getpass("Enter API key for ${selectedTabItem.label}: ")`;
} else if (selectedTabItem.apiKeyText) {
apiKeyText = selectedTabItem.apiKeyText;
}
return (
<div>
<CustomDropdown
selectedOption={selectedOption}
options={modelOptions}
onSelect={setSelectedModel}
modelType="chat"
/>
const initModelText = selectedTabItem?.text || `from langchain.chat_models import init_chat_model
<CodeBlock language="bash">
{`pip install -qU ${selectedOption.packageName}`}
</CodeBlock>
<CodeBlock language="python">
{apiKeyText ? apiKeyText + "\n\n" + selectedOption.text : selectedOption.text}
</CodeBlock>
</div>
);
${llmVarName} = init_chat_model("${selectedTabItem.model}", *, model_provider="${selectedTabItem.value}"${selectedTabItem?.kwargs ? `, ${selectedTabItem.kwargs}` : ""})`;
return (
<div>
<CustomDropdown
selectedOption={selectedTabItem}
options={modelOptions}
onSelect={setSelectedModel}
modelType="chat"
/>
<CodeBlock language="bash">
{`pip install -qU langchain ${selectedTabItem.packageName}`}
</CodeBlock>
<CodeBlock language="python">
{apiKeyText ? apiKeyText + "\n\n" + initModelText : initModelText}
</CodeBlock>
</div>
);
}