diff --git a/libs/experimental/langchain_experimental/llms/anthropic_functions.py b/libs/experimental/langchain_experimental/llms/anthropic_functions.py index 48e38690dea..0df09ae06d9 100644 --- a/libs/experimental/langchain_experimental/llms/anthropic_functions.py +++ b/libs/experimental/langchain_experimental/llms/anthropic_functions.py @@ -144,19 +144,30 @@ class AnthropicFunctions(BaseChatModel): forced = False function_call = "" if "functions" in kwargs: - content = prompt.format(tools=json.dumps(kwargs["functions"], indent=2)) - system = SystemMessage(content=content) - messages = [system] + messages + # get the function call method + if "function_call" in kwargs: + function_call = kwargs["function_call"] + del kwargs["function_call"] + else: + function_call = "auto" + + # should function calling be used + if function_call != "none": + content = prompt.format(tools=json.dumps(kwargs["functions"], indent=2)) + system = SystemMessage(content=content) + messages = [system] + messages + + # is the function call a dictionary (forced function calling) + if isinstance(function_call, dict): + forced = True + function_call_name = function_call["name"] + messages.append(AIMessage(content=f"{function_call_name}")) + del kwargs["functions"] if stop is None: stop = [""] else: stop.append("") - if "function_call" in kwargs: - forced = True - function_call = kwargs["function_call"]["name"] - AIMessage(content=f"{function_call}") - del kwargs["function_call"] else: if "function_call" in kwargs: raise ValueError( @@ -168,12 +179,19 @@ class AnthropicFunctions(BaseChatModel): completion = response.content if forced: tag_parser = TagParser() - tag_parser.feed(completion.strip() + "") - v1 = tag_parser.parse_data["tool_input"][0] + + if "" in completion: + tag_parser.feed(completion.strip() + "") + v1 = tag_parser.parse_data["tool_input"][0] + arguments = json.dumps(_destrip(v1)) + else: + v1 = completion + arguments = "" + kwargs = { "function_call": { - "name": function_call, - "arguments": json.dumps(_destrip(v1)), + "name": function_call_name, + "arguments": arguments, } } message = AIMessage(content="", additional_kwargs=kwargs) @@ -181,7 +199,7 @@ class AnthropicFunctions(BaseChatModel): elif "" in completion: tag_parser = TagParser() tag_parser.feed(completion.strip() + "") - msg = completion.split("")[0] + msg = completion.split("")[0].strip() v1 = tag_parser.parse_data["tool_input"][0] kwargs = { "function_call": { @@ -192,6 +210,7 @@ class AnthropicFunctions(BaseChatModel): message = AIMessage(content=msg, additional_kwargs=kwargs) return ChatResult(generations=[ChatGeneration(message=message)]) else: + response.content = response.content.strip() return ChatResult(generations=[ChatGeneration(message=response)]) @property