Compare commits

...

2 Commits

Author   SHA1        Message                    Date
Bagatur  61644408c2  docs                       2024-01-19 14:16:47 -08:00
Bagatur  43700be083  wip: classification chain  2024-01-19 14:10:46 -08:00
3 changed files with 335 additions and 3 deletions

View File

@@ -0,0 +1,151 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5fbab4e7-2c5e-4682-8c97-bfada89d206f",
   "metadata": {},
   "source": [
    "# Classification"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e135e18-f9b5-41d9-971d-f1c57a76add4",
   "metadata": {},
   "source": [
    "## Direct prompting"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf9d74f6-f18c-4ceb-a21b-ef101a758816",
   "metadata": {},
   "source": [
    "## Function-calling"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81b3dd6a-c648-4930-98f3-733f0f0e5af7",
   "metadata": {},
   "source": [
    "## Logprobs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "edd66c98-1146-4f34-8828-433553f55ab0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_openai import ChatOpenAI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2e9f633e-5d2c-41dd-a1d1-f4b0230facca",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.chains.classification import create_openai_logprobs_classification_chain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "107020bc-fadd-49ad-aa30-c5bee19d8f63",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'classification': 'D', 'confidence': 0.9996887772698445}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "llm = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0)\n",
    "classes = {\"C\": \"The input is about cats\", \"D\": \"The input is about dogs\"}\n",
    "chain = create_openai_logprobs_classification_chain(llm, classes)\n",
    "chain.invoke({\"input\": \"I really love my golden retriever\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c829152c-e113-4ae4-a628-10fc235c5bba",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'classification': 'C', 'confidence': 0.9997948118239739}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chain.invoke({\"input\": \"Aren't siamese just the best\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "88bd9984-a050-4a1b-9707-b8f0994bd5b7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'classification': 'C', 'confidence': 0.677221622476509}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chain.invoke({\"input\": \"They scratched up everything\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "246e36e0-3ca7-47e7-a46d-de9f1ea33c87",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "poetry-venv",
   "language": "python",
   "name": "poetry-venv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
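
The notebook's "Direct prompting" and "Function-calling" sections are empty stubs in this wip commit. As a hedged sketch (not part of the diff), the tools-based chain defined in the new module below would be invoked the same way as the logprobs chain, but returns only the class name rather than a classification/confidence dict:

from langchain.chains.classification import create_openai_tool_classification_chain
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
classes = {"C": "The input is about cats", "D": "The input is about dogs"}
chain = create_openai_tool_classification_chain(llm, classes)
chain.invoke({"input": "I really love my golden retriever"})  # -> "D"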

View File

@@ -0,0 +1,182 @@
import math
from bisect import bisect, bisect_left
from operator import itemgetter
from typing import Any, Dict, List, Literal, Optional, Sequence, Union

from langchain_core.language_models import BaseChatModel, LanguageModelLike
from langchain_core.outputs import LLMResult
from langchain_core.prompts import (
    BasePromptTemplate,
    ChatPromptTemplate,
)
from langchain_core.runnables import Runnable, RunnableLambda

from langchain.output_parsers import JsonOutputKeyToolsParser

_DEFAULT_OPENAI_TOOLS_PROMPT = ChatPromptTemplate.from_messages(
    [
        ("system", "Classify the user input.{class_descriptions}"),
        ("human", "{input}"),
    ]
)
_DEFAULT_OPENAI_LOGPROBS_PROMPT = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Classify the user input.{class_descriptions} MAKE SURE your output is one of the classes and NOTHING else.",  # noqa: E501
        ),
        ("human", "{input}"),
    ]
)


def create_classification_chain(
    llm: LanguageModelLike,
    classes: Union[Sequence[str], Dict[str, str]],
    /,
    *,
    type: Literal["openai-tools", "openai-logprobs"],
    prompt: Optional[BasePromptTemplate] = None,
    **kwargs: Any,
) -> Runnable:
    """Create a chain that classifies an input as one of the given classes."""
    if not classes:
        raise ValueError("classes cannot be empty.")
    if type == "openai-tools":
        return create_openai_tool_classification_chain(
            llm, classes, prompt=prompt, **kwargs
        )
    elif type == "openai-logprobs":
        return create_openai_logprobs_classification_chain(
            llm, classes, prompt=prompt, **kwargs
        )
    # TODO: Add JSON and XML chains.
    else:
        raise ValueError(
            f"Unknown type {type}. Expected one of 'openai-tools', 'openai-logprobs'."
        )


def create_openai_tool_classification_chain(
    llm: LanguageModelLike,
    classes: Union[Sequence[str], Dict[str, str]],
    /,
    *,
    prompt: Optional[BasePromptTemplate] = None,
) -> Runnable[Dict, str]:
    """Create a classification chain that forces an OpenAI tool call."""
    prompt = prompt or _DEFAULT_OPENAI_TOOLS_PROMPT
    if isinstance(classes, dict):
        descriptions = "\n".join(f"{k}: {v}" for k, v in classes.items())
        class_descriptions = f"\n\nThe classes are:\n\n{descriptions}"
    else:
        class_descriptions = ""
    if "class_descriptions" in prompt.input_variables:
        prompt = prompt.partial(class_descriptions=class_descriptions)
    class_names = ", ".join(f'"{c}"' for c in classes)
    tool = {
        "type": "function",
        "function": {
            "name": "classify",
            "description": "Classify the input as one of the given classes.",
            "parameters": {
                "type": "object",
                "properties": {
                    "classification": {
                        "description": (
                            "The classification of the input. Must be one of "
                            f"{class_names}."
                        ),
                        "enum": list(classes),
                        "type": "string",
                    }
                },
                "required": ["classification"],
            },
        },
    }
    # Force the model to call the "classify" tool on every invocation.
    llm_with_tool = llm.bind(
        tools=[tool],
        tool_choice={"type": "function", "function": {"name": "classify"}},
    )
    return (
        prompt
        | llm_with_tool
        | JsonOutputKeyToolsParser(key_name="classify", return_single=True)
        | itemgetter("classification")
    )


def _parse_logprobs(
    result: LLMResult, classes: List[str], top_k: int
) -> Union[Dict, List]:
    original_classes = classes.copy()
    classes = [c.lower() for c in classes]
    top_classes = [c for c in classes if c in result.generations[0][0].text.lower()]
    logprobs = result.generations[0][0].generation_info["logprobs"]["content"]
    all_logprobs = [lp for token in logprobs for lp in token["top_logprobs"]]
    present_token_classes = [
        lp for lp in all_logprobs if lp["token"].strip().lower() in classes
    ]
    if not top_classes and not present_token_classes:
        res = {"classification": None, "confidence": None}
        return res if top_k == 1 else [res]
    # If any individual token matches a class, pool the probability mass of
    # all whitespace/case variants of that token.
    cumulative: Dict[str, float] = {}
    for lp in present_token_classes:
        normalized = lp["token"].strip().lower()
        if normalized in cumulative:
            cumulative[normalized] += math.exp(lp["logprob"])
        else:
            cumulative[normalized] = math.exp(lp["logprob"])
    # If there are present classes that span more than one token, locate their
    # tokens via cumulative character offsets and sum the logprobs across them.
    present_multi_token_classes = set(top_classes).difference(cumulative)
    spans = [len(logprobs[0]["token"])]
    for lp in logprobs[1:]:
        spans.append(spans[-1] + len(lp["token"]))
    for top_class in present_multi_token_classes:
        start = result.generations[0][0].text.lower().find(top_class)
        start_token_idx = bisect(spans, start)
        end = start + len(top_class)
        end_token_idx = bisect_left(spans, end)
        cumulative[top_class] = math.exp(
            sum(lp["logprob"] for lp in logprobs[start_token_idx : end_token_idx + 1])
        )
    res = sorted(
        [
            {"classification": original_classes[classes.index(k)], "confidence": v}
            for k, v in cumulative.items()
        ],
        key=(lambda x: x["confidence"]),
        reverse=True,
    )
    return res[0] if top_k == 1 else res[:top_k]


def create_openai_logprobs_classification_chain(
    llm: BaseChatModel,
    classes: Union[Sequence[str], Dict[str, str]],
    /,
    *,
    prompt: Optional[BasePromptTemplate] = None,
    top_k: int = 1,
) -> Runnable[Dict, Dict]:
    """Create a classification chain that derives confidences from logprobs."""
    prompt = prompt or _DEFAULT_OPENAI_LOGPROBS_PROMPT
    if isinstance(classes, dict):
        descriptions = "\n".join(f"{k}: {v}" for k, v in classes.items())
        class_descriptions = f"\n\nThe classes are:\n\n{descriptions}\n\n"
    else:
        names = ", ".join(classes)
        class_descriptions = f" The classes are: {names}."
    prompt = prompt.partial(class_descriptions=class_descriptions)
    generate = RunnableLambda(llm.generate_prompt, afunc=llm.agenerate_prompt).bind(
        logprobs=True, top_logprobs=top_k
    )
    parse = RunnableLambda(_parse_logprobs).bind(classes=list(classes), top_k=top_k)
    return prompt | (lambda x: [x]) | generate | parse
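
The confidence values in the notebook outputs come from _parse_logprobs above: each token's logprob is exponentiated back into a probability, and variants of the same class token are pooled. A minimal standalone sketch of the single-token case, using hypothetical logprob values:

import math

# Hypothetical top_logprobs for the single token the model generated, as found
# under generation_info["logprobs"]["content"][0]["top_logprobs"].
top_logprobs = [
    {"token": "D", "logprob": -0.0003},
    {"token": " D", "logprob": -8.5},
    {"token": "C", "logprob": -9.0},
]

cumulative = {}
for lp in top_logprobs:
    # Whitespace/case variants of the same class token pool their probability.
    normalized = lp["token"].strip().lower()
    cumulative[normalized] = cumulative.get(normalized, 0.0) + math.exp(lp["logprob"])

print(cumulative)  # {'d': 0.9999..., 'c': 0.0001...}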

View File

@@ -4,7 +4,7 @@ from langchain_core.language_models import BaseLanguageModel
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.pydantic_v1 import BaseModel
 from langchain_core.runnables import Runnable
-from langchain_core.utils.function_calling import convert_pydantic_to_openai_function
+from langchain_core.utils.function_calling import convert_pydantic_to_openai_tool
 from langchain.output_parsers import PydanticToolsParser
@@ -34,8 +34,7 @@ def create_extraction_chain_pydantic(
     prompt = ChatPromptTemplate.from_messages(
         [("system", system_message), ("user", "{input}")]
     )
-    functions = [convert_pydantic_to_openai_function(p) for p in pydantic_schemas]
-    tools = [{"type": "function", "function": d} for d in functions]
+    tools = [convert_pydantic_to_openai_tool(p) for p in pydantic_schemas]
     model = llm.bind(tools=tools)
     chain = prompt | model | PydanticToolsParser(tools=pydantic_schemas)
     return chain
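
For reference, a hedged usage sketch of the refactored extraction chain, assuming it is exposed as langchain.chains.openai_tools.create_extraction_chain_pydantic and run against a tools-capable OpenAI model:

from typing import Optional

from langchain.chains.openai_tools import create_extraction_chain_pydantic
from langchain_core.pydantic_v1 import BaseModel
from langchain_openai import ChatOpenAI


class Person(BaseModel):
    """Information about a person."""

    name: str
    age: Optional[int]


llm = ChatOpenAI(model="gpt-3.5-turbo-1106", temperature=0)
chain = create_extraction_chain_pydantic(Person, llm)
chain.invoke({"input": "Sally is 13"})  # -> [Person(name='Sally', age=13)]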