mirror of https://github.com/hwchase17/langchain.git
Lint Python notebooks with ruff. (#12677)
The new ruff version fixed the blocking bugs, and I was able to fairly easily get us to a passing state: ruff fixed some issues on its own, I fixed a handful by hand, and I added a list of narrowly-targeted exclusions for files that currently fail ruff rules we should probably look into eventually. I went pretty lenient on the docs / cookbook rules, allowing dead code and the like. We may want to tighten the rules further in the future, but this is already a good set of checks that found real issues and will prevent them going forward.
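For reference, narrowly-targeted exclusions of this kind are usually expressed as per-file ignores in the ruff configuration, alongside opting notebooks into linting. The snippet below is a minimal sketch only: the notebook paths and rule codes are illustrative placeholders, not the actual entries added in this PR, and the table name may differ by ruff version (newer releases use [tool.ruff.lint.per-file-ignores]).

    # pyproject.toml -- illustrative sketch; paths and rule codes are placeholders
    [tool.ruff]
    # Lint Jupyter notebooks alongside regular Python sources.
    extend-include = ["*.ipynb"]

    [tool.ruff.per-file-ignores]
    # Docs / cookbook notebooks: tolerate notebook-style unused imports and
    # imports that are not at the top of the file.
    "docs/docs/some_example.ipynb" = ["E402", "F401"]
    "cookbook/*.ipynb" = ["E402", "F401"]

With the notebooks included, running ruff's autofix (something like ruff check --fix) handles most of the mechanical cleanups visible in the diff below; the remaining failures were fixed by hand or listed as ignores.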
@@ -63,11 +63,13 @@
"\n",
"# Load\n",
"from langchain.document_loaders import PyPDFLoader\n",
"\n",
"loader = PyPDFLoader(path + \"cpi.pdf\")\n",
"pdf_pages = loader.load()\n",
"\n",
"# Split\n",
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"\n",
"text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)\n",
"all_splits_pypdf = text_splitter.split_documents(pdf_pages)\n",
"all_splits_pypdf_texts = [d.page_content for d in all_splits_pypdf]"
@@ -132,10 +134,13 @@
"source": [
"from langchain.vectorstores import Chroma\n",
"from langchain.embeddings import OpenAIEmbeddings\n",
"baseline = Chroma.from_texts(texts=all_splits_pypdf_texts,\n",
" collection_name=\"baseline\",\n",
" embedding=OpenAIEmbeddings())\n",
"retriever_baseline=baseline.as_retriever()"
"\n",
"baseline = Chroma.from_texts(\n",
" texts=all_splits_pypdf_texts,\n",
" collection_name=\"baseline\",\n",
" embedding=OpenAIEmbeddings(),\n",
")\n",
"retriever_baseline = baseline.as_retriever()"
]
},
{
@@ -160,7 +165,7 @@
"from langchain.schema.output_parser import StrOutputParser\n",
"\n",
"# Prompt\n",
"prompt_text = \"\"\"You are an assistant tasked with summarizing tables and text for retrieval. \\ \n",
"prompt_text = \"\"\"You are an assistant tasked with summarizing tables and text for retrieval. \\\n",
"These summaries will be embedded and used to retrieve the raw text or table elements. \\\n",
"Give a concise summary of the table or text that is well optimized for retrieval. Table or text: {element} \"\"\"\n",
"prompt = ChatPromptTemplate.from_template(prompt_text)\n",
@@ -169,7 +174,7 @@
"model = ChatOpenAI(temperature=0, model=\"gpt-4\")\n",
"summarize_chain = {\"element\": lambda x: x} | prompt | model | StrOutputParser()\n",
"\n",
"# Apply to text \n",
"# Apply to text\n",
"text_summaries = summarize_chain.batch(texts, {\"max_concurrency\": 5})\n",
"\n",
"# Apply to tables\n",
@@ -192,31 +197,32 @@
"outputs": [],
"source": [
"# Image summary chain\n",
"import os, base64, io\n",
"import os\n",
"import base64\n",
"import io\n",
"from io import BytesIO\n",
"from PIL import Image\n",
"from langchain.schema.messages import HumanMessage\n",
"\n",
"def encode_image(image_path):\n",
" ''' Getting the base64 string '''\n",
" with open(image_path, \"rb\") as image_file:\n",
" return base64.b64encode(image_file.read()).decode('utf-8') \n",
"\n",
"def image_summarize(img_base64,prompt):\n",
" ''' Image summary '''\n",
" chat = ChatOpenAI(model=\"gpt-4-vision-preview\",\n",
" max_tokens=1024)\n",
" \n",
"def encode_image(image_path):\n",
" \"\"\"Getting the base64 string\"\"\"\n",
" with open(image_path, \"rb\") as image_file:\n",
" return base64.b64encode(image_file.read()).decode(\"utf-8\")\n",
"\n",
"\n",
"def image_summarize(img_base64, prompt):\n",
" \"\"\"Image summary\"\"\"\n",
" chat = ChatOpenAI(model=\"gpt-4-vision-preview\", max_tokens=1024)\n",
"\n",
" msg = chat.invoke(\n",
" [\n",
" HumanMessage(\n",
" content=[\n",
" {\"type\": \"text\", \"text\":prompt},\n",
" {\"type\": \"text\", \"text\": prompt},\n",
" {\n",
" \"type\": \"image_url\",\n",
" \"image_url\": {\n",
" \"url\": f\"data:image/jpeg;base64,{img_base64}\"\n",
" },\n",
" \"image_url\": {\"url\": f\"data:image/jpeg;base64,{img_base64}\"},\n",
" },\n",
" ]\n",
" )\n",
@@ -224,6 +230,7 @@
" )\n",
" return msg.content\n",
"\n",
"\n",
"# Store base64 encoded images\n",
"img_base64_list = []\n",
"\n",
@@ -231,17 +238,17 @@
"image_summaries = []\n",
"\n",
"# Prompt\n",
"prompt = \"\"\"You are an assistant tasked with summarizing images for retrieval. \\ \n",
"prompt = \"\"\"You are an assistant tasked with summarizing images for retrieval. \\\n",
"These summaries will be embedded and used to retrieve the raw image. \\\n",
"Give a concise summary of the image that is well optimized for retrieval.\"\"\"\n",
"\n",
"# Apply to images\n",
"for img_file in sorted(os.listdir(path)):\n",
" if img_file.endswith('.jpg'):\n",
" if img_file.endswith(\".jpg\"):\n",
" img_path = os.path.join(path, img_file)\n",
" base64_image = encode_image(img_path)\n",
" img_base64_list.append(base64_image)\n",
" image_summaries.append(image_summarize(base64_image,prompt))"
" image_summaries.append(image_summarize(base64_image, prompt))"
]
},
{
@@ -267,14 +274,10 @@
"from langchain.schema.document import Document\n",
"from langchain.retrievers.multi_vector import MultiVectorRetriever\n",
"\n",
"def create_multi_vector_retriever(vectorstore, \n",
" text_summaries, \n",
" texts, \n",
" table_summaries, \n",
" tables, \n",
" image_summaries, \n",
" images):\n",
" \n",
"\n",
"def create_multi_vector_retriever(\n",
" vectorstore, text_summaries, texts, table_summaries, tables, image_summaries, images\n",
"):\n",
" # Initialize the storage layer\n",
" store = InMemoryStore()\n",
" id_key = \"doc_id\"\n",
@@ -309,18 +312,22 @@
"\n",
" return retriever\n",
"\n",
"\n",
"# The vectorstore to use to index the summaries\n",
"multi_vector_img = Chroma(collection_name=\"multi_vector_img\", \n",
" embedding_function=OpenAIEmbeddings())\n",
"multi_vector_img = Chroma(\n",
" collection_name=\"multi_vector_img\", embedding_function=OpenAIEmbeddings()\n",
")\n",
"\n",
"# Create retriever\n",
"retriever_multi_vector_img = create_multi_vector_retriever(multi_vector_img,\n",
" text_summaries,\n",
" texts,\n",
" table_summaries, \n",
" tables, \n",
" image_summaries, \n",
" img_base64_list)"
"retriever_multi_vector_img = create_multi_vector_retriever(\n",
" multi_vector_img,\n",
" text_summaries,\n",
" texts,\n",
" table_summaries,\n",
" tables,\n",
" image_summaries,\n",
" img_base64_list,\n",
")"
]
},
{
@@ -330,10 +337,10 @@
"metadata": {},
"outputs": [],
"source": [
"# Testing on retrieval \n",
"query=\"What percentage of CPI is dedicated to Housing, and how does it compare to the combined percentage of Medical Care, Apparel, and Other Goods and Services?\"\n",
"suffix_for_images=\" Include any pie charts, graphs, or tables.\"\n",
"docs = retriever_multi_vector_img.get_relevant_documents(query+suffix_for_images)"
"# Testing on retrieval\n",
"query = \"What percentage of CPI is dedicated to Housing, and how does it compare to the combined percentage of Medical Care, Apparel, and Other Goods and Services?\"\n",
"suffix_for_images = \" Include any pie charts, graphs, or tables.\"\n",
"docs = retriever_multi_vector_img.get_relevant_documents(query + suffix_for_images)"
]
},
{
@@ -357,14 +364,16 @@
],
"source": [
"from IPython.display import display, HTML\n",
"def plt_img_base64(img_base64):\n",
"\n",
"\n",
"def plt_img_base64(img_base64):\n",
" # Create an HTML img tag with the base64 string as the source\n",
" image_html = f'<img src=\"data:image/jpeg;base64,{img_base64}\" />'\n",
" \n",
"\n",
" # Display the image by rendering the HTML\n",
" display(HTML(image_html))\n",
"\n",
"\n",
"plt_img_base64(docs[1])"
]
},
@@ -386,17 +395,20 @@
"outputs": [],
"source": [
"# The vectorstore to use to index the summaries\n",
"multi_vector_text = Chroma(collection_name=\"multi_vector_text\", \n",
" embedding_function=OpenAIEmbeddings())\n",
"multi_vector_text = Chroma(\n",
" collection_name=\"multi_vector_text\", embedding_function=OpenAIEmbeddings()\n",
")\n",
"\n",
"# Create retriever\n",
"retriever_multi_vector_img_summary = create_multi_vector_retriever(multi_vector_text,\n",
" text_summaries,\n",
" texts,\n",
" table_summaries, \n",
" tables, \n",
" image_summaries, \n",
" image_summaries)"
"retriever_multi_vector_img_summary = create_multi_vector_retriever(\n",
" multi_vector_text,\n",
" text_summaries,\n",
" texts,\n",
" table_summaries,\n",
" tables,\n",
" image_summaries,\n",
" image_summaries,\n",
")"
]
},
{
@@ -418,14 +430,17 @@
"\n",
"# Create chroma w/ multi-modal embeddings\n",
"multimodal_embd = Chroma(\n",
" collection_name=\"multimodal_embd\",\n",
" embedding_function=OpenCLIPEmbeddings()\n",
" collection_name=\"multimodal_embd\", embedding_function=OpenCLIPEmbeddings()\n",
")\n",
"\n",
"# Get image URIs\n",
"image_uris = sorted([os.path.join(path, image_name) \n",
" for image_name in os.listdir(path) \n",
" if image_name.endswith('.jpg')])\n",
"image_uris = sorted(\n",
" [\n",
" os.path.join(path, image_name)\n",
" for image_name in os.listdir(path)\n",
" if image_name.endswith(\".jpg\")\n",
" ]\n",
")\n",
"\n",
"# Add images and documents\n",
"if image_uris:\n",
@@ -435,7 +450,7 @@
"if tables:\n",
" multimodal_embd.add_texts(texts=tables)\n",
"\n",
"# Make retriever \n",
"# Make retriever\n",
"retriever_multimodal_embd = multimodal_embd.as_retriever()"
]
},
@@ -466,14 +481,14 @@
"\"\"\"\n",
"rag_prompt_text = ChatPromptTemplate.from_template(template)\n",
"\n",
"# Build \n",
"\n",
"# Build\n",
"def text_rag_chain(retriever):\n",
" \n",
" ''' RAG chain '''\n",
" \"\"\"RAG chain\"\"\"\n",
"\n",
" # LLM\n",
" model = ChatOpenAI(temperature=0, model=\"gpt-4\")\n",
" \n",
"\n",
" # RAG pipeline\n",
" chain = (\n",
" {\"context\": retriever, \"question\": RunnablePassthrough()}\n",
@@ -500,13 +515,15 @@
"metadata": {},
"outputs": [],
"source": [
"import re \n",
"import re\n",
"from langchain.schema import Document\n",
"from langchain.schema.runnable import RunnableLambda\n",
"\n",
"\n",
"def looks_like_base64(sb):\n",
" \"\"\"Check if the string looks like base64.\"\"\"\n",
" return re.match('^[A-Za-z0-9+/]+[=]{0,2}$', sb) is not None\n",
" return re.match(\"^[A-Za-z0-9+/]+[=]{0,2}$\", sb) is not None\n",
"\n",
"\n",
"def is_image_data(b64data):\n",
" \"\"\"Check if the base64 data is an image by looking at the start of the data.\"\"\"\n",
@@ -514,7 +531,7 @@
" b\"\\xFF\\xD8\\xFF\": \"jpg\",\n",
" b\"\\x89\\x50\\x4E\\x47\\x0D\\x0A\\x1A\\x0A\": \"png\",\n",
" b\"\\x47\\x49\\x46\\x38\": \"gif\",\n",
" b\"\\x52\\x49\\x46\\x46\": \"webp\"\n",
" b\"\\x52\\x49\\x46\\x46\": \"webp\",\n",
" }\n",
" try:\n",
" header = base64.b64decode(b64data)[:8] # Decode and get the first 8 bytes\n",
@@ -525,6 +542,7 @@
" except Exception:\n",
" return False\n",
"\n",
"\n",
"def split_image_text_types(docs):\n",
" \"\"\"Split base64-encoded images and texts.\"\"\"\n",
" b64_images = []\n",
@@ -539,6 +557,7 @@
" texts.append(doc)\n",
" return {\"images\": b64_images, \"texts\": texts}\n",
"\n",
"\n",
"def img_prompt_func(data_dict):\n",
" # Joining the context texts into a single string\n",
" formatted_texts = \"\\n\".join(data_dict[\"context\"][\"texts\"])\n",
@@ -550,7 +569,7 @@
" \"type\": \"image_url\",\n",
" \"image_url\": {\n",
" \"url\": f\"data:image/jpeg;base64,{data_dict['context']['images'][0]}\"\n",
" }\n",
" },\n",
" }\n",
" messages.append(image_message)\n",
"\n",
@@ -563,22 +582,24 @@
" f\"User-provided question / keywords: {data_dict['question']}\\n\\n\"\n",
" \"Text and / or tables:\\n\"\n",
" f\"{formatted_texts}\"\n",
" )\n",
" ),\n",
" }\n",
" messages.append(text_message)\n",
" return [HumanMessage(content=messages)]\n",
"\n",
"\n",
"def multi_modal_rag_chain(retriever):\n",
" ''' Multi-modal RAG chain '''\n",
" \"\"\"Multi-modal RAG chain\"\"\"\n",
"\n",
" # Multi-modal LLM\n",
" model = ChatOpenAI(temperature=0, \n",
" model=\"gpt-4-vision-preview\", \n",
" max_tokens=1024)\n",
" \n",
" model = ChatOpenAI(temperature=0, model=\"gpt-4-vision-preview\", max_tokens=1024)\n",
"\n",
" # RAG pipeline\n",
" chain = (\n",
" {\"context\": retriever | RunnableLambda(split_image_text_types), \"question\": RunnablePassthrough()}\n",
" {\n",
" \"context\": retriever | RunnableLambda(split_image_text_types),\n",
" \"question\": RunnablePassthrough(),\n",
" }\n",
" | RunnableLambda(img_prompt_func)\n",
" | model\n",
" | StrOutputParser()\n",
@@ -603,12 +624,12 @@
"outputs": [],
"source": [
"# RAG chains\n",
"chain_baseline=text_rag_chain(retriever_baseline)\n",
"chain_mv_text=text_rag_chain(retriever_multi_vector_img_summary)\n",
"chain_baseline = text_rag_chain(retriever_baseline)\n",
"chain_mv_text = text_rag_chain(retriever_multi_vector_img_summary)\n",
"\n",
"# Multi-modal RAG chains\n",
"chain_multimodal_mv_img=multi_modal_rag_chain(retriever_multi_vector_img)\n",
"chain_multimodal_embd=multi_modal_rag_chain(retriever_multimodal_embd)"
"chain_multimodal_mv_img = multi_modal_rag_chain(retriever_multi_vector_img)\n",
"chain_multimodal_embd = multi_modal_rag_chain(retriever_multimodal_embd)"
]
},
{
@@ -694,7 +715,8 @@
"source": [
"# Read\n",
"import pandas as pd\n",
"eval_set = pd.read_csv(path+'cpi_eval.csv')\n",
"\n",
"eval_set = pd.read_csv(path + \"cpi_eval.csv\")\n",
"eval_set.head(3)"
]
},
@@ -715,12 +737,12 @@
"# Populate dataset\n",
"for _, row in eval_set.iterrows():\n",
" # Get Q, A\n",
" q = row['Question']\n",
" a = row['Answer']\n",
" q = row[\"Question\"]\n",
" a = row[\"Answer\"]\n",
" # Use the values in your function\n",
" client.create_example(inputs={\"question\": q}, \n",
" outputs={\"answer\": a}, \n",
" dataset_id=dataset.id)"
" client.create_example(\n",
" inputs={\"question\": q}, outputs={\"answer\": a}, dataset_id=dataset.id\n",
" )"
]
},
{
@@ -764,17 +786,22 @@
" evaluators=[\"qa\"],\n",
")\n",
"\n",
"def run_eval(chain,run_name,dataset_name):\n",
"\n",
"def run_eval(chain, run_name, dataset_name):\n",
" _ = client.run_on_dataset(\n",
" dataset_name=dataset_name,\n",
" llm_or_chain_factory=lambda: (lambda x: x[\"question\"]+suffix_for_images) | chain,\n",
" llm_or_chain_factory=lambda: (lambda x: x[\"question\"] + suffix_for_images)\n",
" | chain,\n",
" evaluation=eval_config,\n",
" project_name=run_name,\n",
" )\n",
"\n",
"for chain, run in zip([chain_baseline, chain_mv_text, chain_multimodal_mv_img, chain_multimodal_embd], \n",
" [\"baseline\", \"mv_text\", \"mv_img\", \"mm_embd\"]):\n",
" run_eval(chain, dataset_name+\"-\"+run, dataset_name)"
"\n",
"for chain, run in zip(\n",
" [chain_baseline, chain_mv_text, chain_multimodal_mv_img, chain_multimodal_embd],\n",
" [\"baseline\", \"mv_text\", \"mv_img\", \"mm_embd\"],\n",
"):\n",
" run_eval(chain, dataset_name + \"-\" + run, dataset_name)"
]
}
],