Bug found in the code for BERT-K and BERT-LS  #2

@twadada

Hi, I've found a bug in the following method, which BERT-K and BERT-LS use to encode sentences.

**** In "methods/bert.py/_logits_for_dropout_target" ****
with torch.no_grad():
    embeddings = self.bert.embeddings(self.list_to_tensor(context_target_enc))

(Omitting Dropout procedure)

logits = self.bert_mlm(inputs_embeds=embeddings)[0][0, target_tok_off]


Here, you use "self.bert.embeddings" to generate the input embeddings (i.e. "embeddings"), but this module returns "token embeddings + token_type_embeddings + positional embeddings". However, what self.bert_mlm expects as "inputs_embeds" is only the token embeddings, because it adds the token-type and positional embeddings itself. So I think your code adds "token_type_embeddings" and "positional embeddings" to the token embeddings twice. To fix this, I think the line that computes "embeddings" needs to be changed to "self.bert.embeddings.word_embeddings(self.list_to_tensor(context_target_enc))".
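To illustrate the reasoning above, here is a minimal sketch with a toy embedding layer. It is not the repository's code and not the real BERT class: "ToyEmbeddings", its sizes, and the sample token ids are hypothetical stand-ins that only mimic the structure "word + position + token_type, then LayerNorm". It compares feeding word embeddings only (the proposed fix) against feeding the full embedding output (the current behaviour) through the "inputs_embeds" path.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class ToyEmbeddings(nn.Module):
    """Hypothetical stand-in for a BERT-style embedding layer."""

    def __init__(self, vocab_size=50, hidden=8, max_len=16):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden)
        self.position_embeddings = nn.Embedding(max_len, hidden)
        self.token_type_embeddings = nn.Embedding(2, hidden)
        self.LayerNorm = nn.LayerNorm(hidden)

    def forward(self, input_ids=None, inputs_embeds=None):
        # As in BERT-style models, inputs_embeds replaces only the word-embedding lookup.
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        seq_len = inputs_embeds.size(1)
        positions = torch.arange(seq_len).unsqueeze(0)
        token_types = torch.zeros(1, seq_len, dtype=torch.long)
        # Position and token-type embeddings are always added at this point.
        summed = (
            inputs_embeds
            + self.position_embeddings(positions)
            + self.token_type_embeddings(token_types)
        )
        return self.LayerNorm(summed)


emb = ToyEmbeddings().eval()
input_ids = torch.tensor([[3, 7, 11, 5]])  # arbitrary token ids for the demo

with torch.no_grad():
    # Reference: the ordinary input_ids path.
    reference = emb(input_ids=input_ids)
    # Proposed fix: pass word embeddings only as inputs_embeds.
    fixed = emb(inputs_embeds=emb.word_embeddings(input_ids))
    # Current behaviour: pass the full embedding output as inputs_embeds,
    # so position and token-type embeddings are added a second time.
    buggy = emb(inputs_embeds=emb(input_ids=input_ids))

matches_fixed = torch.allclose(reference, fixed)
matches_buggy = torch.allclose(reference, buggy)
print("word embeddings only matches reference:", matches_fixed)
print("full embeddings matches reference:", matches_buggy)
```

In this toy setup, only the word-embeddings-only input reproduces the ordinary input_ids result. The same comparison could be run against the actual self.bert_mlm to confirm the effect there.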

Best,

Takashi
