dice_loss_for_NLP by ShannonAI

Code for the research paper "Dice Loss for Data-imbalanced NLP Tasks" (ACL 2020)

created 5 years ago
274 stars

Top 95.2% on sourcepulse

Project Summary

This repository provides the implementation for "Dice Loss for Data-imbalanced NLP Tasks," a method to improve model performance on imbalanced datasets in Natural Language Processing. It is targeted at NLP researchers and practitioners dealing with class imbalance issues. The primary benefit is enhanced accuracy and robustness in tasks like machine reading comprehension, paraphrase identification, named entity recognition, and text classification.

How It Works

The project implements Dice Loss, a loss function derived from the Dice coefficient, a set-overlap metric best known for its effectiveness in image segmentation. Adapted to NLP, it directly optimizes an F1-like overlap between predictions and gold labels, which makes training less sensitive to class imbalance than standard cross-entropy. The repository integrates this loss into BERT-based fine-tuning for each of the supported tasks.
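
The core idea can be sketched in a few lines of PyTorch. The following is a minimal, illustrative soft Dice loss for binary classification, not the repository's exact implementation (the paper's self-adjusting variant additionally down-weights easy examples):

    import torch

    def soft_dice_loss(probs, targets, smooth=1.0):
        # probs:   (N,) predicted probabilities for the positive class
        # targets: (N,) binary gold labels {0, 1}
        # `smooth` keeps the loss defined when a batch has no positives.
        intersection = (probs * targets).sum()
        denominator = probs.sum() + targets.sum()
        dice = (2.0 * intersection + smooth) / (denominator + smooth)
        return 1.0 - dice  # minimize 1 - Dice coefficient

    # toy usage
    logits = torch.tensor([2.0, -1.0, 0.5, -3.0])
    labels = torch.tensor([1.0, 0.0, 1.0, 0.0])
    print(soft_dice_loss(torch.sigmoid(logits), labels).item())

Because the numerator counts only true-positive mass, examples from the dominant negative class contribute little to the gradient, which is what makes the loss attractive for imbalanced data.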

Quick Start & Requirements

  • Install: pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html followed by pip install -r requirements.txt.
  • Prerequisites: Python 3.6.9+, PyTorch 1.7.1, CUDA 10.1 (for the specified PyTorch installation). Requires downloading BERT checkpoints and converting them to PyTorch format using scripts/prepare_ckpt.sh. Datasets for specific tasks (SQuAD 1.1, MRPC, NER datasets, TNews) need to be downloaded separately.
  • Setup: Requires creating a virtual environment, installing dependencies, converting the BERT checkpoints, and preparing the datasets; a quick environment check is sketched after this list.
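
A simple way to confirm that the local environment matches the stated prerequisites (an illustrative check, not part of the repository):

    import sys
    import torch

    # Sanity check against the stated prerequisites:
    # Python 3.6.9+, PyTorch 1.7.1 built for CUDA 10.1.
    assert sys.version_info >= (3, 6), "Python 3.6.9+ expected"
    print("Python :", sys.version.split()[0])
    print("PyTorch:", torch.__version__)       # expected: 1.7.1+cu101
    print("CUDA   :", torch.version.cuda)      # expected: 10.1
    print("GPU OK :", torch.cuda.is_available())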

Highlighted Details

  • Implements Dice Loss for four NLP tasks: Machine Reading Comprehension (SQuAD 1.1), Paraphrase Identification (MRPC), Named Entity Recognition (NER), and Text Classification (TNews).
  • Provides scripts to reproduce the experimental results for fine-tuning BERT with Dice Loss, Focal Loss, and Binary Cross-Entropy; a focal-loss sketch for context follows this list.
  • Reports specific performance gains, e.g., 93.21 Span-F1 for NER on English CoNLL03 with Dice Loss compared to 93.08 with Focal Loss.
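
For context on the baselines being compared, a rough binary focal-loss sketch (illustrative only, not the repository's code) is shown below; it scales cross-entropy so that well-classified examples are down-weighted and the gradient concentrates on hard, typically minority-class cases:

    import torch
    import torch.nn.functional as F

    def binary_focal_loss(logits, targets, gamma=2.0, alpha=0.25):
        # Standard binary focal loss: scale BCE by (1 - p_t)^gamma,
        # where p_t is the model's probability for the true class.
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        p_t = torch.exp(-bce)
        alpha_t = alpha * targets + (1.0 - alpha) * (1.0 - targets)
        return (alpha_t * (1.0 - p_t) ** gamma * bce).mean()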

Maintenance & Community

  • The code accompanies the ACL 2020 paper "Dice Loss for Data-imbalanced NLP Tasks".
  • Contact information (email addresses) is provided for discussions and questions.

Licensing & Compatibility

  • The README does not explicitly state a license.
  • Compatibility for commercial use or closed-source linking is not specified.

Limitations & Caveats

The setup pins specific PyTorch and CUDA versions and requires manual conversion of BERT checkpoints. The project uses BERT as its only backbone, so other architectures are not supported without modification.

Health Check

  • Last commit: 2 years ago
  • Responsiveness: Inactive
  • Pull Requests (30d): 0
  • Issues (30d): 0
  • Star History: 0 stars in the last 90 days
