tabnet by dreamquark-ai

PyTorch implementation of TabNet for tabular data

created 5 years ago
2,816 stars

Top 17.3% on sourcepulse

Project Summary

This repository provides a PyTorch implementation of the TabNet architecture, designed for interpretable and attentive tabular learning. It addresses classification and regression tasks, including multi-task scenarios, and is suitable for researchers and practitioners seeking high-performance, transparent models for tabular data.

How It Works

TabNet employs a sequential attention mechanism that selects which features to reason from at each decision step, producing per-step feature-importance masks and enabling efficient processing of high-dimensional tabular data. Feature transformers built from gated linear units (GLUs) process the selected features, while a relaxation parameter gamma governs feature reuse across steps: gamma = 1 restricts each feature to a single decision step, and larger values allow reuse. The implementation also supports embedding-aware attention and grouped features for improved handling of categorical and correlated inputs.
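The per-step masks are produced by a sparsemax projection, and a running prior scale discounts features already used in earlier steps. The following is an illustrative NumPy sketch of that mechanism (not the library's internal code; the variable names and example logits are made up):

```python
import numpy as np

def sparsemax(z):
    """Project logits z onto the probability simplex, yielding a sparse
    distribution -- the projection TabNet uses for its feature masks."""
    z_sorted = np.sort(z)[::-1]
    k = np.arange(1, len(z) + 1)
    cumsum = np.cumsum(z_sorted)
    support = 1 + k * z_sorted > cumsum       # indices kept in the support
    k_z = k[support][-1]
    tau = (cumsum[support][-1] - 1) / k_z     # threshold
    return np.maximum(z - tau, 0.0)

# Hypothetical two-step illustration of the prior scale: features picked
# at step 1 are down-weighted at step 2 by gamma (gamma = 1 forbids
# reuse entirely; larger gamma permits it).
gamma = 1.3
logits = np.array([2.0, 1.0, 0.1, -1.0])
prior = np.ones_like(logits)

mask1 = sparsemax(logits * prior)   # step-1 mask concentrates on feature 0
prior = prior * (gamma - mask1)     # discount already-used features
mask2 = sparsemax(logits * prior)   # step-2 mask shifts toward feature 1
```

Each mask sums to 1 but is typically sparse, which is what makes the per-step feature attributions interpretable.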

Quick Start & Requirements

  • Installation: pip install pytorch-tabnet or conda install -c conda-forge pytorch-tabnet
  • Prerequisites: PyTorch. GPU acceleration is recommended for performance.
  • Source: For local development, clone the repository and use make start-gpu (or make start for CPU) followed by poetry install and make notebook.
  • Documentation: https://github.com/dreamquark-ai/tabnet

Highlighted Details

  • Scikit-learn compatible API for easy integration.
  • Supports semi-supervised pre-training for improved performance on limited labeled data.
  • Includes on-the-fly data augmentation capabilities.
  • Easy model saving and loading for production deployment.

Maintenance & Community

The project is maintained by dreamquark-ai, though recent commit activity has slowed. Community interaction is encouraged via Slack.

Licensing & Compatibility

The repository is licensed under the Apache 2.0 license, permitting commercial use and integration with closed-source projects.

Limitations & Caveats

The README notes that some implementation choices may differ from the original TabNet paper. Reconstruction during pre-training can be challenging with Batch Normalization and large batch sizes.

Health Check

  • Last commit: 9 months ago
  • Responsiveness: Inactive
  • Pull Requests (30d): 0
  • Issues (30d): 3
  • Star History: 64 stars in the last 90 days
