Meta Dropout in Meta Networks

🎯 Overview

Meta Dropout has been successfully integrated into both Meta Networks implementations to ensure consistent dropout masks across support and query sets within the same task. This is crucial for Meta Networks because the meta-learner needs to process consistent embeddings to generate effective task-specific parameters.

This document covers Meta Dropout integration in:

  1. Embedding-based Meta Networks (Metric-based Meta Learning)
  2. Original Meta Networks (Model-based Meta Learning)

Both implementations share the same EmbeddingNetwork with Meta Dropout, ensuring consistent regularization across different meta-learning paradigms.


📚 Two Meta Networks Implementations

🎯 Embedding-based Meta Networks (eb_meta_network.py)

  • Category: Metric-based Meta Learning
  • Approach: Generates task-specific embeddings for similarity-based classification
  • Meta-learner Output: Task-specific embeddings for query examples
  • Classification: Similarity between query embeddings and class prototypes

🏗️ Original Meta Networks (original_meta_network.py)

  • Category: Model-based Meta Learning
  • Approach: Generates actual FC layer weights and biases
  • Meta-learner Output: Weight matrix W [embedding_dim × num_classes] and bias vector b [num_classes]
  • Classification: Direct multiplication, logits = query_embeddings @ W + b (see the sketch below)
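
As a minimal illustration of the shapes involved in this direct-multiplication step (the tensors below are random placeholders for a 5-way task with 64-dimensional embeddings, not values from the repository):

```python
import torch

query_embeddings = torch.randn(15, 64)   # 15 query examples, 64-dim embeddings
W = torch.randn(64, 5)                   # generated weight matrix [embedding_dim x num_classes]
b = torch.randn(5)                       # generated bias vector [num_classes]

logits = query_embeddings @ W + b        # [15, 5]: one score per class for each query
predictions = logits.argmax(dim=1)       # predicted class indices
```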

🔗 Shared Component: EmbeddingNetwork

Both implementations use the same EmbeddingNetwork (embedding_network.py) with Meta Dropout:

  • Consistent CNN architecture (4 conv layers with batch norm)
  • Meta Dropout at 3 strategic locations (after conv1, conv2, conv3)
  • Shared code promotes consistency and easy comparison between algorithms

🔑 Why Meta Dropout for Meta Networks?

Unlike standard dropout, which generates independent random masks for each sample, Meta Dropout:

  1. Shares the same spatial dropout mask across all samples in support and query sets
  2. Resets masks per task to ensure different tasks get different regularization
  3. Maintains consistency throughout the task processing pipeline

Standard Dropout Problem

```python
# ❌ Standard Dropout (nn.Dropout)
support_emb = embedding_network(support_data)  # Gets masks A, B, C, D, E
query_emb = embedding_network(query_data)      # Gets masks F, G, H, I, J

# Problem: Meta-learner sees inconsistent embeddings!
# Support embeddings used by the meta-learner have different dropout
# than the query embeddings used for final predictions
```

Meta Dropout Solution

```python
# ✅ Meta Dropout
embedding_network.reset_dropout_masks(support_data.shape, device)
support_emb = embedding_network(support_data)  # Gets mask X (shared)
query_emb = embedding_network(query_data)      # Gets mask X (same!)

# Solution: Meta-learner processes consistent embeddings!
# Both support and query use the same spatial dropout pattern
```
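
The MetaDropout layer itself is defined outside this document. As a rough sketch, under the assumption that it follows the reset_mask/forward usage shown here (one spatial mask sampled per task and broadcast across the batch), it might look like this:

```python
import torch
import torch.nn as nn

class MetaDropout(nn.Module):
    """Sketch: dropout whose mask is sampled once per task and reused for every forward pass."""
    def __init__(self, p: float = 0.1):
        super().__init__()
        self.p = p
        self.mask = None

    def reset_mask(self, shape, device):
        # Sample one Bernoulli keep-mask with inverted-dropout scaling for the whole task.
        keep_prob = 1.0 - self.p
        self.mask = torch.bernoulli(
            torch.full(shape, keep_prob, device=device)
        ) / keep_prob

    def forward(self, x):
        if not self.training or self.mask is None:
            return x
        # Broadcast the single task-level mask across the batch dimension.
        return x * self.mask
```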

🏗️ Implementation Details

1. Shared EmbeddingNetwork with Meta Dropout

```python
# From embedding_network.py
class EmbeddingNetwork(nn.Module):
    def __init__(self, embedding_dim: int = 64, dropout_rates: list = None):
        super(EmbeddingNetwork, self).__init__()

        # Default dropout configuration
        if dropout_rates is None:
            dropout_rates = [0.05, 0.10, 0.15]

        # Convolutional layers
        self.conv1 = nn.Conv2d(1, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        # ... conv2/bn2, conv3/bn3, conv4/bn4, pooling, and flatten defined similarly ...

        # Meta Dropout layers at strategic locations
        self.dropout1 = MetaDropout(p=dropout_rates[0])  # After conv1
        self.dropout2 = MetaDropout(p=dropout_rates[1])  # After conv2
        self.dropout3 = MetaDropout(p=dropout_rates[2])  # After conv3

        # Fully connected layer (no classification head)
        self.fc = nn.Linear(64 * 6 * 6, embedding_dim)

    def reset_dropout_masks(self, input_shape, device):
        """Reset masks for a new task. Mask shapes match the feature maps BEFORE pooling."""
        self.dropout1.reset_mask((1, 64, 105, 105), device)  # After conv1
        self.dropout2.reset_mask((1, 64, 52, 52), device)    # After conv2
        self.dropout3.reset_mask((1, 64, 26, 26), device)    # After conv3
        self._masks_initialized = True

    def forward(self, x):
        # Layer 1
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.dropout1(x) if self.training and not self.force_eval else x
        x = self.pool(x)  # 52x52

        # Layer 2
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.dropout2(x) if self.training and not self.force_eval else x
        x = self.pool(x)  # 26x26

        # Layer 3
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.dropout3(x) if self.training and not self.force_eval else x
        x = self.pool(x)  # 13x13

        # Layer 4 (no dropout)
        x = F.relu(self.bn4(self.conv4(x)))
        x = self.pool(x)  # 6x6

        # Flatten and project to embedding space
        x = self.flatten(x)
        x = self.fc(x)
        return x
```
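
As a quick sanity check of the shape progression noted in the comments above (105 → 52 → 26 → 13 → 6), a sketch along these lines should produce 64-dimensional embeddings:

```python
import torch

net = EmbeddingNetwork(embedding_dim=64)
net.eval()  # dropout inactive in eval mode, so no mask reset is needed here

images = torch.randn(5, 1, 105, 105)    # five 105x105 Omniglot-style images
with torch.no_grad():
    embeddings = net(images)
print(embeddings.shape)                  # expected: torch.Size([5, 64])
```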

2. Embedding-based Meta Networks Usage

```python
# From eb_meta_network.py
class MetaNetwork(nn.Module):
    def __init__(self, embedding_dim=64, hidden_dim=128, num_classes=5):
        super(MetaNetwork, self).__init__()
        # Use shared embedding network
        self.embedding_network = EmbeddingNetwork(embedding_dim)
        self.meta_learner = MetaLearner(embedding_dim, hidden_dim, num_classes)

    def forward(self, support_data, support_labels, query_data):
        # Reset dropout masks for this task
        self.embedding_network.reset_dropout_masks(support_data.shape, support_data.device)

        # Extract embeddings with consistent dropout
        support_embeddings = self.embedding_network(support_data)
        query_embeddings = self.embedding_network(query_data)

        # Generate fast weights and classify using similarity
        logits = self.meta_learner(support_embeddings, support_labels, query_embeddings)
        return logits
```

3. Original Meta Networks Usage

```python
# From original_meta_network.py
class OriginalMetaNetwork(nn.Module):
    def __init__(self, embedding_dim=64, hidden_dim=128, num_classes=5):
        super(OriginalMetaNetwork, self).__init__()
        # Use same shared embedding network
        self.embedding_network = EmbeddingNetwork(embedding_dim)
        self.meta_learner = MetaLearner(embedding_dim, hidden_dim, num_classes)

    def forward(self, support_data, support_labels, query_data):
        # Note: Automatic mask reset happens in embedding_network
        # Extract embeddings with consistent dropout
        support_embeddings = self.embedding_network(support_data)
        query_embeddings = self.embedding_network(query_data)

        # Meta-learner generates W and b, then classifies
        logits = self.meta_learner(support_embeddings, support_labels, query_embeddings)
        return logits
```

Key Point: Both implementations call the same EmbeddingNetwork, which automatically manages Meta Dropout masks to ensure consistency within each task.
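
As a usage illustration, here is a minimal, self-contained sketch of an episodic training step built around this forward pass. Random tensors stand in for real Omniglot tasks, and the loop is simplified to one task per update rather than the task batches used in the experiments below:

```python
import torch
import torch.nn.functional as F

model = MetaNetwork(embedding_dim=64, hidden_dim=128, num_classes=5)  # from eb_meta_network.py
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model.train()

for _ in range(100):  # illustrative number of tasks
    # Placeholder 5-way 1-shot task (random data in place of a real task sampler)
    support_data = torch.randn(5, 1, 105, 105)            # one support image per class
    support_labels = torch.arange(5)
    query_data = torch.randn(15, 1, 105, 105)             # 3 query images per class
    query_labels = torch.arange(5).repeat_interleave(3)

    # forward() resets the Meta Dropout masks once for this task, so the support
    # embeddings (used to generate fast weights) and the query embeddings
    # (used for the loss) are computed under the same masks.
    logits = model(support_data, support_labels, query_data)
    loss = F.cross_entropy(logits, query_labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```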


✅ Validation Tests

All tests pass (test_meta_network_dropout.py):

Test 1: Mask Consistency ✅

Support and query embeddings: IDENTICAL
Max difference: 0.000000
✅ PASS: Support and query use the SAME dropout masks!

Test 2: Different Tasks ✅

Task 1 vs Task 2 difference: 0.930160
Task 2 vs Task 3 difference: 1.359920
✅ PASS: Different tasks get DIFFERENT dropout masks!

Test 3: Forward Pass ✅

Output shape: torch.Size([15, 5])
✅ PASS: Forward pass successful with correct output shape!

Test 4: Eval Mode ✅

Difference between runs: 0.0000000000
✅ PASS: Dropout correctly disabled in eval mode!
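
For reference, the mask-consistency property checked by Test 1 can be reproduced with a short script along these lines (a sketch, not the actual contents of test_meta_network_dropout.py):

```python
import torch

torch.manual_seed(0)
net = EmbeddingNetwork(embedding_dim=64)  # from embedding_network.py
net.train()  # keep Meta Dropout active

data = torch.randn(5, 1, 105, 105)  # one Omniglot-sized batch for a single task

# Reset masks once for this task, then embed the same batch twice
# (standing in for the support and query passes).
net.reset_dropout_masks(data.shape, data.device)
emb_support = net(data)
emb_query = net(data)

max_diff = (emb_support - emb_query).abs().max().item()
print(f"Max difference: {max_diff:.6f}")   # expected: 0.000000
assert max_diff == 0.0, "Support and query should see the SAME dropout masks"
```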

📊 Experimental Results

Embedding-based Meta Networks (5-Way 1-Shot Omniglot)

We trained two Embedding-based Meta Network models with identical hyperparameters (2000 tasks, batch size 8, learning rate 0.001) and evaluated them on 200 test tasks:

| Configuration | Accuracy | Std Dev | Loss | High-Performing Tasks (>80%) |
|---|---|---|---|---|
| Without Dropout | 75.8% | ±10.4% | N/A | N/A |
| Meta Dropout [0.05, 0.10, 0.15] | 77.3% | ±11.9% | N/A | N/A |
| Improvement | +1.5% | +1.5% (14.4% increase) | N/A | N/A |

Key Findings:

  • +1.5% accuracy improvement through consistent regularization
  • ⚠️ +14.4% variance increase (acceptable trade-off for better accuracy)
  • 💡 Insight: Meta Dropout helps the meta-learner learn better embedding generation strategies

Original Meta Networks (5-Way 1-Shot Omniglot)

We trained two Original Meta Network models with identical hyperparameters (2000 tasks, batch size 16, learning rate 0.001) and evaluated them on 200 test tasks:

| Configuration | Accuracy | Std Dev | Loss | High-Performing Tasks (>80%) |
|---|---|---|---|---|
| Without Dropout | 84.15% | ±10.27% | 0.4480 | 149/200 (74.5%) |
| Meta Dropout [0.05, 0.10, 0.15, 0.05] | 86.31% | ±9.07% | 0.3836 | 159/200 (79.5%) |
| Improvement | +2.16% | -1.2% (11.7% decrease) | -14.4% | +5.0% |

Key Findings:

  • +2.16% accuracy improvement - larger gain than Embedding-based variant!
  • -11.7% variance reduction - more consistent performance across tasks
  • -14.4% loss reduction - better confidence in predictions
  • +5.0% more tasks with >80% accuracy - improved reliability
  • 💡 Insight: Meta Dropout helps weight/bias generators predict more effective and consistent classifier parameters

Task Distribution with Meta Dropout:

  • 100% of tasks achieved >50% accuracy (perfect reliability)
  • 79.5% of tasks achieved >80% accuracy (up from 74.5%)
  • 43.5% of tasks achieved >90% accuracy (up from 32.5%)

📊 Side-by-Side Comparison

| Metric | Embedding-based MN | Original MN |
|---|---|---|
| Accuracy Gain | +1.5% | +2.16% 🏆 |
| Variance Change | +14.4% (worse) | -11.7% (better) 🏆 |
| Loss Reduction | N/A | -14.4% 🏆 |
| Reliability Gain | N/A | +5.0% tasks >80% 🏆 |

Winner: Original Meta Networks - Meta Dropout delivers substantially larger benefits here than for the Embedding-based variant!


🔬 Comparison: Meta Dropout Across Algorithms

Complete Algorithm Comparison

| Algorithm | Accuracy Change | Variance Change | Why? |
|---|---|---|---|
| MAML | +1.2% | -8.9% | Multiple gradient steps with consistent masks → stable adaptation |
| Embedding-based MN | +1.5% | +14.4% ⚠️ | Direct generation with consistent embeddings → better but more varied |
| Original MN | +2.16% 🏆 | -11.7% | Weight prediction with consistent embeddings → best of both worlds |

Key Insights:

  1. All Algorithms Benefit from Meta Dropout

    • Consistent regularization improves accuracy across all paradigms
    • Gradient-based (MAML): +1.2% accuracy
    • Metric-based (Embedding MN): +1.5% accuracy
    • Model-based (Original MN): +2.16% accuracy - largest improvement!
  2. Variance Effects Differ by Approach

    • MAML: Reduced variance (-8.9%) - consistent gradient signals across adaptation
    • Embedding-based MN: Increased variance (+14.4%) - more diverse similarity patterns
    • Original MN: Reduced variance (-11.7%) ✅ - consistent parameter generation with better convergence
  3. Why Original Meta Networks Benefits Most 🎯

    • Direct parameter generation: Weight/bias generators see consistent embeddings
    • No gradient accumulation noise: Unlike MAML, predictions are deterministic given embeddings
    • Simpler optimization: Consistent inputs lead to more stable weight predictions
    • Better generalization: Regularization during training transfers to more robust generated parameters
  4. Mechanism Differences

    • MAML: Dropout masks consistent across multiple gradient steps
    • Embedding-based MN: Dropout masks consistent for embedding generation
    • Original MN: Dropout masks consistent for parameter generation → most direct benefit

🏆 Performance Rankings

By Accuracy Improvement:

  1. 🥇 Original MN: +2.16%
  2. 🥈 Embedding-based MN: +1.5%
  3. 🥉 MAML: +1.2%

By Variance Reduction (Lower is Better):

  1. 🥇 Original MN: -11.7% (improved consistency)
  2. 🥈 MAML: -8.9% (improved consistency)
  3. 🥉 Embedding-based MN: +14.4% (worse consistency)

Overall Best Performance: 🏆 Original Meta Networks - Achieves both the highest accuracy gain AND reduced variance, making it the clear winner for Meta Dropout integration!


🔍 Deep Dive: Why Original Meta Networks Excels with Meta Dropout

The Parameter Generation Advantage

Original Meta Networks Architecture:

```
Support Set → EmbeddingNetwork (with Meta Dropout) → Consistent Embeddings
                                                              ↓
                                     MetaLearner (U, V, e) processes
                                                              ↓
                              Weight Generator → W [64×5]
                              Bias Generator → b [5]
                                                              ↓
Query Set → EmbeddingNetwork (same masks!) → Consistent Embeddings
                                                              ↓
                                    Classify: logits = query @ W + b
```
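
To make the Weight Generator / Bias Generator step concrete, here is one plausible, simplified sketch of a parameter generator; the actual MetaLearner in original_meta_network.py (with its U, V, e components) may be structured quite differently:

```python
import torch
import torch.nn as nn

class SimpleParameterGenerator(nn.Module):
    """Illustrative generator: pools support embeddings per class, then predicts W and b."""
    def __init__(self, embedding_dim=64, hidden_dim=128, num_classes=5):
        super().__init__()
        self.num_classes = num_classes
        self.weight_head = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim),   # one weight column per class
        )
        self.bias_head = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1),               # one bias per class
        )

    def forward(self, support_embeddings, support_labels, query_embeddings):
        # Average the support embeddings of each class (prototype-style pooling).
        prototypes = torch.stack([
            support_embeddings[support_labels == c].mean(dim=0)
            for c in range(self.num_classes)
        ])                                          # [num_classes, embedding_dim]
        W = self.weight_head(prototypes).t()        # [embedding_dim, num_classes]
        b = self.bias_head(prototypes).squeeze(-1)  # [num_classes]
        return query_embeddings @ W + b             # logits: [num_query, num_classes]
```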

Why This Works So Well:

  1. Consistency Throughout Pipeline

    • Support embeddings used to generate W and b are consistent
    • Query embeddings classified by W and b use the same dropout masks
    • No mismatch between "training" (support) and "testing" (query) representations
  2. Generator Stability 🎯

    • Weight/bias generators learn from consistent support embeddings
    • Reduces noise in the generator's input distribution
    • More reliable parameter predictions → lower variance
  3. Regularization Transfer 🔄

    • Dropout during training teaches generators to be robust
    • Generated parameters work well with masked features
    • Better generalization to test tasks
  4. No Gradient Accumulation

    • Unlike MAML, no inner loop gradient steps
    • Predictions are pure functions of embeddings
    • Consistency directly translates to better outputs

Comparison with Other Algorithms:

Why Embedding-based MN has increased variance:

  • Metric-based classification is sensitive to embedding variations
  • Similarity computations amplify small differences
  • More diverse dropout patterns → more varied similarity scores

Why MAML has reduced variance but lower accuracy gain:

  • Multiple gradient steps can smooth out some inconsistencies
  • But adaptation process adds its own noise
  • Smaller accuracy gain due to gradient-based optimization complexity

Why Original MN achieves the best of both worlds:

  • Direct parameter prediction without gradient steps
  • Full pipeline consistency (support → W,b → query classification)
  • Simpler, more direct optimization → larger gains