This tutorial walks through text summarization with Flan-T5 in ten steps. Flan-T5 is an instruction-tuned variant of T5 (Text-to-Text Transfer Transformer). T5 handles a wide range of NLP tasks by treating every task as a text-to-text problem: the model reads input text and produces output text.
Step 1: Understand Flan-T5
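- Flan-T5 was released by Google and builds on the T5 framework. It starts from pre-trained T5 checkpoints, which are then instruction-finetuned on a large collection of tasks phrased as natural-language instructions. Summarization is among those tasks, so the model can attempt summaries even before any task-specific fine-tuning.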
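To make the text-to-text idea concrete, here is a small plain-Python sketch (no model involved) of how different tasks become (input text, target text) pairs. The "summarize: " prefix is the one this tutorial uses in Step 9. The translation prefix and the helper name `to_text_pair` are illustrative assumptions. Flan-T5 also accepts freer natural-language instructions.

```python
def to_text_pair(task: str, source: str, target: str) -> tuple[str, str]:
    """Cast a (task, source, target) example into a T5-style (input text, target text) pair."""
    # Hypothetical prefix table for illustration only
    prefixes = {
        "summarization": "summarize: ",
        "translation_en_de": "translate English to German: ",
    }
    return prefixes[task] + source, target


pairs = [
    to_text_pair(
        "summarization",
        "The city council approved the new budget on Monday after a long debate.",
        "Council approves budget.",
    ),
    to_text_pair("translation_en_de", "Thank you.", "Danke."),
]
for model_input, model_target in pairs:
    # Both tasks share one interface: a string in, a string out
    print(repr(model_input), "->", repr(model_target))
```

Because every task is just a string in and a string out, the same model, loss, and decoding code serve summarization, translation, and other tasks. Only the input text changes.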
Step 2: Set Up Your Environment
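- You’ll need Python with TensorFlow, plus Hugging Face’s `transformers` and `datasets` libraries.
- Install the necessary libraries. The leading `!` runs the command from a Jupyter or Colab notebook; omit it in a terminal.
```
!pip install "transformers<5" datasets tensorflow
```
- This tutorial assumes a Transformers 4.x release, because newer major releases drop TensorFlow support. If your TensorFlow version ships Keras 3 (TensorFlow 2.16 and later), Transformers’ TensorFlow models also need the backwards-compatible `tf-keras` package: `pip install tf-keras`.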
Step 3: Choose a Dataset
- For this tutorial, let’s use the CNN/DailyMail dataset, a standard dataset for summarization tasks.
- It contains news articles from CNN and the Daily Mail. Each article (the `article` column) is paired with a multi-sentence summary (the `highlights` column).
Step 4: Load the Dataset
- Use Hugging Face’s `datasets` library to load the dataset.
```python
from datasets import load_dataset, DatasetDict

dataset = load_dataset("cnn_dailymail", "3.0.0")

# Training on the full dataset takes a long time. To check that the code runs
# end to end, you can instead load a percentage slice of each split:
# dataset = DatasetDict({
#     "train": load_dataset("cnn_dailymail", "3.0.0", split="train[:10%]"),
#     "test": load_dataset("cnn_dailymail", "3.0.0", split="test[:20%]"),
#     "validation": load_dataset("cnn_dailymail", "3.0.0", split="validation[:20%]"),
# })
```
Step 5: Load Flan-T5
- Import and load the Flan-T5 model.
```python
from transformers import TFAutoModelForSeq2SeqLM, AutoTokenizer

model_name = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)
```
Step 6: Preprocess the Data
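- Tokenize the articles (inputs) and the highlights (targets). Adjust the maximum token lengths to your needs: articles longer than 512 tokens and summaries longer than 128 tokens are truncated.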
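```python
prefix = "summarize: "  # same task prefix used at inference time in Step 9

def preprocess_function(examples):
    # Prepend the task prefix to every article
    inputs = [prefix + doc for doc in examples["article"]]
    # Truncate articles to 512 tokens; padding is left to the data collator
    model_inputs = tokenizer(inputs, max_length=512, truncation=True)
    # Tokenize the reference summaries as labels (replaces the deprecated as_target_tokenizer)
    labels = tokenizer(text_target=examples["highlights"], max_length=128, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_datasets = dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=dataset["train"].column_names,  # drop raw text columns
)
```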
Transformers provides a `DataCollatorForSeq2Seq` collator that dynamically pads the inputs and the labels for us. To instantiate it, we only need to provide the tokenizer and the model:
```python
from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="tf")
```
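Conceptually, the collator does something like the following pure-Python sketch. It assumes T5 conventions: the pad token id is 0, the decoder starts from the pad token, and label padding uses -100 so those positions are ignored by the loss. The function name `collate_seq2seq` and the token ids are hypothetical, with 1 standing for T5's end-of-sequence token. The real collator returns TensorFlow tensors and supports more options.

```python
PAD_ID = 0            # T5 pad token id
LABEL_PAD_ID = -100   # label positions with this value are ignored by the loss
DECODER_START_ID = 0  # T5 starts decoding from the pad token

def collate_seq2seq(features):
    """Pad a list of {"input_ids", "attention_mask", "labels"} dicts to common lengths."""
    max_in = max(len(f["input_ids"]) for f in features)
    max_lab = max(len(f["labels"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": [], "decoder_input_ids": []}
    for f in features:
        n_pad = max_in - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [PAD_ID] * n_pad)
        batch["attention_mask"].append(f["attention_mask"] + [0] * n_pad)
        labels = f["labels"] + [LABEL_PAD_ID] * (max_lab - len(f["labels"]))
        batch["labels"].append(labels)
        # Teacher forcing: the decoder sees the labels shifted one step to the right,
        # with -100 replaced by the pad id
        shifted = [DECODER_START_ID] + labels[:-1]
        batch["decoder_input_ids"].append([PAD_ID if t == LABEL_PAD_ID else t for t in shifted])
    return batch

features = [
    {"input_ids": [21, 22, 23, 1], "attention_mask": [1, 1, 1, 1], "labels": [31, 32, 1]},
    {"input_ids": [41, 1], "attention_mask": [1, 1], "labels": [51, 1]},
]
batch = collate_seq2seq(features)
for key, rows in batch.items():
    print(key, rows)
```

Padding per batch, rather than to a fixed `max_length`, keeps short batches short, which saves compute on a dataset with highly variable article lengths.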
We just need to convert our datasets to `tf.data.Dataset`s using the data collator defined above, and then `compile()` and `fit()` the model. First, the datasets:
```python
tf_train_dataset = model.prepare_tf_dataset(
    tokenized_datasets["train"],
    batch_size=8,
    shuffle=True,
    collate_fn=data_collator,
)

tf_validation_dataset = model.prepare_tf_dataset(
    tokenized_datasets["validation"],
    batch_size=8,
    shuffle=False,
    collate_fn=data_collator,
)
```
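To estimate how long an epoch will take, it helps to count batches. The sketch below rests on an assumption taken from the Transformers documentation of `prepare_tf_dataset`: when `drop_remainder` is not passed, it defaults to the value of `shuffle`. Under that assumption, the shuffled training set drops its last incomplete batch while the validation set keeps it. Check this against your installed version. The split size of 1,003 examples is hypothetical.

```python
import math

def batches_per_epoch(num_examples: int, batch_size: int, drop_remainder: bool) -> int:
    """Number of batches Keras iterates over in one pass through the data."""
    if drop_remainder:
        return num_examples // batch_size        # last partial batch discarded
    return math.ceil(num_examples / batch_size)  # last partial batch kept

# Hypothetical split size, e.g. after loading a small percentage slice
train_steps = batches_per_epoch(1_003, batch_size=8, drop_remainder=True)   # shuffle=True
val_steps = batches_per_epoch(1_003, batch_size=8, drop_remainder=False)    # shuffle=False
print(train_steps, val_steps)
```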
Step 7: Fine-Tune the Model
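- Fine-tune Flan-T5 on the CNN/DailyMail training split.
- With the Keras API, training is configured through `compile()` (the optimizer) and `fit()` (the number of epochs and the validation data). Fine-tuning is resource-intensive, so a GPU is strongly recommended.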
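```python
from transformers import AdamWeightDecay

# A small learning rate with weight decay is a common choice for fine-tuning;
# tune these values for your setup.
optimizer = AdamWeightDecay(learning_rate=2e-5, weight_decay_rate=0.01)

# No loss argument: Transformers TF models compute the seq2seq loss internally
# from the "labels" column, ignoring padded label positions (-100).
model.compile(optimizer=optimizer)

model.fit(
    tf_train_dataset,
    validation_data=tf_validation_dataset,
    epochs=1,  # increase for better results if you have the compute
)
```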
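Because `compile()` receives no loss, the model computes its own token-level cross-entropy from the labels. The plain-Python sketch below illustrates the idea under two assumptions: positions labelled -100 are excluded, and the loss is averaged over the remaining tokens. The function name and the toy probabilities are illustrative, not the library's code. A real model produces logits over its full vocabulary and applies a softmax.

```python
import math

IGNORE_INDEX = -100  # padded label positions

def masked_cross_entropy(token_probs, labels):
    """Mean negative log-likelihood over non-ignored label positions.

    token_probs[i] maps candidate token ids to the model's probability at decoder step i.
    """
    total, count = 0.0, 0
    for probs, label in zip(token_probs, labels):
        if label == IGNORE_INDEX:
            continue  # padding contributes nothing to the loss
        total += -math.log(probs[label])
        count += 1
    return total / count

# Toy decoder output over a 3-token vocabulary for 3 label positions, the last one padded
token_probs = [
    {0: 0.1, 1: 0.2, 7: 0.7},
    {0: 0.05, 1: 0.9, 7: 0.05},
    {0: 0.98, 1: 0.01, 7: 0.01},
]
labels = [7, 1, IGNORE_INDEX]
print(round(masked_cross_entropy(token_probs, labels), 4))
```

This is why the labels should be padded with -100 rather than the pad token id. Otherwise the model would also be trained to predict padding.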
Step 8: Evaluate the Model
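- After training, evaluate the model on the test set. The standard metric for summarization is ROUGE, which measures the n-gram and longest-common-subsequence overlap between generated summaries and the reference highlights.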
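The sketch below shows what ROUGE-1 and ROUGE-L measure, using lowercasing and whitespace tokenization. It is a simplified teaching version without stemming or ROUGE-Lsum sentence splitting, so its scores will not match published numbers. For reportable results, use an established implementation such as the `rouge` metric in Hugging Face's `evaluate` library (backed by the `rouge_score` package), applied to summaries produced by the generation function in Step 9. The example sentences are made up.

```python
from collections import Counter

def _f1(overlap: int, pred_len: int, ref_len: int) -> float:
    if overlap == 0:
        return 0.0
    precision = overlap / pred_len
    recall = overlap / ref_len
    return 2 * precision * recall / (precision + recall)

def rouge1_f1(prediction: str, reference: str) -> float:
    """Unigram-overlap F1 with clipped counts."""
    pred, ref = prediction.lower().split(), reference.lower().split()
    overlap = sum((Counter(pred) & Counter(ref)).values())
    return _f1(overlap, len(pred), len(ref))

def _lcs_length(a, b) -> int:
    """Length of the longest common subsequence via dynamic programming."""
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            curr.append(prev[j - 1] + 1 if x == y else max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]

def rougeL_f1(prediction: str, reference: str) -> float:
    """Longest-common-subsequence F1."""
    pred, ref = prediction.lower().split(), reference.lower().split()
    return _f1(_lcs_length(pred, ref), len(pred), len(ref))

reference = "the council approved the budget on monday"
prediction = "on monday the council passed the budget"
print(rouge1_f1(prediction, reference), rougeL_f1(prediction, reference))
```

ROUGE-1 rewards shared words regardless of order, while ROUGE-L only credits words that appear in the same relative order. That is why the reordered prediction scores lower on ROUGE-L.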
Step 9: Perform Summarization
- Use the trained model to summarize new texts.
```python
def generate_summary(text):
    # Prepend the "summarize: " instruction and truncate long inputs to 512 tokens
    inputs = tokenizer("summarize: " + text, return_tensors="tf", max_length=512, truncation=True)
    outputs = model.generate(
        **inputs,  # passes both input_ids and attention_mask
        max_length=150,
        min_length=40,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

text = "Your new text here."
print(generate_summary(text))
```
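The `length_penalty` argument changes how finished beam-search hypotheses are ranked. The sketch below assumes the normalization used by Transformers' beam search: a hypothesis's summed log-probability is divided by its length raised to `length_penalty`. Log-probabilities are negative, so under this assumption values above 0 favor longer outputs, and `length_penalty=2.0` pushes toward longer summaries. The two hypotheses and their scores are made up.

```python
def beam_score(sum_logprobs: float, length: int, length_penalty: float) -> float:
    """Length-normalized beam score; higher is better."""
    return sum_logprobs / (length ** length_penalty)

# Two hypothetical finished hypotheses: (summed log-probability, length in tokens)
short_hyp = (-6.0, 20)
long_hyp = (-9.0, 40)

for lp in (0.0, 1.0, 2.0):
    s = beam_score(*short_hyp, lp)
    l = beam_score(*long_hyp, lp)
    winner = "short" if s > l else "long"
    print(f"length_penalty={lp}: short={s:.4f} long={l:.4f} -> {winner}")
```

With no normalization the shorter hypothesis wins, because every extra token lowers the summed log-probability. Raising `length_penalty` shifts the preference toward the longer one.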
Step 10: Experiment and Iterate
- Experiment with different model configurations, training longer, or using different datasets to improve results.
This tutorial provides a basic framework for text summarization using Flan-T5. For more advanced use, consider exploring additional parameters and methods in the Transformers library.