Fine-Tuning FLAN-T5 with Reinforcement Learning (PPO) and PEFT to Generate Less-Toxic Summaries

In this project we will fine-tune a FLAN-T5 model to generate less toxic content with Meta AI’s hate speech reward model
natural-language-processing
deep-learning
aws
hugging-face
fine-tuning
Author

Pranath Fernando

Published

July 18, 2023

1 Introduction

In this project, we will fine-tune a FLAN-T5 model to generate less toxic content using Meta AI’s hate speech reward model. The reward model is a binary classifier that predicts either “not hate” or “hate” for a given text. We will use Proximal Policy Optimization (PPO) to fine-tune the model and reduce its toxicity.

2 Set up Required Dependencies

from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM, GenerationConfig
from datasets import load_dataset
from peft import PeftModel, PeftConfig, LoraConfig, TaskType

# trl: Transformer Reinforcement Learning library
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
from trl import create_reference_model
from trl.core import LengthSampler

import torch
import evaluate

import numpy as np
import pandas as pd

# tqdm library makes the loops show a smart progress meter.
from tqdm import tqdm
tqdm.pandas()

3 Load FLAN-T5 Model - Prepare Reward Model and Toxicity Evaluator

3.1 Load Data and FLAN-T5 Model Fine-Tuned with Summarization Instruction

We will use the Hugging Face dataset DialogSum and the pre-trained model FLAN-T5.

model_name="google/flan-t5-base"
huggingface_dataset_name = "knkarthick/dialogsum"

dataset_original = load_dataset(huggingface_dataset_name)

dataset_original
Downloading and preparing dataset csv/knkarthick--dialogsum to /root/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-391706c81424fc80/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1...
Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-391706c81424fc80/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1. Subsequent calls will reuse this data.
DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 12460
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 1500
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 500
    })
})

The next step is to preprocess the dataset. We will take only part of it, filter for dialogues of a particular length (long enough to be interesting but still easy to read), wrap each dialogue with the summarization instruction, and tokenize the prompts. The token ids are saved in the field input_ids and the decoded version of the prompts in the field query.

We could do all of that step by step in the cell below, but it is a good habit to organize it in a function, build_dataset:

def build_dataset(model_name,
                  dataset_name,
                  input_min_text_length, 
                  input_max_text_length):

    """
    Preprocess the dataset and split it into train and test parts.

    Parameters:
    - model_name (str): Tokenizer model name.
    - dataset_name (str): Name of the dataset to load.
    - input_min_text_length (int): Minimum length of the dialogues.
    - input_max_text_length (int): Maximum length of the dialogues.
        
    Returns:
    - dataset_splits (datasets.dataset_dict.DatasetDict): Preprocessed dataset containing train and test parts.
    """
    
    # load dataset (only "train" part will be enough for this lab).
    dataset = load_dataset(dataset_name, split="train")
    
    # Filter the dialogues of length between input_min_text_length and input_max_text_length characters.
    dataset = dataset.filter(lambda x: len(x["dialogue"]) > input_min_text_length and len(x["dialogue"]) <= input_max_text_length, batched=False)

    # Prepare the tokenizer. Setting device_map="auto" allows switching between GPU and CPU automatically.
    tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
    
    def tokenize(sample):
        
        # Wrap each dialogue with the instruction.
        prompt = f"""
Summarize the following conversation.

{sample["dialogue"]}

Summary:
"""
        sample["input_ids"] = tokenizer.encode(prompt)
        
        # This must be called "query", which is a requirement of our PPO library.
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    # Tokenize each dialogue.
    dataset = dataset.map(tokenize, batched=False)
    dataset.set_format(type="torch")
    
    # Split the dataset into train and test parts.
    dataset_splits = dataset.train_test_split(test_size=0.2, shuffle=False, seed=42)

    return dataset_splits

dataset = build_dataset(model_name=model_name,
                        dataset_name=huggingface_dataset_name,
                        input_min_text_length=200, 
                        input_max_text_length=1000)

print(dataset)
Found cached dataset csv (/root/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-391706c81424fc80/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)
DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic', 'input_ids', 'query'],
        num_rows: 8017
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic', 'input_ids', 'query'],
        num_rows: 2005
    })
})

In a previous project we fine-tuned a PEFT model with summarization instructions. The training in that notebook was done on a subset of the data, and a checkpoint of the fully trained PEFT model is available on S3.

Let’s load the same model checkpoint here:

!aws s3 cp --recursive s3://dlai-generative-ai/models/peft-dialogue-summary-checkpoint/ ./peft-dialogue-summary-checkpoint-from-s3/ 
download: s3://dlai-generative-ai/models/peft-dialogue-summary-checkpoint/special_tokens_map.json to peft-dialogue-summary-checkpoint-from-s3/special_tokens_map.json
download: s3://dlai-generative-ai/models/peft-dialogue-summary-checkpoint/adapter_config.json to peft-dialogue-summary-checkpoint-from-s3/adapter_config.json
download: s3://dlai-generative-ai/models/peft-dialogue-summary-checkpoint/tokenizer_config.json to peft-dialogue-summary-checkpoint-from-s3/tokenizer_config.json
download: s3://dlai-generative-ai/models/peft-dialogue-summary-checkpoint/tokenizer.json to peft-dialogue-summary-checkpoint-from-s3/tokenizer.json
download: s3://dlai-generative-ai/models/peft-dialogue-summary-checkpoint/adapter_model.bin to peft-dialogue-summary-checkpoint-from-s3/adapter_model.bin

List the model file and check its size (it is less than 15 MB):

!ls -alh ./peft-dialogue-summary-checkpoint-from-s3/adapter_model.bin
-rw-r--r-- 1 root root 14M May 15 11:18 ./peft-dialogue-summary-checkpoint-from-s3/adapter_model.bin

Prepare a function to pull out the number of model parameters (it is the same as in the previous project):

def print_number_of_trainable_model_parameters(model):
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_model_params += param.numel()
    return f"\ntrainable model parameters: {trainable_model_params}\nall model parameters: {all_model_params}\npercentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%"

Add the adapter to the original FLAN-T5 model. In the previous project we added the fully trained adapter only for inference, so there was no need to pass a LoRA configuration. Now we need to pass one to the constructed PEFT model and set is_trainable=True.

lora_config = LoraConfig(
    r=32, # Rank
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM # FLAN-T5
)

model = AutoModelForSeq2SeqLM.from_pretrained(model_name, 
                                              torch_dtype=torch.bfloat16)

peft_model = PeftModel.from_pretrained(model, 
                                       './peft-dialogue-summary-checkpoint-from-s3/', 
                                       lora_config=lora_config,
                                       torch_dtype=torch.bfloat16, 
                                       device_map="auto",                                       
                                       is_trainable=True)

print(f'PEFT model parameters to be updated:\n{print_number_of_trainable_model_parameters(peft_model)}\n')
PEFT model parameters to be updated:

trainable model parameters: 3538944
all model parameters: 251116800
percentage of trainable model parameters: 1.41%
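
As a sanity check, the trainable-parameter count above follows directly from the LoRA configuration. Below is a minimal sketch of the arithmetic, assuming FLAN-T5-base dimensions (hidden size 768, 12 encoder and 12 decoder blocks):

# LoRA adds two low-rank matrices (A: r x d_model, B: d_model x r) to each targeted projection.
d_model = 768   # FLAN-T5-base hidden size
r = 32          # LoRA rank from lora_config above

params_per_projection = 2 * r * d_model   # A and B together: 49,152 parameters

# "q" and "v" are adapted in every attention block: 12 encoder self-attention blocks,
# plus 12 decoder self-attention and 12 decoder cross-attention blocks = 36 blocks.
num_projections = 36 * 2                  # 72 adapted projections

print(params_per_projection * num_projections)   # 3538944 - matches the printout above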

In this project, we are preparing to fine-tune the LLM using Reinforcement Learning (RL). RL will be discussed briefly in the next section of this article, but at this stage we just need to prepare the Proximal Policy Optimization (PPO) model by passing the instruction-fine-tuned PEFT model to it. PPO will be used to optimize the RL policy against the reward model.

ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(peft_model,                                                               
                                                               torch_dtype=torch.bfloat16,
                                                               is_trainable=True)

print(f'PPO model parameters to be updated (ValueHead + 769 params):\n{print_number_of_trainable_model_parameters(ppo_model)}\n')
print(ppo_model.v_head)
PPO model parameters to be updated (ValueHead + 769 params):

trainable model parameters: 3539713
all model parameters: 251117569
percentage of trainable model parameters: 1.41%

ValueHead(
  (dropout): Dropout(p=0.1, inplace=False)
  (summary): Linear(in_features=768, out_features=1, bias=True)
  (flatten): Flatten(start_dim=1, end_dim=-1)
)

During PPO, only a few parameters will be updated: specifically, the parameters of the ValueHead. More information about this class of models can be found in the documentation. The number of trainable parameters can be computed as \((n+1)\times m\), where \(n\) is the number of input units (here \(n=768\)) and \(m\) is the number of output units (here \(m=1\)). The \(+1\) term accounts for the bias.
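
We can confirm that count directly from the v_head printed above; a minimal check, assuming the ppo_model created earlier:

# The value head is a single Linear(768 -> 1) layer: 768 weights plus one bias.
n, m = 768, 1
print((n + 1) * m)  # 769

# Equivalently, sum the parameters of the value head itself:
print(sum(p.numel() for p in ppo_model.v_head.parameters()))  # 769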

Now create a frozen copy of the PPO model that will not be fine-tuned - a reference model. The reference model represents the LLM before detoxification. None of the parameters of the reference model will be updated during PPO training; this is intentional.

ref_model = create_reference_model(ppo_model)

print(f'Reference model parameters to be updated:\n{print_number_of_trainable_model_parameters(ref_model)}\n')
Reference model parameters to be updated:

trainable model parameters: 0
all model parameters: 251117569
percentage of trainable model parameters: 0.00%

Everything is set. It is time to prepare the reward model!

3.2 Prepare Reward Model

Reinforcement Learning (RL) is one type of machine learning where agents take actions in an environment aimed at maximizing their cumulative rewards. The agent’s behavior is defined by the policy. And the goal of reinforcement learning is for the agent to learn an optimal, or nearly-optimal, policy that maximizes the reward function.

In the previous section the original policy was based on the instruct PEFT model - this is the LLM before detoxification. We could ask human labelers to give feedback on the toxicity of the outputs, but employing them for the entire fine-tuning process would be expensive. A practical way to avoid that is to use a reward model that encourages the agent to detoxify the dialogue summaries. The intuitive approach is to run some form of sentiment analysis across two classes (nothate and hate) and give a higher reward when there is a higher chance of getting the class nothate as the output.

We will use Meta AI’s RoBERTa-based hate speech model for the reward model. This model will output logits and then predict probabilities across two classes: nothate and hate. The logits of the output nothate will be taken as a positive reward. Then, the model will be fine-tuned with PPO using those reward values.

Create an instance of the required model class for the RoBERTa model, and load its tokenizer so we can test it. Notice that model label 0 corresponds to the class nothate and label 1 to the class hate.

toxicity_model_name = "facebook/roberta-hate-speech-dynabench-r4-target"
toxicity_tokenizer = AutoTokenizer.from_pretrained(toxicity_model_name, device_map="auto")
toxicity_model = AutoModelForSequenceClassification.from_pretrained(toxicity_model_name, device_map="auto")
print(toxicity_model.config.id2label)
The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.
{0: 'nothate', 1: 'hate'}

Take some non-toxic text, tokenize it, and pass it to the model. Print the output logits, probabilities, and the corresponding reward that will be used for fine-tuning.

non_toxic_text = "#Person 1# tells Tommy that he didn't like the movie."

toxicity_input_ids = toxicity_tokenizer(non_toxic_text, return_tensors="pt").input_ids

logits = toxicity_model(input_ids=toxicity_input_ids).logits
print(f'logits [not hate, hate]: {logits.tolist()[0]}')

# Print the probabilities for [not hate, hate]
probabilities = logits.softmax(dim=-1).tolist()[0]
print(f'probabilities [not hate, hate]: {probabilities}')

# get the logits for "not hate" - this is the reward!
not_hate_index = 0
nothate_reward = (logits[:, not_hate_index]).tolist()
print(f'reward (high): {nothate_reward}')
logits [not hate, hate]: [3.114100694656372, -2.4896175861358643]
probabilities [not hate, hate]: [0.9963293671607971, 0.003670616541057825]
reward (high): [3.114100694656372]

Now let’s pass in a toxic comment. It will receive a lower reward because it is more toxic.

toxic_text = "#Person 1# tells Tommy that the movie was terrible, dumb and stupid."

toxicity_input_ids = toxicity_tokenizer(toxic_text, return_tensors="pt").input_ids

logits = toxicity_model(toxicity_input_ids).logits
print(f'logits [not hate, hate]: {logits.tolist()[0]}')

# Print the probabilities for [not hate, hate]
probabilities = logits.softmax(dim=-1).tolist()[0]
print(f'probabilities [not hate, hate]: {probabilities}')

# Get the logits for "not hate" - this is the reward!
nothate_reward = (logits[:, not_hate_index]).tolist() 
print(f'reward (low): {nothate_reward}')
logits [not hate, hate]: [-0.6921188831329346, 0.3722729980945587]
probabilities [not hate, hate]: [0.25647106766700745, 0.7435289621353149]
reward (low): [-0.6921188831329346]

Set up a Hugging Face inference pipeline to simplify the code for the toxicity reward model:

device = 0 if torch.cuda.is_available() else "cpu"

sentiment_pipe = pipeline("sentiment-analysis", 
                          model=toxicity_model_name, 
                          device=device)
reward_logits_kwargs = {
    "top_k": None, # Return all scores.
    "function_to_apply": "none", # Set to "none" to retrieve raw logits.
    "batch_size": 16
}

reward_probabilities_kwargs = {
    "top_k": None, # Return all scores.
    "function_to_apply": "softmax", # Set to "softmax" to apply softmax and retrieve probabilities.
    "batch_size": 16
}

print("Reward model output:")
print("For non-toxic text")
print(sentiment_pipe(non_toxic_text, **reward_logits_kwargs))
print(sentiment_pipe(non_toxic_text, **reward_probabilities_kwargs))
print("For toxic text")
print(sentiment_pipe(toxic_text, **reward_logits_kwargs))
print(sentiment_pipe(toxic_text, **reward_probabilities_kwargs))
Reward model output:
For non-toxic text
[{'label': 'nothate', 'score': 3.114100694656372}, {'label': 'hate', 'score': -2.4896175861358643}]
[{'label': 'nothate', 'score': 0.9963293671607971}, {'label': 'hate', 'score': 0.003670616541057825}]
For toxic text
[{'label': 'hate', 'score': 0.3722729980945587}, {'label': 'nothate', 'score': -0.6921188831329346}]
[{'label': 'hate', 'score': 0.7435289621353149}, {'label': 'nothate', 'score': 0.25647106766700745}]

The outputs are the logits for both the nothate (positive) and hate (negative) classes, but PPO will use only the nothate logit as the positive reward signal to help detoxify the LLM outputs.
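
Note that with top_k=None the pipeline returns the two entries sorted by score - compare the orderings for the non-toxic and toxic examples above - so the nothate entry is not always at index 0. As a small illustrative helper (a sketch assuming the sentiment_pipe defined above), the reward can also be pulled out by label rather than by position:

def nothate_reward_from_output(pipe_output):
    # Return the raw "nothate" logit from a single pipeline result,
    # regardless of the order in which the two labels are returned.
    return next(item["score"] for item in pipe_output if item["label"] == "nothate")

print(nothate_reward_from_output(sentiment_pipe(non_toxic_text, **reward_logits_kwargs)))  # ~3.11
print(nothate_reward_from_output(sentiment_pipe(toxic_text, **reward_logits_kwargs)))      # ~-0.69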

3.3 Evaluate Toxicity

To evaluate the model before and after fine-tuning/detoxification we need to set up the toxicity evaluation metric. The toxicity score is a decimal value between 0 and 1 where 1 is the highest toxicity.

toxicity_evaluator = evaluate.load("toxicity", 
                                    toxicity_model_name,
                                    module_type="measurement",
                                    toxic_label="hate")

Let’s calculate the toxicity of the same sentences. It is no surprise that the toxicity scores are the probabilities of the hate class returned directly by the reward model.

toxicity_score = toxicity_evaluator.compute(predictions=[
    non_toxic_text
])

print("Toxicity score for non-toxic text:")
print(toxicity_score["toxicity"])

toxicity_score = toxicity_evaluator.compute(predictions=[
    toxic_text
])

print("\nToxicity score for toxic text:")
print(toxicity_score["toxicity"])
Toxicity score for non-toxic text:
[0.003670616541057825]

Toxicity score for toxic text:
[0.7435289621353149]
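
As a quick cross-check, here is a minimal sketch (assuming the toxicity_model and toxicity_tokenizer loaded earlier) showing that the evaluator score is simply the softmax probability of the hate class:

# Reproduce the evaluator's toxicity score for the toxic example by hand.
input_ids = toxicity_tokenizer(toxic_text, return_tensors="pt").input_ids
logits = toxicity_model(input_ids=input_ids).logits

hate_probability = logits.softmax(dim=-1)[0, 1].item()  # index 1 corresponds to the 'hate' label
print(hate_probability)  # ~0.7435, matching the evaluator output above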

This evaluator can be used to compute the toxicity of the dialogues prepared previously. We will need to pass in the test dataset (dataset["test"]), the same tokenizer that was used in that section, the frozen reference model prepared earlier, and the toxicity evaluator. It is convenient to wrap the required steps in a function, evaluate_toxicity.

def evaluate_toxicity(model, 
                      toxicity_evaluator, 
                      tokenizer, 
                      dataset, 
                      num_samples):
    
    """
    Evaluate the toxicity of the summaries generated by the model for the given dataset.

    Parameters:
    - model (trl model): Model to be evaluated.
    - toxicity_evaluator (evaluate_modules toxicity metrics): Toxicity evaluator.
    - tokenizer (transformers tokenizer): Tokenizer to be used.
    - dataset (dataset): Input dataset for the evaluation.
    - num_samples (int): Maximum number of samples for the evaluation.
        
    Returns:
    tuple: A tuple containing two numpy.float64 values:
    - mean (numpy.float64): Mean of the samples toxicity.
    - std (numpy.float64): Standard deviation of the samples toxicity.
    """

    max_new_tokens=100

    toxicities = []
    input_texts = []
    for i, sample in tqdm(enumerate(dataset)):
        input_text = sample["query"]

        if i > num_samples:
            break
            
        input_ids = tokenizer(input_text, return_tensors="pt", padding=True).input_ids
        
        generation_config = GenerationConfig(max_new_tokens=max_new_tokens,
                                             top_k=0.0,
                                             top_p=1.0,
                                             do_sample=True)

        response_token_ids = model.generate(input_ids=input_ids,
                                            generation_config=generation_config)
        
        generated_text = tokenizer.decode(response_token_ids[0], skip_special_tokens=True)
        
        toxicity_score = toxicity_evaluator.compute(predictions=[(input_text + " " + generated_text)])

        toxicities.extend(toxicity_score["toxicity"])

    # Compute mean & std using np.
    mean = np.mean(toxicities)
    std = np.std(toxicities)
        
    return mean, std

Now calculate the model toxicity before fine-tuning/detoxification:

tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")

mean_before_detoxification, std_before_detoxification = evaluate_toxicity(model=ref_model, 
                                                                          toxicity_evaluator=toxicity_evaluator, 
                                                                          tokenizer=tokenizer, 
                                                                          dataset=dataset["test"], 
                                                                          num_samples=10)

print(f'toxicity [mean, std] before detox: [{mean_before_detoxification}, {std_before_detoxification}]')
11it [00:25,  2.33s/it]
toxicity [mean, std] before detox: [0.02970629831014032, 0.03363027283000358]

4 Perform Fine-Tuning to Detoxify the Summaries

Optimize an RL policy against the reward model using Proximal Policy Optimization (PPO).

4.1 Initialize PPOTrainer

For the PPOTrainer initialization, we will need a collator. Here it will be a function that transforms a list of dictionaries into a dictionary of lists. We can define and test it:

def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])

test_data = [{"key1": "value1", "key2": "value2", "key3": "value3"}]
print(f'Collator input: {test_data}')
print(f'Collator output: {collator(test_data)}')
Collator input: [{'key1': 'value1', 'key2': 'value2', 'key3': 'value3'}]
Collator output: {'key1': ['value1'], 'key2': ['value2'], 'key3': ['value3']}

Set up the configuration parameters and initialize the PPOTrainer with the ppo_model, the tokenizer, and the frozen ref_model. The first model is optimized, while the second serves as a reference for calculating the KL-divergence from the starting point. This acts as an additional reward signal during PPO training, making sure the optimized model does not deviate too much from the original LLM.
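
Conceptually, trl combines the scalar reward-model score with a per-token KL penalty against the reference model. The sketch below only illustrates the idea and is not the library's exact implementation; kl_coef and the log-probability tensors are purely illustrative:

import torch

def penalized_rewards(reward_score, policy_logprobs, ref_logprobs, kl_coef=0.2):
    # Illustrative per-token reward: a KL penalty keeps the policy close to the
    # reference model, and the reward-model score is added at the final token.
    kl = policy_logprobs - ref_logprobs    # per-token KL estimate
    rewards = -kl_coef * kl                # penalize divergence at every position
    rewards[-1] += reward_score            # scalar reward on the last generated token
    return rewards

# Toy example with 5 generated tokens:
print(penalized_rewards(torch.tensor(2.5), torch.randn(5), torch.randn(5)))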

learning_rate=1.41e-5
max_ppo_epochs=1
mini_batch_size=4
batch_size=16

config = PPOConfig(
    model_name=model_name,    
    learning_rate=learning_rate,
    ppo_epochs=max_ppo_epochs,
    mini_batch_size=mini_batch_size,
    batch_size=batch_size
)

ppo_trainer = PPOTrainer(config=config, 
                         model=ppo_model, 
                         ref_model=ref_model, 
                         tokenizer=tokenizer, 
                         dataset=dataset["train"], 
                         data_collator=collator)

4.2 Fine-Tune the Model

The fine-tuning loop consists of the following main steps:

  1. Get the query responses from the policy LLM (PEFT model).
  2. Get the sentiments (rewards) for the query/response pairs from the hate speech RoBERTa model.
  3. Optimize the policy with PPO using the (query, response, reward) triplet.

You will know training is running correctly if you see the following metrics appearing:

  • objective/kl: minimize kl divergence,
  • ppo/returns/mean: maximize mean returns,
  • ppo/policy/advantages_mean: maximize advantages.
# May take 20-30 mins to run this cell
output_min_length = 100
output_max_length = 400
output_length_sampler = LengthSampler(output_min_length, output_max_length)

generation_kwargs = {
    "min_length": 5,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True
}

reward_kwargs = {
    "top_k": None, # Return all scores.
    "function_to_apply": "none", # You want the raw logits without softmax.
    "batch_size": 16
}

max_ppo_steps = 10

for step, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    # Break when you reach max_steps.
    if step >= max_ppo_steps:
        break   

    prompt_tensors = batch["input_ids"]

    # Get response from FLAN-T5/PEFT LLM.
    summary_tensors = []

    for prompt_tensor in prompt_tensors:
        max_new_tokens = output_length_sampler()        
            
        generation_kwargs["max_new_tokens"] = max_new_tokens
        summary = ppo_trainer.generate(prompt_tensor, **generation_kwargs)
        
        summary_tensors.append(summary.squeeze()[-max_new_tokens:])
        
    # This needs to be called "response".
    batch["response"] = [tokenizer.decode(r.squeeze()) for r in summary_tensors]

    # Compute reward outputs.
    query_response_pairs = [q + r for q, r in zip(batch["query"], batch["response"])]    
    rewards = sentiment_pipe(query_response_pairs, **reward_kwargs)

    # You use the `nothate` item because this is the score for the positive `nothate` class.
    reward_tensors = [torch.tensor(reward[not_hate_index]["score"]) for reward in rewards]    

    # Run PPO step.
    stats = ppo_trainer.step(prompt_tensors, summary_tensors, reward_tensors)
    ppo_trainer.log_stats(stats, batch, reward_tensors)
    
    print(f'objective/kl: {stats["objective/kl"]}')
    print(f'ppo/returns/mean: {stats["ppo/returns/mean"]}')
    print(f'ppo/policy/advantages_mean: {stats["ppo/policy/advantages_mean"]}')
    print('-'.join('' for x in range(100)))
0it [00:00, ?it/s]You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
1it [01:43, 103.28s/it]2it [03:26, 103.07s/it]3it [05:01, 99.70s/it] 4it [06:24, 93.12s/it]5it [07:51, 90.59s/it]6it [09:43, 97.93s/it]7it [11:08, 93.86s/it]8it [12:31, 90.48s/it]9it [14:03, 90.88s/it]10it [15:30, 93.00s/it]
objective/kl: 29.314075469970703
ppo/returns/mean: -0.6003289222717285
ppo/policy/advantages_mean: -5.777030853693077e-09
---------------------------------------------------------------------------------------------------
objective/kl: 37.534847259521484
ppo/returns/mean: -1.0223705768585205
ppo/policy/advantages_mean: -1.4861395669640842e-08
---------------------------------------------------------------------------------------------------
objective/kl: 34.685516357421875
ppo/returns/mean: -0.8575112819671631
ppo/policy/advantages_mean: 1.101110846946085e-08
---------------------------------------------------------------------------------------------------
objective/kl: 23.096426010131836
ppo/returns/mean: -0.35878801345825195
ppo/policy/advantages_mean: 2.7160158566630344e-09
---------------------------------------------------------------------------------------------------
objective/kl: 26.646108627319336
ppo/returns/mean: -0.3976229429244995
ppo/policy/advantages_mean: 1.1457839121931102e-08
---------------------------------------------------------------------------------------------------
objective/kl: 33.84138870239258
ppo/returns/mean: -0.649031400680542
ppo/policy/advantages_mean: 2.6997275526241538e-09
---------------------------------------------------------------------------------------------------
objective/kl: 25.600406646728516
ppo/returns/mean: -0.49898040294647217
ppo/policy/advantages_mean: 3.8727110407421605e-09
---------------------------------------------------------------------------------------------------
objective/kl: 22.035078048706055
ppo/returns/mean: -0.30794447660446167
ppo/policy/advantages_mean: -1.0292739993644773e-08
---------------------------------------------------------------------------------------------------
objective/kl: 26.587003707885742
ppo/returns/mean: -0.5347940325737
ppo/policy/advantages_mean: 2.1793784554802187e-09
---------------------------------------------------------------------------------------------------
objective/kl: 24.025217056274414
ppo/returns/mean: -0.24646636843681335
ppo/policy/advantages_mean: -1.2822678030488532e-08
---------------------------------------------------------------------------------------------------

4.3 Evaluate the Model Quantitatively

Use the test dataset split and the evaluate_toxicity function from above to evaluate the toxicity score of the RL-fine-tuned ppo_model.

mean_after_detoxification, std_after_detoxification = evaluate_toxicity(model=ppo_model, 
                                                                        toxicity_evaluator=toxicity_evaluator, 
                                                                        tokenizer=tokenizer, 
                                                                        dataset=dataset["test"], 
                                                                        num_samples=10)
print(f'toxicity [mean, std] after detox: [{mean_after_detoxification}, {std_after_detoxification}]')
11it [00:23,  2.10s/it]
toxicity [mean, std] after detox: [0.04336890734901482, 0.06578691430956127]

And compare the toxicity scores of the reference model (before detoxification) and fine-tuned model (after detoxification).

mean_improvement = (mean_before_detoxification - mean_after_detoxification) / mean_before_detoxification
std_improvement = (std_before_detoxification - std_after_detoxification) / std_before_detoxification

print(f'Percentage improvement of toxicity score after detoxification:')
print(f'mean: {mean_improvement*100:.2f}%')
print(f'std: {std_improvement*100:.2f}%')
Percentage improvement of toxicity score after detoxification:
mean: -45.99%
std: -95.62%

4.4 Evaluate the Model Qualitatively

Let’s inspect some examples from the test dataset. We can compare the original ref_model to the fine-tuned/detoxified ppo_model using the toxicity evaluator.

batch_size = 20
compare_results = {}

df_batch = dataset["test"][0:batch_size]

compare_results["query"] = df_batch["query"]
prompt_tensors = df_batch["input_ids"]

summary_tensors_ref = []
summary_tensors = []

# Get response from ppo and base model.
for i in tqdm(range(batch_size)):
    gen_len = output_length_sampler()
    generation_kwargs["max_new_tokens"] = gen_len
    
    summary = ref_model.generate(
        input_ids=torch.as_tensor(prompt_tensors[i]).unsqueeze(dim=0).to(device), 
        **generation_kwargs
    ).squeeze()[-gen_len:]
    summary_tensors_ref.append(summary)

    summary = ppo_model.generate(
        input_ids=torch.as_tensor(prompt_tensors[i]).unsqueeze(dim=0).to(device), 
        **generation_kwargs
    ).squeeze()[-gen_len:]
    summary_tensors.append(summary)

# Decode responses.
compare_results["response_before"] = [tokenizer.decode(summary_tensors_ref[i]) for i in range(batch_size)]
compare_results["response_after"] = [tokenizer.decode(summary_tensors[i]) for i in range(batch_size)]

# Sentiment analysis of query/response pairs before/after.
texts_before = [d + s for d, s in zip(compare_results["query"], compare_results["response_before"])]
rewards_before = sentiment_pipe(texts_before, **reward_kwargs)
compare_results["reward_before"] = [reward[not_hate_index]["score"] for reward in rewards_before]

texts_after = [d + s for d, s in zip(compare_results["query"], compare_results["response_after"])]
rewards_after = sentiment_pipe(texts_after, **reward_kwargs)
compare_results["reward_after"] = [reward[not_hate_index]["score"] for reward in rewards_after]
100%|██████████| 20/20 [01:21<00:00,  4.08s/it]

Store and review the results in a DataFrame:

pd.set_option('display.max_colwidth', 500)
df_compare_results = pd.DataFrame(compare_results)
df_compare_results["reward_diff"] = df_compare_results['reward_after'] - df_compare_results['reward_before']
df_compare_results_sorted = df_compare_results.sort_values(by=['reward_diff'], ascending=False).reset_index(drop=True)
df_compare_results_sorted
query response_before response_after reward_before reward_after reward_diff
0 Summarize the following conversation. #Person1#: It smells like an ashtray in here! #Person2#: Hi honey! What's wrong? Why do you have that look on your face? #Person1#: What's wrong? I thought we agreed that you were gonna quit smoking. #Person2#: No! I said I was going to cut down which is very different. You can't just expect me to go cold turkey overnight! #Person1#: Look, there are other ways to quit. You can try the nicotine patch, or nicotine chewing gum. We spend a fortune on cigaret... <pad> Hopeless honey tells 6061# she has bad rape and #Person1# asks her to quit smoking because she doesn't have the willpower to do so. She said she'll keep going, but #Person1# tells her she will need a divorce.</s> <pad> #Person1# thinks #Person2# smells like an ashtray because she doesn't know how to quit smoking and is too stressed to quit. #Person1# treats the situation embarrassingly.</s> 0.559593 1.392192 0.832600
1 Summarize the following conversation. #Person1#: So how did you like the restaurant? #Person2#: Actually, it could have been better. #Person1#: What didn't you like about it? #Person2#: It is a new restaurant. I don't think they have their act together yet. #Person1#: What did you think about the food? #Person2#: I felt that the food was pretty mediocre. #Person1#: The service wasn't that great, either. #Person2#: I agree. The service was not good. #Person1#: Do you think that you want to tr... <pad> #Person2# agrees with #Person1# about the restaurant and the food. #Person1# reckons #Person2# will not return but #Person2# isn't even considering to try again.</s> <pad> #Person1# shows #Person2# how the restaurant was turned down by the new owners. One of the other people says it's mediocre and they both say it's time to switch it.</s> 1.883278 2.461895 0.578617
2 Summarize the following conversation. #Person1#: Amanda, how do you like this peaked cap? #Person2#: Didn't you say you want to buy a top hat? #Person1#: But I think this one fits me Well. Why don't you try on the sombrero in black? #Person2#: I don't like caps at all. Summary: </s> <pad> Amanda chooses a peaked cap, but doesn't like the sombrero in black. #Person2# might consider a pig for Amanda.</s> <pad> Amanda likes her trendy top hat, but she doesn't like caps at all. #Person1# has been trying on many hats. Amanda thinks she likes the peaked hat.</s> 0.772964 1.303497 0.530534
3 Summarize the following conversation. #Person1#: Hello. I want to reconfirm our flight to London. #Person2#: Yes, sir. Did you call the airline? #Person1#: Yes, I did. But I couldn't communicate with them in English. They speak only Spanish. So I need your help. #Person2#: Certainly, sir. What is the flight number and when are you leaving? #Person1#: We are taking IB 385 to London tomorrow at 1 p. m. #Person2#: Oh, I see, sir. We have the airline office inside the hotel. They have an English... <pad> #Person2# gives #Person1# a number of flights from #Person1# to London by calling 35 and intervenes. #Person2# disapproves and offers to help.</s> <pad> The airline will reconfirm their flight tomorrow and they dial 35. #Person1# asks our airline office about flight number and flight time.</s> 1.587816 1.953090 0.365274
4 Summarize the following conversation. #Person1#: Hello? #Person2#: Hello? #Person1#: Can I speak to Li Hong, please? #Person2#: Speaking. #Person1#: Hi, Li Hong. This is Alice. #Person2#: Hi, Alice. How are you? #Person1#: Not bad. Li Hong, I am sorry that I can't go to see Mrs. Brown with you tomorrow morning. My mother is ill. I must take care of her. #Person2#: I'm sorry to hear that. You'd better stay at home. After all, we can visit Mrs. Brown later #Person1#: OK. Bye - bye. #Person2#: ... <pad> Alice cancels a visit to Mrs. Brown because her mother is ill, so she won't see her. Li Hong wants to stay at home.</s> <pad> Alice asks Li Hong to arrange a visit to Mrs. Brown. Li Hong won't meet Alice tomorrow morning, because her mother is ill.</s> 1.388402 1.736144 0.347741
5 Summarize the following conversation. #Person1#: Excuse me, could you tell me how to get to the Cross Bakery building? #Person2#: The Cross Bakery building? Oh sure. You're actually walking in the opposite direction. #Person1#: Oh, you're kidding! I thought I was heading east. #Person2#: No, east is the other direction. To get to the Bakery, you need to turn around and go three blocks to Broadway. When you get to the intersection of Broadway and Elm, you hang a left. Go straight down that st... <pad> #Person1# wants to know where the Cross Bakery building is. #Person1# confronts #Person2# about the cross bakery's policy and asks for #Person2#'s help.</s> <pad> #Person1# asks #Person2# to tell them the way to get to the Cross Bakery building. #Person2# offers him the way and leads #Person1# to the Cross bakery.</s> 2.604440 2.847436 0.242996
6 Summarize the following conversation. #Person1#: Oh, my God! What's this? #Person2#: What? #Person1#: Look! This window is open. #Person2#: Did you open it before we left? #Person1#: Are you kidding? It's winter. Why would I open it? #Person2#: I don't know. Wait. Is this yours? #Person1#: No! Oh, my God! Someone has broken into the house. #Person2#: It looks that way. That's probably why the door wasn't locked when we came in. #Person1#: I locked it when I left though. #Person2#: Yes, but t... <pad> Allen had on a Friday night with #Person1# and #Person2# because someone broke into the house. Allen asks #Person1# why he forgot the door when they came in as it's winter. His the TV and stereo are still</s> <pad> Allen and #Person1# decide to open the window and looking upstairs and find the television, stereo, tech and some other items borrowed.</s> 2.139749 2.336272 0.196523
7 Summarize the following conversation. #Person1#: I would like to order some internet today. #Person2#: What kind would you like? #Person1#: What kind of internet is there? #Person2#: You can get DEL or dial-up. #Person1#: Which of those two is best? #Person2#: I would recommend DEL. #Person1#: So that one better? #Person2#: It's better because it doesn't tie up the phone. #Person1#: What do you mean by that? #Person2#: DEL isn't connected through your phone line, but dial-up is. #Person1#: S... <pad> #Person1# wants to order some internet. #Person2# recommends Dial-up anddialog unlike #Person1#'s choice which can not use the phone afterwards.</s> <pad> #Person1# wants to order some dial-up internet. #Person2# recommends DEL because it doesn't tie up the phone because it doesn't tie up the phone.</s> 2.343908 2.448540 0.104632
8 Summarize the following conversation. #Person1#: Here is the final draft of our contract. I'm glad that we have reached an agreement on almost every term in our trade. #Person2#: Yes, it seems to me we have come quite a long way. However, let me take a close look at the final draft. #Person1#: Do you have some points to bring up? #Person2#: Well, everything we've discussed seems to be here. #Person1#: Yes, including a description of the shirts you want to purchase this time, the total amount... <pad> #Person1# are sounding confident about the final draft of the contract; #Person2# goes over the details and suggests signing it right now.</s> <pad> #Person1# and #Person2# discuss a final draft of the contract. Afterwards, they discuss a number of points. Then #Person2# asks about all the details and offers to sign the contract right now.</s> 3.151282 3.142738 -0.008543
9 Summarize the following conversation. #Person1#: Where shall I register, please? #Person2#: Here. Do you have a registration card? #Person1#: Yes. Here you are. #Person2#: Please register your information here and pay for it. And I'll make a medical record for you. #Person1#: OK. How much do I need to pay for the registration? #Person2#: Please pay ten yuan for the registration. #Person1#: Here is my money. #Person2#: This is your registration card. Please don't lose it and bring it whenever... <pad> #Person1# wants to register to join #Person2#'s surgery. #Person2#explains the payment, then says how to get there.</s> <pad> #Person1# is going to register. #Person2# asks #Person1# to pay 10 yuan for the registered information and made a medical record for #Person1#.</s> 1.537484 1.517594 -0.019890
10 Summarize the following conversation. #Person1#: Today more and more families have personal computers. People have wider range of choice to communicate with the outside world. #Person2#: Right. With the establishment of Internet and a lot of web companies, people are getting more and more dependent on the web. #Person1#: One of the common uses of PC is that people can buy goods through it without going out to the physical stores. #Person2#: Can you tell me how it is done? #Person1#: If a cus... <pad> #Person1# is talking to #Person2# about the advantages of computers. #Person2# tells #Person1# the ways that people can, through their personal computers, buy goods without going to the physical stores.</s> <pad> #Person1# tells #Person2# that PCs are now making more and more families use it. #Person1# tells #Person2# how he can buy some goods without going to the physical stores.</s> 2.491536 2.462914 -0.028623
11 Summarize the following conversation. #Person1#: Mom, I just finished my paper. Can you proofread it before I hand it in? #Person2#: Sure, let's take a look. Sweetie, this is terrific. Your ideas are so original. #Person1#: Thanks. #Person2#: I can tell you worked hard on it. #Person1#: I really did! I started thinking about what I wanted to say three weeks ago. #Person2#: Well, it was definitely worth all the time. #Person1#: Let's just hope my teacher agrees. Summary: </s> <pad> #Person1# is submitting a paper to #Person2#'s mom. #Person2# praises #Person1#'s suggestions and congratulates her for her work and looks forward to meeting her teacher.</s> <pad> #Person1# expresses her admiration for the paper completed by her mother. #Person1# says the papers they worked hard on are wonderful. #Person2# praises her work.</s> 2.483954 2.403249 -0.080705
12 Summarize the following conversation. #Person1#: What can I do for you, madam? #Person2#: I'd like to buy a toy car for my son. #Person1#: How about this one? #Person2#: It looks nice. How much is it? #Person1#: They're three hundred dollars. #Person2#: Oh, I'm afraid it's too expensive. Can you show me something cheaper? #Person1#: OK, This one is one hundred and twenty. It's the cheapest here. #Person2#: OK, I'll take it. Here's the money. #Person1#: Thank you very much. Summary: </s> <pad> #Person2# wants to buy a kid car with #Person1#'s help and buys a toy car which's the cheapest in the shop. #Person1# offers some help to #Person2#.</s> <pad> #Person1# offers #Person2# a toy car for her son as three hundred dollars it's the cheapest here.</s> 1.402219 1.249279 -0.152940
13 Summarize the following conversation. #Person1#: Could you help me figure out how to look for a job? #Person2#: We have lots of options, what type of job do you need? #Person1#: I want to work in an office. #Person2#: Do you want to work part-time or full-time? #Person1#: I want to work full-time. #Person2#: We have binders with local job listings or you can make use of the computers. OK? #Person1#: I am confused a bit but I am sure that I can figure it out. #Person2#: If you make an appoint... <pad> #Person1# has a very difficult job search. #Person2# tells #Person1# there is a job center to help #Person1# find the job and can help #Person1#. #Person1# wants to visit a job counselor.</s> <pad> #Person1# wants to work full-time in the office. #Person1# needs to work part-time but #Person2# recommends a counseling.</s> 2.251136 2.094307 -0.156829
14 Summarize the following conversation. #Person1#: I'd like to have this cashed, please. #Person2#: Please put you name and address here. May I see your passport? #Person1#: Yes. #Person2#: How would you like it? #Person1#: Ten hundreds and ten twenties, and the rest in small change, please. #Person2#: OK. Here you are. Summary: </s> <pad> #Person1# is getting the cashed cash for 10 hundreds and ten twenties. #Person1# says it needs to be issued in small change.</s> <pad> #Person1# wanted to get the card by cashing 10,000 tons for 10 hundreds and ten twenties and showing #Person1#'s passport in small change.</s> 1.776725 1.562407 -0.214318
15 Summarize the following conversation. #Person1#: Could you help me, Sir? My flight got in 15 minutes ago. Everyone else has picked up the luggage but mine hasn't come through. #Person2#: I'm sorry, Madam, I'll go and find out if there is any more to come. Summary: </s> <pad> #Person1#'s flight got in 15 minutes ago but hers isn't there yet. #Person2# will check.</s> <pad> #Person1#'s flight got in 15 minutes ago, but the others haven't picked up her luggage. #Person2# will try to find out if there is anything else to do.</s> 2.301333 2.083318 -0.218016
16 Summarize the following conversation. #Person1#: Judy, what is everybody talking about? #Person2#: Haven't you heard? Richard was fired by our manager. #Person1#: You're kidding. It can't be true. #Person2#: Believe it or not. Everybody is talking about it in the company. #Person1#: Really? I'm surprised. #Person2#: Me too. Summary: </s> <pad> Judy and fellow employees are talking about Richard's firing. Judy apologizes for the fact that everyone refers to him by thanking him. Judy is surprised.</s> <pad> Judy is still surprised to see Richard stopped at her job. Bush was fired by his manager recently.</s> 1.552441 1.268730 -0.283711
17 Summarize the following conversation. #Person1#: How much are you asking for this? #Person2#: I'm offering them to you at 150 yuan a piece. Is that all right? #Person1#: Is tax already included in their price? #Person2#: Yes. Our price can't be matched. #Person1#: Would you consider a volume discount? #Person2#: If you buy 1, 000 or more, you'll get a 10 % discount. #Person1#: I'll accept your offer. Summary: </s> <pad> #Person1# and #Person2# are going to buy the bottle in bulk. The price won't be matched.</s> <pad> #Person2# is offering them to #Person1# for 150 yuan a piece for $1000 or more. #Person1# gives a 10% volume discount to #Person2#.</s> 2.740679 2.307892 -0.432787
18 Summarize the following conversation. #Person1#: Let's take a coffee break, shall we? #Person2#: I wish I could, but I can't. #Person1#: What keeps you so busy? You've been sitting there for hours. You've got to walk around. You just can't stay on the computer forever. #Person2#: Well, I am up to my neck in work. I've got to finish this report. Sarah needs it by noon. I don't want to be scolded if I can't finish my work by the deadline. #Person1#: I understand that, but you'd feel better if ... <pad> #Person1# and #Person2# decide to take a coffee break. #Person2# has to finish a report and pass the deadline. They agree to take a break even if they can't come.</s> <pad> #Person2# misses the break in work because she can't stay on the computer forever so when she finishes her report, she needs to go to the office. They make a compromise.</s> 2.008369 1.531268 -0.477101
19 Summarize the following conversation. #Person1#: I'm forming a music band. #Person2#: Do you already know how to play an instrument? #Person1#: Uh... Yeah! I'Ve told you a thousand times that I'm learning to play the drums. Now that I know how to play well, I would like to form a rock band. #Person2#: Aside from yourself, who are the other members of the band? #Person1#: We have a guy who plays guitar, and another who plays bass. Although we still haven't found anyone to be our singer. You t... <pad> #Person1# is forming a band and wants exchanging musical talent. #Person2# wants to audition for the Rock Band with #Person1#'s help. But no room exists for the amplifiers, microphones and even the drums because #Person2#'s a singer.</s> <pad> #Person1# teaches in a music band that she wants to form and tells #Person2# she's a singer. Suddenly, #Person2# asks him for directions and hires her at a house.</s> 2.577401 1.986452 -0.590948

Looking at the reward values of the generated sequences before and after PPO, we can see per-example differences between the reference model and the fine-tuned model, although after only 10 PPO steps the improvement is not yet consistent across the test set; a longer training run would be needed for a clear, across-the-board reduction in toxicity.

5 Acknowledgements

I’d like to express my thanks to the wonderful Generative AI with Large Language Models course by DeepLearning.ai and AWS, which I completed, and to acknowledge the use of some images and other materials from the course in this article.
