Skip to content

Conversation

@nurtas-m
Copy link
Collaborator

Add new stop criteria. Previously number of the allowable number of epochs without improvement was fixed. Training stopped if no improvement was seen during this window. New stop criteria include possibility of increasing of that window depending on the number of epochs needed for previous improvements during training.

Copy link
Contributor

@nshmyrev nshmyrev left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs work

README.md Outdated
# for the model with size 512:
--max_steps 150000
```

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should work by default, not with options.

num_iter_cover_train = int(sum(train_bucket_sizes) /
self.params.batch_size /
self.params.steps_per_checkpoint)
current_step, iter_inx, num_epochs_last_impr, max_num_epochs,\
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"inx" is unclear abbreviation.

self.params.batch_size /
self.params.steps_per_checkpoint)
current_step, iter_inx, num_epochs_last_impr, max_num_epochs,\
num_up_trends, num_down_trends = 0, 0, 0, 2, 0, 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"trend" is not a proper word here.

current_step, iter_inx, num_epochs_last_impr, max_num_epochs,\
num_up_trends, num_down_trends = 0, 0, 0, 2, 0, 0
prev_train_losses, prev_valid_losses, prev_epoch_valid_losses = [], [], []
num_iter_cover_train = max(1, int(sum(train_bucket_sizes) /
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is called "epoch", no? "cover" is not a good word in this context.

if len(prev_epoch_valid_losses) > 0:
print('Previous min epoch eval loss: %f, current epoch eval loss: %f' %
(min(prev_epoch_valid_losses), epoch_eval_loss))
# Check if there was improvement during last epoch
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

an improvement

# Check if there was improvement during last epoch
if (epoch_eval_loss < min(prev_epoch_valid_losses)):
if num_epochs_last_impr > max_num_epochs/1.5:
max_num_epochs = int(1.5 * num_epochs_last_impr)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1.5 must be separated into something standalone and properly named, not used in multiple places in the code without name.



def __calc_epoch_loss(self, epoch_losses):
"""Calculate average loss during the epoch.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment is wrong, this is not really an average.

prev_train_losses.append(train_loss)
prev_valid_losses.append(eval_loss)
step_time, train_loss = 0.0, 0.0
iter_idx += 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use current step instead of iter_idx

num_epochs_last_impr = 0
else:
print('No improvement during last epoch.')
num_epochs_last_impr += 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

epochs_without_improvement

num_iter_total += num_iter_cover_valid
for batch_id in xrange(num_iter_cover_valid):
iter_total += iter_per_valid
for batch_id in xrange(iter_per_valid):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use xrange with batch_size step.

self.params.batch_size))
num_iter_total += num_iter_cover_valid
for batch_id in xrange(num_iter_cover_valid):
iter_total += iter_per_valid
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Count iter_total in inner loop

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Development

Successfully merging this pull request may close these issues.

3 participants