diff --git a/flowtron.py b/flowtron.py index 06dea68..eac0a30 100644 --- a/flowtron.py +++ b/flowtron.py @@ -589,7 +589,7 @@ def forward(self, mel, speaker_vecs, text, in_lens, out_lens): [text, speaker_vecs.expand(text.size(0), -1, -1)], 2) log_s_list = [] attns_list = [] - mask = ~get_mask_from_lengths(in_lens)[..., None] + mask = ~get_mask_from_lengths(in_lens)[..., None].bool() for i, flow in enumerate(self.flows): mel, log_s, gate, attn = flow( mel, encoder_outputs, mask, out_lens)