{ "cells": [ { "cell_type": "markdown", "id": "e6f397e4-5acd-4b64-aaef-496b6b604480", "metadata": {}, "source": [ "# Optimizing image generation prompting using CLIP scores\n", "In this notebook, we compare different prompts by measuring a text-to-image similarity metric, the [CLIP score](https://lightning.ai/docs/torchmetrics/stable/multimodal/clip_score.html). The higher the score, the better the image matches the text." ] }, { "cell_type": "code", "execution_count": 1, "id": "450a4c57-6a80-4b47-8b18-ee62ea9e5705", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import numpy as np\n", "from skimage.io import imread\n", "import stackview\n", "from diffusers import DiffusionPipeline, AutoencoderTiny\n", "import pandas as pd\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "id": "d969af2a-3870-4148-8e8b-94536a7efdf2", "metadata": {}, "source": [ "This is how the metric is initialized:" ] }, { "cell_type": "code", "execution_count": 2, "id": "668f368c-75b8-4ff4-bfe7-51d1a6ac2c38", "metadata": {}, "outputs": [], "source": [ "from torchmetrics.multimodal.clip_score import CLIPScore\n", "metric = CLIPScore(model_name_or_path=\"openai/clip-vit-base-patch16\")" ] }, { "cell_type": "markdown", "id": "4815fcf7-7ac7-4321-961b-e1a0ba4bb8c7", "metadata": {}, "source": [ "We start with this example image." ] }, { "cell_type": "code", "execution_count": 3, "id": "e76fb935-d7e4-4faa-bd8a-40d2670e6259", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", "\n", "\n", "\n", "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
shape(512, 512, 3)
dtypeuint8
size768.0 kB
min0
max255
\n", "\n", "
" ], "text/plain": [ "StackViewNDArray([[[176, 178, 179],\n", " [175, 178, 178],\n", " [177, 177, 180],\n", " ...,\n", " [182, 186, 188],\n", " [185, 188, 191],\n", " [191, 194, 197]],\n", "\n", " [[178, 180, 181],\n", " [178, 179, 181],\n", " [178, 180, 181],\n", " ...,\n", " [185, 189, 192],\n", " [187, 191, 192],\n", " [191, 195, 198]],\n", "\n", " [[181, 183, 185],\n", " [180, 182, 183],\n", " [180, 181, 183],\n", " ...,\n", " [190, 193, 196],\n", " [189, 193, 196],\n", " [192, 195, 198]],\n", "\n", " ...,\n", "\n", " [[125, 91, 66],\n", " [124, 90, 65],\n", " [123, 89, 65],\n", " ...,\n", " [137, 92, 64],\n", " [136, 91, 62],\n", " [135, 89, 61]],\n", "\n", " [[122, 88, 64],\n", " [121, 87, 63],\n", " [121, 87, 63],\n", " ...,\n", " [142, 96, 68],\n", " [142, 96, 68],\n", " [139, 94, 65]],\n", "\n", " [[120, 86, 62],\n", " [120, 86, 60],\n", " [119, 85, 61],\n", " ...,\n", " [144, 99, 70],\n", " [144, 99, 70],\n", " [142, 97, 68]]], dtype=uint8)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "image = imread(\"data/real_cat.png\")\n", "stackview.insight(image)" ] }, { "cell_type": "code", "execution_count": 4, "id": "79111f68-bd94-4536-bed5-b70d1374d91d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(25.3473)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# CLIPScore expects channel-first (C, H, W) image tensors, but imread returned (H, W, C).\n", "image_chw = torch.as_tensor(image).permute(2, 0, 1)\n", "score = metric(image_chw, \"cat\")\n", "score.detach()" ] }, { "cell_type": "code", "execution_count": 5, "id": "2f37d6eb-9bbd-4bc1-8535-a1a6206cc859", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "30.786287307739258" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "score = metric(image_chw, \"microscope\")\n", "float(score.detach())" ] }, { "cell_type": "markdown", "id": "1a0f745b-aae0-4bf8-9081-c20599a44a6d", "metadata": {}, "source": [ "## Recap: Generating images\n", "We will now use a prompt to generate an image 
and will measure if it shows a cat." ] }, { "cell_type": "code", "execution_count": 6, "id": "57dcb8e3-764d-41e1-9ea2-811d9c1f655a", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "dfb9c2beef8a4a1eb426d20f4bb7f955", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading pipeline components...: 0%| | 0/6 [00:00