
Commit 1c16e54

deploy: caf2b61
1 parent 46cba34

3 files changed: +43 -11 lines


etc/compute_embeddings.py

Lines changed: 41 additions & 9 deletions
@@ -1,8 +1,11 @@
 import argparse
 import json
+from timeit import default_timer as timer
+from datetime import date
 
 import numpy as np
 import torch
+import torch.nn.functional as F
 import sklearn.manifold
 import transformers
 
@@ -13,33 +16,62 @@ def parse_arguments():
     parser.add_argument("json", default=False, help="the path the json containing all papers.")
     parser.add_argument("outpath", default=False, help="the target path of the visualizations papers.")
     parser.add_argument("--seed", default=0, help="The seed for TSNE.", type=int)
+    parser.add_argument("--model", default='sentence-transformers/all-MiniLM-L6-v2', help="The name of the HF model")
+    parser.add_argument("--save_emb", action='store_true', help="Save embeddings in CSV for Tensorboard Projector")
+
     return parser.parse_args()
 
+def mean_pooling(token_embeddings, attention_mask):
+    """ Mean Pooling, takes attention mask into account for correct averaging"""
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
 
-if __name__ == "__main__":
-    args = parse_arguments()
-    tokenizer = transformers.AutoTokenizer.from_pretrained("deepset/sentence_bert")
-    model = transformers.AutoModel.from_pretrained("deepset/sentence_bert")
+def main(args):
+    tokenizer = transformers.AutoTokenizer.from_pretrained(args.model)
+    model = transformers.AutoModel.from_pretrained(args.model)
     model.eval()
 
     with open(args.json) as f:
         data = json.load(f)
 
     print(f"Num papers: {len(data)}")
 
-    all_embeddings = []
+    corpus = []
     for paper_info in data:
+        corpus.append(tokenizer.sep_token.join([paper_info['title'], paper_info['abstract']]))
+
+    batch_size = 4
+    all_embeddings = []
+    start = timer()
+    for i in range(0, len(corpus), batch_size):
+        encoded_batch = tokenizer(corpus[i:min(i+batch_size, len(corpus))], padding=True, truncation=True, return_tensors='pt')
         with torch.no_grad():
-            token_ids = torch.tensor([tokenizer.encode(paper_info["abstract"])][:512])
-            hidden_states, _ = model(token_ids)[-2:]
-            all_embeddings.append(hidden_states.mean(0).mean(0).numpy())
+            hidden_state = model(**encoded_batch).last_hidden_state
+            all_embeddings.append(mean_pooling(hidden_state, encoded_batch['attention_mask']))
+
+    all_embeddings = torch.cat(all_embeddings, dim=0)
+    all_embeddings = F.normalize(all_embeddings, p=2, dim=1)
+    print(f"elapsed {timer()-start:.1f}s")
+
+    if args.save_emb:
+        filename = f"{args.model.replace('/', '_')}-{date.today().strftime('%d.%m.%y')}"
+        np.savetxt(f"{filename}-emb.tsv", all_embeddings, delimiter="\t")
+        import csv
+        with open(f"{filename}-meta.tsv", 'w', newline='') as csvfile:
+            w = csv.writer(csvfile, delimiter='\t', quoting=csv.QUOTE_MINIMAL)
+            w.writerow(["year", "key", "title"])
+            for paper in data:
+                w.writerow([paper["year"], paper["key"], paper["title"]])
 
     np.random.seed(args.seed)
-    all_embeddings = np.array(all_embeddings)
     out = sklearn.manifold.TSNE(n_components=2, metric="cosine").fit_transform(all_embeddings)
 
     for i, paper_info in enumerate(data):
        paper_info['tsne_embedding'] = out[i].tolist()
 
    with open(args.outpath, 'w') as f:
        json.dump(data, f)
+
+if __name__ == "__main__":
+    args = parse_arguments()
+    main(args)
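For reference, the embedding recipe introduced above (batched tokenization, attention-mask-aware mean pooling over the last hidden state, then L2 normalization) mirrors the usage documented for sentence-transformers/all-MiniLM-L6-v2. A minimal standalone sketch of the same steps, using placeholder sentences rather than the paper corpus (assumes torch and transformers are installed; sentence strings are made up for illustration):

    import torch
    import torch.nn.functional as F
    import transformers

    def mean_pooling(token_embeddings, attention_mask):
        # Zero out padding positions before averaging so padded batches
        # produce the same embeddings as unpadded ones.
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)

    name = "sentence-transformers/all-MiniLM-L6-v2"
    tokenizer = transformers.AutoTokenizer.from_pretrained(name)
    model = transformers.AutoModel.from_pretrained(name)
    model.eval()

    # Placeholder inputs; the script instead joins each paper's title and
    # abstract with the tokenizer's SEP token.
    sentences = ["A study of t-SNE for visualization.", "Sentence embeddings with transformers."]
    batch = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        hidden = model(**batch).last_hidden_state  # (batch, tokens, 384)
    emb = F.normalize(mean_pooling(hidden, batch["attention_mask"]), p=2, dim=1)
    print(emb @ emb.T)  # cosine similarities, since rows are unit length

With --save_emb set, the script also writes <model>-<date>-emb.tsv (one embedding per row) and <model>-<date>-meta.tsv (year, key, title per paper), which match the vector and metadata file formats expected by the TensorBoard Embedding Projector.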

topics.json

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.

tsne.json

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.
