A novel extension of TabTransformer with gated fusion for residual-based model stacking
Quick Start • Architecture • Production • Changelog
This project implements TabTransformer++, an enhanced transformer architecture designed specifically for tabular data in a residual learning framework. Rather than predicting targets directly, the model learns to correct errors from simpler base models—a powerful technique for competition-winning ensembles.
+--------------------+ +------------------------+ +-----------------------+
| Base Model | | TabTransformer++ | | Final Prediction |
| (HistGBR, XGBoost) | --> | Predicts Residual | --> | Base + Residual |
| -> base_pred | | (error) | | |
+--------------------+ +------------------------+ +-----------------------+
Why residual learning?
- Base models capture linear/tree patterns efficiently
- Transformers excel at learning complex feature interactions
- Combined: each model focuses on what it does best
TabTransformer++ introduces six key innovations over the original TabTransformer:
Each feature is represented in two complementary ways:
| Type | Creation | Captures |
|---|---|---|
| Token Embedding | Quantile bin -> learned vector | Discrete patterns, ordinal relationships |
| Value Embedding | Raw scalar -> MLP projection | Precise numeric magnitude |
Why both? Binning loses precision (1.01 and 1.99 may share a bin), but raw scalars lack pattern-matching power.
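The token path can be sketched in a few lines of NumPy. This is an illustrative standalone snippet, not the package's `TabularTokenizer` implementation:

```python
import numpy as np

values = np.random.default_rng(0).lognormal(size=1000)

# Token path: quantile binning produces equal-population bins,
# so skewed distributions still spread across all bin ids.
n_bins = 32
edges = np.quantile(values, np.linspace(0, 1, n_bins + 1)[1:-1])
tokens = np.digitize(values, edges)   # integer bin ids in [0, n_bins - 1]

# Value path: the raw scalar is kept alongside the token,
# so 1.01 and 1.99 stay distinguishable even if they share a bin.
print(tokens.min(), tokens.max())  # 0 31
```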
Per-feature gates control the blend between token and scalar representations:
final_emb[i] = token_emb[i] + sigmoid(gate[i]) * value_emb[i]

Safe Initialization: Gates are initialized to -2.0 (sigmoid ≈ 0.12), biasing the model to rely on stable token embeddings first. This prevents early divergence before the model learns when to trust scalar values.
- Gates are learned independently for each feature
- Model adapts to each column's characteristics automatically
- Low gate → token-dominant (categorical treatment)
- High gate → scalar-dominant (precise numeric treatment)
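The gating mechanism above can be sketched as a small PyTorch module. The class and argument names here are illustrative, not the package's internals:

```python
import torch
import torch.nn as nn

class GatedFusion(nn.Module):
    """Minimal sketch of per-feature gated fusion (illustrative names)."""

    def __init__(self, n_features: int, gate_init: float = -2.0):
        super().__init__()
        # One learnable gate logit per feature; -2.0 gives sigmoid(-2.0) ~ 0.12,
        # so training starts token-dominant.
        self.gate = nn.Parameter(torch.full((n_features,), gate_init))

    def forward(self, token_emb: torch.Tensor, value_emb: torch.Tensor) -> torch.Tensor:
        # token_emb, value_emb: (batch, n_features, embed_dim)
        g = torch.sigmoid(self.gate).view(1, -1, 1)
        return token_emb + g * value_emb

fusion = GatedFusion(n_features=8)
tok = torch.randn(32, 8, 64)
val = torch.randn(32, 8, 64)
out = fusion(tok, val)
print(out.shape)  # torch.Size([32, 8, 64])
```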
Each feature gets its own projection network instead of sharing:
Linear(1 -> 64) -> GELU -> Linear(64 -> 64) -> LayerNorm
Allows different transformations for different feature distributions.
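A per-feature projection bank can be sketched with an `nn.ModuleList`, one small MLP per column. This is a standalone sketch of the idea, not the package's code:

```python
import torch
import torch.nn as nn

embed_dim = 64
n_features = 8

# One independent projection per feature, mirroring
# Linear(1 -> 64) -> GELU -> Linear(64 -> 64) -> LayerNorm
value_mlps = nn.ModuleList(
    nn.Sequential(
        nn.Linear(1, embed_dim),
        nn.GELU(),
        nn.Linear(embed_dim, embed_dim),
        nn.LayerNorm(embed_dim),
    )
    for _ in range(n_features)
)

x = torch.randn(32, n_features)  # raw scalar values, one column per feature
# Project each column with its own network, then stack to (batch, n_features, embed_dim)
value_emb = torch.stack(
    [mlp(x[:, i : i + 1]) for i, mlp in enumerate(value_mlps)], dim=1
)
print(value_emb.shape)  # torch.Size([32, 8, 64])
```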
During training, randomly zero out feature embeddings (p=0.12):
mask = (random > p) # per-sample, per-feature
mask[:, 0] = 1.0 # Never drop CLS token
x = x * mask / (1 - p) # Inverted scaling for magnitude consistency

Prevents over-reliance on any single feature. The inverted scaling maintains expected magnitude between train and test modes (like standard Dropout).
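A runnable NumPy sketch shows why the inverted scaling works: for the droppable tokens, the expected value after TokenDrop matches eval mode, where nothing is dropped:

```python
import numpy as np

rng = np.random.default_rng(0)
p = 0.12
x = np.ones((4096, 9, 64))                   # (batch, tokens, dim); token 0 is [CLS]

mask = (rng.random((4096, 9, 1)) > p).astype(x.dtype)
mask[:, 0] = 1.0                             # never drop the [CLS] token
x_drop = x * mask / (1 - p)                  # inverted scaling

# Each non-CLS token is kept with probability (1 - p) and scaled by 1 / (1 - p),
# so its expectation is unchanged: E[x_drop] == x.
print(abs(x_drop[:, 1:].mean() - 1.0) < 0.02)  # True
```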
BERT-style [CLS] token prepended to the sequence:
[CLS, feat_1, feat_2, ..., feat_n, base_pred, dt_pred]
CLS attends to all features and produces the final representation.
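Prepending the learned [CLS] token is a single concatenation; the sequence length below (8 features + 2 meta tokens) is illustrative:

```python
import torch
import torch.nn as nn

embed_dim = 64
# Learned [CLS] token, expanded across the batch and prepended to the sequence
cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

feat_emb = torch.randn(32, 10, embed_dim)   # T features + 2 meta tokens
x = torch.cat([cls_token.expand(32, -1, -1), feat_emb], dim=1)
print(x.shape)  # torch.Size([32, 11, 64])
# After the encoder, x[:, 0] is the pooled representation fed to the head.
```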
Uses norm_first=True for more stable training without warmup:
Pre-LN: x = x + Attention(LayerNorm(x)) [Stable]
Post-LN: x = LayerNorm(x + Attention(x)) [Requires warmup]
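In PyTorch, Pre-LN is a single flag on the stock encoder layer. Using the dimensions from this project (64-dim embeddings, 4 heads, FFN 256, 3 layers):

```python
import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(
    d_model=64, nhead=4, dim_feedforward=256,
    activation="gelu", batch_first=True,
    norm_first=True,   # Pre-LN: LayerNorm inside the residual branch, no warmup needed
)
encoder = nn.TransformerEncoder(layer, num_layers=3)

x = torch.randn(8, 11, 64)   # (batch, 1 CLS + T features + 2 meta, embed_dim)
out = encoder(x)
print(out.shape)  # torch.Size([8, 11, 64])
```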
+-------------------------------------+
| INPUT: T features + 2 meta |
| (tokens, raw_values) per feature |
+-------------------------------------+
|
+----------------------------+----------------------------+
| | |
v v v
+-------------+ +-------------+ +-------------+
| Feature 1 | | Feature 2 | ... | Feature T |
| token->embed| | token->embed| | token->embed|
| value->MLP | | value->MLP | | value->MLP |
| gate fusion | | gate fusion | | gate fusion |
+-------------+ +-------------+ +-------------+
| | |
+----------------------------+----------------------------+
|
v
+----------------------+
| Embedding Dropout |
| (p=0.05) |
+----------------------+
|
v
+----------------------+
| Prepend [CLS] |
| Token |
+----------------------+
|
v
+----------------------+
| TokenDrop |
| (p=0.12, train) |
+----------------------+
|
v
+-------------------------------------+
| TRANSFORMER ENCODER |
| +-------------------------------+ |
| | Layer 1: 4-head attention | |
| | + FFN(64->256->64) + PreLN | |
| +-------------------------------+ |
| +-------------------------------+ |
| | Layer 2: 4-head attention | |
| | + FFN(64->256->64) + PreLN | |
| +-------------------------------+ |
| +-------------------------------+ |
| | Layer 3: 4-head attention | |
| | + FFN(64->256->64) + PreLN | |
| +-------------------------------+ |
+-------------------------------------+
|
v
+----------------------+
| Extract [CLS] |
| Embedding |
+----------------------+
|
v
+-------------------------------------+
| PREDICTION HEAD |
| LayerNorm -> Linear(64->192) |
| -> GELU -> Dropout -> Linear(192->1)|
+-------------------------------------+
|
v
+---------------------+
| Predicted Residual |
| (robust-scaled) |
+---------------------+
TabTransformer++ includes built-in interpretability tools:
Extract and visualize learned gate values to understand feature treatment:
gate_values = extract_gate_values(model, feature_names)
visualize_gate_values(gate_values)

- Low gate (near 0): Feature works better as categorical bins
- High gate (near 1): Feature requires precise scalar values
Visualize learned embeddings using t-SNE or PCA:
visualize_token_embeddings(model, tokenizer, feature_idx=0, method='pca')

Shows how the model organizes quantile bins in embedding space, revealing learned semantic relationships.
Even when RMSE is comparable, TabTransformer++ offers unique advantages:
| Capability | XGBoost | TabTransformer++ |
|---|---|---|
| Dense Embeddings | ❌ No | ✅ Each row becomes a learned vector |
| Multi-Modal Fusion | ❌ Cannot combine with images/text | ✅ Embeddings fuse with vision/NLP models |
| Transfer Learning | ❌ Must retrain from scratch | ✅ Pre-train on large tables, fine-tune on small |
| Interpretable Gates | ❌ Feature importance only | ✅ Learn token vs scalar preference per feature |
| GPU Batch Inference | | ✅ Native PyTorch batching |
The Real Value: TabTransformer++ generates dense embeddings suitable for downstream multi-modal tasks (e.g., combining tabular property data with house images).
The notebook implements a complete 5-fold cross-validation pipeline:
# HistGradientBoostingRegressor for base predictions (captures non-linearity)
model_base = HistGradientBoostingRegressor(max_iter=100, max_depth=5)
# RandomForest for additional signal
model_dt = RandomForestRegressor(n_estimators=20, max_depth=8)
# Out-of-fold predictions to prevent leakage
residual = target - base_pred

Why HistGradientBoostingRegressor instead of Ridge?
- Captures non-linear patterns that linear models miss
- Leaves purer high-order feature interactions for the Transformer
- Faster than RandomForest due to histogram-based splits
- Quantile binning: 32 bins for features, 128 for base_pred, 64 for tree_pred
- Robust scaling: (x - median) / IQR, resistant to outliers (replaces Z-score)
- Fit on training fold only (leak-free)
Why Robust Scaling? Z-score (x - mean) / std is sensitive to outliers, which can cause gradient explosions in the scalar path. Robust scaling using median and IQR stabilizes training across all folds.
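A minimal sketch of the scaler (the helper name is illustrative, not the package API). Note how the single outlier barely moves the scale of the core values:

```python
import numpy as np

def robust_scale(x_train, x):
    """Median/IQR scaling; statistics come from the training fold only."""
    med = np.median(x_train)
    q1, q3 = np.percentile(x_train, [25, 75])
    iqr = max(q3 - q1, 1e-9)          # guard against constant columns
    return (x - med) / iqr

train = np.array([1.0, 2.0, 3.0, 4.0, 100.0])   # one extreme outlier
print(robust_scale(train, train))
# [-1.   -0.5   0.    0.5  48.5] — inliers stay in a tight range; a Z-score
# would have been dominated by the outlier's effect on mean and std.
```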
- EMA (Polyak averaging): Maintains exponential moving average of weights
- Huber loss: Robust to outliers
- AdamW optimizer: With weight decay regularization
Post-training calibration maps z-scored predictions to actual residuals:
from sklearn.isotonic import IsotonicRegression

iso = IsotonicRegression(out_of_bounds="clip")
iso.fit(preds_z, y_va_raw)
calibrated_residual = iso.predict(preds_z)
final_prediction = base_pred + calibrated_residual

This section outlines how TabTransformer++ fits into a production ML system.
┌─────────────────────────────────────────────────────────────────────────────┐
│ TRAINING PIPELINE │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌──────────────┐ ┌───────────────────┐ ┌─────────────────────────┐ │
│ │ Raw Data │───▶│ TabularTokenizer │───▶│ Feature Store │ │
│ │ (Offline) │ │ .fit() on TRAIN │ │ (Serialize tokenizer) │ │
│ └──────────────┘ └───────────────────┘ └─────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────┐ │
│ │ TabTransformer++ │ │
│ │ PyTorch Training │ │
│ └─────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ Model Export │ │
│ ├─────────────────────────────────────────────────────────────────────┤ │
│ │ • torch.jit.script() → TorchScript (.pt) │ │
│ │ • torch.onnx.export() → ONNX (.onnx) │ │
│ │ • TensorRT optimization for NVIDIA GPUs │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────────────────┐
│ INFERENCE PIPELINE │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌──────────────┐ ┌───────────────────┐ ┌─────────────────────────┐ │
│ │ New Request │───▶│ Feature Store │───▶│ Tokenizer.transform() │ │
│ │ (Online) │ │ (Load tokenizer) │ │ (Consistent binning) │ │
│ └──────────────┘ └───────────────────┘ └─────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌────────────────────────────────────────────┐ │
│ │ Inference Runtime │ │
│ ├────────────────────────────────────────────┤ │
│ │ • ONNX Runtime (CPU/GPU) │ │
│ │ • TensorRT (NVIDIA, <1ms latency) │ │
│ │ • TorchServe / Triton Inference Server │ │
│ └────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────┐ │
│ │ Prediction + Post-Processing │ │
│ │ base_pred + calibrated_residual │ │
│ └─────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
The TabularTokenizer encapsulates learned quantile bins and scaling statistics. For online/offline consistency:
import pickle
# After training
with open("tokenizer.pkl", "wb") as f:
pickle.dump(tokenizer, f)
# Upload to Feature Store (e.g., Feast, Tecton, SageMaker Feature Store)
feature_store.register_artifact("tabtransformer_tokenizer", "tokenizer.pkl")

Why Feature Store?
- Ensures identical preprocessing in training and serving
- Version control for tokenizer artifacts
- Supports A/B testing different tokenizer configurations
# Export to ONNX (cross-platform, optimized inference)
import torch.onnx
model.eval()
dummy_tok = torch.randint(0, 32, (1, num_features))
dummy_val = torch.randn(1, num_features)
torch.onnx.export(
model,
(dummy_tok, dummy_val),
"tabtransformer.onnx",
input_names=["tokens", "values"],
output_names=["prediction"],
dynamic_axes={"tokens": {0: "batch"}, "values": {0: "batch"}},
)
# For NVIDIA GPUs: Convert to TensorRT
# trtexec --onnx=tabtransformer.onnx --saveEngine=tabtransformer.trt --fp16

Inference Latency Targets:
| Runtime | Hardware | Typical Latency |
|---|---|---|
| PyTorch | CPU | 5-20ms |
| ONNX Runtime | CPU | 2-8ms |
| ONNX Runtime | GPU | 0.5-2ms |
| TensorRT | NVIDIA GPU | <1ms |
Problem: Training uses batch statistics; serving sees single rows.
Solution: Store computed features, don't recompute at inference.
| Feature Type | Training | Serving |
|---|---|---|
| Raw features | Compute from source | Fetch from Feature Store |
| Base model predictions | OOF predictions | Pre-computed daily batch |
| Tokenized features | Batch transform | Single-row transform |
Preventing Train-Serve Skew:
- Tokenizer versioning: Hash tokenizer params, embed in model metadata
- Feature validation: Assert feature distributions at inference time
- Shadow mode: Run new model in parallel, compare outputs before deployment
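Tokenizer versioning can be as simple as hashing the fitted parameters; the helper below is an illustrative sketch, not part of the package:

```python
import hashlib
import json

def tokenizer_fingerprint(params: dict) -> str:
    """Deterministic hash of tokenizer parameters, suitable for embedding in
    model metadata and re-checking at serving time (illustrative helper)."""
    blob = json.dumps(params, sort_keys=True).encode("utf-8")
    return hashlib.sha256(blob).hexdigest()[:12]

train_hash = tokenizer_fingerprint({"n_bins": 32, "features": ["a", "b"], "version": 1})
serve_hash = tokenizer_fingerprint({"n_bins": 32, "features": ["a", "b"], "version": 1})
assert train_hash == serve_hash   # mismatch at inference => train-serve skew
print(train_hash)
```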
Option A: Batch Prediction (Offline)
Airflow/Prefect → Load Data → Transform → Predict → Write to DB
- Use for: Daily scoring of large datasets
- Latency: Hours (acceptable)
- Cost: Low (spot instances)
Option B: Real-Time API (Online)
API Gateway → Load Balancer → Inference Pod (ONNX/TensorRT) → Response
- Use for: User-facing predictions
- Latency: <50ms p99
- Scaling: Horizontal pod autoscaling
Option C: Streaming (Near Real-Time)
Kafka → Feature Compute → Model Inference → Kafka → Downstream
- Use for: Event-driven predictions
- Latency: Seconds
- Throughput: High (parallelizable)
# Clone the repository
git clone https://github.com/LEDazzio01/Tab-Transformer-Plus-Plus.git
cd Tab-Transformer-Plus-Plus
# Install dependencies
pip install numpy pandas torch scikit-learn jupyter

# Install the package
pip install -e .
# Train on built-in California Housing dataset
ttpp train --dataset cal_housing --epochs 10 --batch_size 1024
# Train on your own CSV data
ttpp train --train_data data/train.csv --target_col price --epochs 20
# Train with explicit train/test split
ttpp train --train_data train.csv --test_data test.csv --target_col target --n_folds 5

jupyter notebook TabTransformer_Residual_Learning.ipynb

The notebook demonstrates the full pipeline using the California Housing dataset.
import pandas as pd
from tab_transformer_plus_plus import (
TabTransformerGated,
TabularTokenizer,
TTConfig,
ModelFactory,
Trainer,
TrainingConfig,
load_data,
compute_rmse,
)
# Load your data (or use built-in datasets)
train_df, test_df, target_col, features = load_data(seed=42)
# Fit tokenizer on TRAINING data only (prevents leakage)
tokenizer = TabularTokenizer(n_bins=32, features=features, target=target_col)
tokenizer.fit(train_df) # Never fit on full dataset!
# Transform data
X_train_tok, X_train_val = tokenizer.transform(train_df)
X_test_tok, X_test_val = tokenizer.transform(test_df)
# Create model with configuration
config = TTConfig(
n_features=len(features),
n_bins=32,
embed_dim=64,
n_heads=4,
n_layers=3,
)
model = TabTransformerGated(config)
# Train using the Trainer class
train_config = TrainingConfig(epochs=10, batch_size=1024, learning_rate=2e-3)
trainer = Trainer(model=model, config=train_config)
# ... or use train_tabular for the full residual learning pipeline

# Save model checkpoint
ModelFactory.save_checkpoint(model, config, "model.pt")
# Load model checkpoint
model, config = ModelFactory.from_checkpoint("model.pt")
# Save/load tokenizer
tokenizer.save("tokenizer.pkl")
loaded_tokenizer = TabularTokenizer.load("tokenizer.pkl")

Register your own base models using the factory pattern:
from tab_transformer_plus_plus import BaseModelConfig, BaseModelFactory
from sklearn.linear_model import Ridge

# Register a custom model
BaseModelFactory.register("ridge", Ridge)

# Use in training
config = BaseModelConfig(model_type="ridge", hyperparams={"alpha": 1.0})

from tab_transformer_plus_plus import (
EarlyStoppingCallback,
LRSchedulerCallback,
CheckpointCallback,
)
callbacks = [
EarlyStoppingCallback(patience=5, min_delta=0.001),
LRSchedulerCallback(scheduler_type="cosine"),
CheckpointCallback(save_dir="checkpoints/", save_best_only=True),
]
trainer = Trainer(model=model, config=train_config, callbacks=callbacks)

All hyperparameters are centralized in the TTConfig and TrainingConfig classes:
| Category | Parameter | Default | Description |
|---|---|---|---|
| Tokenization | n_bins | 32 | Quantile bins for numeric features |
| Architecture | embed_dim | 64 | Embedding dimension (d_model) |
| | n_heads | 4 | Multi-head attention heads |
| | n_layers | 3 | Transformer encoder layers |
| | mlp_hidden | 192 | Prediction head hidden dim |
| Regularization | dropout | 0.1 | Attention & FFN dropout |
| | emb_dropout | 0.05 | Post-embedding dropout |
| | tokendrop_p | 0.12 | TokenDrop probability |
| Training | epochs | 10 | Training epochs |
| | batch_size | 1024 | Batch size |
| | learning_rate | 2e-3 | AdamW learning rate |
| | weight_decay | 0.01 | AdamW weight decay |
Access default values from constants:
from tab_transformer_plus_plus import (
DEFAULT_EPOCHS,
DEFAULT_BATCH_SIZE,
DEFAULT_LEARNING_RATE,
DEFAULT_N_BINS,
GATE_INIT_VALUE,
)

Tab-Transformer-Plus-Plus/
├── README.md # This documentation
├── CHANGELOG.md # Version history and changes
├── LICENSE # MIT License
├── pyproject.toml # Package configuration
├── requirements.txt # Dependencies
├── TabTransformer_Residual_Learning.ipynb # Interactive notebook demo
├── src/
│ └── tab_transformer_plus_plus/
│ ├── __init__.py # Package exports and public API
│ ├── base_models.py # BaseModelFactory and ensemble logic
│ ├── cli.py # Command-line interface
│ ├── configs.py # TTConfig, TrainingConfig, etc.
│ ├── constants.py # Default values and magic numbers
│ ├── data_loader.py # Data loading and splitting utilities
│ ├── exceptions.py # Custom exception hierarchy
│ ├── metrics.py # MetricRegistry and compute functions
│ ├── model.py # TabTransformerGated model (vectorized)
│ ├── protocols.py # Protocol classes for type safety
│ ├── tokenizer.py # TabularTokenizer (with serialization)
│ ├── train.py # High-level training pipeline
│ ├── trainer.py # Trainer class with callbacks
│ └── utils.py # Utility and visualization functions
└── tests/
├── test_cli.py # CLI argument parsing tests
├── test_integration.py # End-to-end integration tests
├── test_model.py # Model architecture tests
├── test_tokenizer.py # Tokenizer edge case tests
└── test_utils.py # Utility function tests
If you use this code, please cite the original TabTransformer paper:
@article{huang2020tabtransformer,
title={TabTransformer: Tabular Data Modeling Using Contextual Embeddings},
author={Huang, Xin and Khetan, Ashish and Cvitkovic, Milan and Karnin, Zohar},
journal={arXiv preprint arXiv:2012.06678},
year={2020}
}

- TabTransformer Paper — Huang et al. (2020)
- tab-transformer-pytorch — Reference implementation by lucidrains
- California Housing Dataset — Demo dataset from scikit-learn
This project is licensed under the MIT License. See LICENSE for details.