Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 38 additions & 4 deletions rmgpy/data/kinetics/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -1144,11 +1144,24 @@ def get_rate_rule(self, template):
raise ValueError('No entry for template {0}.'.format(template))
return entry

def add_rules_from_training(self, thermo_database=None, train_indices=None):
def add_rules_from_training(self, thermo_database=None, train_indices=None, constraints={}):
"""
For each reaction involving real reactants and products in the training
set, add a rate rule for that reaction.
set, add a rate rule for that reaction if it does not violate any provided `constraints`

`constraints` (dict):
"metal" : list of allowed metals (e.g ['Pt']) or None (all allowed)
"facet" : list of allowed facets (e.g ['111']) or None (all allowed)
"elements" : list of allowed elements (e.g ['C','H','O']) or None (all allowed)
"forward_only" : True/False, if True, only reactions in the forward direction are allowed
"""

# Parse the constraints for selection of training reactions
allowed_metals = constraints.get('metal')
allowed_facets = constraints.get('facet')
allowed_elements = constraints.get('elements')
forward_only = constraints.get('forward_only', False)

try:
depository = self.get_training_depository()
except:
Expand All @@ -1173,12 +1186,33 @@ def add_rules_from_training(self, thermo_database=None, train_indices=None):

reverse_entries = []
for entry in entries:
# skip entry if it has an element that is not allowed
if allowed_elements is not None:
violates_element_constraint = False
for reactant in entry.item.reactants:
for element in reactant.molecule[0].get_element_count().keys():
if element not in allowed_elements:
violates_element_constraint = True
break
if violates_element_constraint:
continue
if entry.item.is_surface_reaction():
# skip entry if the metal is not allowed
if allowed_metals:
if entry.metal not in allowed_metals:
continue
# skip entry if the facet is not allowed
if allowed_facets:
if entry.facet not in allowed_facets:
continue

try:
template = self.get_reaction_template(entry.item)
except UndeterminableKineticsError:
# Some entries might be stored in the reverse direction for
# this family; save them so we can try this
reverse_entries.append(entry)
# this family; save them so we can try this in reverse if `forward only` is False
if not forward_only:
reverse_entries.append(entry)
continue

tentries[entry.index].item.is_forward = True
Expand Down
29 changes: 29 additions & 0 deletions rmgpy/rmg/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,26 @@ def generated_species_constraints(**kwargs):

rmg.species_constraints[key] = value

def training_reactions_constraints(**kwargs):
valid_constraints = [
'metal',
'facet',
'elements',
'forward_only'
]

for key, value in kwargs.items():
if key not in valid_constraints:
raise InputError('Invalid generated species constraint {0!r}.'.format(key))
if key == 'forward_only':
if not isinstance(value, bool):
raise InputError('Invalid value for `forward_only` constraint {0!r}. Value must be a bool (True or False)'.format(value))
rmg.training_reactions_constraints[key] = value
continue
if not isinstance(value, list):
value = [value]
rmg.training_reactions_constraints[key] = [str(v) for v in value]


def thermo_central_database(host,
port,
Expand Down Expand Up @@ -968,6 +988,7 @@ def read_input_file(path, rmg0):
'pressureDependence': pressure_dependence,
'options': options,
'generatedSpeciesConstraints': generated_species_constraints,
'trainingReactionsConstraints': training_reactions_constraints,
'thermoCentralDatabase': thermo_central_database,
'uncertainty': uncertainty,
'restartFromSeed': restart_from_seed,
Expand Down Expand Up @@ -1217,6 +1238,14 @@ def save_input_file(path, rmg):
f.write(' {0} = {1},\n'.format(constraint, value))
f.write(')\n\n')

# Training Reactions Constraints
if rmg.training_reactions_constraints:
f.write('trainingReactionsConstraints(\n')
for constraint, value in sorted(list(rmg.training_reactions_constraints.items()), key=lambda constraint: constraint[0]):
if value is not None:
f.write(' {0} = {1},\n'.format(constraint, value))
f.write(')\n\n')

# Options
f.write('options(\n')
f.write(' units = "{0}",\n'.format(rmg.units))
Expand Down
3 changes: 2 additions & 1 deletion rmgpy/rmg/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def clear(self):
self.ml_estimator = None
self.ml_settings = None
self.species_constraints = {}
self.training_reactions_constraints = {}
self.walltime = '00:00:00:00'
self.save_seed_modulus = -1
self.max_iterations = None
Expand Down Expand Up @@ -417,7 +418,7 @@ def load_database(self):
self.species_constraints = {}
for family in self.database.kinetics.families.values():
if not family.auto_generated:
family.add_rules_from_training(thermo_database=self.database.thermo)
family.add_rules_from_training(thermo_database=self.database.thermo, constraints=self.training_reactions_constraints)

# If requested by the user, write a text file for each kinetics family detailing the source of each entry
if self.kinetics_datastore:
Expand Down