Hello,
When I run the suggested command `python train.py --dump_path /some_path_on_your_computer/ --exp_name my_first_experiment --exp_id 1 --operation "gcd"`, I get the following error:
```
INFO - 01/11/25 20:38:27 - 0:16:02 - ============ End of epoch 3 ============
INFO - 01/11/25 20:38:27 - 0:16:02 - Creating valid iterator for arithmetic ...
INFO - 01/11/25 20:38:27 - 0:16:02 - (128/10000) Found 110/128 valid top-1 predictions. Generating solutions ...
INFO - 01/11/25 20:38:27 - 0:16:02 - Found 110/128 solutions in beam hypotheses.
INFO - 01/11/25 20:38:27 - 0:16:02 - (256/10000) Found 110/128 valid top-1 predictions. Generating solutions ...
INFO - 01/11/25 20:38:27 - 0:16:02 - Found 110/128 solutions in beam hypotheses.
INFO - 01/11/25 20:38:27 - 0:16:02 - (384/10000) Found 113/128 valid top-1 predictions. Generating solutions ...
INFO - 01/11/25 20:38:27 - 0:16:02 - Found 113/128 solutions in beam hypotheses.
INFO - 01/11/25 20:38:27 - 0:16:02 - (512/10000) Found 107/128 valid top-1 predictions. Generating solutions ...
INFO - 01/11/25 20:38:27 - 0:16:02 - Found 107/128 solutions in beam hypotheses.
INFO - 01/11/25 20:38:27 - 0:16:02 - (640/10000) Found 105/128 valid top-1 predictions. Generating solutions ...
INFO - 01/11/25 20:38:27 - 0:16:02 - Found 105/128 solutions in beam hypotheses.
INFO - 01/11/25 20:38:27 - 0:16:02 - (768/10000) Found 103/128 valid top-1 predictions. Generating solutions ...
INFO - 01/11/25 20:38:27 - 0:16:02 - Found 103/128 solutions in beam hypotheses.
INFO - 01/11/25 20:38:27 - 0:16:02 - (896/10000) Found 106/128 valid top-1 predictions. Generating solutions ...
Traceback (most recent call last):
  File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/aj659-on20241119/code/Users/aj659/Int2Int/train.py", line 345, in <module>
    main(params)
  File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/aj659-on20241119/code/Users/aj659/Int2Int/train.py", line 296, in main
    scores = evaluator.run_all_evals()
  File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/aj659-on20241119/code/Users/aj659/Int2Int/src/evaluator.py", line 91, in run_all_evals
    self.enc_dec_step(data_type, task, scores)
  File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/aj659-on20241119/code/Users/aj659/Int2Int/src/evaluator.py", line 283, in enc_dec_step
    generated, _ = decoder.generate(
  File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/aj659-on20241119/code/Users/aj659/Int2Int/src/model/transformer.py", line 714, in generate
    generated[-1].masked_fill_(unfinished_sents.byte(), self.eos_index)
RuntimeError: masked_fill only supports boolean masks, but got dtype Byte
```
Replacing `.byte()` with `.bool()` in that line (`src/model/transformer.py`, line 714) seems to resolve the error, since `masked_fill_` only accepts a boolean mask.
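A minimal sketch of the patched call, for anyone who wants to check the change in isolation. It assumes, as the traceback suggests, that `generated` is a 2-D tensor of token ids whose last row `generated[-1]` holds the final position of each sentence in the batch, and that `unfinished_sents` is a 0/1 tensor flagging sentences that have not yet produced EOS. The helper name `force_eos` and the toy values are hypothetical; only the `.bool()` change comes from this report.

```python
import torch


def force_eos(generated: torch.Tensor, unfinished_sents: torch.Tensor, eos_index: int) -> torch.Tensor:
    """Hypothetical helper: write eos_index into the last row for every unfinished sentence."""
    # masked_fill_ requires a torch.bool mask; .bool() maps nonzero flags to True
    # (the original .byte() produced a uint8 mask, which is what raised the RuntimeError).
    generated[-1].masked_fill_(unfinished_sents.bool(), eos_index)
    return generated


# Toy example: 5 positions x 4 sentences, every token set to a dummy id 7.
generated = torch.full((5, 4), 7, dtype=torch.long)
unfinished_sents = torch.tensor([1, 0, 1, 0])  # sentences 0 and 2 are still unfinished
force_eos(generated, unfinished_sents, eos_index=1)
print(generated[-1].tolist())
```

Because `generated[-1]` is a view into `generated`, the in-place `masked_fill_` updates the full tensor, so the call does not need its return value assigned.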