Lint Python notebooks with ruff. (#12677)

The new ruff version fixed the blocking bugs, and I was able to fairly
easily get us to a passing state: ruff fixed some issues on its own, I fixed
a handful by hand, and I added a list of narrowly-targeted exclusions
for files that currently fail ruff rules we should probably look into
eventually.

I went pretty lenient on the docs / cookbooks rules, allowing dead code
and the like. We may want to tighten the rules further in the future,
but this is already a good set of checks: it found real issues and will
prevent new ones going forward.
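
One common way to express narrowly-targeted exclusions with ruff is per-file
ignores in pyproject.toml. Below is a minimal sketch of that shape, assuming
ruff's extend-include and per-file-ignores settings; the path and rule code
are illustrative placeholders, not the actual entries added in this commit:

    [tool.ruff]
    # Also lint notebook sources, not just regular .py files.
    extend-include = ["*.ipynb"]

    [tool.ruff.per-file-ignores]
    # Narrowly targeted: only the named notebook skips only the named rule,
    # so other notebooks and other rules are still fully checked.
    "docs/docs/example_notebook.ipynb" = ["F821"]  # placeholder path and rule
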
Predrag Gruevski
2023-11-14 15:58:22 -05:00
committed by GitHub
parent 344cab0739
commit 2ebd167dba
189 changed files with 2249 additions and 2362 deletions

@@ -63,11 +63,13 @@
"\n",
"# Load\n",
"from langchain.document_loaders import PyPDFLoader\n",
"\n",
"loader = PyPDFLoader(path + \"cpi.pdf\")\n",
"pdf_pages = loader.load()\n",
"\n",
"# Split\n",
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"\n",
"text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)\n",
"all_splits_pypdf = text_splitter.split_documents(pdf_pages)\n",
"all_splits_pypdf_texts = [d.page_content for d in all_splits_pypdf]"
@@ -132,10 +134,13 @@
"source": [
"from langchain.vectorstores import Chroma\n",
"from langchain.embeddings import OpenAIEmbeddings\n",
"baseline = Chroma.from_texts(texts=all_splits_pypdf_texts,\n",
" collection_name=\"baseline\",\n",
" embedding=OpenAIEmbeddings())\n",
"retriever_baseline=baseline.as_retriever()"
"\n",
"baseline = Chroma.from_texts(\n",
" texts=all_splits_pypdf_texts,\n",
" collection_name=\"baseline\",\n",
" embedding=OpenAIEmbeddings(),\n",
")\n",
"retriever_baseline = baseline.as_retriever()"
]
},
{
@@ -160,7 +165,7 @@
"from langchain.schema.output_parser import StrOutputParser\n",
"\n",
"# Prompt\n",
"prompt_text = \"\"\"You are an assistant tasked with summarizing tables and text for retrieval. \\ \n",
"prompt_text = \"\"\"You are an assistant tasked with summarizing tables and text for retrieval. \\\n",
"These summaries will be embedded and used to retrieve the raw text or table elements. \\\n",
"Give a concise summary of the table or text that is well optimized for retrieval. Table or text: {element} \"\"\"\n",
"prompt = ChatPromptTemplate.from_template(prompt_text)\n",
@@ -169,7 +174,7 @@
"model = ChatOpenAI(temperature=0, model=\"gpt-4\")\n",
"summarize_chain = {\"element\": lambda x: x} | prompt | model | StrOutputParser()\n",
"\n",
"# Apply to text \n",
"# Apply to text\n",
"text_summaries = summarize_chain.batch(texts, {\"max_concurrency\": 5})\n",
"\n",
"# Apply to tables\n",
@@ -192,31 +197,32 @@
"outputs": [],
"source": [
"# Image summary chain\n",
"import os, base64, io\n",
"import os\n",
"import base64\n",
"import io\n",
"from io import BytesIO\n",
"from PIL import Image\n",
"from langchain.schema.messages import HumanMessage\n",
"\n",
"def encode_image(image_path):\n",
" ''' Getting the base64 string '''\n",
" with open(image_path, \"rb\") as image_file:\n",
" return base64.b64encode(image_file.read()).decode('utf-8') \n",
"\n",
"def image_summarize(img_base64,prompt):\n",
" ''' Image summary '''\n",
" chat = ChatOpenAI(model=\"gpt-4-vision-preview\",\n",
" max_tokens=1024)\n",
" \n",
"def encode_image(image_path):\n",
" \"\"\"Getting the base64 string\"\"\"\n",
" with open(image_path, \"rb\") as image_file:\n",
" return base64.b64encode(image_file.read()).decode(\"utf-8\")\n",
"\n",
"\n",
"def image_summarize(img_base64, prompt):\n",
" \"\"\"Image summary\"\"\"\n",
" chat = ChatOpenAI(model=\"gpt-4-vision-preview\", max_tokens=1024)\n",
"\n",
" msg = chat.invoke(\n",
" [\n",
" HumanMessage(\n",
" content=[\n",
" {\"type\": \"text\", \"text\":prompt},\n",
" {\"type\": \"text\", \"text\": prompt},\n",
" {\n",
" \"type\": \"image_url\",\n",
" \"image_url\": {\n",
" \"url\": f\"data:image/jpeg;base64,{img_base64}\"\n",
" },\n",
" \"image_url\": {\"url\": f\"data:image/jpeg;base64,{img_base64}\"},\n",
" },\n",
" ]\n",
" )\n",
@@ -224,6 +230,7 @@
" )\n",
" return msg.content\n",
"\n",
"\n",
"# Store base64 encoded images\n",
"img_base64_list = []\n",
"\n",
@@ -231,17 +238,17 @@
"image_summaries = []\n",
"\n",
"# Prompt\n",
"prompt = \"\"\"You are an assistant tasked with summarizing images for retrieval. \\ \n",
"prompt = \"\"\"You are an assistant tasked with summarizing images for retrieval. \\\n",
"These summaries will be embedded and used to retrieve the raw image. \\\n",
"Give a concise summary of the image that is well optimized for retrieval.\"\"\"\n",
"\n",
"# Apply to images\n",
"for img_file in sorted(os.listdir(path)):\n",
" if img_file.endswith('.jpg'):\n",
" if img_file.endswith(\".jpg\"):\n",
" img_path = os.path.join(path, img_file)\n",
" base64_image = encode_image(img_path)\n",
" img_base64_list.append(base64_image)\n",
" image_summaries.append(image_summarize(base64_image,prompt))"
" image_summaries.append(image_summarize(base64_image, prompt))"
]
},
{
@@ -267,14 +274,10 @@
"from langchain.schema.document import Document\n",
"from langchain.retrievers.multi_vector import MultiVectorRetriever\n",
"\n",
"def create_multi_vector_retriever(vectorstore, \n",
" text_summaries, \n",
" texts, \n",
" table_summaries, \n",
" tables, \n",
" image_summaries, \n",
" images):\n",
" \n",
"\n",
"def create_multi_vector_retriever(\n",
" vectorstore, text_summaries, texts, table_summaries, tables, image_summaries, images\n",
"):\n",
" # Initialize the storage layer\n",
" store = InMemoryStore()\n",
" id_key = \"doc_id\"\n",
@@ -309,18 +312,22 @@
"\n",
" return retriever\n",
"\n",
"\n",
"# The vectorstore to use to index the summaries\n",
"multi_vector_img = Chroma(collection_name=\"multi_vector_img\", \n",
" embedding_function=OpenAIEmbeddings())\n",
"multi_vector_img = Chroma(\n",
" collection_name=\"multi_vector_img\", embedding_function=OpenAIEmbeddings()\n",
")\n",
"\n",
"# Create retriever\n",
"retriever_multi_vector_img = create_multi_vector_retriever(multi_vector_img,\n",
" text_summaries,\n",
" texts,\n",
" table_summaries, \n",
" tables, \n",
" image_summaries, \n",
" img_base64_list)"
"retriever_multi_vector_img = create_multi_vector_retriever(\n",
" multi_vector_img,\n",
" text_summaries,\n",
" texts,\n",
" table_summaries,\n",
" tables,\n",
" image_summaries,\n",
" img_base64_list,\n",
")"
]
},
{
@@ -330,10 +337,10 @@
"metadata": {},
"outputs": [],
"source": [
"# Testing on retrieval \n",
"query=\"What percentage of CPI is dedicated to Housing, and how does it compare to the combined percentage of Medical Care, Apparel, and Other Goods and Services?\"\n",
"suffix_for_images=\" Include any pie charts, graphs, or tables.\"\n",
"docs = retriever_multi_vector_img.get_relevant_documents(query+suffix_for_images)"
"# Testing on retrieval\n",
"query = \"What percentage of CPI is dedicated to Housing, and how does it compare to the combined percentage of Medical Care, Apparel, and Other Goods and Services?\"\n",
"suffix_for_images = \" Include any pie charts, graphs, or tables.\"\n",
"docs = retriever_multi_vector_img.get_relevant_documents(query + suffix_for_images)"
]
},
{
@@ -357,14 +364,16 @@
],
"source": [
"from IPython.display import display, HTML\n",
"def plt_img_base64(img_base64):\n",
"\n",
"\n",
"def plt_img_base64(img_base64):\n",
" # Create an HTML img tag with the base64 string as the source\n",
" image_html = f'<img src=\"data:image/jpeg;base64,{img_base64}\" />'\n",
" \n",
"\n",
" # Display the image by rendering the HTML\n",
" display(HTML(image_html))\n",
"\n",
"\n",
"plt_img_base64(docs[1])"
]
},
@@ -386,17 +395,20 @@
"outputs": [],
"source": [
"# The vectorstore to use to index the summaries\n",
"multi_vector_text = Chroma(collection_name=\"multi_vector_text\", \n",
" embedding_function=OpenAIEmbeddings())\n",
"multi_vector_text = Chroma(\n",
" collection_name=\"multi_vector_text\", embedding_function=OpenAIEmbeddings()\n",
")\n",
"\n",
"# Create retriever\n",
"retriever_multi_vector_img_summary = create_multi_vector_retriever(multi_vector_text,\n",
" text_summaries,\n",
" texts,\n",
" table_summaries, \n",
" tables, \n",
" image_summaries, \n",
" image_summaries)"
"retriever_multi_vector_img_summary = create_multi_vector_retriever(\n",
" multi_vector_text,\n",
" text_summaries,\n",
" texts,\n",
" table_summaries,\n",
" tables,\n",
" image_summaries,\n",
" image_summaries,\n",
")"
]
},
{
@@ -418,14 +430,17 @@
"\n",
"# Create chroma w/ multi-modal embeddings\n",
"multimodal_embd = Chroma(\n",
" collection_name=\"multimodal_embd\",\n",
" embedding_function=OpenCLIPEmbeddings()\n",
" collection_name=\"multimodal_embd\", embedding_function=OpenCLIPEmbeddings()\n",
")\n",
"\n",
"# Get image URIs\n",
"image_uris = sorted([os.path.join(path, image_name) \n",
" for image_name in os.listdir(path) \n",
" if image_name.endswith('.jpg')])\n",
"image_uris = sorted(\n",
" [\n",
" os.path.join(path, image_name)\n",
" for image_name in os.listdir(path)\n",
" if image_name.endswith(\".jpg\")\n",
" ]\n",
")\n",
"\n",
"# Add images and documents\n",
"if image_uris:\n",
@@ -435,7 +450,7 @@
"if tables:\n",
" multimodal_embd.add_texts(texts=tables)\n",
"\n",
"# Make retriever \n",
"# Make retriever\n",
"retriever_multimodal_embd = multimodal_embd.as_retriever()"
]
},
@@ -466,14 +481,14 @@
"\"\"\"\n",
"rag_prompt_text = ChatPromptTemplate.from_template(template)\n",
"\n",
"# Build \n",
"\n",
"# Build\n",
"def text_rag_chain(retriever):\n",
" \n",
" ''' RAG chain '''\n",
" \"\"\"RAG chain\"\"\"\n",
"\n",
" # LLM\n",
" model = ChatOpenAI(temperature=0, model=\"gpt-4\")\n",
" \n",
"\n",
" # RAG pipeline\n",
" chain = (\n",
" {\"context\": retriever, \"question\": RunnablePassthrough()}\n",
@@ -500,13 +515,15 @@
"metadata": {},
"outputs": [],
"source": [
"import re \n",
"import re\n",
"from langchain.schema import Document\n",
"from langchain.schema.runnable import RunnableLambda\n",
"\n",
"\n",
"def looks_like_base64(sb):\n",
" \"\"\"Check if the string looks like base64.\"\"\"\n",
" return re.match('^[A-Za-z0-9+/]+[=]{0,2}$', sb) is not None\n",
" return re.match(\"^[A-Za-z0-9+/]+[=]{0,2}$\", sb) is not None\n",
"\n",
"\n",
"def is_image_data(b64data):\n",
" \"\"\"Check if the base64 data is an image by looking at the start of the data.\"\"\"\n",
@@ -514,7 +531,7 @@
" b\"\\xFF\\xD8\\xFF\": \"jpg\",\n",
" b\"\\x89\\x50\\x4E\\x47\\x0D\\x0A\\x1A\\x0A\": \"png\",\n",
" b\"\\x47\\x49\\x46\\x38\": \"gif\",\n",
" b\"\\x52\\x49\\x46\\x46\": \"webp\"\n",
" b\"\\x52\\x49\\x46\\x46\": \"webp\",\n",
" }\n",
" try:\n",
" header = base64.b64decode(b64data)[:8] # Decode and get the first 8 bytes\n",
@@ -525,6 +542,7 @@
" except Exception:\n",
" return False\n",
"\n",
"\n",
"def split_image_text_types(docs):\n",
" \"\"\"Split base64-encoded images and texts.\"\"\"\n",
" b64_images = []\n",
@@ -539,6 +557,7 @@
" texts.append(doc)\n",
" return {\"images\": b64_images, \"texts\": texts}\n",
"\n",
"\n",
"def img_prompt_func(data_dict):\n",
" # Joining the context texts into a single string\n",
" formatted_texts = \"\\n\".join(data_dict[\"context\"][\"texts\"])\n",
@@ -550,7 +569,7 @@
" \"type\": \"image_url\",\n",
" \"image_url\": {\n",
" \"url\": f\"data:image/jpeg;base64,{data_dict['context']['images'][0]}\"\n",
" }\n",
" },\n",
" }\n",
" messages.append(image_message)\n",
"\n",
@@ -563,22 +582,24 @@
" f\"User-provided question / keywords: {data_dict['question']}\\n\\n\"\n",
" \"Text and / or tables:\\n\"\n",
" f\"{formatted_texts}\"\n",
" )\n",
" ),\n",
" }\n",
" messages.append(text_message)\n",
" return [HumanMessage(content=messages)]\n",
"\n",
"\n",
"def multi_modal_rag_chain(retriever):\n",
" ''' Multi-modal RAG chain '''\n",
" \"\"\"Multi-modal RAG chain\"\"\"\n",
"\n",
" # Multi-modal LLM\n",
" model = ChatOpenAI(temperature=0, \n",
" model=\"gpt-4-vision-preview\", \n",
" max_tokens=1024)\n",
" \n",
" model = ChatOpenAI(temperature=0, model=\"gpt-4-vision-preview\", max_tokens=1024)\n",
"\n",
" # RAG pipeline\n",
" chain = (\n",
" {\"context\": retriever | RunnableLambda(split_image_text_types), \"question\": RunnablePassthrough()}\n",
" {\n",
" \"context\": retriever | RunnableLambda(split_image_text_types),\n",
" \"question\": RunnablePassthrough(),\n",
" }\n",
" | RunnableLambda(img_prompt_func)\n",
" | model\n",
" | StrOutputParser()\n",
@@ -603,12 +624,12 @@
"outputs": [],
"source": [
"# RAG chains\n",
"chain_baseline=text_rag_chain(retriever_baseline)\n",
"chain_mv_text=text_rag_chain(retriever_multi_vector_img_summary)\n",
"chain_baseline = text_rag_chain(retriever_baseline)\n",
"chain_mv_text = text_rag_chain(retriever_multi_vector_img_summary)\n",
"\n",
"# Multi-modal RAG chains\n",
"chain_multimodal_mv_img=multi_modal_rag_chain(retriever_multi_vector_img)\n",
"chain_multimodal_embd=multi_modal_rag_chain(retriever_multimodal_embd)"
"chain_multimodal_mv_img = multi_modal_rag_chain(retriever_multi_vector_img)\n",
"chain_multimodal_embd = multi_modal_rag_chain(retriever_multimodal_embd)"
]
},
{
@@ -694,7 +715,8 @@
"source": [
"# Read\n",
"import pandas as pd\n",
"eval_set = pd.read_csv(path+'cpi_eval.csv')\n",
"\n",
"eval_set = pd.read_csv(path + \"cpi_eval.csv\")\n",
"eval_set.head(3)"
]
},
@@ -715,12 +737,12 @@
"# Populate dataset\n",
"for _, row in eval_set.iterrows():\n",
" # Get Q, A\n",
" q = row['Question']\n",
" a = row['Answer']\n",
" q = row[\"Question\"]\n",
" a = row[\"Answer\"]\n",
" # Use the values in your function\n",
" client.create_example(inputs={\"question\": q}, \n",
" outputs={\"answer\": a}, \n",
" dataset_id=dataset.id)"
" client.create_example(\n",
" inputs={\"question\": q}, outputs={\"answer\": a}, dataset_id=dataset.id\n",
" )"
]
},
{
@@ -764,17 +786,22 @@
" evaluators=[\"qa\"],\n",
")\n",
"\n",
"def run_eval(chain,run_name,dataset_name):\n",
"\n",
"def run_eval(chain, run_name, dataset_name):\n",
" _ = client.run_on_dataset(\n",
" dataset_name=dataset_name,\n",
" llm_or_chain_factory=lambda: (lambda x: x[\"question\"]+suffix_for_images) | chain,\n",
" llm_or_chain_factory=lambda: (lambda x: x[\"question\"] + suffix_for_images)\n",
" | chain,\n",
" evaluation=eval_config,\n",
" project_name=run_name,\n",
" )\n",
"\n",
"for chain, run in zip([chain_baseline, chain_mv_text, chain_multimodal_mv_img, chain_multimodal_embd], \n",
" [\"baseline\", \"mv_text\", \"mv_img\", \"mm_embd\"]):\n",
" run_eval(chain, dataset_name+\"-\"+run, dataset_name)"
"\n",
"for chain, run in zip(\n",
" [chain_baseline, chain_mv_text, chain_multimodal_mv_img, chain_multimodal_embd],\n",
" [\"baseline\", \"mv_text\", \"mv_img\", \"mm_embd\"],\n",
"):\n",
" run_eval(chain, dataset_name + \"-\" + run, dataset_name)"
]
}
],