 import argparse
 import json
+from timeit import default_timer as timer
+from datetime import date

 import numpy as np
 import torch
+import torch.nn.functional as F
 import sklearn.manifold
 import transformers

@@ -13,33 +16,62 @@ def parse_arguments():
     parser.add_argument("json", default=False, help="The path to the JSON file containing all papers.")
     parser.add_argument("outpath", default=False, help="The target path for the visualization JSON.")
     parser.add_argument("--seed", default=0, help="The seed for TSNE.", type=int)
+    parser.add_argument("--model", default='sentence-transformers/all-MiniLM-L6-v2', help="The name of the HF model.")
+    parser.add_argument("--save_emb", action='store_true', help="Save embeddings as TSV for the TensorBoard Projector.")
+
     return parser.parse_args()

+def mean_pooling(token_embeddings, attention_mask):
+    """Mean pooling that takes the attention mask into account for correct averaging."""
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

-if __name__ == "__main__":
-    args = parse_arguments()
-    tokenizer = transformers.AutoTokenizer.from_pretrained("deepset/sentence_bert")
-    model = transformers.AutoModel.from_pretrained("deepset/sentence_bert")
+def main(args):
+    tokenizer = transformers.AutoTokenizer.from_pretrained(args.model)
+    model = transformers.AutoModel.from_pretrained(args.model)
     model.eval()

     with open(args.json) as f:
         data = json.load(f)

     print(f"Num papers: {len(data)}")

-    all_embeddings = []
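+    # Build one input text per paper by joining its title and abstract with the tokenizer's separator token.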
+    corpus = []
     for paper_info in data:
+        corpus.append(tokenizer.sep_token.join([paper_info['title'], paper_info['abstract']]))
+
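+    # Embed the corpus in small batches so each batch only pads to its own longest sequence and memory stays bounded.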
+    batch_size = 4
+    all_embeddings = []
+    start = timer()
+    for i in range(0, len(corpus), batch_size):
+        encoded_batch = tokenizer(corpus[i:min(i + batch_size, len(corpus))], padding=True, truncation=True, return_tensors='pt')
         with torch.no_grad():
-            token_ids = torch.tensor([tokenizer.encode(paper_info["abstract"])][:512])
-            hidden_states, _ = model(token_ids)[-2:]
-            all_embeddings.append(hidden_states.mean(0).mean(0).numpy())
+            hidden_state = model(**encoded_batch).last_hidden_state
+            all_embeddings.append(mean_pooling(hidden_state, encoded_batch['attention_mask']))
+
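+    # Concatenate the per-batch embeddings and L2-normalize them so cosine similarities are well behaved.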
+    all_embeddings = torch.cat(all_embeddings, dim=0)
+    all_embeddings = F.normalize(all_embeddings, p=2, dim=1)
+    print(f"elapsed {timer() - start:.1f}s")
+
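+    # Optionally export embeddings and paper metadata as TSV files for the TensorBoard Embedding Projector.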
+    if args.save_emb:
+        filename = f"{args.model.replace('/', '_')}-{date.today().strftime('%d.%m.%y')}"
+        np.savetxt(f"{filename}-emb.tsv", all_embeddings, delimiter="\t")
+        import csv
+        with open(f"{filename}-meta.tsv", 'w', newline='') as csvfile:
+            w = csv.writer(csvfile, delimiter='\t', quoting=csv.QUOTE_MINIMAL)
+            w.writerow(["year", "key", "title"])
+            for paper in data:
+                w.writerow([paper["year"], paper["key"], paper["title"]])

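+    # Reduce the normalized embeddings to 2D with t-SNE (cosine metric) for the visualization.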
     np.random.seed(args.seed)
-    all_embeddings = np.array(all_embeddings)
     out = sklearn.manifold.TSNE(n_components=2, metric="cosine").fit_transform(all_embeddings)

     for i, paper_info in enumerate(data):
         paper_info['tsne_embedding'] = out[i].tolist()

     with open(args.outpath, 'w') as f:
         json.dump(data, f)
+
+if __name__ == "__main__":
+    args = parse_arguments()
+    main(args)