Fine-tuning Gemma on local hardware#

In this notebook we fine-tune the google/gemma-2b-it model using Hugging Face infrastructure, locally on a consumer-grade graphics card with 16 GB of memory. Gemma is provided under and subject to the Gemma Terms of Use found at ai.google.dev/gemma/terms. This notebook was written by modifying code from this article about fine-tuning Llama 3, which is highly recommended reading.

Read more#

Troubleshooting#

  • If you run this notebook on Windows and receive error messages saying that CUDA initialization failed, make sure you have bitsandbytes version 0.43.2 or later installed. You can check the installed version with the snippet below.

  • If you run out of GPU memory, make sure you use the right hardware. This notebook was developed using an RTX 3080 mobile GPU with 16 GB of memory.
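
To check which bitsandbytes version is installed, you can run this minimal check (and upgrade with pip install -U bitsandbytes if it is too old):

import bitsandbytes
print(bitsandbytes.__version__)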

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import (
    LoraConfig,
    PeftModel,
    prepare_model_for_kbit_training,
    get_peft_model,
)
import os, torch
from datasets import load_dataset
from trl import SFTTrainer, setup_chat_format
from functools import partial

First, we define the model we want to fine-tune and the name under which we will store the new fine-tuned model.

base_model = "google/gemma-2b-it"
#"google/codegemma-1.1-7b-it"
#"google/gemma-2b"
#"google/codegemma-2b"
#"google/gemma-2b-it"
new_model = "gemma-2b-it-bia-proof-of-concept2"

Configuration#

Here we define the compute precision (16-bit floats) and the attention implementation used by the model.

torch_dtype = torch.float16
attn_implementation = "eager"

We will use the QLoRA fine-tuning scheme to save GPU memory: the base model is loaded in 4-bit precision while the LoRA adapters are trained on top of it.

# QLoRA config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_use_double_quant=True,
)

Initialization of model and tokenizer#

Here we download and initialize the model and the tokenizer.

# Load model
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=attn_implementation
)
`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model)
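# setup_chat_format adds the ChatML special tokens (<|im_start|>, <|im_end|>) and resizes the embedding layer accordingly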
model, tokenizer = setup_chat_format(model, tokenizer)
tokenizer.padding_side = 'right'
# LoRA config
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
)
model = get_peft_model(model, peft_config)
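
Optionally, we can check how many parameters are actually trainable with LoRA compared to the size of the full model:

# print the number of trainable (LoRA) parameters vs. all parameters
model.print_trainable_parameters()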

Dataset preparation#

Next, we load our dataset for fine-tuning from the Hugging Face Hub.

dataset_name = "haesleinhuepf/bio-image-analysis-qa"
dataset_raw = load_dataset(dataset_name, split="all")

def format_chat_template(row, tokenizer):
    row_json = [{"role": "user", "content": row["question"]},
               {"role": "assistant", "content": row["answer"]}]
    row["text"] = tokenizer.apply_chat_template(row_json, tokenize=False)
    return row

format_chat_template_partial = partial(format_chat_template, tokenizer=tokenizer)

dataset_w_template = dataset_raw.map(
    format_chat_template_partial,
    num_proc=4,
)

print(dataset_w_template['text'][3])
<|im_start|>user
How can we use indices in Python to crop images, similar to cropping lists and tuples?<|im_end|>
<|im_start|>assistant

This code imports the necessary functions from the skimage.io module. It then reads an image called "blobs.tif" and assigns it to the variable 'image'. It crops the image, taking the first 128 rows, and assigns the result to 'cropped_image1'. The cropped image is then displayed using the 'imshow' function. Lastly, a list of numbers is created called 'mylist'.

```python
from skimage.io import imread, imshow, imshow

image = imread("../../data/blobs.tif")

cropped_image1 = image[0:128]

imshow(cropped_image1);

mylist = [1,2,2,3,4,5,78]
```
<|im_end|>

We then split the data into two sets: one for training and one for testing.

dataset_train_test = dataset_w_template.train_test_split(test_size=0.3)
dataset_train_test
DatasetDict({
    train: Dataset({
        features: ['question', 'answer', 'text'],
        num_rows: 91
    })
    test: Dataset({
        features: ['question', 'answer', 'text'],
        num_rows: 39
    })
})
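
Model training#

Next, we configure the training hyperparameters and set up the SFTTrainer with our model, datasets and LoRA configuration.
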
training_arguments = TrainingArguments(
    output_dir=new_model,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    optim="paged_adamw_32bit",
    num_train_epochs=10,
    eval_strategy="steps",
    eval_steps=0.2,
    logging_steps=1,
    warmup_steps=10,
    logging_strategy="steps",
    learning_rate=2e-4,
    fp16=False,
    bf16=False,
    group_by_length=True,
    report_to="none"
)
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset_train_test["train"],
    eval_dataset=dataset_train_test["test"],
    peft_config=peft_config,
    max_seq_length=512,
    dataset_text_field="text",
    tokenizer=tokenizer,
    args=training_arguments,
    packing=False,
)
C:\Users\haase\miniconda3\envs\genai-gpu\Lib\site-packages\huggingface_hub\utils\_deprecation.py:100: FutureWarning: Deprecated argument(s) used in '__init__': max_seq_length, dataset_text_field. Will not be supported from version '1.0.0'.

Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.
  warnings.warn(message, FutureWarning)
C:\Users\haase\miniconda3\envs\genai-gpu\Lib\site-packages\trl\trainer\sft_trainer.py:280: UserWarning: You passed a `max_seq_length` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.
  warnings.warn(
C:\Users\haase\miniconda3\envs\genai-gpu\Lib\site-packages\trl\trainer\sft_trainer.py:318: UserWarning: You passed a `dataset_text_field` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.
  warnings.warn(
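
As the warnings above indicate, newer trl versions expect max_seq_length and dataset_text_field to be passed via an SFTConfig rather than directly to the SFTTrainer. A minimal, warning-free sketch of the same setup (assuming a trl version that provides SFTConfig with these fields) could look like this; below we continue with the trainer constructed above.

# Sketch: pass SFT-specific options through SFTConfig instead of SFTTrainer
from trl import SFTConfig

sft_config = SFTConfig(
    output_dir=new_model,
    max_seq_length=512,
    dataset_text_field="text",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    num_train_epochs=10,
    learning_rate=2e-4,
    report_to="none",
)
# trainer = SFTTrainer(model=model, args=sft_config, train_dataset=..., eval_dataset=..., peft_config=peft_config, tokenizer=tokenizer)
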
trainer.train()
[450/450 03:26, Epoch 9/10]
Step Training Loss Validation Loss
90 0.702000 1.277717
180 0.208500 1.599397
270 0.158900 1.870157
360 0.060700 2.060952
450 0.043300 2.185464

C:\Users\haase\miniconda3\envs\genai-gpu\Lib\site-packages\peft\utils\save_and_load.py:232: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.
  warnings.warn(
TrainOutput(global_step=450, training_loss=0.5049639929375714, metrics={'train_runtime': 208.7607, 'train_samples_per_second': 4.359, 'train_steps_per_second': 2.156, 'total_flos': 2017628122275840.0, 'train_loss': 0.5049639929375714, 'epoch': 9.89010989010989})
trainer.save_model(new_model + "_temp")
C:\Users\haase\miniconda3\envs\genai-gpu\Lib\site-packages\peft\utils\save_and_load.py:232: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.
  warnings.warn(
trainer.model.save_pretrained(new_model)
trainer.model.push_to_hub(new_model, use_temp_dir=False)
CommitInfo(commit_url='https://huggingface.co/haesleinhuepf/gemma-2b-it-bia-proof-of-concept2/commit/0261693ce0b873fffb094bd96ac633b58f8b85bd', commit_message='Upload model', commit_description='', oid='0261693ce0b873fffb094bd96ac633b58f8b85bd', pr_url=None, pr_revision=None, pr_num=None)
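
The repository on the Hub now only contains the LoRA adapter weights. If you want a standalone model that can be loaded without PEFT, you can merge the adapter back into the base model. The following is a minimal sketch under the assumption that the adapter was saved locally under new_model; note that the base model needs the same chat-format setup as during training before the adapter can be applied.

# Sketch: merge the LoRA adapter into the base model for standalone use
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from trl import setup_chat_format
import torch

base = AutoModelForCausalLM.from_pretrained(base_model, torch_dtype=torch.float16, device_map="auto")
base_tokenizer = AutoTokenizer.from_pretrained(base_model)
# re-apply the chat format so that the embedding sizes match the fine-tuned adapter
base, base_tokenizer = setup_chat_format(base, base_tokenizer)

merged_model = PeftModel.from_pretrained(base, new_model)
merged_model = merged_model.merge_and_unload()  # fold the LoRA weights into the base weights
merged_model.save_pretrained(new_model + "_merged")
base_tokenizer.save_pretrained(new_model + "_merged")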

Testing the model#

After the model is trained, we can run a first test with it.

messages = [{"role": "user", "content": """
Write Python code to load the image ../11a_prompt_engineering/data/blobs.tif,
segment the nuclei in it and
show the result
"""}]

prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.float16,
    device_map="auto",
)

outputs = pipe(prompt, max_new_tokens=120, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
print(outputs[0]["generated_text"])
The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'Gemma2ForCausalLM', 'GitForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'JambaForCausalLM', 'JetMoeForCausalLM', 'LlamaForCausalLM', 'MambaForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MistralForCausalLM', 'MixtralForCausalLM', 'MptForCausalLM', 'MusicgenForCausalLM', 'MusicgenMelodyForCausalLM', 'MvpForCausalLM', 'OlmoForCausalLM', 'OpenLlamaForCausalLM', 'OpenAIGPTLMHeadModel', 'OPTForCausalLM', 'PegasusForCausalLM', 'PersimmonForCausalLM', 'PhiForCausalLM', 'Phi3ForCausalLM', 'PLBartForCausalLM', 'ProphetNetForCausalLM', 'QDQBertLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'RecurrentGemmaForCausalLM', 'ReformerModelWithLMHead', 'RemBertForCausalLM', 'RobertaForCausalLM', 'RobertaPreLayerNormForCausalLM', 'RoCBertForCausalLM', 'RoFormerForCausalLM', 'RwkvForCausalLM', 'Speech2Text2ForCausalLM', 'StableLmForCausalLM', 'Starcoder2ForCausalLM', 'TransfoXLLMHeadModel', 'TrOCRForCausalLM', 'WhisperForCausalLM', 'XGLMForCausalLM', 'XLMWithLMHeadModel', 'XLMProphetNetForCausalLM', 'XLMRobertaForCausalLM', 'XLMRobertaXLForCausalLM', 'XLNetLMHeadModel', 'XmodForCausalLM'].
<|im_start|>user

Write Python code to load the image ../11a_prompt_engineering/data/blobs.tif,
segment the nuclei in it and
show the result
<|im_end|>
<|im_start|>assistant
The code is importing the `imread` and `imshow` functions from the `skimage.io` module, and the `pyclesperanto_prototype` module is imported and assigned to the variable `cle`.

```python
from skimage.io import imread, imshow
import pyclesperanto_prototype as cle

cle.load_image('../11a_prompt_engineering/data/blobs.tif')
```
The loaded image is then assigned to the variable `image`.

```python
image = cle.load_image('../11a_