Text Summarization

Text Summarization using Flan-T5: A Simple Tutorial

This tutorial walks through text summarization with Flan-T5, step by step. Flan-T5 is an instruction-fine-tuned variant of T5 (Text-to-Text Transfer Transformer), a model that handles a wide range of NLP tasks by treating each one as a text-to-text problem: the input is a string and the output is a string.

Step 1: Understand Flan-T5

  • Flan-T5 is a model developed by Google, based on the T5 architecture. It starts from a pre-trained T5 checkpoint and is further instruction-fine-tuned on a large collection of tasks, including summarization.
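
To make the text-to-text idea concrete, here is a minimal sketch of how different tasks reduce to plain input strings. The helper name to_text_to_text and the example sentences are illustrative assumptions and are not part of any library; only the general convention of prepending a task prefix comes from T5.

```python
# Hypothetical helper: casts a task into the text-to-text format by
# prepending a plain-language task prefix to the input string.
def to_text_to_text(task_prefix: str, text: str) -> str:
    return f"{task_prefix}: {text.strip()}"

# The same model receives every task as a string and answers with a string.
examples = [
    to_text_to_text("summarize", "The city council met on Monday to discuss the budget. "),
    to_text_to_text("translate English to German", "The house is small."),
]

for prompt in examples:
    print(prompt)
```

The "summarize" prefix is the one used again in Step 9 when generating summaries.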

Step 2: Set Up Your Environment

  • You’ll need Python and TensorFlow installed.
  • Install the necessary libraries (the leading ! runs a shell command from a notebook cell):
Python
!pip install transformers datasets tensorflow

Step 3: Choose a Dataset

  • For this tutorial, let’s use the CNN/DailyMail dataset, a standard dataset for summarization tasks.
  • It contains news articles (CNN and Daily Mail) paired with multi-sentence summaries.
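
Each record in this dataset exposes an article field (the full news text), a highlights field (the reference summary), and an id. The sketch below uses a made-up record, not real data, to show the shape that the later steps rely on.

```python
# A made-up record with the same field names as CNN/DailyMail.
# The text is invented for illustration and is not taken from the dataset.
record = {
    "id": "example-0001",
    "article": (
        "Heavy rain flooded several streets in the town centre on Tuesday. "
        "Local officials closed two bridges and asked residents to avoid travel. "
        "Water levels are expected to fall by Thursday."
    ),
    "highlights": "Heavy rain floods town centre.\nTwo bridges closed; levels to fall by Thursday.",
}

article_words = len(record["article"].split())
summary_words = len(record["highlights"].split())

# The article is the model input; the highlights are the training target.
print(f"article: {article_words} words, summary: {summary_words} words")
print(f"compression ratio: {article_words / summary_words:.1f}x")
```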

Step 4: Load the Dataset

  • Use Hugging Face’s datasets library to load the dataset.
Python

from datasets import load_dataset
dataset = load_dataset("cnn_dailymail", "3.0.0")

# Training on the full dataset takes a long time. To check that the code
# runs, you can load a small slice (%) of each split instead:
"""
from datasets import DatasetDict

dataset = DatasetDict({
    'train': load_dataset("cnn_dailymail", "3.0.0", split='train[:10%]'),
    'test': load_dataset("cnn_dailymail", "3.0.0", split='test[:20%]'),
    'validation': load_dataset("cnn_dailymail", "3.0.0", split='validation[:20%]'),
})
"""

Step 5: Load Flan-T5

  • Import and load the Flan-T5 model.
Python
from transformers import TFAutoModelForSeq2SeqLM, AutoTokenizer

model_name = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)

Step 6: Preprocess the Data

  • Tokenize the texts. Adjust the maximum token lengths according to your needs.
Python
def preprocess_function(examples):
    # Prepend the same task prefix that is used at inference time (Step 9).
    inputs = ["summarize: " + doc for doc in examples["article"]]
    # Truncate only; padding is left to the data collator defined below.
    model_inputs = tokenizer(inputs, max_length=512, truncation=True)

    # text_target tokenizes the summaries as labels. It replaces the
    # deprecated tokenizer.as_target_tokenizer() context manager.
    labels = tokenizer(text_target=examples["highlights"], max_length=128, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_datasets = dataset.map(preprocess_function, batched=True)

Transformers provides a DataCollatorForSeq2Seq collator that will dynamically pad the inputs and the labels for us. To instantiate this collator, we simply need to provide the tokenizer and model:

Python
from transformers import DataCollatorForSeq2Seq
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="tf")
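
To illustrate what dynamic padding means, the following pure-Python sketch pads a batch to the length of its longest member. It is a simplified stand-in for the collator, not its actual implementation. It assumes a pad token id of 0 (the T5 tokenizer's pad id) and pads labels with -100, the collator's default label padding value, which the loss ignores. The function name pad_batch and the token ids are made up.

```python
PAD_TOKEN_ID = 0       # T5 tokenizers use 0 as the pad token id
LABEL_PAD_ID = -100    # label positions with this value are ignored by the loss

def pad_batch(features):
    """Pad input_ids and labels to the longest sequence in this batch only."""
    max_in = max(len(f["input_ids"]) for f in features)
    max_lab = max(len(f["labels"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_in, n_lab = len(f["input_ids"]), len(f["labels"])
        batch["input_ids"].append(f["input_ids"] + [PAD_TOKEN_ID] * (max_in - n_in))
        # 1 marks real tokens, 0 marks padding
        batch["attention_mask"].append([1] * n_in + [0] * (max_in - n_in))
        batch["labels"].append(f["labels"] + [LABEL_PAD_ID] * (max_lab - n_lab))
    return batch

features = [
    {"input_ids": [21, 7, 94, 1], "labels": [55, 1]},
    {"input_ids": [21, 8, 1], "labels": [60, 61, 62, 1]},
]
batch = pad_batch(features)
print(batch["input_ids"])
print(batch["labels"])
```

Because the padding length is computed per batch, batches of short examples stay short, which wastes less computation than padding every example to a fixed maximum length.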

We just need to convert our datasets to tf.data.Datasets using the data collator we defined above, and then compile() and fit() the model. First, the datasets:

Python
tf_train_dataset = model.prepare_tf_dataset(
    tokenized_datasets["train"],
    batch_size=8,
    shuffle=True,
    collate_fn=data_collator,
)

tf_validation_dataset = model.prepare_tf_dataset(
    tokenized_datasets["validation"],
    batch_size=8,
    shuffle=False,
    collate_fn=data_collator,
)

Step 7: Fine-Tune the Model

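  • Fine-tune Flan-T5 on the dataset prepared above.
  • Compile and train the model. This process can be resource-intensive.
Python
# No loss is passed to compile(): Transformers TF models fall back to their
# built-in loss, which is computed from the "labels" in each batch.
# Token-level accuracy says little about summary quality, so no Keras
# metric is tracked here; see Step 8 for evaluation.
model.compile(
    optimizer="adam",
)
model.fit(
    tf_train_dataset,
    validation_data=tf_validation_dataset,
    epochs=1,  # Keras default, stated explicitly; increase for better results
)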

Step 8: Evaluate the Model

  • After training, evaluate the model’s performance on the test set. ROUGE, which measures n-gram overlap between generated and reference summaries, is the standard metric for this task.
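
As a rough illustration of what ROUGE measures, the sketch below computes a simplified ROUGE-1 F1 score (unigram overlap) in plain Python. It assumes lowercase whitespace tokenization and omits stemming as well as the ROUGE-2 and ROUGE-L variants that full implementations provide, so an established ROUGE package is the better choice for reported results. The function name rouge1_f1 and the example sentences are assumptions.

```python
from collections import Counter

def rouge1_f1(prediction: str, reference: str) -> float:
    """Simplified ROUGE-1: F1 over overlapping unigrams (clipped counts)."""
    pred_counts = Counter(prediction.lower().split())
    ref_counts = Counter(reference.lower().split())
    # Counter intersection keeps the minimum count of each shared token.
    overlap = sum((pred_counts & ref_counts).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred_counts.values())
    recall = overlap / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)

reference = "two bridges closed after heavy rain"
prediction = "heavy rain closed two bridges"
print(round(rouge1_f1(prediction, reference), 3))
```

In practice, the predictions would come from running the fine-tuned model over the test articles and the references from the highlights field.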

Step 9: Perform Summarization

  • Use the trained model to summarize new texts.
Python
def generate_summary(text):
    # truncation=True is required for max_length to be applied reliably.
    inputs = tokenizer.encode("summarize: " + text, return_tensors="tf", max_length=512, truncation=True)
    outputs = model.generate(inputs, max_length=150, min_length=40, length_penalty=2.0, num_beams=4, early_stopping=True)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


text = "Your new text here."
print(generate_summary(text))
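
The length_penalty argument affects how beam search ranks finished candidates: in the Transformers implementation, a beam's score is its summed log-probability divided by its length raised to the power of length_penalty. Since log-probabilities are negative, values above 0 favor longer outputs. The sketch below reproduces that arithmetic with made-up numbers; beam_score is a hypothetical helper, not a library function.

```python
def beam_score(sum_logprobs: float, length: int, length_penalty: float) -> float:
    """Length-normalized beam score: summed log-probability / length**penalty."""
    return sum_logprobs / (length ** length_penalty)

# Made-up candidates: (summed log-probability, length in tokens)
short_candidate = (-6.0, 10)
long_candidate = (-9.0, 30)

for penalty in (0.0, 1.0, 2.0):
    short_score = beam_score(*short_candidate, penalty)
    long_score = beam_score(*long_candidate, penalty)
    winner = "long" if long_score > short_score else "short"
    print(f"length_penalty={penalty}: short={short_score:.4f} long={long_score:.4f} -> {winner}")
```

With no penalty the shorter candidate wins on raw log-probability, while the value of 2.0 used above pushes generation toward longer summaries.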

Step 10: Experiment and Iterate

  • Experiment with different model configurations, training longer, or using different datasets to improve results.

This tutorial provides a basic framework for text summarization using Flan-T5. For more advanced use, consider exploring additional parameters and methods in the Transformers library.