Bert-Multi-Label-Text-Classification by lonePatient

PyTorch code for multi-label text classification

created 6 years ago
920 stars

Top 40.4% on sourcepulse

View on GitHub
Project Summary

This repository provides a PyTorch implementation of multi-label text classification using pre-trained BERT and XLNet models. It is designed for researchers and practitioners working on text classification tasks that require assigning multiple labels to a single document. The primary benefit is leveraging powerful transformer architectures for nuanced classification.

How It Works

The project fine-tunes pre-trained BERT and XLNet models for multi-label classification. Raw text is run through WordPiece tokenization before being fed into the transformer. The codebase is organized into modular components for data handling, configuration, training, and output, which makes customization and experimentation straightforward.
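
As a rough illustration of this setup, here is a minimal PyTorch sketch of a multi-label BERT head trained with a per-label binary loss. The class name `BertForMultiLabel` is an assumption for illustration, not code from the repo; the tuple-style model output follows the transformers 2.x API the project pins.

```python
# Minimal sketch of the multi-label setup described above: a pooled BERT
# encoding followed by one sigmoid decision per label. Names here are
# illustrative, not taken from the repository.
import torch.nn as nn
from transformers import BertModel

class BertForMultiLabel(nn.Module):
    def __init__(self, num_labels, model_name="bert-base-uncased"):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        # transformers 2.x returns a tuple; the pooled [CLS] output is element [1]
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.dropout(outputs[1])
        return self.classifier(pooled)  # raw logits, one per label

# Multi-label training treats each label as an independent binary decision,
# so the loss is BCEWithLogitsLoss rather than CrossEntropyLoss.
criterion = nn.BCEWithLogitsLoss()
```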

Quick Start & Requirements

  • Install: pip install pytorch-transformers (from GitHub)
  • Dependencies: PyTorch 1.1+, csv, tqdm, numpy, pickle, scikit-learn, matplotlib, pandas, transformers==2.5.1.
  • Pretrained Models: Download the bert-base-uncased and xlnet-base-cased weights; the BERT files go in the pybert/pretrain/bert/base-uncased directory.
  • Data: Download the Kaggle data (the Toxic Comment Classification Challenge dataset) and place it in pybert/dataset.
  • Configuration: Modify paths and parameters in pybert/configs/basic_config.py.
  • Preprocessing: python run_bert.py --do_data (a tokenization sketch follows this list)
  • Training: python run_bert.py --do_train --save_best --do_lower_case
  • Prediction: python run_bert.py --do_test --do_lower_case
  • Documentation: GitHub Repository
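
To make the preprocessing step concrete, here is a hedged sketch of the WordPiece tokenization that the --do_data and --do_lower_case flags correspond to, using the encode_plus API from the pinned transformers 2.x release. The sample text and the 256-token cap are illustrative choices, not values from the repo's config.

```python
# Sketch of WordPiece preprocessing under the pinned transformers==2.5.1 API.
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
enc = tokenizer.encode_plus(
    "An example comment to classify.",
    max_length=256,            # raw text kept well under BERT's 512-token limit
    pad_to_max_length=True,    # transformers 2.x padding argument
    return_tensors="pt",
)
print(enc["input_ids"].shape)  # torch.Size([1, 256])
```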

Highlighted Details

  • Supports both BERT and XLNET architectures.
  • Reports per-label AUC scores for training and validation (a metric sketch follows this list).
  • Fine-tuning all layers is recommended for better performance over feature-based approaches.
  • Input length is capped at BERT's 512-token limit; the README suggests keeping raw text around 128–256 tokens.
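
A minimal sketch of per-label AUC reporting with scikit-learn's roc_auc_score, as mentioned in the list above; the label names and arrays below are placeholder data, not output from the repo.

```python
# One ROC-AUC per label, treating each column as an independent binary task.
import numpy as np
from sklearn.metrics import roc_auc_score

labels = ["label_0", "label_1", "label_2"]  # placeholder names
y_true = np.array([[1, 0, 1], [0, 1, 0], [1, 1, 0], [0, 0, 1]])
y_prob = np.array([[0.9, 0.2, 0.7], [0.1, 0.8, 0.3],
                   [0.8, 0.6, 0.2], [0.3, 0.1, 0.9]])

for i, name in enumerate(labels):
    print(f"{name}: {roc_auc_score(y_true[:, i], y_prob[:, i]):.3f}")
```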

Maintenance & Community

No specific information on maintainers, community channels, or roadmap is provided in the README.

Licensing & Compatibility

The README does not explicitly state a license. It mentions using pytorch-transformers from GitHub, which is Apache-2.0 licensed, but compatibility for commercial use or closed-source linking is not specified for this repository itself.

Limitations & Caveats

The project pins dependency versions that are now dated (transformers==2.5.1, PyTorch 1.1+). Non-tensor calculations (e.g., accuracy, F1) are not supported inside DataParallel when training on multiple GPUs, so such metrics must be computed outside the parallelized forward pass (see the sketch below). The README also notes that converting TensorFlow checkpoints requires specific handling to avoid loading corrupted models.
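
A hedged sketch of the usual workaround for the DataParallel caveat: keep the model's forward pass tensor-only, gather logits back to the CPU, and compute non-tensor metrics afterwards with scikit-learn. The `evaluate` helper and the loader's tuple layout are assumptions for illustration, not the repo's actual API.

```python
# Compute accuracy/F1 outside the DataParallel forward, on gathered tensors.
import torch
from sklearn.metrics import f1_score

def evaluate(model, loader, device, threshold=0.5):
    model.eval()
    all_logits, all_targets = [], []
    with torch.no_grad():
        for input_ids, attention_mask, targets in loader:
            logits = model(input_ids.to(device), attention_mask.to(device))
            all_logits.append(logits.cpu())  # move off the GPUs first
            all_targets.append(targets)
    preds = (torch.sigmoid(torch.cat(all_logits)) > threshold).int().numpy()
    return f1_score(torch.cat(all_targets).numpy(), preds, average="micro")
```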

Health Check

  • Last commit: 2 years ago
  • Responsiveness: Inactive
  • Pull Requests (30d): 0
  • Issues (30d): 0
  • Star History: 13 stars in the last 90 days
