{
"cells": [
{
"cell_type": "markdown",
"id": "e6f397e4-5acd-4b64-aaef-496b6b604480",
"metadata": {},
"source": [
"# Optimizing image generation prompting using CLIP scores\n",
"In this notebook we will compare different prompts and measure a text-to-image similarity metric: The [CLIP score](https://lightning.ai/docs/torchmetrics/stable/multimodal/clip_score.html)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "450a4c57-6a80-4b47-8b18-ee62ea9e5705",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import numpy as np\n",
"from skimage.io import imread\n",
"import stackview\n",
"from diffusers import DiffusionPipeline, AutoencoderTiny\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"id": "d969af2a-3870-4148-8e8b-94536a7efdf2",
"metadata": {},
"source": [
"This is how the metric is initialized"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "668f368c-75b8-4ff4-bfe7-51d1a6ac2c38",
"metadata": {},
"outputs": [],
"source": [
"from torchmetrics.multimodal.clip_score import CLIPScore\n",
"metric = CLIPScore(model_name_or_path=\"openai/clip-vit-base-patch16\")"
]
},
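{
"cell_type": "markdown",
"id": "f3a1b2c4-d5e6-47a8-9b0c-1d2e3f4a5b6c",
"metadata": {},
"source": [
"The metric computes $\\mathrm{CLIPScore}(I, C) = \\max(100 \\cdot \\cos(E_I, E_C), 0)$, i.e. 100 times the cosine similarity between the CLIP embeddings of image $I$ and caption $C$, clipped at zero. The score thus lies between 0 and 100, and higher values indicate better image-text agreement."
]
},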
{
"cell_type": "markdown",
"id": "4815fcf7-7ac7-4321-961b-e1a0ba4bb8c7",
"metadata": {},
"source": [
"We start with this example image."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "e76fb935-d7e4-4faa-bd8a-40d2670e6259",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"\n",
" \n",
" | \n",
"\n",
"\n",
"\n",
"shape | (512, 512, 3) | \n",
"dtype | uint8 | \n",
"size | 768.0 kB | \n",
"min | 0 | max | 255 | \n",
" \n",
" \n",
" | \n",
"
\n",
"
"
],
"text/plain": [
"StackViewNDArray([[[176, 178, 179],\n",
" [175, 178, 178],\n",
" [177, 177, 180],\n",
" ...,\n",
" [182, 186, 188],\n",
" [185, 188, 191],\n",
" [191, 194, 197]],\n",
"\n",
" [[178, 180, 181],\n",
" [178, 179, 181],\n",
" [178, 180, 181],\n",
" ...,\n",
" [185, 189, 192],\n",
" [187, 191, 192],\n",
" [191, 195, 198]],\n",
"\n",
" [[181, 183, 185],\n",
" [180, 182, 183],\n",
" [180, 181, 183],\n",
" ...,\n",
" [190, 193, 196],\n",
" [189, 193, 196],\n",
" [192, 195, 198]],\n",
"\n",
" ...,\n",
"\n",
" [[125, 91, 66],\n",
" [124, 90, 65],\n",
" [123, 89, 65],\n",
" ...,\n",
" [137, 92, 64],\n",
" [136, 91, 62],\n",
" [135, 89, 61]],\n",
"\n",
" [[122, 88, 64],\n",
" [121, 87, 63],\n",
" [121, 87, 63],\n",
" ...,\n",
" [142, 96, 68],\n",
" [142, 96, 68],\n",
" [139, 94, 65]],\n",
"\n",
" [[120, 86, 62],\n",
" [120, 86, 60],\n",
" [119, 85, 61],\n",
" ...,\n",
" [144, 99, 70],\n",
" [144, 99, 70],\n",
" [142, 97, 68]]], dtype=uint8)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"image = imread(\"data/real_cat.png\")\n",
"stackview.insight(image)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "79111f68-bd94-4536-bed5-b70d1374d91d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(25.3473)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"score = metric(torch.as_tensor(image), \"cat\")\n",
"score.detach()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "2f37d6eb-9bbd-4bc1-8535-a1a6206cc859",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"30.786287307739258"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"score = metric(torch.as_tensor(image), \"microscope\")\n",
"float(score.detach())"
]
},
{
"cell_type": "markdown",
"id": "1a0f745b-aae0-4bf8-9081-c20599a44a6d",
"metadata": {},
"source": [
"## Recap: Generating images\n",
"We will now use a prompt to generate an image and will measure if it shows a cat."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "57dcb8e3-764d-41e1-9ea2-811d9c1f655a",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "dfb9c2beef8a4a1eb426d20f4bb7f955",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading pipeline components...: 0%| | 0/6 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"pipe = DiffusionPipeline.from_pretrained(\n",
" \"stabilityai/stable-diffusion-2-1-base\", torch_dtype=torch.float16\n",
")\n",
"pipe = pipe.to(\"cuda\")"
]
},
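{
"cell_type": "markdown",
"id": "a7c9e1f3-2b4d-46e8-8a0c-3d5f7b9e1c2a",
"metadata": {},
"source": [
"Optionally, the `AutoencoderTiny` imported above can replace the pipeline's VAE to speed up latent decoding. The following is a sketch assuming the `madebyollin/taesd` weights, which are intended for Stable Diffusion 1.x/2.x; it trades a little decoding quality for speed."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b8d0f2a4-3c5e-47f9-9b1d-4e6a8c0f2d3b",
"metadata": {},
"outputs": [],
"source": [
"# Optional speed-up (sketch): swap in a tiny VAE for faster latent decoding.\n",
"# Assumes the madebyollin/taesd weights, intended for Stable Diffusion 1.x/2.x.\n",
"pipe.vae = AutoencoderTiny.from_pretrained(\n",
"    \"madebyollin/taesd\", torch_dtype=torch.float16\n",
").to(\"cuda\")"
]
},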
{
"cell_type": "code",
"execution_count": null,
"id": "861a15dd-9041-4283-84ed-4ac57d165b08",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d4b0634c8c37433a9d0cb38e78b3e129",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/50 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"prompt = \"Draw a realistic photo of a cat.\"\n",
"\n",
"cat = pipe(prompt).images[0]\n",
"cat"
]
},
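{
"cell_type": "markdown",
"id": "c9e1a3b5-4d6f-48a0-8c2e-5f7b9d1e3a4c",
"metadata": {},
"source": [
"Image generation is stochastic, so repeated runs of the same prompt produce different images and scores. For reproducible comparisons, a seeded generator can be passed to the pipeline. A minimal sketch (the seed 42 and the variable name `cat_fixed` are arbitrary choices):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d0f2b4c6-5e7a-49b1-9d3f-6a8c0e2f4b5d",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: fix the random seed to make the generated image reproducible\n",
"generator = torch.Generator(\"cuda\").manual_seed(42)\n",
"cat_fixed = pipe(prompt, generator=generator).images[0]\n",
"cat_fixed"
]
},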
{
"cell_type": "code",
"execution_count": null,
"id": "c51c842a-31a4-4db7-985e-d7e1f0fb07ed",
"metadata": {},
"outputs": [],
"source": [
"score = metric(torch.as_tensor(np.array(cat)), \"cat\")\n",
"np.asarray(score.detach())"
]
},
{
"cell_type": "markdown",
"id": "8041cdcd-48a6-4f33-990e-9f8c61a14638",
"metadata": {},
"source": [
"## Benchmarking prompts\n",
"To compare different prompts quantitatively, we run image generation in a loop and measure their quality. As a control, we also generate images that show dogs and no cats.\n",
"\n",
"Using this strategy, we can do prompt engineering / prompt optimization in an informed way."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0787aeb9-cbe6-4c09-bc55-77261b3a4e4d",
"metadata": {},
"outputs": [],
"source": [
"num_attempts = 2\n",
"prompts = [\"Draw a realistic photo of a cat.\", \n",
" \"Draw a cat\",\n",
" \"cat\", \n",
" \"Draw a realistic photo of a dog.\"]\n",
"\n",
"data = {\"prompt\":[],\n",
" \"score\":[]}\n",
"for prompt in prompts:\n",
" for i in range(num_attempts):\n",
" image = pipe(prompt).images[0]\n",
"\n",
" score = metric(torch.as_tensor(np.array(image)), \"cat\")\n",
" data[\"score\"].append(float(score.detach()))\n",
" data[\"prompt\"].append(prompt)\n",
"\n",
" print(f\"{prompt}: {score}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2eaa61f9-979b-471f-93ce-7047ddb9db5d",
"metadata": {},
"outputs": [],
"source": [
"data = pd.DataFrame(data)\n",
"data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "596379f8-eba5-4d68-8e1e-5a347329193f",
"metadata": {},
"outputs": [],
"source": [
"# Plotting the boxplot\n",
"plt.figure(figsize=(10, 6))\n",
"sns.boxplot(x='prompt', y='score', data=data)\n",
"plt.title('Boxplot of Scores by Prompt')\n",
"plt.xlabel('Prompt')\n",
"plt.ylabel('Score')\n",
"plt.show()"
]
},
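{
"cell_type": "markdown",
"id": "e1a3c5d7-6f8b-4ac2-8e4a-7b9d1f3a5c6e",
"metadata": {},
"source": [
"In addition to the boxplot, summary statistics make the comparison concrete:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f2b4d6e8-7a9c-4bd3-9f5b-8c0e2a4b6d7f",
"metadata": {},
"outputs": [],
"source": [
"# Mean, spread and range of CLIP scores per prompt\n",
"data.groupby(\"prompt\")[\"score\"].describe()"
]
},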
{
"cell_type": "markdown",
"id": "1a14a0c5-21a8-4665-93c7-23363bb6facc",
"metadata": {},
"source": [
"## Exercise\n",
"Append more prompts to the list above. Attempt to improve the score for the image showing a cat."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e382e4da-7c08-4d63-8ed9-5f84aceed554",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "c004aeb1-c666-4c61-8134-5ad4b8cf57e9",
"metadata": {},
"source": [
"Measure which of the prompts produces the most photorealistic images. "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3d5144e7-7cc0-4bf3-9e93-62e56d45720b",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}