NLP Augmentation Hands-On Part-2
Part-2 TextAttack library for Data Augmentation
In part-1 we implemented data augmentation using WordNet synonyms.
In part-2 (this article) we will use TextAttack for data augmentation.
These are the learning objectives:
- What is TextAttack and how to use it for Data Augmentation?
- Use built-in and custom augmentation recipes for augmentation.
- Compare the various TextAttack methods on model performance.
What is TextAttack and how to use it for Data Augmentation?
TextAttack is a Python framework for adversarial attacks, data augmentation, and model training in NLP.
Reasons to use TextAttack:
- Understand NLP models better by running different adversarial attacks on them and examining the output.
- Augment your dataset to increase model generalization and robustness downstream.
- Train NLP models using just a single command.
TextAttack also lets you build custom data augmenters, and it implements augmentation techniques from recent research papers.
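A custom augmenter is built from a few parts: a word-level transformation, a constraint on which words may change, and the two settings used throughout this article, pct_words_to_swap and transformations_per_example. The pure-Python sketch below shows how those parts fit together. The class SimpleAugmenter, the toy transformation delete_random_char and the STOPWORDS constraint are hypothetical illustrations and not TextAttack's API. In TextAttack itself you would pass a transformation and a list of constraints to its own Augmenter class.

```python
import random

# Hypothetical constraint: words that must never be modified.
STOPWORDS = {"a", "an", "the", "i", "do", "not", "is", "what"}


def delete_random_char(word, rng):
    """Toy transformation: drop one randomly chosen character from the word."""
    if len(word) < 2:
        return word
    i = rng.randrange(len(word))
    return word[:i] + word[i + 1:]


class SimpleAugmenter:
    """Hypothetical augmenter = one transformation + a stopword constraint."""

    def __init__(self, transformation, pct_words_to_swap=0.2,
                 transformations_per_example=2, seed=0):
        self.transformation = transformation
        self.pct_words_to_swap = pct_words_to_swap
        self.transformations_per_example = transformations_per_example
        self.rng = random.Random(seed)

    def augment(self, text):
        words = text.split()
        # Only words that pass the constraint are eligible for modification.
        candidates = [i for i, w in enumerate(words) if w.lower() not in STOPWORDS]
        # Change roughly pct_words_to_swap of the words (at least one, if possible).
        n_swap = min(max(1, int(len(words) * self.pct_words_to_swap)), len(candidates))
        results = []
        for _ in range(self.transformations_per_example):
            new_words = list(words)
            for i in self.rng.sample(candidates, n_swap):
                new_words[i] = self.transformation(new_words[i], self.rng)
            results.append(" ".join(new_words))
        return results


augmenter = SimpleAugmenter(delete_random_char, pct_words_to_swap=0.2,
                            transformations_per_example=2)
print(augmenter.augment("What I cannot create, I do not understand."))
```

Keeping the transformation separate from the constraint means you can swap either one independently. This is the same design idea behind composing TextAttack augmenters.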
Example of an adversarial attack on text: the predicted sentiment changes when only one or a few words are changed.
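The mechanism behind such an attack can be shown with a toy sketch. A tiny lexicon-based sentiment scorer acts as a hypothetical stand-in for a trained model, and a greedy search swaps one word at a time until the predicted label flips. Real attacks in TextAttack use trained models and richer transformations and constraints. This sketch only illustrates the idea.

```python
# Toy word-score lexicon standing in for a trained sentiment model (hypothetical).
LEXICON = {"great": 2, "good": 1, "bad": -1, "terrible": -2}


def predict_sentiment(text):
    """Label a text by summing word scores; unknown words score 0."""
    score = sum(LEXICON.get(w.strip(".,!?").lower(), 0) for w in text.split())
    return "positive" if score > 0 else "negative"


def one_word_attack(text, substitutes):
    """Greedy search: swap one word at a time until the predicted label flips.

    Returns the first adversarial text found, or None if no swap flips the label.
    """
    original_label = predict_sentiment(text)
    words = text.split()
    for i, word in enumerate(words):
        for sub in substitutes.get(word.lower(), []):
            candidate = " ".join(words[:i] + [sub] + words[i + 1:])
            if predict_sentiment(candidate) != original_label:
                return candidate
    return None


print(one_word_attack("The movie was good", {"good": ["fine", "decent"]}))
# prints: The movie was fine
```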
Use built-in and custom augmentation recipes for augmentation.
Installation
pip install textattack
Once TextAttack is installed, let's see how we can use its prebuilt recipes for augmentation.
We will try the following recipes available in TextAttack and also evaluate their effect on model performance.
- CheckListAugmenter: It combines Name Replacement, Location Replacement, Number Alteration, and Contraction/Extension. The original paper is “Beyond Accuracy: Behavioral Testing of NLP Models with CheckList” (Ribeiro et al., 2020).
from textattack.augmentation import CheckListAugmenter
augmenter = CheckListAugmenter(pct_words_to_swap=0.2, transformations_per_example=5)
s = 'What I cannot create, I do not understand.'
print(augmenter.augment(s))
# Output
['What I cannot create, I do not understand.', "What I cannot create, I don't understand."]
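Of CheckList's four transformations, Contraction/Extension is the easiest to picture, and it is the change visible in the output above. The following minimal sketch uses regular expressions and a small hypothetical mapping. It is not TextAttack's implementation, and the contraction table is only an illustrative subset.

```python
import re

# Small, hypothetical subset of contraction pairs (not CheckList's full table).
CONTRACTIONS = {"do not": "don't", "is not": "isn't", "I am": "I'm", "it is": "it's"}
EXPANSIONS = {short: long for long, short in CONTRACTIONS.items()}


def _apply_each(text, mapping):
    """Produce one variant per mapping entry found in text (all occurrences replaced)."""
    variants = []
    for src, dst in mapping.items():
        pattern = r"\b" + re.escape(src) + r"\b"
        if re.search(pattern, text):
            variants.append(re.sub(pattern, dst, text))
    return variants


def contract_or_expand(text):
    """Contract expanded forms and expand contracted ones."""
    return _apply_each(text, CONTRACTIONS) + _apply_each(text, EXPANSIONS)


print(contract_or_expand("What I cannot create, I do not understand."))
# prints: ["What I cannot create, I don't understand."]
```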
- CLAREAugmenter: CLARE builds on a pre-trained masked language model and modifies the inputs in a context-aware manner. The original paper is “Contextualized Perturbation for Textual Adversarial Attack” (Li et al., 2020).
from textattack.augmentation import CLAREAugmenter
augmenter = CLAREAugmenter(pct_words_to_swap=0.2, transformations_per_example=3)
s = 'What I cannot create, I do not understand.'
print(augmenter.augment(s))
# Output
['What I cannot create, I do n not understand.', "What I cannot create, âĢ¦ I do not understand.", "What I cannot âĢĶ create, I do not understand."]
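CLARE perturbs text through three contextual actions: Replace, Insert and Merge. Each action masks part of the input and lets a masked language model propose fillers. The sketch below shows only that mask-and-fill mechanism. toy_masked_lm is a hypothetical lookup table standing in for a real pre-trained model. The similarity checks that the real method applies to candidates are also omitted, and the merge example shows why they matter: without them, a filler can flip the meaning. Running a real masked language model for many candidate positions is one reason CLARE is so much slower than the rule-based recipes.

```python
# Hypothetical lookup table standing in for a pre-trained masked language model:
# it proposes fillers based only on the token just before the mask.
TOY_FILLERS = {"I": ["really", "truly"], "do": ["not"], "not": ["fully"]}
MASK = "<mask>"


def toy_masked_lm(masked_tokens):
    """Return candidate fillers for the single MASK token."""
    i = masked_tokens.index(MASK)
    prev = masked_tokens[i - 1] if i > 0 else ""
    return TOY_FILLERS.get(prev, [])


def replace_candidates(tokens, i, mlm):
    """Replace: mask token i and fill it (skip fillers identical to the original)."""
    masked = tokens[:i] + [MASK] + tokens[i + 1:]
    return [" ".join(tokens[:i] + [w] + tokens[i + 1:])
            for w in mlm(masked) if w != tokens[i]]


def insert_candidates(tokens, i, mlm):
    """Insert: add a mask before token i and fill it."""
    masked = tokens[:i] + [MASK] + tokens[i:]
    return [" ".join(tokens[:i] + [w] + tokens[i:]) for w in mlm(masked)]


def merge_candidates(tokens, i, mlm):
    """Merge: replace the bigram tokens[i:i + 2] with one mask and fill it."""
    masked = tokens[:i] + [MASK] + tokens[i + 2:]
    return [" ".join(tokens[:i] + [w] + tokens[i + 2:]) for w in mlm(masked)]


tokens = "I do not understand".split()
print(insert_candidates(tokens, 1, toy_masked_lm))
# prints: ['I really do not understand', 'I truly do not understand']
```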
- EasyDataAugmenter: It combines random synonym replacement, random deletion, random word-order swaps, and random insertion using WordNet. The original paper is “EDA: Easy Data Augmentation Techniques for Boosting Performance on Text Classification Tasks” (Wei and Zou, 2019) https://arxiv.org/abs/1901.11196.
from textattack.augmentation import EasyDataAugmenter
augmenter = EasyDataAugmenter(pct_words_to_swap=0.2, transformations_per_example=3)
s = 'What I cannot create, I do not understand.'
print(augmenter.augment(s))
# Output
['What realize I cannot create, I do not understand.', 'What not cannot create, I do I understand.', 'What I cannot create, unity do not understand.']
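The four EDA operations are simple enough to write from scratch. The sketch below is a standalone illustration, not TextAttack's code. SYNONYMS is a tiny hypothetical table standing in for WordNet, tokenization is a plain whitespace split (so "create," with its comma would not match "create"), and the sketch produces exactly one variant per operation.

```python
import random

# Tiny hypothetical synonym table standing in for WordNet.
SYNONYMS = {"create": ["make", "build"], "understand": ["grasp", "comprehend"]}


def synonym_replacement(words, n, rng):
    """Replace up to n words that have synonyms with a random synonym."""
    new = list(words)
    candidates = [i for i, w in enumerate(new) if w in SYNONYMS]
    for i in rng.sample(candidates, min(n, len(candidates))):
        new[i] = rng.choice(SYNONYMS[new[i]])
    return new


def random_insertion(words, n, rng):
    """Insert a synonym of a random word at a random position, n times."""
    new = list(words)
    for _ in range(n):
        options = [w for w in new if w in SYNONYMS]
        if not options:
            break
        synonym = rng.choice(SYNONYMS[rng.choice(options)])
        new.insert(rng.randrange(len(new) + 1), synonym)
    return new


def random_swap(words, n, rng):
    """Swap the positions of two random words, n times."""
    new = list(words)
    for _ in range(n):
        if len(new) < 2:
            break
        i, j = rng.sample(range(len(new)), 2)
        new[i], new[j] = new[j], new[i]
    return new


def random_deletion(words, p, rng):
    """Delete each word with probability p, always keeping at least one word."""
    if not words:
        return []
    kept = [w for w in words if rng.random() > p]
    return kept if kept else [rng.choice(words)]


def eda(sentence, alpha=0.1, seed=0):
    """Return one variant per EDA operation (naive whitespace tokenization)."""
    rng = random.Random(seed)
    words = sentence.split()
    n = max(1, int(alpha * len(words)))
    return [
        " ".join(synonym_replacement(words, n, rng)),
        " ".join(random_insertion(words, n, rng)),
        " ".join(random_swap(words, n, rng)),
        " ".join(random_deletion(words, alpha, rng)),
    ]


for variant in eda("I cannot create what I do not understand", alpha=0.1):
    print(variant)
```

Random swaps and deletions are applied without regard to grammar. That explains outputs such as 'What not cannot create, I do I understand.' above.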
CLAREAugmenter is much slower than CheckListAugmenter.
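To check the speed difference on your own data, you can time each augmenter's augment method over the same sample of texts. time_augmenter below is a hypothetical helper. It works with any callable that takes a string, such as augmenter.augment from the snippets above.

```python
import time


def time_augmenter(augment, texts):
    """Time an augment callable over texts; return (total_seconds, seconds_per_text)."""
    start = time.perf_counter()
    for text in texts:
        augment(text)
    elapsed = time.perf_counter() - start
    return elapsed, elapsed / max(len(texts), 1)


# Usage with a trivial stand-in; with TextAttack you would pass e.g. augmenter.augment.
total, per_text = time_augmenter(lambda t: [t.lower()], ["Hello World"] * 1000)
print(f"{total:.4f}s total, {per_text * 1e6:.1f} microseconds per text")
```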
Compare the various TextAttack methods on model performance.
Let's use the same dataset that we used in Part-1.
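The snippets below call augment_data_textattack(data, target, augmenter), a helper that is not defined in this article. The function below is only a guess at what such a helper might do. Its name (augment_data_sketch) and its behavior are assumptions, and the actual helper may differ. It keeps every original sentence, adds its augmented variants under the same label, and drops variants identical to the original. This last step matters because CheckListAugmenter, for example, returned the unchanged sentence in its output above.

```python
def augment_data_sketch(data, target, augmenter):
    """Hypothetical helper: keep each original text and append its augmented
    variants with the same label, skipping variants identical to the original."""
    augmented_data, augmented_label = [], []
    for text, label in zip(data, target):
        augmented_data.append(text)
        augmented_label.append(label)
        for new_text in augmenter.augment(text):
            if new_text != text:
                augmented_data.append(new_text)
                augmented_label.append(label)
    return augmented_data, augmented_label


class UppercaseAugmenter:
    """Stand-in with the same augment(text) -> list of strings shape as the augmenters above."""

    def augment(self, text):
        return [text, text.upper()]


texts, labels = augment_data_sketch(["good movie", "bad plot"], [1, 0], UppercaseAugmenter())
print(texts)   # ['good movie', 'GOOD MOVIE', 'bad plot', 'BAD PLOT']
print(labels)  # [1, 1, 0, 0]
```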
- Classification Accuracy using EasyDataAugmenter
from textattack.augmentation import EasyDataAugmenter
eda_augmenter = EasyDataAugmenter(pct_words_to_swap=0.1, transformations_per_example=1)
# augment_data_textattack(data, target, augmenter) is a helper that is not defined in this article
augmented_data_eda, augmented_label_eda = augment_data_textattack(data, target, eda_augmenter)

from sklearn.feature_extraction.text import CountVectorizer
count_vect = CountVectorizer()
# Bag-of-words counts for the augmented corpus
X_train_counts = count_vect.fit_transform(augmented_data_eda)

from sklearn.naive_bayes import MultinomialNB
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_val_score
mnb = MultinomialNB()
# Mean 5-fold cross-validated accuracy of Multinomial Naive Bayes
print("Mean Accuracy: {:.2}".format(cross_val_score(mnb, X_train_counts, augmented_label_eda, cv=5).mean()))
# Output
Mean Accuracy: 0.79
- Classification Accuracy using CheckListAugmenter
from textattack.augmentation import CheckListAugmenter
checklist_augmenter = CheckListAugmenter(pct_words_to_swap=0.1, transformations_per_example=1)
augmented_data_checklist, augmented_label_checklist = augment_data_textattack(data, target, checklist_augmenter)

from sklearn.feature_extraction.text import CountVectorizer
count_vect = CountVectorizer()
X_train_counts = count_vect.fit_transform(augmented_data_checklist)

from sklearn.naive_bayes import MultinomialNB
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_val_score
mnb = MultinomialNB()
print("Mean Accuracy: {:.2}".format(cross_val_score(mnb, X_train_counts, augmented_label_checklist, cv=5).mean()))
# Output
Mean Accuracy: 0.82
Conclusion
- CLAREAugmenter is very slow: it can take a few hours even on a small dataset.
- Speed comparison (slowest to fastest)
CLAREAugmenter << EasyDataAugmenter < CheckListAugmenter
- Accuracy comparison
Without augmentation: 81%
With EasyDataAugmenter: 79%
With CheckListAugmenter: 82%
Here are the links to GitHub and Google Colab:
akgeni/textattack_aug (github.com)
Thank you for reading. Please upvote if you liked it; it will encourage me to write more content.