diff --git a/rmgpy/data/kinetics/family.py b/rmgpy/data/kinetics/family.py index 31b7695632..dd6793f1b7 100644 --- a/rmgpy/data/kinetics/family.py +++ b/rmgpy/data/kinetics/family.py @@ -1144,11 +1144,24 @@ def get_rate_rule(self, template): raise ValueError('No entry for template {0}.'.format(template)) return entry - def add_rules_from_training(self, thermo_database=None, train_indices=None): + def add_rules_from_training(self, thermo_database=None, train_indices=None, constraints={}): """ For each reaction involving real reactants and products in the training - set, add a rate rule for that reaction. + set, add a rate rule for that reaction if it does not violate any provided `constraints` + + `constraints` (dict): + "metal" : list of allowed metals (e.g ['Pt']) or None (all allowed) + "facet" : list of allowed facets (e.g ['111']) or None (all allowed) + "elements" : list of allowed elements (e.g ['C','H','O']) or None (all allowed) + "forward_only" : True/False, if True, only reactions in the forward direction are allowed """ + + # Parse the constraints for selection of training reactions + allowed_metals = constraints.get('metal') + allowed_facets = constraints.get('facet') + allowed_elements = constraints.get('elements') + forward_only = constraints.get('forward_only', False) + try: depository = self.get_training_depository() except: @@ -1173,12 +1186,33 @@ def add_rules_from_training(self, thermo_database=None, train_indices=None): reverse_entries = [] for entry in entries: + # skip entry if it has an element that is not allowed + if allowed_elements is not None: + violates_element_constraint = False + for reactant in entry.item.reactants: + for element in reactant.molecule[0].get_element_count().keys(): + if element not in allowed_elements: + violates_element_constraint = True + break + if violates_element_constraint: + continue + if entry.item.is_surface_reaction(): + # skip entry if the metal is not allowed + if allowed_metals: + if entry.metal not in 
def training_reactions_constraints(**kwargs):
    """
    Record the constraints used to select which training reactions are
    turned into rate rules.

    Each accepted keyword is stored in ``rmg.training_reactions_constraints``
    under the same key. Recognized keywords:

        metal        : allowed metal(s), a single value or a list (e.g. ['Pt'])
        facet        : allowed facet(s), a single value or a list (e.g. ['111'])
        elements     : allowed element(s), a single value or a list (e.g. ['C', 'H', 'O'])
        forward_only : bool, stored unchanged

    For `metal`, `facet` and `elements` a scalar is wrapped in a list and
    every item is converted to ``str``; ``None`` is stored as ``None``
    (meaning "no restriction") rather than being stringified.

    Raises an :class:`InputError` if an unknown keyword is given or if
    `forward_only` is not a bool.
    """
    valid_constraints = ('metal', 'facet', 'elements', 'forward_only')

    for key, value in kwargs.items():
        if key not in valid_constraints:
            raise InputError('Invalid training reactions constraint {0!r}.'.format(key))

        if key == 'forward_only':
            # Reject truthy/falsy stand-ins (0, 1, 'True', ...) so that a typo
            # in the input file cannot silently flip the behaviour.
            if not isinstance(value, bool):
                raise InputError('Invalid value for `forward_only` constraint {0!r}. '
                                 'Value must be a bool (True or False)'.format(value))
            rmg.training_reactions_constraints[key] = value
            continue

        if value is None:
            # Keep None as-is: turning it into ['None'] would create an
            # allow-list that matches nothing.
            rmg.training_reactions_constraints[key] = None
            continue

        # Accept a bare scalar as shorthand for a one-item list; tuples are
        # treated like lists instead of being stringified as a whole.
        if not isinstance(value, (list, tuple)):
            value = [value]
        rmg.training_reactions_constraints[key] = [str(v) for v in value]
source of each entry if self.kinetics_datastore: