diff --git a/docs/docs/integrations/chat/cohere.ipynb b/docs/docs/integrations/chat/cohere.ipynb index 6abebf77354..ee54225cfbb 100644 --- a/docs/docs/integrations/chat/cohere.ipynb +++ b/docs/docs/integrations/chat/cohere.ipynb @@ -40,18 +40,10 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 1, "id": "2108b517-1e8d-473d-92fa-4f930e8072a7", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "········\n" - ] - } - ], + "outputs": [], "source": [ "import getpass\n", "import os\n", @@ -90,7 +82,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, "id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9", "metadata": { "tags": [] @@ -103,7 +95,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "id": "70cf04e8-423a-4ff6-8b09-f11fb711c817", "metadata": { "tags": [] @@ -115,7 +107,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c", "metadata": { "tags": [] @@ -124,22 +116,22 @@ { "data": { "text/plain": [ - "AIMessage(content=\"Who's there?\")" + "AIMessage(content=\"4! That's one, two, three, four. Keep adding and we'll reach new heights!\", response_metadata={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'token_count': {'prompt_tokens': 73, 'response_tokens': 21, 'total_tokens': 94, 'billed_tokens': 25}})" ] }, - "execution_count": 3, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "messages = [HumanMessage(content=\"knock knock\")]\n", + "messages = [HumanMessage(content=\"1\"), HumanMessage(content=\"2 3\")]\n", "chat.invoke(messages)" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "id": "c5fac0e9-05a4-4fc1-a3b3-e5bbb24b971b", "metadata": { "tags": [] @@ -148,10 +140,10 @@ { "data": { "text/plain": [ - "AIMessage(content=\"Who's there?\")" + "AIMessage(content='4! According to the rules of addition, 1 + 2 equals 3, and 3 + 3 equals 6.', response_metadata={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'token_count': {'prompt_tokens': 73, 'response_tokens': 28, 'total_tokens': 101, 'billed_tokens': 32}})" ] }, - "execution_count": 4, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -162,7 +154,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "id": "025be980-e50d-4a68-93dc-c9c7b500ce34", "metadata": { "tags": [] @@ -172,7 +164,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Who's there?" + "4! It's a pleasure to be of service in this mathematical game." ] } ], @@ -183,17 +175,17 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "id": "064288e4-f184-4496-9427-bcf148fa055e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[AIMessage(content=\"Who's there?\")]" + "[AIMessage(content='4! According to the rules of addition, 1 + 2 equals 3, and 3 + 3 equals 6.', response_metadata={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'token_count': {'prompt_tokens': 73, 'response_tokens': 28, 'total_tokens': 101, 'billed_tokens': 32}})]" ] }, - "execution_count": 6, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -214,7 +206,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "id": "0851b103", "metadata": {}, "outputs": [], @@ -227,17 +219,17 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "id": "ae950c0f-1691-47f1-b609-273033cae707", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "AIMessage(content=\"Why did the bear go to the chiropractor?\\n\\nBecause she was feeling a bit grizzly!\\n\\nHope you found that joke about bears to be a little bit amusing! If you'd like to hear another one, just let me know. In the meantime, if you have any other questions or need assistance with a different topic, feel free to let me know. \\n\\nJust remember, even if you have a sore back like the bear, it's always best to consult a licensed professional for injuries or pain you may be experiencing. \\n\\nWould you like me to tell you another joke?\")" + "AIMessage(content='What do you call a bear with no teeth? A gummy bear!', response_metadata={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'token_count': {'prompt_tokens': 72, 'response_tokens': 14, 'total_tokens': 86, 'billed_tokens': 20}})" ] }, - "execution_count": 8, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -263,7 +255,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.1" + "version": "3.11.7" } }, "nbformat": 4, diff --git a/docs/docs/integrations/retrievers/cohere.ipynb b/docs/docs/integrations/retrievers/cohere.ipynb index e3f8e084a61..4dad0589985 100644 --- a/docs/docs/integrations/retrievers/cohere.ipynb +++ b/docs/docs/integrations/retrievers/cohere.ipynb @@ -10,6 +10,19 @@ "This notebook covers how to get started with Cohere RAG retriever. This allows you to leverage the ability to search documents over various connectors or by supplying your own." ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "2c367be3", + "metadata": {}, + "outputs": [], + "source": [ + "import getpass\n", + "import os\n", + "\n", + "os.environ[\"COHERE_API_KEY\"] = getpass.getpass()" + ] + }, { "cell_type": "code", "execution_count": null, @@ -218,7 +231,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.11.7" } }, "nbformat": 4, diff --git a/docs/docs/integrations/text_embedding/cohere.ipynb b/docs/docs/integrations/text_embedding/cohere.ipynb index fb245bfb578..53defbcbb6f 100644 --- a/docs/docs/integrations/text_embedding/cohere.ipynb +++ b/docs/docs/integrations/text_embedding/cohere.ipynb @@ -10,6 +10,19 @@ "Let's load the Cohere Embedding class." ] }, + { + "cell_type": "code", + "execution_count": 1, + "id": "1bfad19b", + "metadata": {}, + "outputs": [], + "source": [ + "import getpass\n", + "import os\n", + "\n", + "os.environ[\"COHERE_API_KEY\"] = getpass.getpass()" + ] + }, { "cell_type": "code", "execution_count": 2, @@ -50,7 +63,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "[-0.072631836, 0.06921387, -0.02658081, 0.022705078, 0.027328491, 0.046905518, -0.01838684, -0.029525757, 0.0041046143, -0.028198242, 0.0496521, 0.026901245, 0.03274536, 0.01574707, -0.081726074, -0.022369385, 0.049591064, 0.06549072, -0.015083313, -0.053863525, 0.098083496, 0.034698486, -0.08557129, -0.0024662018, -0.07519531, 0.03265381, 0.006046295, -0.0060691833, 0.032196045, 0.07537842, 9.024143e-05, -0.00869751, 0.022735596, 0.06329346, 0.068481445, -0.006778717, -0.07885742, 0.049560547, -0.008811951, 0.025253296, 0.050750732, -0.05343628, 0.051361084, -0.02319336, 0.026382446, 0.088378906, 0.03567505, -0.0736084, 0.039215088, -0.020584106, -0.03112793, -0.071777344, 0.018218994, -0.01876831, 0.040863037, 0.080078125, 0.046020508, -0.030792236, -0.011779785, -0.024871826, -0.06652832, 0.04748535, -0.038116455, 0.08453369, 0.08746338, 0.059509277, -0.037628174, -0.045410156, -0.054626465, -0.0036334991, -0.035949707, -0.011070251, 0.054534912, 0.0803833, 0.052734375, 0.06689453, 0.0074310303, 0.018249512, -0.023773193, 0.03845215, -0.113220215, 0.014251709, 0.028289795, -0.03942871, 0.029525757, 0.03036499, 0.035095215, 0.031829834, -0.0015306473, 0.027252197, 0.005088806, -0.035858154, -0.113220215, 0.021606445, 0.012046814, -0.06137085, 0.0057640076, -0.06994629, 0.02532959, 0.016952515, -0.010398865, -0.0066184998, -0.020904541, -0.12030029, 0.0036029816, -0.061553955, 0.023956299, -0.07330322, 0.013053894, -0.009613037, -0.062683105, 0.00013184547, 0.12030029, 0.028167725, 0.048614502, -0.09301758, -0.020324707, 0.022369385, -0.14025879, -0.052764893, 0.07220459, 0.028198242, 0.01499939, -0.029449463, 0.004711151, -0.05947876, 0.1640625, -0.09240723, 0.019500732, -0.0031089783, 0.0032081604, -0.0049934387, -0.01676941, 0.002691269, 0.02848816, 0.013504028, -0.057800293, 0.049041748, -0.022384644, 0.05517578, -0.031982422, 0.055389404, 0.0859375, 0.019866943, -0.052978516, 0.030929565, -0.15979004, 0.068481445, -0.020080566, -0.033477783, 0.07922363, -0.020736694, -0.025680542, 0.054016113, -0.028839111, -0.016189575, 0.03564453, 0.0001078248, 0.06304932, -0.022781372, 0.06555176, 0.010093689, 0.03286743, 0.14111328, -0.008468628, -0.04849243, 0.04525757, 0.065979004, -0.012138367, -0.017044067, 0.059509277, 0.035339355, -0.017807007, -0.027267456, -0.0034656525, -0.02078247, -0.033477783, 0.05041504, -0.043518066, -0.064208984, 0.034942627, -0.009300232, -0.08148193, 0.007774353, -0.03540039, -0.008255005, -0.1060791, -0.0703125, 0.091308594, 0.10095215, -0.081970215, 0.02355957, -0.026382446, -0.0070610046, -0.051208496, -0.014961243, 0.07269287, -0.033721924, 0.017669678, -0.08972168, 0.035339355, 0.03579712, -0.07299805, -0.014144897, -0.008850098, 0.023742676, -0.05847168, -0.07873535, -0.015388489, -0.039642334, -0.028930664, 0.008926392, -0.040283203, -0.02897644, -0.013557434, -0.006088257, 0.024169922, -0.10217285, 0.014526367, 0.007381439, -0.0005607605, -0.058410645, -0.008399963, -0.08001709, 0.05065918, 0.01727295, 0.012191772, -0.016571045, 0.03717041, -0.02607727, 0.060760498, 0.057678223, -0.06585693, 0.059173584, 0.023117065, -0.034118652, -0.03189087, 0.010429382, 0.010368347, -0.011230469, -0.020980835, -0.04019165, 0.048187256, -0.019638062, -0.024414062, -0.0019989014, 0.04336548, 0.117248535, 0.00033903122, -0.0014419556, 0.013946533, -0.11541748, 0.030059814, -0.06500244, 0.05441284, 0.021759033, 0.030380249, 0.080566406, 0.02331543, -0.04586792, 0.037322998, 0.011390686, -0.01374054, 0.1459961, -0.050964355, 0.081970215, -0.061645508, 0.07067871, -0.036956787, 0.060455322, 0.051361084, -0.05831909, 0.05328369, -0.008628845, 0.054534912, -0.047332764, 0.030578613, -0.048828125, -0.018112183, 0.022979736, -0.07318115, -0.0423584, -0.094177246, -0.04071045, 0.054260254, 0.0423584, 0.075805664, -0.06365967, 0.009269714, -0.054779053, -0.007637024, -0.01876831, 0.08453369, 0.058898926, -0.07727051, 0.04360962, 0.010574341, -0.027694702, 0.024917603, -0.0463562, 0.040222168, -0.05496216, -0.048461914, 0.013710022, -0.1038208, 0.027954102, 0.031951904, -0.05618286, 0.0025730133, -0.06549072, -0.049957275, 0.01499939, -0.11090088, -0.009017944, 0.021835327, 0.03503418, 0.058746338, -0.12756348, -0.0345459, -0.04699707, -0.029830933, -0.06726074, 0.010612488, -0.024108887, 0.016464233, 0.013076782, -0.06298828, -0.0657959, -0.0025234222, -0.0625, 0.013420105, 0.05810547, -0.006362915, -0.028625488, 0.06085205, 0.12310791, 0.04751587, -0.027740479, -0.02029419, -0.02293396, 0.048858643, -0.006793976, -0.0061073303, 0.029067993, -0.0076942444, -0.00088596344, -0.007446289, 0.12756348, 0.082092285, -0.0037841797, 0.03866577, 0.040374756, 0.019104004, -0.0345459, 0.019042969, -0.038116455, 0.045410156, 0.062683105, -0.024963379, 0.085632324, 0.005897522, 0.008285522, 0.008811951, 0.026504517, 0.025558472, -0.005554199, -0.017822266, -0.112854004, -0.03768921, -0.00097227097, -0.061401367, 0.050567627, -0.010734558, 0.07220459, 0.03643799, 0.0007662773, -0.020980835, -0.04711914, -0.03488159, -0.09655762, 0.0048561096, 0.028030396, 0.04586792, -0.014915466]\n" + "[-0.09338379, 0.0871582, -0.03326416, 0.01953125, 0.07702637, 0.034729004, -0.058380127, -0.031021118, -0.030517578, -0.055999756, 0.050842285, -0.006752014, 0.038391113, -0.0014362335, -0.041137695, -0.008880615, 0.026000977, -0.023010254, 0.05456543, -0.03366089, 0.055633545, 0.028579712, -0.068603516, 0.03970337, -0.06677246, 0.06732178, -0.013053894, -0.0060920715, 0.038116455, 0.057800293, 0.048736572, 0.026855469, 0.009849548, 0.08312988, 0.073791504, 0.01663208, -0.0871582, 0.01802063, -0.0020828247, -0.0031356812, 0.039978027, -0.03164673, 0.009796143, 0.011375427, 0.0068855286, 0.092285156, 0.05218506, -0.060943604, 0.038269043, -0.018218994, -0.04510498, -0.0847168, 0.008300781, -0.060058594, 0.0012111664, 0.05102539, 0.05218506, -0.047210693, -0.051239014, -0.044158936, -0.058166504, 0.07849121, -0.019165039, 0.06451416, 0.024887085, 0.011405945, -0.03768921, -0.018814087, -0.06829834, -0.052825928, -0.019104004, -0.021194458, 0.043518066, 0.07525635, 0.082336426, 0.0037651062, -0.0060310364, -0.03265381, 0.011375427, -0.013847351, -0.07232666, 0.02986145, 0.03866577, -0.029083252, 0.008666992, 0.03845215, 0.045196533, 0.012756348, -0.018051147, 0.032440186, -0.030715942, -0.045440674, -0.11187744, 0.032073975, 0.021972656, -0.044921875, -0.030410767, -0.03668213, 0.12420654, 0.05029297, -0.032989502, -0.049438477, 0.001704216, -0.08074951, 0.00046396255, -0.04107666, 0.020599365, -0.089416504, 0.020477295, -0.038726807, -0.04437256, -0.019256592, 0.048583984, 0.046020508, 0.03741455, -0.037475586, -0.050720215, 0.052856445, -0.10229492, -0.00010281801, 0.058776855, 0.021453857, -0.031051636, 0.01676941, 0.024047852, -0.026306152, 0.15258789, -0.09979248, 0.04888916, 0.045166016, 0.008865356, -0.043914795, -0.032928467, 0.0052757263, 0.06072998, 0.036956787, -0.058013916, 0.053466797, -0.03225708, 0.018371582, -0.0042533875, 0.047943115, 0.06530762, 0.039855957, -0.025360107, 0.047332764, -0.15124512, 0.08325195, 0.016174316, -0.029724121, 0.111816406, -0.05230713, -0.06964111, 0.03060913, -0.04257202, -0.0284729, 0.007843018, -0.03866577, 0.07867432, -0.04446411, 0.028869629, -0.015823364, 0.02659607, 0.085754395, 0.03878784, -0.04232788, 0.017074585, 0.026779175, -0.04284668, -0.017105103, 0.10058594, 0.022323608, -0.007007599, -0.09661865, -0.01322937, -0.004627228, 0.057800293, 0.057159424, -0.033294678, -0.066101074, 0.010910034, 0.033569336, -0.062042236, -0.0072021484, -0.070373535, 0.034729004, -0.07434082, -0.06604004, 0.061401367, 0.09576416, -0.070739746, 0.066833496, -0.019042969, -0.0051994324, -0.07696533, -0.03564453, 0.048614502, -0.048919678, 0.036224365, -0.06652832, 0.03338623, 0.05847168, 0.009414673, -0.035095215, 0.011787415, -0.007675171, -0.057006836, -0.045074463, -0.027999878, -0.049102783, -0.025787354, -0.010101318, -0.000813961, -0.009963989, -0.013343811, 0.04046631, 0.02758789, -0.07086182, 0.09442139, -0.012275696, -0.018936157, -0.011940002, 0.10638428, -0.10913086, 0.05606079, 0.008895874, 0.017089844, 0.019958496, 0.03173828, -0.037322998, 0.019699097, 0.046722412, -0.08959961, 0.059448242, 0.018875122, -0.057495117, -0.039276123, 0.009063721, -0.0178833, 0.032073975, -0.08178711, -0.061431885, 0.05731201, 0.012886047, -0.025360107, 0.04498291, 0.027923584, 0.125, 0.013374329, -0.013069153, -0.031677246, -0.109558105, 0.05731201, -0.03765869, 0.04650879, -0.005706787, 0.021697998, -0.0008239746, 0.030090332, -0.048736572, 0.07940674, -0.017120361, 0.018737793, 0.12011719, -0.03564453, 0.07519531, -0.039611816, -0.014968872, -0.045288086, 0.07702637, 0.010681152, -0.04736328, 0.07623291, 0.008071899, 0.080078125, -0.060516357, 0.043426514, -0.026489258, -0.018188477, 0.049560547, -0.068847656, -0.03387451, -0.09661865, -0.03768921, 0.028549194, 0.036621094, 0.05307007, -0.053894043, 0.0019035339, -0.07788086, -0.010597229, -0.027420044, 0.10900879, 0.019302368, -0.06726074, 0.04937744, 0.05154419, -0.050598145, 0.07562256, -0.05569458, 0.073913574, -0.052337646, -0.0149383545, -0.00037050247, 0.037322998, 0.018478394, -0.03201294, -0.04788208, 0.03062439, -0.055786133, 0.0018081665, 0.029510498, -0.10864258, -0.027374268, 0.040405273, 0.01474762, -0.010726929, -0.086242676, -0.02658081, -0.057159424, -0.0095825195, -0.11804199, -0.014289856, -0.006881714, -0.028533936, 0.005382538, -0.053771973, -0.015853882, 0.0034332275, -0.08441162, -0.028182983, -0.00856781, -0.060394287, -0.036590576, 0.03062439, 0.112854004, -0.008041382, -0.03353882, 0.0181427, -0.03466797, 0.026565552, -0.033813477, 0.0074310303, -0.02017212, -0.047729492, 0.00010108948, -0.032073975, 0.08630371, 0.08557129, -0.0115737915, 0.044067383, 0.062042236, 0.00819397, -0.016082764, 0.01574707, 0.0154418945, 0.06726074, 0.056884766, 0.01210022, 0.048095703, -0.0017309189, 0.018295288, -0.00592041, 0.062286377, 0.040649414, -0.032928467, -0.05392456, -0.13891602, -0.033050537, 0.047973633, -0.07824707, 0.024627686, -0.02923584, 0.09118652, 0.0690918, 0.045837402, -0.06402588, -0.028747559, -0.06542969, -0.08496094, 0.06762695, 0.04220581, 0.059539795, 0.0023174286]\n" ] } ], @@ -103,7 +116,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" + "version": "3.11.7" }, "vscode": { "interpreter": { diff --git a/libs/community/langchain_community/chat_models/cohere.py b/libs/community/langchain_community/chat_models/cohere.py index 657bca68a38..e3f20ad9c43 100644 --- a/libs/community/langchain_community/chat_models/cohere.py +++ b/libs/community/langchain_community/chat_models/cohere.py @@ -80,7 +80,7 @@ def get_cohere_chat_request( "AUTO" if documents is not None or connectors is not None else None ) - return { + req = { "message": messages[-1].content, "chat_history": [ {"role": get_role(x), "message": x.content} for x in messages[:-1] @@ -91,6 +91,8 @@ def get_cohere_chat_request( **kwargs, } + return {k: v for k, v in req.items() if v is not None} + class ChatCohere(BaseChatModel, BaseCohere): """`Cohere` chat large language models. @@ -142,7 +144,11 @@ class ChatCohere(BaseChatModel, BaseCohere): **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: request = get_cohere_chat_request(messages, **self._default_params, **kwargs) - stream = self.client.chat(**request, stream=True) + + if hasattr(self.client, "chat_stream"): # detect and support sdk v5 + stream = self.client.chat_stream(**request) + else: + stream = self.client.chat(**request, stream=True) for data in stream: if data.event_type == "text-generation": @@ -160,7 +166,11 @@ class ChatCohere(BaseChatModel, BaseCohere): **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: request = get_cohere_chat_request(messages, **self._default_params, **kwargs) - stream = await self.async_client.chat(**request, stream=True) + + if hasattr(self.async_client, "chat_stream"): # detect and support sdk v5 + stream = self.async_client.chat_stream(**request) + else: + stream = self.async_client.chat(**request, stream=True) async for data in stream: if data.event_type == "text-generation": @@ -220,7 +230,7 @@ class ChatCohere(BaseChatModel, BaseCohere): return await agenerate_from_stream(stream_iter) request = get_cohere_chat_request(messages, **self._default_params, **kwargs) - response = self.client.chat(**request, stream=False) + response = self.client.chat(**request) message = AIMessage(content=response.text) generation_info = None diff --git a/libs/community/langchain_community/embeddings/cohere.py b/libs/community/langchain_community/embeddings/cohere.py index 2d4676d1254..dcc1c68c792 100644 --- a/libs/community/langchain_community/embeddings/cohere.py +++ b/libs/community/langchain_community/embeddings/cohere.py @@ -4,6 +4,8 @@ from langchain_core.embeddings import Embeddings from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator from langchain_core.utils import get_from_dict_or_env +from langchain_community.llms.cohere import _create_retry_decorator + class CohereEmbeddings(BaseModel, Embeddings): """Cohere embedding models. @@ -34,7 +36,7 @@ class CohereEmbeddings(BaseModel, Embeddings): cohere_api_key: Optional[str] = None - max_retries: Optional[int] = 3 + max_retries: int = 3 """Maximum number of retries to make when generating.""" request_timeout: Optional[float] = None """Timeout in seconds for the Cohere API request.""" @@ -52,7 +54,6 @@ class CohereEmbeddings(BaseModel, Embeddings): cohere_api_key = get_from_dict_or_env( values, "cohere_api_key", "COHERE_API_KEY" ) - max_retries = values.get("max_retries") request_timeout = values.get("request_timeout") try: @@ -61,13 +62,11 @@ class CohereEmbeddings(BaseModel, Embeddings): client_name = values["user_agent"] values["client"] = cohere.Client( cohere_api_key, - max_retries=max_retries, timeout=request_timeout, client_name=client_name, ) values["async_client"] = cohere.AsyncClient( cohere_api_key, - max_retries=max_retries, timeout=request_timeout, client_name=client_name, ) @@ -78,10 +77,30 @@ class CohereEmbeddings(BaseModel, Embeddings): ) return values + def embed_with_retry(self, **kwargs: Any) -> Any: + """Use tenacity to retry the embed call.""" + retry_decorator = _create_retry_decorator(self.max_retries) + + @retry_decorator + def _embed_with_retry(**kwargs: Any) -> Any: + return self.client.embed(**kwargs) + + return _embed_with_retry(**kwargs) + + def aembed_with_retry(self, **kwargs: Any) -> Any: + """Use tenacity to retry the embed call.""" + retry_decorator = _create_retry_decorator(self.max_retries) + + @retry_decorator + async def _embed_with_retry(**kwargs: Any) -> Any: + return await self.async_client.embed(**kwargs) + + return _embed_with_retry(**kwargs) + def embed( self, texts: List[str], *, input_type: Optional[str] = None ) -> List[List[float]]: - embeddings = self.client.embed( + embeddings = self.embed_with_retry( model=self.model, texts=texts, input_type=input_type, @@ -93,7 +112,7 @@ class CohereEmbeddings(BaseModel, Embeddings): self, texts: List[str], *, input_type: Optional[str] = None ) -> List[List[float]]: embeddings = ( - await self.async_client.embed( + await self.aembed_with_retry( model=self.model, texts=texts, input_type=input_type, diff --git a/libs/community/langchain_community/llms/cohere.py b/libs/community/langchain_community/llms/cohere.py index 91fc906ebd9..1f5aa75e8c7 100644 --- a/libs/community/langchain_community/llms/cohere.py +++ b/libs/community/langchain_community/llms/cohere.py @@ -24,25 +24,32 @@ from langchain_community.llms.utils import enforce_stop_tokens logger = logging.getLogger(__name__) -def _create_retry_decorator(llm: Cohere) -> Callable[[Any], Any]: +def _create_retry_decorator(max_retries: int) -> Callable[[Any], Any]: import cohere + # support v4 and v5 + retry_conditions = ( + retry_if_exception_type(cohere.error.CohereError) + if hasattr(cohere, "error") + else retry_if_exception_type(Exception) + ) + min_seconds = 4 max_seconds = 10 # Wait 2^x * 1 second between each retry starting with # 4 seconds, then up to 10 seconds, then 10 seconds afterwards return retry( reraise=True, - stop=stop_after_attempt(llm.max_retries), + stop=stop_after_attempt(max_retries), wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), - retry=(retry_if_exception_type(cohere.error.CohereError)), + retry=retry_conditions, before_sleep=before_sleep_log(logger, logging.WARNING), ) def completion_with_retry(llm: Cohere, **kwargs: Any) -> Any: """Use tenacity to retry the completion call.""" - retry_decorator = _create_retry_decorator(llm) + retry_decorator = _create_retry_decorator(llm.max_retries) @retry_decorator def _completion_with_retry(**kwargs: Any) -> Any: @@ -53,7 +60,7 @@ def completion_with_retry(llm: Cohere, **kwargs: Any) -> Any: def acompletion_with_retry(llm: Cohere, **kwargs: Any) -> Any: """Use tenacity to retry the completion call.""" - retry_decorator = _create_retry_decorator(llm) + retry_decorator = _create_retry_decorator(llm.max_retries) @retry_decorator async def _completion_with_retry(**kwargs: Any) -> Any: