diff --git a/docs/scripts/model_feat_table.py b/docs/scripts/model_feat_table.py index ecf37fbcd69..0f21a0ed6a4 100644 --- a/docs/scripts/model_feat_table.py +++ b/docs/scripts/model_feat_table.py @@ -15,12 +15,17 @@ LLM_FEAT_TABLE_CORRECTION = { "PromptLayerOpenAI": {"batch_generate": False, "batch_agenerate": False}, } CHAT_MODEL_IGNORE = ("FakeListChatModel", "HumanInputChatModel") + CHAT_MODEL_FEAT_TABLE_CORRECTION = { "ChatMLflowAIGateway": {"_agenerate": False}, "PromptLayerChatOpenAI": {"_stream": False, "_astream": False}, "ChatKonko": {"_astream": False, "_agenerate": False}, + "ChatOpenAI": {"tool_calling": True}, + "ChatAnthropic": {"tool_calling": True}, + "ChatMistralAI": {"tool_calling": True}, } + LLM_TEMPLATE = """\ --- sidebar_position: 1 @@ -101,6 +106,7 @@ def get_llm_table(): "_astream", "batch_generate", "batch_agenerate", + "tool_calling", ] title = [ "Model", @@ -110,6 +116,7 @@ def get_llm_table(): "Async stream", "Batch", "Async batch", + "Tool calling", ] rows = [title, [":-"] + [":-:"] * (len(title) - 1)] for llm, feats in sorted(final_feats.items()): @@ -117,7 +124,8 @@ def get_llm_table(): return "\n".join(["|".join(row) for row in rows]) -def get_chat_model_table(): +def get_chat_model_table() -> str: + """Get the table of chat models.""" feat_table = {} for cm in chat_models.__all__: feat_table[cm] = {} @@ -133,8 +141,15 @@ def get_chat_model_table(): for k, v in {**feat_table, **CHAT_MODEL_FEAT_TABLE_CORRECTION}.items() if k not in CHAT_MODEL_IGNORE } - header = ["model", "_agenerate", "_stream", "_astream"] - title = ["Model", "Invoke", "Async invoke", "Stream", "Async stream"] + header = ["model", "_agenerate", "_stream", "_astream", "tool_calling"] + title = [ + "Model", + "Invoke", + "Async invoke", + "Stream", + "Async stream", + "Tool calling", + ] rows = [title, [":-"] + [":-:"] * (len(title) - 1)] for llm, feats in sorted(final_feats.items()): rows += [[llm, "✅"] + ["✅" if feats.get(h) else "❌" for h in header[1:]]]