Conversation

lehendo (Contributor) commented Jan 25, 2026

This PR adds K-means clustering-based conformal prediction; a rough usage sketch follows the bullet summary below.

  • Implemented ClusterLabel for cluster-specific calibration thresholds using K-means on patient embeddings
  • Added example script tuev_kmeans_conformal.py demonstrating usage on TUEV EEG dataset
  • Includes test suite with coverage for initialization, calibration, and prediction
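
A rough usage sketch (constructor and calibrate arguments are approximated from the diff below; trained_model, cal_dataset, and train_embeddings are placeholders):

from pyhealth.calib.predictionset import ClusterLabel

# Wrap a trained PyHealth model; alpha is the target miscoverage rate
cal_model = ClusterLabel(model=trained_model, alpha=0.1, n_clusters=5)

# Calibration fits K-means on combined train + calibration embeddings
# and computes one conformal threshold per cluster
cal_model.calibrate(
    cal_dataset=cal_dataset,
    train_embeddings=train_embeddings,  # np.ndarray of shape (n_train, d)
)

# At inference, forward() assigns each sample to a cluster and uses that
# cluster's threshold to build the prediction set in pred["y_predset"]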

Copilot AI left a comment

Pull request overview

This PR adds K-means clustering-based conformal prediction (ClusterLabel) for multiclass classification in EEG analysis. The method groups patients into clusters using K-means on embeddings and computes cluster-specific calibration thresholds to improve prediction set efficiency compared to global thresholds.
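
For readers new to the technique, a minimal self-contained sketch of cluster-conditional split conformal prediction (illustration only, not the PR's code; the PR's _query_quantile plays the role of the quantile step shown here):

import numpy as np
from sklearn.cluster import KMeans

def fit_cluster_thresholds(embeddings, y_prob, y_true, n_clusters=5, alpha=0.1, seed=0):
    """Fit K-means on calibration embeddings and compute one split-conformal
    threshold per cluster from the true-class probabilities."""
    km = KMeans(n_clusters=n_clusters, random_state=seed, n_init=10).fit(embeddings)
    scores = y_prob[np.arange(len(y_true)), y_true]  # conformity = P(true class)
    thresholds = {}
    for c in range(n_clusters):
        s = np.sort(scores[km.labels_ == c])
        n = len(s)
        # Finite-sample corrected quantile: the floor(alpha * (n + 1))-th
        # smallest score; fall back to -inf (include all classes) when the
        # cluster has too few calibration samples.
        k = int(np.floor(alpha * (n + 1)))
        thresholds[c] = s[k - 1] if 1 <= k <= n else -np.inf
    return km, thresholds

def predict_sets(km, thresholds, test_embeddings, test_probs):
    """Keep every class whose probability clears the threshold of the
    cluster its sample was assigned to."""
    cluster_ids = km.predict(test_embeddings)
    t = np.array([thresholds[c] for c in cluster_ids])[:, None]
    return test_probs >= t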

Changes:

  • Implemented ClusterLabel class that performs K-means clustering on patient embeddings and applies cluster-specific calibration thresholds
  • Added comprehensive test suite covering initialization, calibration, and prediction
  • Included example script demonstrating usage on TUEV EEG dataset with ContraWR model

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 6 comments.

File summary:

  • pyhealth/calib/predictionset/cluster/cluster_label.py: Core implementation of cluster-based conformal prediction with K-means clustering
  • pyhealth/calib/predictionset/cluster/__init__.py: Module initialization exporting the ClusterLabel class
  • pyhealth/calib/predictionset/__init__.py: Updated to export ClusterLabel from the cluster submodule
  • tests/core/test_cluster_label.py: Test suite covering initialization, calibration, and forward pass with various configurations
  • examples/conformal_eeg/tuev_kmeans_conformal.py: Example script demonstrating ClusterLabel usage on the TUEV EEG events dataset

alpha = np.asarray(alpha)
self.alpha = alpha

# Store clustering parameters

Copilot AI Jan 26, 2026

Missing input validation for n_clusters parameter. The parameter should be validated to ensure it's a positive integer (n_clusters > 0). Invalid values like 0 or negative numbers would cause sklearn's KMeans to raise an error during calibration, but it's better to fail fast with a clear error message during initialization.

Suggested change
- # Store clustering parameters
+ # Store clustering parameters
+ if not isinstance(n_clusters, int) or n_clusters <= 0:
+     raise ValueError(
+         f"n_clusters must be a positive integer, got {n_clusters!r}"
+     )

Comment on lines +293 to +310
cluster_id = self.kmeans_model.predict(sample_embedding)[0]

# Get cluster-specific threshold
cluster_threshold = self.cluster_thresholds[cluster_id]

# Convert to tensor if needed
if isinstance(cluster_threshold, np.ndarray):
    cluster_threshold = torch.tensor(
        cluster_threshold, device=self.device, dtype=pred["y_prob"].dtype
    )
else:
    cluster_threshold = torch.tensor(
        cluster_threshold, device=self.device, dtype=pred["y_prob"].dtype
    )

# Construct prediction set using cluster-specific threshold
pred["y_predset"] = pred["y_prob"] >= cluster_threshold

Copilot AI Jan 26, 2026

Critical batch handling bug: The forward method only uses the first sample's cluster assignment for the entire batch. At line 293, cluster_id = self.kmeans_model.predict(sample_embedding)[0] extracts only the first element from the cluster predictions array. When processing a batch with multiple samples, this causes all samples in the batch to use the same threshold (from the first sample's cluster), leading to incorrect prediction sets for all but the first sample.

The method needs to handle batches properly by:

  1. Predicting cluster assignments for all samples in the batch
  2. Retrieving the appropriate threshold for each sample based on its cluster
  3. Applying per-sample thresholds to construct prediction sets

This will require either processing samples individually in a loop or implementing vectorized batch operations with per-sample threshold application.

Suggested change
- cluster_id = self.kmeans_model.predict(sample_embedding)[0]
- # Get cluster-specific threshold
- cluster_threshold = self.cluster_thresholds[cluster_id]
- # Convert to tensor if needed
- if isinstance(cluster_threshold, np.ndarray):
-     cluster_threshold = torch.tensor(
-         cluster_threshold, device=self.device, dtype=pred["y_prob"].dtype
-     )
- else:
-     cluster_threshold = torch.tensor(
-         cluster_threshold, device=self.device, dtype=pred["y_prob"].dtype
-     )
- # Construct prediction set using cluster-specific threshold
- pred["y_predset"] = pred["y_prob"] >= cluster_threshold
+ # Predict cluster assignments for all samples in the batch
+ cluster_ids = self.kmeans_model.predict(sample_embedding)
+ # Get cluster-specific thresholds for each sample
+ cluster_thresholds = np.array(
+     [self.cluster_thresholds[cid] for cid in cluster_ids]
+ )
+ # Convert thresholds to tensor on the correct device/dtype
+ cluster_thresholds = torch.as_tensor(
+     cluster_thresholds, device=self.device, dtype=pred["y_prob"].dtype
+ )
+ # Reshape thresholds to broadcast over any non-batch dimensions of y_prob
+ if pred["y_prob"].ndim > 1:
+     # (batch_size,) -> (batch_size, 1, ..., 1)
+     view_shape = (cluster_thresholds.shape[0],) + (1,) * (pred["y_prob"].ndim - 1)
+     cluster_thresholds = cluster_thresholds.view(view_shape)
+ # Construct prediction set using per-sample, cluster-specific thresholds
+ pred["y_predset"] = pred["y_prob"] >= cluster_thresholds

Comment on lines +275 to +286
# Get base model prediction
pred = self.model(**kwargs)

# Extract embedding for this sample to assign to cluster
embed_kwargs = {**kwargs, "embed": True}
embed_output = self.model(**embed_kwargs)
if "embed" not in embed_output:
    raise ValueError(
        f"Model {type(self.model).__name__} does not return "
        "embeddings. Make sure the model supports the "
        "embed=True flag in its forward() method."
    )

Copilot AI Jan 26, 2026

Performance issue: The forward method calls the model twice for each prediction: once at line 276 to get probabilities and once at line 280 to get embeddings. This doubles the computational cost and memory usage during inference.

The model should be called only once with embed=True to get both predictions and embeddings in a single forward pass. The returned dictionary should contain both 'y_prob' (or similar prediction keys) and 'embed'. This would require verifying that models support returning both outputs simultaneously when embed=True is set.
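
A possible shape for the single-pass version (a sketch only; it assumes the wrapped model populates both 'y_prob' and 'embed' in one call when embed=True, which would need to be verified per supported model):

# Single forward pass: request embeddings alongside predictions
pred = self.model(**kwargs, embed=True)
if "embed" not in pred or "y_prob" not in pred:
    raise ValueError(
        f"Model {type(self.model).__name__} must return both 'y_prob' and "
        "'embed' from a single forward() call when embed=True."
    )
sample_embedding = pred["embed"].detach().cpu().numpy()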

Comment on lines +174 to +259
print("Extracting embeddings from calibration set...")
cal_embeddings = extract_embeddings(
self.model, cal_dataset, batch_size=32, device=self.device
)
else:
cal_embeddings = np.asarray(cal_embeddings)

if train_embeddings is None:
raise ValueError(
"train_embeddings must be provided. "
"Extract embeddings from training set using extract_embeddings()."
)
else:
train_embeddings = np.asarray(train_embeddings)

# Combine embeddings for clustering
print(f"Combining embeddings: train={train_embeddings.shape}, cal={cal_embeddings.shape}")
all_embeddings = np.concatenate([train_embeddings, cal_embeddings], axis=0)
print(f"Total embeddings for clustering: {all_embeddings.shape}")

# Fit K-means on combined embeddings
print(f"Fitting K-means with {self.n_clusters} clusters...")
self.kmeans_model = KMeans(
n_clusters=self.n_clusters,
random_state=self.random_state,
n_init=10,
)
self.kmeans_model.fit(all_embeddings)

# Assign calibration samples to clusters
# Note: cal_embeddings start at index len(train_embeddings) in all_embeddings
cal_start_idx = len(train_embeddings)
cal_cluster_labels = self.kmeans_model.labels_[cal_start_idx:]

print(f"Cluster assignments: {np.bincount(cal_cluster_labels)}")

# Compute conformity scores (probabilities of true class)
conformity_scores = y_prob[np.arange(N), y_true]

# Compute cluster-specific thresholds
self.cluster_thresholds = {}
for cluster_id in range(self.n_clusters):
cluster_mask = cal_cluster_labels == cluster_id
cluster_scores = conformity_scores[cluster_mask]

if len(cluster_scores) == 0:
print(
f"Warning: No calibration samples in cluster {cluster_id}, "
"using -inf threshold (include all classes)"
)
if isinstance(self.alpha, float):
self.cluster_thresholds[cluster_id] = -np.inf
else:
self.cluster_thresholds[cluster_id] = np.array(
[-np.inf] * K
)
else:
if isinstance(self.alpha, float):
# Marginal coverage: single threshold per cluster
t = _query_quantile(cluster_scores, self.alpha)
self.cluster_thresholds[cluster_id] = t
else:
# Class-conditional coverage: one threshold per class per cluster
if len(self.alpha) != K:
raise ValueError(
f"alpha must have length {K} for class-conditional "
f"coverage, got {len(self.alpha)}"
)
t = []
for k in range(K):
class_mask = (y_true[cluster_mask] == k)
if np.sum(class_mask) > 0:
class_scores = cluster_scores[class_mask]
t_k = _query_quantile(class_scores, self.alpha[k])
else:
# If no calibration examples for this class in this cluster
print(
f"Warning: No calibration examples for class {k} "
f"in cluster {cluster_id}, using -inf threshold"
)
t_k = -np.inf
t.append(t_k)
self.cluster_thresholds[cluster_id] = np.array(t)

if self.debug:
print(f"Cluster thresholds: {self.cluster_thresholds}")

Copilot AI Jan 26, 2026

The print statements in the calibrate method should use a logging framework or be controlled by the debug flag. Lines 174, 190, 192, 195, 208, 220-222, and 250-252 unconditionally print to stdout, which can clutter output in production usage. Consider using Python's logging module or only printing when self.debug is True, consistent with line 259 which gates debug output.
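
One possible shape using the standard library (a sketch; the messages mirror the existing prints):

import logging

logger = logging.getLogger(__name__)

# Inside calibrate(), replace unconditional prints with leveled logging:
logger.info("Extracting embeddings from calibration set...")
logger.info("Fitting K-means with %d clusters...", self.n_clusters)
logger.warning(
    "No calibration samples in cluster %d, using -inf threshold "
    "(include all classes)", cluster_id
)
if self.debug:
    logger.debug("Cluster thresholds: %s", self.cluster_thresholds)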

Comment on lines +290 to +291
if sample_embedding.ndim == 1:
    sample_embedding = sample_embedding.reshape(1, -1)

Copilot AI Jan 26, 2026

The condition at line 290 checking if embeddings are 1D appears to be attempting to handle single-sample inputs, but this conflicts with the batch-based usage throughout the codebase (as evidenced by the tests using batch_size=2 and the example script using batch_size=32). This suggests unclear design intent: the method should handle batches consistently. If single-sample support is needed, it should be handled as a batch of size 1, not as a special case with different dimensionality.

Suggested change
- if sample_embedding.ndim == 1:
-     sample_embedding = sample_embedding.reshape(1, -1)
+ # Ensure embeddings are always treated as a batch (even for single samples)
+ sample_embedding = np.atleast_2d(sample_embedding)

Comment on lines +175 to +176
cal_embeddings = extract_embeddings(
    self.model, cal_dataset, batch_size=32, device=self.device

Copilot AI Jan 26, 2026

The batch_size parameter is hardcoded to 32 at line 176. This should use a configurable parameter or match the batch size used elsewhere in the method. Consider adding a batch_size parameter to the calibrate method or using a class attribute to control this, especially since users may want to adjust it based on available memory.

Suggested change
- cal_embeddings = extract_embeddings(
-     self.model, cal_dataset, batch_size=32, device=self.device
+ batch_size = getattr(self, "batch_size", 32)
+ cal_embeddings = extract_embeddings(
+     self.model, cal_dataset, batch_size=batch_size, device=self.device
