Meta Dropout has been successfully integrated into both Meta Networks implementations to ensure consistent dropout masks across support and query sets within the same task. This is crucial for Meta Networks because the meta-learner needs to process consistent embeddings to generate effective task-specific parameters.
This document covers Meta Dropout integration in:
- Embedding-based Meta Networks (Metric-based Meta Learning)
- Original Meta Networks (Model-based Meta Learning)
Both implementations share the same EmbeddingNetwork with Meta Dropout, ensuring consistent regularization across different meta-learning paradigms.
**Embedding-based Meta Networks**
- Category: Metric-based Meta Learning
- Approach: Generates task-specific embeddings for similarity-based classification
- Meta-learner Output: Task-specific embeddings for query examples
- Classification: Similarity between query embeddings and class prototypes (see the sketch after this list)
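As a rough illustration of this metric-based step (the helper below is hypothetical; the prototype averaging and squared-Euclidean scoring are assumptions, not code from the repository):

```python
import torch

def prototype_logits(support_emb, support_labels, query_emb, num_classes):
    """Hypothetical sketch: score queries by similarity to per-class prototypes."""
    # Average the support embeddings of each class into a prototype [num_classes, embedding_dim]
    prototypes = torch.stack([
        support_emb[support_labels == c].mean(dim=0) for c in range(num_classes)
    ])
    # Negative squared Euclidean distance serves as the similarity logits [num_queries, num_classes]
    return -torch.cdist(query_emb, prototypes).pow(2)
```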
**Original Meta Networks**
- Category: Model-based Meta Learning
- Approach: Generates actual FC layer weights and biases
- Meta-learner Output: Weight matrix W [embedding_dim × num_classes] and bias vector b [num_classes]
- Classification: Direct multiplication: `logits = query_embeddings @ W + b` (see the sketch below)
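A minimal sketch of this step (tensor values are placeholders; only the shapes follow from the description above):

```python
import torch

embedding_dim, num_classes, num_queries = 64, 5, 15
query_embeddings = torch.randn(num_queries, embedding_dim)  # produced by the EmbeddingNetwork
W = torch.randn(embedding_dim, num_classes)                 # weight matrix generated by the meta-learner
b = torch.randn(num_classes)                                # bias vector generated by the meta-learner

logits = query_embeddings @ W + b                           # [num_queries, num_classes]
assert logits.shape == (num_queries, num_classes)
```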
Both implementations use the same EmbeddingNetwork (embedding_network.py) with Meta Dropout:
- Consistent CNN architecture (4 conv layers with batch norm)
- Meta Dropout at 3 strategic locations (after conv1, conv2, conv3)
- Shared code promotes consistency and easy comparison between algorithms
Unlike standard dropout, which generates an independent random mask for each sample, Meta Dropout:
- ✅ Shares the same spatial dropout mask across all samples in support and query sets
- ✅ Resets masks per task to ensure different tasks get different regularization
- ✅ Maintains consistency throughout the task processing pipeline
```python
# ❌ Standard Dropout (nn.Dropout)
support_emb = embedding_network(support_data)  # Gets masks A, B, C, D, E
query_emb = embedding_network(query_data)      # Gets masks F, G, H, I, J
# Problem: Meta-learner sees inconsistent embeddings!
# Support embeddings used by meta-learner have different dropout
# than query embeddings used for final predictions
```

```python
# ✅ Meta Dropout
embedding_network.reset_dropout_masks(support_data.shape, device)
support_emb = embedding_network(support_data)  # Gets mask X (shared)
query_emb = embedding_network(query_data)      # Gets mask X (same!)
# Solution: Meta-learner processes consistent embeddings!
# Both support and query use the same spatial dropout pattern
```
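The MetaDropout module itself is defined outside this excerpt; a minimal sketch consistent with the `reset_mask(shape, device)` / `forward(x)` usage shown above (the exact implementation is an assumption) could look like:

```python
import torch
import torch.nn as nn

class MetaDropout(nn.Module):
    """Sketch: dropout that reuses one mask for every sample within a task."""

    def __init__(self, p: float = 0.1):
        super().__init__()
        self.p = p
        self.mask = None  # shared per-task mask, sampled by reset_mask()

    def reset_mask(self, shape, device):
        # Sample one Bernoulli keep-mask with inverted-dropout scaling for the whole task
        keep_prob = 1.0 - self.p
        self.mask = torch.bernoulli(torch.full(shape, keep_prob, device=device)) / keep_prob

    def forward(self, x):
        if not self.training or self.mask is None:
            return x
        # Broadcast the shared mask (batch dimension of 1) over all samples in the batch
        return x * self.mask
```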
```python
# From embedding_network.py
class EmbeddingNetwork(nn.Module):
    def __init__(self, embedding_dim: int = 64, dropout_rates: list = None):
        super(EmbeddingNetwork, self).__init__()

        # Default dropout configuration
        if dropout_rates is None:
            dropout_rates = [0.05, 0.10, 0.15]

        # Convolutional layers
        self.conv1 = nn.Conv2d(1, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        # ... conv2, conv3, conv4 ...

        # Meta Dropout layers at strategic locations
        self.dropout1 = MetaDropout(p=dropout_rates[0])  # After conv1
        self.dropout2 = MetaDropout(p=dropout_rates[1])  # After conv2
        self.dropout3 = MetaDropout(p=dropout_rates[2])  # After conv3

        # Fully connected layer (no classification head)
        self.fc = nn.Linear(64 * 6 * 6, embedding_dim)

    def reset_dropout_masks(self, input_shape, device):
        """Reset masks for a new task - shapes match the feature maps BEFORE pooling"""
        self.dropout1.reset_mask((1, 64, 105, 105), device)  # After conv1
        self.dropout2.reset_mask((1, 64, 52, 52), device)    # After conv2
        self.dropout3.reset_mask((1, 64, 26, 26), device)    # After conv3
        self._masks_initialized = True

    def forward(self, x):
        # Layer 1
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.dropout1(x) if self.training and not self.force_eval else x
        x = self.pool(x)  # 52x52

        # Layer 2
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.dropout2(x) if self.training and not self.force_eval else x
        x = self.pool(x)  # 26x26

        # Layer 3
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.dropout3(x) if self.training and not self.force_eval else x
        x = self.pool(x)  # 13x13

        # Layer 4 (no dropout)
        x = F.relu(self.bn4(self.conv4(x)))
        x = self.pool(x)  # 6x6

        # Flatten and project to embedding space
        x = self.flatten(x)
        x = self.fc(x)
        return x
```
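The hard-coded mask shapes in `reset_dropout_masks` follow from 105×105 single-channel inputs (Omniglot-sized, by assumption) and 2×2 max pooling after each block; a quick check of the spatial sizes:

```python
# Feature-map sizes before each 2x2 max pool: 105 → 52 → 26 → 13, then 6x6 after the last pool
size = 105
for layer in range(1, 5):
    print(f"conv{layer} output (mask shape before pooling): {size}x{size}")
    size //= 2  # 2x2 max pool halves the spatial size (floor division)
print(f"flattened size fed to fc: 64 * {size} * {size} = {64 * size * size}")  # 64 * 6 * 6 = 2304
```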
```python
# From eb_meta_network.py
class MetaNetwork(nn.Module):
    def __init__(self, embedding_dim=64, hidden_dim=128, num_classes=5):
        super(MetaNetwork, self).__init__()

        # Use shared embedding network
        self.embedding_network = EmbeddingNetwork(embedding_dim)
        self.meta_learner = MetaLearner(embedding_dim, hidden_dim, num_classes)

    def forward(self, support_data, support_labels, query_data):
        # Reset dropout masks for this task
        self.embedding_network.reset_dropout_masks(support_data.shape, support_data.device)

        # Extract embeddings with consistent dropout
        support_embeddings = self.embedding_network(support_data)
        query_embeddings = self.embedding_network(query_data)

        # Generate fast weights and classify using similarity
        logits = self.meta_learner(support_embeddings, support_labels, query_embeddings)
        return logits
```
```python
# From original_meta_network.py
class OriginalMetaNetwork(nn.Module):
    def __init__(self, embedding_dim=64, hidden_dim=128, num_classes=5):
        super(OriginalMetaNetwork, self).__init__()

        # Use the same shared embedding network
        self.embedding_network = EmbeddingNetwork(embedding_dim)
        self.meta_learner = MetaLearner(embedding_dim, hidden_dim, num_classes)

    def forward(self, support_data, support_labels, query_data):
        # Note: Automatic mask reset happens in embedding_network

        # Extract embeddings with consistent dropout
        support_embeddings = self.embedding_network(support_data)
        query_embeddings = self.embedding_network(query_data)

        # Meta-learner generates W and b, then classifies
        logits = self.meta_learner(support_embeddings, support_labels, query_embeddings)
        return logits
```

**Key Point:** Both implementations call the same EmbeddingNetwork, which automatically manages Meta Dropout masks to ensure consistency within each task.
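As a usage sketch for a single 5-way task (the task tensors, sampler, and optimizer setup below are assumptions for illustration; the hyperparameters mirror the experiments reported later):

```python
import torch
import torch.nn.functional as F

model = OriginalMetaNetwork(embedding_dim=64, hidden_dim=128, num_classes=5)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# One hypothetical 5-way task: 5 support images and 15 query images (1x105x105)
support_data = torch.randn(5, 1, 105, 105)
support_labels = torch.arange(5)
query_data = torch.randn(15, 1, 105, 105)
query_labels = torch.arange(5).repeat_interleave(3)

model.train()
optimizer.zero_grad()
logits = model(support_data, support_labels, query_data)  # Meta Dropout masks are handled per task
loss = F.cross_entropy(logits, query_labels)
loss.backward()
optimizer.step()
```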
All tests pass (`test_meta_network_dropout.py`):

```
Support and query embeddings: IDENTICAL
Max difference: 0.000000
✅ PASS: Support and query use the SAME dropout masks!

Task 1 vs Task 2 difference: 0.930160
Task 2 vs Task 3 difference: 1.359920
✅ PASS: Different tasks get DIFFERENT dropout masks!

Output shape: torch.Size([15, 5])
✅ PASS: Forward pass successful with correct output shape!

Difference between runs: 0.0000000000
✅ PASS: Dropout correctly disabled in eval mode!
```
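A stripped-down version of the mask-consistency checks (a hedged reconstruction, not a copy of `test_meta_network_dropout.py`):

```python
import torch

net = EmbeddingNetwork(embedding_dim=64)
net.train()

x = torch.randn(5, 1, 105, 105)  # the same inputs stand in for support and query

# Within one task: two forward passes must share the same masks
net.reset_dropout_masks(x.shape, x.device)
emb_support = net(x)
emb_query = net(x)
assert torch.allclose(emb_support, emb_query), "masks changed within a task"

# Across tasks: resetting the masks should change the embeddings of identical inputs
net.reset_dropout_masks(x.shape, x.device)
emb_next_task = net(x)
print("max difference across tasks:", (emb_support - emb_next_task).abs().max().item())
```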
We trained two Embedding-based Meta Network models with identical hyperparameters (2000 tasks, batch size 8, learning rate 0.001) and evaluated them on 200 test tasks:
| Configuration | Accuracy | Std Dev | Loss | High-Performing Tasks (>80%) |
|---|---|---|---|---|
| Without Dropout | 75.8% | ±10.4% | N/A | N/A |
| Meta Dropout [0.05, 0.10, 0.15] | 77.3% | ±11.9% | N/A | N/A |
| Improvement | +1.5% | +1.5% (14.4% increase) | N/A | N/A |
Key Findings:
- ✅ +1.5% accuracy improvement through consistent regularization
- ⚠️ +14.4% variance increase (acceptable trade-off for better accuracy)
- 💡 Insight: Meta Dropout helps the meta-learner learn better embedding generation strategies
We trained two Original Meta Network models with identical hyperparameters (2000 tasks, batch size 16, learning rate 0.001) and evaluated them on 200 test tasks:
| Configuration | Accuracy | Std Dev | Loss | High-Performing Tasks (>80%) |
|---|---|---|---|---|
| Without Dropout | 84.15% | ±10.27% | 0.4480 | 149/200 (74.5%) |
| Meta Dropout [0.05, 0.10, 0.15, 0.05] | 86.31% | ±9.07% | 0.3836 | 159/200 (79.5%) |
| Improvement | +2.16% | -1.2% (11.7% decrease) | -14.4% | +5.0% |
Key Findings:
- ✅ +2.16% accuracy improvement - larger gain than Embedding-based variant!
- ✅ -11.7% variance reduction - more consistent performance across tasks
- ✅ -14.4% loss reduction - better confidence in predictions
- ✅ +5.0% more tasks with >80% accuracy - improved reliability
- 💡 Insight: Meta Dropout helps weight/bias generators predict more effective and consistent classifier parameters
Task Distribution with Meta Dropout:
- 100% of tasks achieved >50% accuracy (perfect reliability)
- 79.5% of tasks achieved >80% accuracy (up from 74.5%)
- 43.5% of tasks achieved >90% accuracy (up from 32.5%)
| Metric | Embedding-based MN | Original MN |
|---|---|---|
| Accuracy Gain | +1.5% | +2.16% 🏆 |
| Variance Change | +14.4% worse | -11.7% better 🏆 |
| Loss Reduction | N/A | -14.4% 🏆 |
| Reliability Gain | N/A | +5.0% tasks >80% 🏆 |
Winner: Original Meta Networks - Meta Dropout delivers markedly larger gains for Original Meta Networks than for the Embedding-based variant!
| Algorithm | Accuracy Change | Variance Change | Why? |
|---|---|---|---|
| MAML | +1.2% | -8.9% ✅ | Multiple gradient steps with consistent masks → stable adaptation |
| Embedding-based MN | +1.5% | +14.4% | Direct generation with consistent embeddings → better but more varied |
| Original MN | +2.16% 🏆 | -11.7% ✅ | Weight prediction with consistent embeddings → best of both worlds! |
- **All Algorithms Benefit from Meta Dropout**
  - Consistent regularization improves accuracy across all paradigms
  - Gradient-based (MAML): +1.2% accuracy
  - Metric-based (Embedding MN): +1.5% accuracy
  - Model-based (Original MN): +2.16% accuracy - largest improvement!
- **Variance Effects Differ by Approach**
  - MAML: Reduced variance (-8.9%) - consistent gradient signals across adaptation
  - Embedding-based MN: Increased variance (+14.4%) - more diverse similarity patterns
  - Original MN: Reduced variance (-11.7%) ✅ - consistent parameter generation with better convergence
- **Why Original Meta Networks Benefits Most** 🎯
  - Direct parameter generation: Weight/bias generators see consistent embeddings
  - No gradient accumulation noise: Unlike MAML, predictions are deterministic given embeddings
  - Simpler optimization: Consistent inputs lead to more stable weight predictions
  - Better generalization: Regularization during training transfers to more robust generated parameters
- **Mechanism Differences**
  - MAML: Dropout masks consistent across multiple gradient steps
  - Embedding-based MN: Dropout masks consistent for embedding generation
  - Original MN: Dropout masks consistent for parameter generation → most direct benefit
By Accuracy Improvement:
- 🥇 Original MN: +2.16%
- 🥈 Embedding-based MN: +1.5%
- 🥉 MAML: +1.2%
By Variance Reduction (Lower is Better):
- 🥇 Original MN: -11.7% (improved consistency)
- 🥈 MAML: -8.9% (improved consistency)
- 🥉 Embedding-based MN: +14.4% (worse consistency)
Overall Best Performance: 🏆 Original Meta Networks - Achieves both the highest accuracy gain AND reduced variance, making it the clear winner for Meta Dropout integration!
Original Meta Networks Architecture:

```
Support Set → EmbeddingNetwork (with Meta Dropout) → Consistent Embeddings
                         ↓
            MetaLearner (U, V, e) processes
                         ↓
            Weight Generator → W [64×5]
            Bias Generator   → b [5]
                         ↓
Query Set → EmbeddingNetwork (same masks!) → Consistent Embeddings
                         ↓
            Classify: logits = query @ W + b
```
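The MetaLearner itself is not reproduced in this document; the following is a purely hypothetical sketch of the generation path in the diagram (class name, heads, and per-class pooling are assumptions, not the repository's implementation):

```python
import torch
import torch.nn as nn

class ParamGeneratorSketch(nn.Module):
    """Hypothetical illustration of weight/bias generation, not the repository's MetaLearner."""

    def __init__(self, embedding_dim=64, hidden_dim=128, num_classes=5):
        super().__init__()
        self.num_classes = num_classes
        self.encoder = nn.Linear(embedding_dim, hidden_dim)
        self.weight_head = nn.Linear(hidden_dim, embedding_dim)  # one weight column per class
        self.bias_head = nn.Linear(hidden_dim, 1)                # one bias per class

    def forward(self, support_emb, support_labels, query_emb):
        # Summarize each class from its (consistently masked) support embeddings
        summaries = torch.stack([
            torch.tanh(self.encoder(support_emb[support_labels == c].mean(dim=0)))
            for c in range(self.num_classes)
        ])                                            # [num_classes, hidden_dim]
        W = self.weight_head(summaries).t()           # [embedding_dim, num_classes]
        b = self.bias_head(summaries).squeeze(-1)     # [num_classes]
        return query_emb @ W + b                      # [num_queries, num_classes]
```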
- **Consistency Throughout Pipeline** ✅
  - Support embeddings used to generate W and b are consistent
  - Query embeddings classified by W and b use the same dropout masks
  - No mismatch between "training" (support) and "testing" (query) representations
- **Generator Stability** 🎯
  - Weight/bias generators learn from consistent support embeddings
  - Reduces noise in the generator's input distribution
  - More reliable parameter predictions → lower variance
- **Regularization Transfer** 🔄
  - Dropout during training teaches generators to be robust
  - Generated parameters work well with masked features
  - Better generalization to test tasks
- **No Gradient Accumulation** ⚡
  - Unlike MAML, no inner loop gradient steps
  - Predictions are pure functions of embeddings
  - Consistency directly translates to better outputs
Why Embedding-based MN has increased variance:
- Metric-based classification is sensitive to embedding variations
- Similarity computations amplify small differences
- More diverse dropout patterns → more varied similarity scores
Why MAML has reduced variance but lower accuracy gain:
- Multiple gradient steps can smooth out some inconsistencies
- But adaptation process adds its own noise
- Smaller accuracy gain due to gradient-based optimization complexity
Why Original MN achieves the best of both worlds:
- Direct parameter prediction without gradient steps
- Full pipeline consistency (support → W,b → query classification)
- Simpler, more direct optimization → larger gains