Meta Networks is a meta-learning algorithm that learns to generate task-specific parameters (called "fast weights") through a meta-learner network. Unlike MAML, which learns a good initialization for gradient-based adaptation, Meta Networks directly generates classifier parameters from support-set examples.
Paper: Meta Networks - Munkhdalai & Yu, ICML 2017
This implementation is a variant of the original Meta Networks algorithm, commonly known as Embedding-based Meta Networks.
| Aspect | This Implementation | Original Meta Networks |
|---|---|---|
| Name | Embedding-based Meta Networks | Meta Networks |
| Category | 🎯 Metric-based Meta Learning | 🏗️ Model-based Meta Learning |
| Approach | Generates task-specific embeddings for metric learning | Generates weights for the entire base network |
| Fast Weights | Used for computing similarity metrics | Used as actual network parameters |
| Similarity | Closer to Matching Networks / Prototypical Networks | More general weight generation framework |
- ✅ This variant uses the meta-learner to generate task-specific embeddings that are used in a metric-based classification approach (similar to prototypical networks)
- 🔜 Original approach (coming soon): The meta-learner predicts the actual weights of the base network, making it a true model-based meta-learning algorithm
The original Meta Networks implementation (where the meta-learner predicts base network weights) will be added to this repository next. This is the more commonly used and cited approach from the original paper.
- Embedding Network (`EmbeddingNetwork`)
  - 4 convolutional layers with batch normalization
  - Extracts fixed-dimensional feature embeddings from images
  - Input: 105×105 grayscale images
  - Output: 64-dimensional embeddings

- Meta-Learner (`MetaLearner`)
  - Learns three key parameters:
    - U matrix (hidden_dim × embedding_dim): projects support embeddings
    - V matrix (hidden_dim × embedding_dim): projects query embeddings
    - e vector (hidden_dim): base embedding capturing task structure
  - Generates task-specific classifier weights from the support set

- Meta Network (`MetaNetwork`)
  - Combines `EmbeddingNetwork` and `MetaLearner`
  - End-to-end trainable system
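As a concrete illustration of the embedding component, here is a minimal PyTorch sketch of a 4-conv embedder mapping 105×105 grayscale images to 64-dimensional embeddings. The exact layer widths and pooling scheme are assumptions for illustration; the repository's `EmbeddingNetwork` may differ in detail:

```python
import torch
import torch.nn as nn

class EmbeddingNetwork(nn.Module):
    """Illustrative 4-conv embedder: 105x105 grayscale -> 64-dim embedding.
    Channel counts and pooling are assumptions, not the repo's exact layers."""
    def __init__(self, embedding_dim=64):
        super().__init__()
        def block(c_in, c_out):
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.MaxPool2d(2),
            )
        self.features = nn.Sequential(
            block(1, 64), block(64, 64), block(64, 64), block(64, 64)
        )
        self.pool = nn.AdaptiveAvgPool2d(1)  # collapses spatial dims regardless of input size
        self.fc = nn.Linear(64, embedding_dim)

    def forward(self, x):                          # x: (B, 1, 105, 105)
        h = self.pool(self.features(x)).flatten(1) # (B, 64)
        return self.fc(h)                          # (B, embedding_dim)
```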
For each batch of tasks:

1. Extract embeddings from the support and query sets using `EmbeddingNetwork`
2. Generate fast weights from the support embeddings using `MetaLearner`:
   - For each support example (x_i, y_i):
     - Compute embedding: h_i = EmbeddingNetwork(x_i)
     - Project: r_i = U @ h_i
     - Add base: w_i = r_i + e
   - Average per class: W_c = mean(w_i for all i where y_i = c)
3. Classify queries using the generated weights:
   - For each query x:
     - Compute embedding: h = EmbeddingNetwork(x)
     - Project: query_proj = V @ h
     - Compute logits: logits_c = query_proj^T @ W_c
4. Backpropagate the loss to update U, V, e, and `EmbeddingNetwork`
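The fast-weight and classification steps above can be sketched in a few lines of PyTorch. This assumes the support/query embeddings have already been computed; function names and shapes are illustrative, not the repository's exact API:

```python
import torch

def generate_fast_weights(support_emb, support_labels, U, e, n_way):
    """Per-example fast weights w_i = U @ h_i + e, averaged per class.
    support_emb: (N_support, embedding_dim), U: (hidden_dim, embedding_dim), e: (hidden_dim,)"""
    w = support_emb @ U.T + e                  # (N_support, hidden_dim)
    # Class weight W_c = mean of w_i over support examples with label c
    W = torch.stack([w[support_labels == c].mean(dim=0) for c in range(n_way)])
    return W                                   # (n_way, hidden_dim)

def classify_queries(query_emb, V, W):
    """Logits_c = (V @ h)^T W_c for each query embedding h."""
    q = query_emb @ V.T                        # (N_query, hidden_dim)
    return q @ W.T                             # (N_query, n_way)
```

Note that both functions are pure tensor ops, so the loss on the query logits backpropagates through U, V, e, and the embeddings in a single outer-loop step.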
- Single forward pass - no gradient-based adaptation needed!
- Very fast compared to MAML's inner loop optimization
- Meta-learner directly generates optimal classifier
- Accuracy: ~75.5%
- Competitive with MAML
- Much faster inference (no adaptation loop)
| Aspect | MAML | Meta Networks |
|---|---|---|
| Approach | Learn good initialization | Learn to generate parameters |
| Adaptation | Gradient-based (inner loop) | Direct generation (meta-learner) |
| Inference Speed | Slower (requires gradient steps) | Faster (single forward pass) |
| Parameters | Model parameters | U, V, e matrices + embedding network |
| Memory | Higher (computation graph) | Lower (no inner loop gradients) |
```python
from algorithms.eb_meta_network import MetaNetwork, train_meta_network
from utils.load_omniglot import OmniglotDataset, OmniglotTaskDataset
from torch.utils.data import DataLoader

# Load data
dataset = OmniglotDataset("omniglot/images_background")
task_dataset = OmniglotTaskDataset(dataset, n_way=5, k_shot=1, k_query=15, num_tasks=2000)
dataloader = DataLoader(task_dataset, batch_size=4, shuffle=True)

# Create and train model
model = MetaNetwork(embedding_dim=64, hidden_dim=128, num_classes=5)
model, optimizer, losses = train_meta_network(
    model=model,
    task_dataloader=dataloader,
    learning_rate=0.001
)
```

```python
from algorithms.eb_meta_network import evaluate_meta_network
from evaluation.eval_visualization import plot_evaluation_results

# Evaluate on test tasks
eval_results = evaluate_meta_network(
    model=model,
    eval_dataloader=test_dataloader,
    num_classes=5
)

# Visualize results
plot_evaluation_results(eval_results)
```

- `EB_Meta_Network.py`: Complete Meta Networks implementation
  - `EmbeddingNetwork`: Feature extractor
  - `MetaLearner`: Fast weight generator
  - `MetaNetwork`: Combined system
  - `train_meta_network()`: Training function
  - `evaluate_meta_network()`: Evaluation function
- `examples/embedding_based_meta_network.ipynb`: Complete tutorial notebook
- Shared utilities (used by both MAML and Meta Networks):
  - `utils/load_omniglot.py`: Dataset loaders
  - `utils/evaluate.py`: Visualization functions
- Embedding dimension: 64
- Hidden dimension: 128 (for U, V, e)
- Learning rate: 0.001
- Optimizer: Adam
- Gradient clipping: max_norm=1.0
- Batch size: 4 tasks
- N-way: 5 (5 characters per task)
- K-shot: 1 (1 support example per class)
- K-query: 15 (15 query examples per class)
- Training tasks: 2000
- Test tasks: 200
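The settings above can be collected into a plain config dict. The key names below are illustrative; the repository passes these values individually as function arguments:

```python
# Hypothetical config mirroring the hyperparameters listed above.
config = {
    "embedding_dim": 64,
    "hidden_dim": 128,        # size of U, V, and e
    "learning_rate": 1e-3,    # Adam optimizer
    "grad_clip_max_norm": 1.0,
    "batch_size": 4,          # tasks per batch
    "n_way": 5,               # characters per task
    "k_shot": 1,              # support examples per class
    "k_query": 15,            # query examples per class
    "num_train_tasks": 2000,
    "num_test_tasks": 200,
}
```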
The U, V, and e parameters learn to:
- U: Extract task-relevant features from support embeddings
- V: Transform query embeddings for classification
- e: Capture base task structure (what's common across tasks)
Together, they form a powerful mechanism for rapid classifier generation!
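In equation form, the roles of U, V, and e combine as follows (using the notation from the training steps above, with $S_c$ the set of support examples of class $c$ and $h_i = \text{EmbeddingNetwork}(x_i)$):

$$
W_c = \frac{1}{|S_c|} \sum_{i \in S_c} \left( U h_i + e \right),
\qquad
\text{logits}_c(x) = (V h_x)^\top W_c
$$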
- Meta Networks Paper: Munkhdalai & Yu, "Meta Networks", ICML 2017
- Omniglot Dataset: Lake et al., "Human-level concept learning through probabilistic program induction"
- Related: MAML (Model-Agnostic Meta-Learning)
  - See `docs/MAML_vs_FOMAML.md` for a comparison
See examples/embedding_based_meta_network.ipynb for a complete step-by-step tutorial with:
- Data loading and visualization
- Model architecture explanation
- Training from scratch
- Evaluation and visualization
- Comparison with MAML
Happy Meta-Learning! 🤖💡