From c4cd6b97ad173a0331a15f41f8e88d2e5a73db0d Mon Sep 17 00:00:00 2001
From: daylight-00
Date: Wed, 29 Oct 2025 23:48:25 +0900
Subject: [PATCH 1/3] fix: improve boundary handling in one_hot_buckets function

- Clamp bucket indices to valid range [0, n-1]
- Fix off-by-one error in bucket assignment
- Ensure values below 'low' go to first bucket
- Ensure values above 'high' go to last bucket
---
 rf_diffusion/conditions/v2.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/rf_diffusion/conditions/v2.py b/rf_diffusion/conditions/v2.py
index 6b8b2fb..5022cef 100644
--- a/rf_diffusion/conditions/v2.py
+++ b/rf_diffusion/conditions/v2.py
@@ -48,10 +48,11 @@ def one_hot_buckets(a, low, high, n, eps=1e-6):
     First category absorbs anything below low
     Last category absorbs anything above high
     '''
-    step = (high-low) / n
-    bins = torch.linspace(low+step, high-step, n-1)
-    cat = torch.bucketize(a, bins).long()
-    return F.one_hot(cat, num_classes=n)
+    a = a.float()
+    buckets = torch.linspace(low, high+eps, n+1)
+    bucket_idx = torch.searchsorted(buckets, a) - 1
+    bucket_idx = torch.clamp(bucket_idx, 0, n-1)
+    return F.one_hot(bucket_idx, n)
 
 def init_radius_of_gyration(indep, feature_conf, feature_inference_conf, **kwargs):
     """

From 05442a31ab3b1d393408a6643e16451985fe6ed6 Mon Sep 17 00:00:00 2001
From: daylight-00
Date: Wed, 29 Oct 2025 23:50:19 +0900
Subject: [PATCH 2/3] feat: add metadata support for atom-wise feature computation

Make ligand atom information accessible during feature computation:
- Store metadata (including ligand_atom_names) in Sampler.sample_init()
- Pass metadata to get_extra_tXd_inference() calls via kwargs
- Enables features like atom-wise RASA to map atom names to values

This infrastructure is required for per-atom feature specification,
where different atoms in a ligand can have different conditioning values.
---
 rf_diffusion/inference/model_runners.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/rf_diffusion/inference/model_runners.py b/rf_diffusion/inference/model_runners.py
index 57f3a3b..a9c4283 100644
--- a/rf_diffusion/inference/model_runners.py
+++ b/rf_diffusion/inference/model_runners.py
@@ -192,6 +192,7 @@ def sample_init(self, i_des=0):
         """
         indep_uncond, self.indep_orig, self.indep_cond, metadata, self.is_diffused, self.atomizer, contig_map, t_step_input, self.conditions_dict = self.dataset[i_des % len(self.dataset)]
         indep = self.indep_cond.clone()
+        self.metadata = metadata
         return indep, contig_map, self.atomizer, t_step_input
 
     def symmetrise_prev_pred(self, px0, seq_in, alpha):
@@ -236,7 +237,7 @@ def sample_step(self, t, indep, rfo, extra, features_cache):
 
         extra_tXd_names = getattr(self._conf, 'extra_tXd', [])
         t_cont = t/self._conf.diffuser.T
-        indep.extra_t1d, indep.extra_t2d = features.get_extra_tXd_inference(indep, extra_tXd_names, self._conf.extra_tXd_params, self._conf.inference.conditions, t_cont=t_cont, features_cache=features_cache, **self.conditions_dict)
+        indep.extra_t1d, indep.extra_t2d = features.get_extra_tXd_inference(indep, extra_tXd_names, self._conf.extra_tXd_params, self._conf.inference.conditions, t_cont=t_cont, features_cache=features_cache, metadata=getattr(self, 'metadata', {}), **self.conditions_dict)
 
         rfi = self.model_adaptor.prepro(indep, t, self.is_diffused)
         rf2aa.tensor_util.to_device(rfi, self.device)
@@ -323,7 +324,7 @@ class FlowMatching(Sampler):
     def run_model(self, t, indep, rfo, is_diffused, features_cache):
         extra_tXd_names = getattr(self._conf, 'extra_tXd', [])
         t_cont = t/self._conf.diffuser.T
-        indep.extra_t1d, indep.extra_t2d = features.get_extra_tXd_inference(indep, extra_tXd_names, self._conf.extra_tXd_params, self._conf.inference.conditions, t_cont=t_cont, features_cache=features_cache, **self.conditions_dict)
+        indep.extra_t1d, indep.extra_t2d = features.get_extra_tXd_inference(indep, extra_tXd_names, self._conf.extra_tXd_params, self._conf.inference.conditions, t_cont=t_cont, features_cache=features_cache, metadata=getattr(self, 'metadata', {}), **self.conditions_dict)
 
         rfi = self.model_adaptor.prepro(indep, t, is_diffused)
         rf2aa.tensor_util.to_device(rfi, self.device)
@@ -524,12 +525,14 @@ class FlowMatching_make_conditional_diffuse_all(FlowMatching_make_conditional):
 
     def sample_init(self, i_des=0):
         indep_uncond, self.indep_orig, self.indep_cond, metadata, self.is_diffused, atomizer, contig_map, t_step_input, self.conditions_dict = self.dataset[i_des % len(self.dataset)]
+        self.metadata = metadata
         return indep_uncond, contig_map, atomizer, t_step_input
 
 
 class FlowMatching_make_conditional_diffuse_all_xt_unfrozen(FlowMatching):
     def sample_init(self, i_des=0):
         indep_uncond, self.indep_orig, self.indep_cond, metadata, self.is_diffused, atomizer, contig_map, t_step_input, self.conditions_dict = self.dataset[i_des % len(self.dataset)]
+        self.metadata = metadata
         return indep_uncond, contig_map, atomizer, t_step_input
 
     def sample_step(self, t, indep, rfo, extra, features_cache):
@@ -548,6 +551,7 @@ class ClassifierFreeGuidance(FlowMatching): # WIP
 
     def sample_init(self, i_des=0):
         indep_uncond, self.indep_orig, self.indep_cond, metadata, self.is_diffused, atomizer, contig_map, t_step_input, self.conditions_dict = self.dataset[i_des % len(self.dataset)]
+        self.metadata = metadata
         return indep_uncond, contig_map, atomizer, t_step_input
 
     def get_grads(self, t, indep_in, indep_t, rfo, is_diffused, features_cache):

From 3dab2315cc2ad6bf12734d6b56017b079bbc1a96 Mon Sep 17 00:00:00 2001
From: daylight-00
Date: Thu, 30 Oct 2025 19:07:34 +0900
Subject: [PATCH 3/3] feat: implement atom-wise RASA specification for ligands

Add parse_atomwise_rasa_config() to support per-atom RASA values:
- Parse configuration strings like '0.0,O7:0.8,C8:1.0,C9:1.0'
- Map ligand atom names to specific RASA values
- Fall back to global RASA when atom-specific values not provided
- Add validation and logging for atom matching

Update get_relative_sasa_inference() to:
- Use new parsing function
- Accept metadata via kwargs
- Print summary statistics of applied RASA values
---
 rf_diffusion/conditions/v2.py | 74 +++++++++++++++++++++++++++++++++--
 1 file changed, 71 insertions(+), 3 deletions(-)

diff --git a/rf_diffusion/conditions/v2.py b/rf_diffusion/conditions/v2.py
index 5022cef..83ebce2 100644
--- a/rf_diffusion/conditions/v2.py
+++ b/rf_diffusion/conditions/v2.py
@@ -119,23 +119,91 @@ def get_radius_of_gyration_inference(indep, feature_conf, feature_inference_conf
     ic(out[0:2, :], out[-3:-1, :])
     return out
 
+def parse_atomwise_rasa_config(rasa_config, indep, metadata):
+    """
+    Parse the atomwise RASA configuration string and create a per-atom RASA tensor.
+
+    Args:
+        rasa_config (str or float): Either a single float for global RASA, or
+            a string like "0.0,O7:0.8,C8:1.0,C9:1.0"
+        indep (Indep): The indep object containing is_sm mask
+        metadata (dict): Metadata containing ligand_atom_names
+
+    Returns:
+        torch.Tensor: Per-atom RASA values for the entire indep
+    """
+    rasa = torch.full((indep.length(),), 0.0)
+
+    # If it's just a number, apply globally to small molecules
+    if isinstance(rasa_config, (float, int)):
+        rasa[indep.is_sm] = float(rasa_config)
+        return rasa
+
+    # Parse the string format: "global_value,atom1:value1,atom2:value2,..."
+    config_str = str(rasa_config)
+    parts = [p.strip() for p in config_str.split(',')]
+    global_value = float(parts[0])
+    rasa[indep.is_sm] = global_value
+
+    if not metadata or 'ligand_atom_names' not in metadata:
+        print("[RASA WARNING] No metadata or ligand_atom_names found, using global RASA")
+        return rasa
+
+    # Build atom name to specific RASA mapping
+    atom_rasa_map = {}
+    for part in parts[1:]:
+        if ':' not in part:
+            continue
+        atom_name, value_str = part.split(':', 1)
+        atom_rasa_map[atom_name.strip()] = float(value_str.strip())
+    if not atom_rasa_map:
+        return rasa
+
+    # Apply atom-specific values
+    ligand_atom_names = metadata['ligand_atom_names']
+    sm_indices = torch.where(indep.is_sm)[0]
+    n_sm_atoms = len(sm_indices)
+
+    # Validate indices
+    if n_sm_atoms > len(ligand_atom_names):
+        print(f"[RASA ERROR] More SM atoms ({n_sm_atoms}) than ligand names ({len(ligand_atom_names)})")
+        return rasa
+
+    # Ligand atom names are stored at the end of the array
+    ligand_names_start = len(ligand_atom_names) - n_sm_atoms
+    matched_atoms = []
+    for i, sm_idx in enumerate(sm_indices):
+        ligand_name_idx = ligand_names_start + i
+        if ligand_name_idx < len(ligand_atom_names):
+            atom_name_in_metadata = ligand_atom_names[ligand_name_idx].strip()
+            if atom_name_in_metadata in atom_rasa_map:
+                rasa[sm_idx] = atom_rasa_map[atom_name_in_metadata]
+                matched_atoms.append(f"{atom_name_in_metadata}={atom_rasa_map[atom_name_in_metadata]}")
+
+    if matched_atoms:
+        print(f"[RASA] Set atom-specific RASA for {len(matched_atoms)} atoms: {', '.join(matched_atoms)}")
+
+    return rasa
 
 def get_relative_sasa_inference(indep, feature_conf, feature_inference_conf, cache, **kwargs):
     """
-    Calculates the radius of gyration fature
+    Calculates the relative SASA feature with support for atom-wise specification
 
     Args:
         indep (Indep): The holy indep.
         feature_conf (omegaconf): The feature config.
         feature_inference_conf (omegaconf): The feature inference config.
         cache (dict): data cache
+        **kwargs: Additional keyword arguments including metadata
 
     Returns:
-        sasa feature
+        dict: Dictionary with 't1d' key containing the SASA feature tensor
     """
     if not feature_inference_conf.active:
         return {'t1d':torch.zeros((indep.length(), feature_conf.n_bins + 1))}
-    rasa = torch.full((indep.length(),), feature_inference_conf.rasa)
+
+    metadata = kwargs.get('metadata', {})
+    rasa = parse_atomwise_rasa_config(feature_inference_conf.rasa, indep, metadata)
     one_hot = one_hot_buckets(rasa, feature_conf.low, feature_conf.high, feature_conf.n_bins)
     is_feature_applicable = indep.is_sm
     one_hot[~is_feature_applicable] = 0
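
Quick sanity check for PATCH 1/3 (an editorial sketch, not part of the patch series): the snippet below copies the new one_hot_buckets body verbatim and assumes only that torch is installed. It demonstrates the clamping behaviour described in the commit message, with out-of-range values absorbed by the first and last buckets.

import torch
import torch.nn.functional as F

def one_hot_buckets(a, low, high, n, eps=1e-6):
    # Same body as the patched function in rf_diffusion/conditions/v2.py
    a = a.float()
    buckets = torch.linspace(low, high+eps, n+1)
    bucket_idx = torch.searchsorted(buckets, a) - 1
    bucket_idx = torch.clamp(bucket_idx, 0, n-1)
    return F.one_hot(bucket_idx, n)

vals = torch.tensor([-0.5, 0.0, 0.4, 1.0, 1.5])
print(one_hot_buckets(vals, low=0.0, high=1.0, n=4))
# -0.5 and 0.0 land in bucket 0; 0.4 lands in bucket 1; 1.0 and 1.5 land in bucket 3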
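
For PATCH 3/3, the accepted config value is either a single float or a string of the form "global,ATOM:value,ATOM:value,...". The torch-free sketch below mirrors only the string-parsing part of parse_atomwise_rasa_config to make that format concrete; the helper name split_rasa_config is hypothetical and not part of the patch.

def split_rasa_config(config_str):
    # First field: global RASA applied to every small-molecule atom.
    # Remaining "NAME:value" fields override individual ligand atoms by name.
    parts = [p.strip() for p in config_str.split(',')]
    global_value = float(parts[0])
    atom_rasa_map = {}
    for part in parts[1:]:
        if ':' not in part:
            continue
        name, value = part.split(':', 1)
        atom_rasa_map[name.strip()] = float(value.strip())
    return global_value, atom_rasa_map

print(split_rasa_config('0.0,O7:0.8,C8:1.0,C9:1.0'))
# (0.0, {'O7': 0.8, 'C8': 1.0, 'C9': 1.0})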