#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import print_function
# Standard imports
#-----------------------------------------------------------------------------#
import os, sys, time, random
from shutil import rmtree
# enforce auto-flushing of stdout on every write
sys.stdout = os.fdopen(sys.stdout.fileno(),'w',0)
import numpy as np
import imageio
from keras.layers import Embedding
import cPickle
from gensim.models.doc2vec import TaggedDocument
# Local imports
#-----------------------------------------------------------------------------#
from WikiParse.main import download_wikidump, parse_wikidump, text_corpus, item_corpus
from WikiLearn.code.vectorize import doc2vec, make_seconds_pretty
from WikiLearn.code.classify import vector_classifier_keras
from pathfinder import get_queries, astar_path
import matplotlib.pyplot as plt
import matplotlib.patches as patches
# Main function
#-----------------------------------------------------------------------------#
classification_dict = None # used by get_classified_sequences, cross-call variable
stop_words_dict = None # used by get_classified_sequences, cross-call variable
most_common_dict = None # used by get_classified_sequences, highest-frequency words kept after vocab trimming
class_dict = None # maps stringified class index to class name
class_pretty = None # maps class name to space-padded display name
classes = None # maps raw class label (as read from file) to index in class_names
seen_article_dict = None # articles pulled on the last iteration, prevent repeats
# Returns the first seq_per_class article bodies for each class, converted to word vectors.
# A single sequence of word vectors is returned for each article; if a start_at value is
# specified, reads that far into text.tsv (line-wise) before starting to add documents.
def get_classified_sequences(
encoder, # word2vec/doc2vec model wrapper
seq_per_class, # sequences per class to return
    min_words_per_seq,           # skip sequences with fewer than this many words
max_words_per_seq, # maximum words to return per sequence
class_names, # names to use for classes
class_map, # dict mapping from real classes to class_name indices
start_at=0, # where to pick up reading source file from
print_avg_length=False, # if true, print average length/class
    remove_stop_words=True,      # if true, remove stop words from sequences
    trim_vocab_to=-1,            # keep only words whose count in word_list.tsv (col 3) is above this value; -1 disables
replace_removed=False, # if true, any words removed will be replaced w/ zero-vector
    swap_with_word_idx=False,    # if true, word vectors are swapped for vocabulary indices
max_class_mappings=-1, # if non -1, limit size of classification mapping
classifications="quality.tsv", # where to pull article classifications from
tailored_to_content=False, # diff. exec for content-classified documents
total_seq=None, # if seq_per_class is -1, this is used to specify the number of total seqs
multi_class_file=False # if true, class mapping file may contain multiple classes per article
):
global classification_dict # holds mappings from article_id to classification
global stop_words_dict # holds mappings for all stop words
global most_common_dict # holds mappings for highest frequency words
global class_dict # holds mappings from index in class_names to class name
global class_pretty # holds mappings from class names to formatted class names
global classes # holds mappings from class (found in file) to index in class_names list
global seen_article_dict # holds the last article read for each class on last iteration
y = [] # sequence classifications
x = [] # sequences
    if classes==None:
        classes = {}
longest_class_len = 0
counts = {}
for n in class_names:
if len(n)>longest_class_len:
longest_class_len = len(n)
counts[n] = 0
counts["unknown/not_in_model"]=0
if seen_article_dict==None:
seen_article_dict = {}
for i in range(len(class_names)):
seen_article_dict[class_names[i]]=0
if class_dict==None:
class_dict = {}
for i in range(len(class_names)):
class_dict[str(i)] = class_names[i]
if class_pretty==None:
class_pretty = {}
for i in range(len(class_names)):
class_name_pretty = class_names[i]
while len(class_name_pretty)<=longest_class_len:
class_name_pretty+=" "
class_pretty[class_names[i]]=class_name_pretty
class_lengths = {}
for i in range(len(class_names)):
class_lengths[i]=0
zero_vector = [0.00]*300 # put in place of non-modeled words
print_sentences = 0 # approx number of sentences to print during sentence encoding
text = "text.tsv" # where to get article contents from
stop_words_filename = "WikiLearn/data/stopwords/stopwords_long.txt" # source of stopwords
word_list_filename = "WikiLearn/data/models/dictionary/text/word_list.tsv" # source of word list
# create stop words dictionary if not yet loaded
if remove_stop_words and stop_words_dict==None:
stop_words_dict = {}
stop_words = open(stop_words_filename,"r").read().replace("\'","").split("\n")
for s in stop_words:
if s not in [""," "] and len(s)>0:
stop_words_dict[s] = True
# Create the classifications dictionary if not yet loaded
if classification_dict==None:
num_lines = len(open(classifications,"r").read().split("\n"))
qf = open(classifications,"r")
i=0
kept=0
num_overmapped=0
classification_dict = {}
for line in qf:
i+=1
if max_class_mappings!=-1 and i>max_class_mappings: break
sys.stdout.write("\rCreating class dict (%d/%d) | Kept:%d | Overmapped:%d"%(i,num_lines,kept,num_overmapped))
try:
if multi_class_file:
items=line.strip().split("\t")
article_id=items[0]
article_subclasses=items[1].split(" ")
mapped_class=None
overmapped=False
for a in article_subclasses:
try:
cur_mapped_class = class_map[a]
if mapped_class==None:
mapped_class = class_map[a]
else:
if cur_mapped_class!=mapped_class:
overmapped=True
break
except:
continue
if overmapped==False and mapped_class!=None:
classification_dict[article_id]=mapped_class
class_lengths[mapped_class]+=1
kept+=1
else:
num_overmapped+=1
else:
article_id,article_subclass = line.strip().split("\t")
mapped_class = class_map[article_subclass]
classification_dict[article_id] = mapped_class
class_lengths[mapped_class]+=1
kept+=1
except:
                i-=1
continue
sys.stdout.write("\n")
qf.close()
print_class_counts=False
if print_class_counts:
for key,val in class_lengths.items():
print("%s\t%d\n"%(key,val))
class_lengths = {}
for c in class_names:
class_lengths[c]=0
# if trimming vocabulary
if ((trim_vocab_to!=-1 or swap_with_word_idx) and most_common_dict==None):
most_common_dict = {}
word_list = open(word_list_filename,"r").read().split("\n")
w_idx=0
for w in word_list:
items = w.split("\t")
if len(items)==3 and int(items[2])>trim_vocab_to:
try:
in_model = encoder.model[items[1].lower()]
most_common_dict[items[1]]=w_idx
w_idx+=1
except:
continue
# number of total items we will have to load (sentences or docs)
if seq_per_class!=-1:
approx_total = len(class_names)*seq_per_class
else:
approx_total = total_seq
f = open(text,"r")
i=0
enc_start_time = time.time()
num_loaded=0
eof=True
earliest_class_full_at=-1 # i value when first class was filled
# iterate over each line (document) in text.tsv
for line in f:
i+=1
# skip
if i<start_at: continue
percent_done = "%0.1f%%" % (100.0*float(num_loaded+1)/float(approx_total))
        perc_loaded = "%0.1f%%" % (100.0 *float(i)/13119700.0) # 13119700 = approx. line count of text.tsv
sys.stdout.write("\rEncoding: %s done (%d/%d) | %s total | %s "% (percent_done,num_loaded+1,approx_total,perc_loaded,make_seconds_pretty(time.time()-enc_start_time)))
sys.stdout.flush()
# load in the current line
try:
#article_id, article_contents = line.decode('utf-8', errors='replace').strip().split('\t')
article_id, article_contents = line.strip().split('\t')
except:
continue
# see if this article has a class mapping
try:
art_effclass_idx = classification_dict[article_id]
art_effclass = class_names[art_effclass_idx]
art_eff_y = art_effclass_idx
except:
# skip article if no listed quality
counts['unknown/not_in_model'] +=1
continue
# skip this article if we got farther than this on the last iteration for this class
if seen_article_dict[art_effclass]>i: continue
# if we have already loaded the maximum number of articles in this category
if seq_per_class!=-1 and counts[art_effclass]>=seq_per_class: continue
# split article into sentences
article_sentences = article_contents.split(". ")
full_doc = []
full_doc_str = []
doc_arr = np.zeros(shape=(max_words_per_seq,300)).astype(float)
# iterate over each sentence in the article
for a in article_sentences:
if len(full_doc)>=max_words_per_seq: break
            cleaned_a = a.replace(","," ").replace("(","").replace(")","")
            # collapse triple and double spaces down to single spaces
            cleaned_a = cleaned_a.replace("   "," ").replace("  "," ")
            cleaned_a = cleaned_a.replace("  "," ").lower()
sentence_words = cleaned_a.split(" ")
            if remove_stop_words:
                # drop any words present in the stop-word dictionary
                sentence_words = [w for w in sentence_words if w not in stop_words_dict]
for w in sentence_words:
if trim_vocab_to!=-1 or swap_with_word_idx:
try:
idx = most_common_dict[w.lower()]
except:
if replace_removed:
if swap_with_word_idx: full_doc.append(0)
else: full_doc.append(zero_vector)
continue
try:
word_vec = encoder.model[w.lower()]
if swap_with_word_idx: full_doc.append(most_common_dict[w.lower()])
else: full_doc.append(word_vec)
full_doc_str.append(w)
except:
if replace_removed:
if swap_with_word_idx: full_doc.append(0)
else: full_doc.append(zero_vector)
full_doc_str.append(w)
continue
# add the document
if len(full_doc)<min_words_per_seq: continue
# print out doc maybe
if random.randint(0,approx_total)<print_sentences:
sys.stdout.write("\rExample %s sentence (decoded): %s\n"%(class_pretty[art_effclass],''.join(e+" " for e in full_doc_str[:10])+"..."))
# populate numpy array with full_doc contents
for q in range(len(full_doc)):
if q==max_words_per_seq: break
doc_arr[q,:] = full_doc[q]
# add document/quality to total found so far
x.append(doc_arr)
#y.append(art_eff_y)
y.append(art_effclass_idx)
counts[art_effclass]+=1
class_lengths[art_effclass]+=len(full_doc)
num_loaded+=1
seen_article_dict[art_effclass]=i+1
if seq_per_class!=-1 and counts[art_effclass]>=seq_per_class:
if earliest_class_full_at==-1: earliest_class_full_at=i
            # check if every class has been filled (ignore the unknown/not_in_model counter)
            if min([val for key,val in counts.items() if key!="unknown/not_in_model"])>=seq_per_class:
eof=False
break
if seq_per_class==-1:
if len(x)>=total_seq:
eof=False
break
sys.stdout.write("\n")
f.close()
if print_avg_length:
        for key,value in class_lengths.iteritems():
            print("%s avg length (words): %0.1f"%(key,float(value)/seq_per_class))
sys.stdout.write("\n")
y=np.array(y)
#y = np.ravel(np.array(y)) # flatten classifications
    X = np.array(x)
if seq_per_class==-1: earliest_class_full_at=i
if eof: earliest_class_full_at=-1 # denote eof
return X,y,earliest_class_full_at
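# A minimal usage sketch (hypothetical labels and counts; assumes `enc` wraps a
# trained word2vec model and quality.tsv maps article ids onto these raw classes):
#   X, y, stopped_at = get_classified_sequences(enc, 100, 2, 500,
#       class_names=["good","mid","poor"],
#       class_map={"good":0,"mid":1,"poor":2})
#   # X: float array of shape (n, 500, 300); y: int class indices of shape (n,)
#   # stopped_at: text.tsv line to resume the next call from (-1 once at EOF)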
# Resets all state-saving items used in get_classified_sequences
def reset_corpus():
global seen_article_dict
seen_article_dict=None # start classes back at start
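# Typical epoch loop: call reset_corpus() before each full pass so every class is
# read from the top again (see classify_importance / classify_us_state below):
#   for j in range(epochs):
#       reset_corpus()
#       doc_idx = 0
#       ...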
def make_gif(parent_folder,frame_duration=0.3):
items = os.listdir(parent_folder)
png_filenames = []
for elem in items:
if elem.find(".png")!=-1 and elem.find("heatmap")!=-1:
png_filenames.append(elem)
sorted_png = []
    while len(png_filenames)>0:
lowest = 10000000
lowest_idx = -1
for p in png_filenames:
old_save_format=False
if old_save_format:
iter_val = int(p.split("-")[2].split(":")[1])
epoch_val = int(p.split("-")[3].split(":")[1].split(".")[0])
val = float(iter_val)+0.1*epoch_val
else:
iter_val = int(p.split("-")[3].split(":")[1].split(".")[0])
epoch_val = int(p.split("-")[2].split(":")[1])
val = float(epoch_val)+0.1*iter_val
if lowest_idx==-1 or val<lowest:
lowest = val
lowest_idx = png_filenames.index(p)
sorted_png.append(png_filenames[lowest_idx])
del png_filenames[lowest_idx]
if len(png_filenames)==0: break
png_filenames = sorted_png
with imageio.get_writer(parent_folder+"/prediction-heatmap.gif", mode='I',duration=frame_duration) as writer:
for filename in png_filenames:
image = imageio.imread(parent_folder+"/"+filename)
writer.append_data(image)
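# Example (hypothetical path): stitch the saved "...heatmap...png" frames of a
# training run into prediction-heatmap.gif:
#   make_gif("WikiLearn/data/models/classifier/pics", frame_duration=0.5)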
seen_article_dict = None
def get_classified_docs(encoder,mapping_file,class_names,class_map,doc_idx,per_class):
global seen_article_dict
class_cts = {}
for c in class_names:
class_cts[class_map[c]]=0
if seen_article_dict==None:
seen_article_dict={}
for c in class_names:
seen_article_dict[class_map[c]]=0
docvecs=[]
ys=[]
eof=True
t0=time.time()
f=open(mapping_file,"r")
i=0
dropped=0
kept=0
for line in f:
i+=1
        perc_loaded = "%0.1f%%" % (100.0 *float(i)/5162289.0) # 5162289 = approx. line count of the mapping file
sys.stdout.write("\rEncoding: %d %s total | Kept:%d | Dropped:%d"%(i,perc_loaded,kept,dropped))
if i<=doc_idx: continue
items=line.strip().split("\t")
if len(items)==2:
try:
item_class = class_map[items[1]]
if seen_article_dict[item_class]>=i: continue
if class_cts[item_class]>=per_class: continue
article_id = items[0]
article_vec = encoder.model.docvecs[article_id]
docvecs.append(article_vec)
ys.append(item_class)
class_cts[item_class]+=1
seen_article_dict[item_class]=i
if min([val for _,val in class_cts.items()])>=per_class:
eof=False
break
kept+=1
except:
dropped+=1
f.close()
sys.stdout.write(" | %s\n"%(make_seconds_pretty(time.time()-t0)))
X = np.array(docvecs)
y = np.ravel(np.array(ys))
if eof: stopped_at=-1
else: stopped_at=i
return [X,y,stopped_at]
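# Minimal usage sketch (the per_class value is hypothetical): fetch up to 1000
# doc2vec document vectors per importance class, resuming from `stopped_at`:
#   X, y, stopped_at = get_classified_docs(enc, "importance.tsv",
#       ["top","high","mid","low"], {"top":0,"high":1,"mid":2,"low":3}, 0, 1000)
#   # stopped_at is -1 once importance.tsv has been fully consumed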
def classify_importance_docs(encoder,directory,gif=True,model_type="lstm"):
if not os.path.exists(directory): os.makedirs(directory)
print("\nClassifying importance on document level...\n")
cur_time = int(time.time())
directory = os.path.join(directory,str(cur_time))
# all output quality classes
class_names = ["top","high","mid","low"]
# tagged class : index of name in class_names (to treat it as)
class_map = {"top":0,"high":1,"mid":2,"low":3}
class_str = ""
for c in class_names:
class_str+=c
if class_names.index(c)!=len(class_names)-1: class_str+=" | "
sys.stdout.write("Classes:\t\t%s\n"% (class_str))
per_class = 15500
epochs=500
sys.stdout.write(" \t\tDocs/Iter: %d\n"%per_class)
sys.stdout.write(" \t\tEpochs: %d\n\n"%epochs)
classifier = vector_classifier_keras(class_names,directory)
one_fetch=True
if one_fetch:
i=1
X,y,_ = get_classified_docs(encoder,"importance.tsv",class_names,class_map,0,5000000)
for j in range(epochs):
loss = classifier.train_doc_iter(X,y,iteration=i,epoch=j,plot=True)
else:
for j in range(epochs):
reset_corpus()
doc_idx=0
i=0
while True:
i+=1
print("\nEpoch:%d | Iteration: %d | doc_idx: %d"%(j,i,doc_idx))
X,y,doc_idx = get_classified_docs(encoder,"importance.tsv",class_names,class_map,doc_idx,per_class)
loss = classifier.train_doc_iter(X,y,iteration=i,epoch=j,plot=True,save_all=False)
if doc_idx==-1: break
classifier.model.save(os.path.join(classifier.directory,"%s-classifier-last.h5"%(model_type)))
if gif: make_gif(classifier.pic_dir)
sys.stdout.write("\nDone\n")
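# e.g. (hypothetical output directory):
#   classify_importance_docs(enc, "WikiLearn/data/models/classifier/importance")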
def classify_importance(encoder,directory,gif=True,model_type="lstm"):
if not os.path.exists(directory): os.makedirs(directory)
print("\nClassifying importance by word-vector sequences...\n")
cur_time = int(time.time())
directory = os.path.join(directory,str(cur_time))
# all output quality classes
class_names = ["top","high","mid","low"]
# tagged class : index of name in class_names (to treat it as)
class_map = {"top":0,"high":1,"mid":2,"low":3}
class_str = ""
for c in class_names:
class_str+=c
if class_names.index(c)!=len(class_names)-1: class_str+=" | "
sys.stdout.write("Classes:\t\t%s\n"% (class_str))
#### SETTINGS
    #dpipc = 960
    dpipc=-1    # docs per class per iteration (-1 = no per-class cap; use dpi total instead)
    dpi=20000   # total docs per iteration (passed as total_seq when dpipc is -1)
min_words = 2
max_words = 1000 # maximum number of words to maintain in each document
remove_stop_words = False # if True, removes stop words before calculating sentence lengths
max_class_mappings = 100000 # max items to load
limit_vocab_size = 1000 # if !=-1, trim vocab to 'limit_vocab_size' words
    batch_size = None # if None, defaults to what's set in classify.py; larger values need more VRAM
replace_removed = True # replace words not found in model with zero vector
swap_with_word_idx = False
epochs=50
####
classifications = "importance.tsv"
# if using CNN, this must be non -1
if model_type=="cnn":
if limit_vocab_size==-1:
print("WARNING: limit_vocab_size must be non -1 for CNN")
            sys.exit(1)
remove_stop_words=True
swap_with_word_idx=False
sys.stdout.write("Model Type: \t%s\n"%model_type)
sys.stdout.write("Max Words/Doc: \t%d\n"%max_words)
sys.stdout.write("Min Words/Doc: \t%d\n"%min_words)
sys.stdout.write("Stopwords: \t%s\n"%("<leave>" if not remove_stop_words else "<remove>"))
sys.stdout.write("Limit Vocab: \t%s\n"%("<none>" if limit_vocab_size==-1 else str(limit_vocab_size)))
sys.stdout.write("Replace Non-Model: \t%s\n"%("True" if replace_removed else "False"))
sys.stdout.write("Swap w/ Index: \t%s\n"%("True" if swap_with_word_idx else "False"))
sys.stdout.write("Doc/Class/Iter: \t%d\n\n"%dpipc)
sys.stdout.write("Batch Size: %s\n"%("<default>" if batch_size is None else str(batch_size)))
sys.stdout.write("Epochs: %s\n\n"%("<default>" if epochs is None else str(epochs)))
classifier = vector_classifier_keras(class_names=class_names,directory=directory,model_type=model_type,vocab_size=limit_vocab_size)
last_loss=None
for j in range(epochs):
reset_corpus()
doc_idx=0
i=0
while True:
i+=1
print("\nEpoch:%d | Iteration: %d | doc_idx: %d"%(j,i,doc_idx))
X,y,doc_idx = get_classified_sequences( encoder, dpipc, min_words, max_words,
class_names=class_names,
class_map=class_map,
start_at=doc_idx,
remove_stop_words=remove_stop_words,
trim_vocab_to=limit_vocab_size,
replace_removed=replace_removed,
swap_with_word_idx=swap_with_word_idx,
classifications=classifications,
max_class_mappings=max_class_mappings,
                                                    total_seq=dpi )
loss=classifier.train_seq_iter(X,y,i,j,plot=True)
            if doc_idx==-1 or X.shape[0]<=dpi-1: break
# write out ordered gif of all items in picture directory (heatmaps)
if gif: make_gif(classifier.pic_dir)
    sys.stdout.write("\nReached end of text.tsv\n")
def classify_us_state(encoder,directory,gif=True,model_type="lstm"):
if not os.path.exists(directory): os.makedirs(directory)
print("Classifying U.S. states by word-vector sequences...\n")
cur_time = int(time.time())
directory = os.path.join(directory,str(cur_time))
class_names_string=[]
class_names_id=[]
class_sizes=[]
class_map={}
cat_tree=False
if cat_tree:
limit=1000
f=open("category_children-string.tsv")
all_classes=False
if all_classes:
limit=100
# load in content category strings, ids, and counts
f=open("largest_categories-meta.txt","r")
lines = f.read().split("\n")
i=0
for l in lines:
i+=1
if i>limit: break
items=l.strip().split(" | ")
if len(items)==3 and items[0]!="String":
class_map[items[1]]=len(class_names_string)
class_names_string.append(items[0])
class_names_id.append(items[1])
class_sizes.append(int(items[2]))
f.close()
i=0
for c in class_names_string:
sys.stdout.write("%s\t\t%s\n"%("Classes:" if i==0 else " ",c))
i+=1
custom_classes=True
if custom_classes:
us_states=[]
s_f=open("us_states.txt","r")
for line in s_f:
line = line.strip().lower().replace(" ","_")
if len(line)>2:
us_states.append(line)
s_f.close()
class_names_string=us_states
class_tags=[]
for c in class_names_string:
class_tags.append([c])
class_map={} # dict from cat id to class_name_string index to use as class
contained_classes={}
for c in class_names_string:
contained_classes[c]=[]
class_sizes={} # number of articles in each class
for c in class_names_string:
class_sizes[c]=0
class_names_id=[]
t0=time.time()
#f=open("largest_categories-meta.txt","r")
f=open("sorted_categories.tsv","r")
lines=f.read().split("\n")
i=0
for l in lines:
i+=1
sys.stdout.write("\rMapping categories (%d/%d)"%(i,len(lines)))
items=l.strip().split("\t")
if len(items)==3 and items[0]!="String":
cur_cat_str = items[0].lower()[9:]
cur_cat_id = items[1]
cur_cat_ct = int(items[2])
                if cur_cat_ct==0: continue # skip this category if it has no articles in it
found_class=False
# mapping this cat to a class
c_index=0
for c_name,c_tags in zip(class_names_string,class_tags):
# iterate over each tag for a match
for t in c_tags:
tag_loc=cur_cat_str.find(t)
if tag_loc!=-1:
                            if tag_loc!=0: # may just be the tail of another word (don't include)
if cur_cat_str[tag_loc-1]!="_":
continue
if tag_loc+len(t)!=len(cur_cat_str): # if not at the end of the category string
if cur_cat_str[tag_loc+len(t)] not in ["_","s"]: # if the beginning of another word
continue
if c_name=="virginia" and cur_cat_str.find("west_virginia")!=-1: continue
found_class=True
class_map[cur_cat_id]=c_index
contained_classes[c_name].append(cur_cat_str)
class_sizes[c_name]+=cur_cat_ct
if found_class: break
c_index+=1
f.close()
sys.stdout.write(" | %s\n"%(make_seconds_pretty(time.time()-t0)))
i=0
for c in class_names_string:
contained_str=""
q=0
for cont in contained_classes[c]:
q+=1
contained_str+=cont
if q!=len(contained_classes[c]): contained_str+=","
sys.stdout.write("%s\t\t%d - %s\n"%("Classes:" if i==0 else " ",class_sizes[c],c))
print_full_classes=False
if print_full_classes:
sys.stdout.write(" \t%s\n"%(contained_str))
i+=1
#### SETTINGS
dpipc=160
min_words = 2
max_words = 120 # maximum number of words to maintain in each document
remove_stop_words = False # if True, removes stop words before calculating sentence lengths
limit_vocab_size = 3000 # if !=-1, trim vocab to 'limit_vocab_size' words
    batch_size = None # if None, defaults to what's set in classify.py; larger values need more VRAM
replace_removed = True # replace words not found in model with zero vector
swap_with_word_idx = False
epochs=100
####
# if using CNN, this must be non -1
if model_type=="cnn":
if limit_vocab_size==-1:
print("WARNING: limit_vocab_size must be non -1 for CNN")
            sys.exit(1)
remove_stop_words=True
swap_with_word_idx=False
sys.stdout.write("Model Type: \t%s\n"%model_type)
sys.stdout.write("Max Words/Doc: \t%d\n"%max_words)
sys.stdout.write("Min Words/Doc: \t%d\n"%min_words)
sys.stdout.write("Stopwords: \t%s\n"%("<leave>" if not remove_stop_words else "<remove>"))
sys.stdout.write("Limit Vocab: \t%s\n"%("<none>" if limit_vocab_size==-1 else str(limit_vocab_size)))
sys.stdout.write("Replace Non-Model: \t%s\n"%("True" if replace_removed else "False"))
sys.stdout.write("Swap w/ Index: \t%s\n"%("True" if swap_with_word_idx else "False"))
sys.stdout.write("Doc/Class/Iter: \t%d\n\n"%dpipc)
sys.stdout.write("Batch Size: %s\n"%("<default>" if batch_size is None else str(batch_size)))
sys.stdout.write("Epochs: %s\n\n"%("<default>" if epochs is None else str(epochs)))
classifier = vector_classifier_keras(class_names=class_names_string,directory=directory,model_type=model_type,vocab_size=limit_vocab_size)
last_loss=None
# iterate over the full corpus on each iteration
for j in range(epochs):
reset_corpus() # tell get_classified_sequences to reset all state variables
doc_idx = 0
i=0
while True:
i+=1
print("\nEpoch:%d | Iteration: %d | doc_idx: %d"%(j,i,doc_idx))
X,y,doc_idx = get_classified_sequences( encoder, dpipc, min_words, max_words,
class_names=class_names_string,
class_map=class_map,
start_at=doc_idx,
remove_stop_words=remove_stop_words,
trim_vocab_to=limit_vocab_size,
replace_removed=replace_removed,
swap_with_word_idx=swap_with_word_idx,
classifications="categories.tsv",
multi_class_file=True )
num_worse=0
plot=True
loss = classifier.train_seq_iter(X,y,i,j,plot=plot,save_all=False)
if doc_idx==-1: break # if at the end of the corpus
classifier.model.save(os.path.join(classifier.directory,"%s-classifier-epoch_%d.h5"%(model_type,j)))
        if last_loss!=None and loss>last_loss:
            print("\nLoss is increasing, ending training.\n")
            break
        last_loss = loss # compare each epoch against the previous epoch's loss
# write out ordered gif of all items in picture directory (heatmaps)
if gif: make_gif(classifier.pic_dir)
    sys.stdout.write("\nReached end of text.tsv\n")
def classify_content(encoder,directory,gif=True,model_type="lstm"):
if not os.path.exists(directory): os.makedirs(directory)
print("Classifying contents by word-vector sequences...\n")
cur_time = int(time.time())
directory = os.path.join(directory,str(cur_time))
class_names_string=[]
class_names_id=[]
class_sizes=[]
class_map={} # from id to index in class_names_id
cat_tree=False
if cat_tree:
limit=1000
f=open("category_children-string.tsv")
all_classes=False
if all_classes:
limit=100
# load in content category strings, ids, and counts
#f=open("largest_categories-meta.txt","r")
f=open("sorted_categories.tsv","r")
lines = f.read().split("\n")
i=0
for l in lines:
i+=1
if i>limit: break
#items=l.strip().split(" | ")
items=l.strip().split("\t")
if len(items)==3 and items[0]!="String":
class_map[items[1]]=len(class_names_string)
class_names_string.append(items[0])
class_names_id.append(items[1])
class_sizes.append(int(items[2]))
f.close()
i=0
for c in class_names_string:
sys.stdout.write("%s\t\t%s\n"%("Classes:" if i==0 else " ",c))
i+=1
# read from titles.tsv and search actual titles rather than categories here
custom_from_titles=True
if custom_from_titles:
load=True
if os.path.isfile("mapped-std_categories-ids.txt") and load:
print("Loading mapped articles...")
f=open("std_categories.txt","r")
class_names_string=[]
class_map={}
class_tags=[]
contained_classes={}
cur_class_tags=[]
last_line=None
comment=False
for line in f:
line=line.strip()
if line=="/*":
comment=True
continue
if line=="*/":
comment=False
continue
if comment:
continue
if last_line==None:
#class_names_string.append(line)
last_line=line
continue
if line=="=":
class_names_string.append(last_line)
contained_classes[last_line]=[]
continue
if line=="]":
class_tags.append(cur_class_tags)
cur_class_tags=[]
continue
                last_line=last_line.lower().replace(" ","_").replace("&amp;","&") # un-escape HTML ampersands
for line_tag in last_line.split("/"):
for line_tag2 in line_tag.split("&"):
line_tag2=line_tag2.strip("_")
#class_map[line_tag2]=len(class_names_string)-1
contained_classes[class_names_string[-1]].append(line_tag2)
cur_class_tags.append(line_tag2)
last_line=line
f.close()
i=0
for c in class_names_string:
class_map[c]=i
i+=1
else:
f_dest=open("mapped-std_categories.txt","w")
f_dest_ids=open("mapped-std_categories-ids.txt","w")
f=open("std_categories.txt","r")
class_names_string=[]
class_map={}
class_tags=[]
contained_classes={}
cur_class_tags=[]
last_line=None
comment=False
for line in f:
line=line.strip()
if line=="/*":
comment=True
continue
if line=="*/":
comment=False
continue
if comment:
continue
if last_line==None:
#class_names_string.append(line)
last_line=line
continue
if line=="=":
class_names_string.append(last_line)
contained_classes[last_line]=[]
continue
if line=="]":
class_tags.append(cur_class_tags)
cur_class_tags=[]
continue
last_line=last_line.lower().replace(" ","_")
for line_tag in last_line.split("/"):
for line_tag2 in line_tag.split("&"):
line_tag2=line_tag2.strip("_")
#class_map[line_tag2]=len(class_names_string)-1
contained_classes[class_names_string[-1]].append(line_tag2)
cur_class_tags.append(line_tag2)
last_line=line
f.close()
i=0
for c in class_names_string:
class_map[c]=i
i+=1
class_sizes={} # number of articles in each class
for c in class_names_string:
class_sizes[c]=0
class_names_id=[]
t0=time.time()
#f=open("largest_categories-meta.txt","r")
f=open("titles.tsv","r")
#f=open("sorted_categories.tsv","r")
#lines=f.read().split("\n")
            lines=13000000 # approx. number of lines in titles.tsv (progress display only)
i=0
print_counts=True
num_mapped=0
for l in f:
i+=1
#if i>100000: break
if print_counts:
sys.stdout.write("\rMapping titles (%d/%d) | Mapped:%d | Counts: {"%(i,lines,num_mapped))
for c in class_names_string:
sys.stdout.write("%s:%d%s"%(c,class_sizes[c],", " if class_names_string.index(c)!=len(class_names_string)-1 else "}"))
else:
sys.stdout.write("\rMapping titles (%d/%d) | Mapped:%d"%(i,lines,num_mapped))
items=l.strip().split("\t")
if len(items)==2:
cur_title_str = items[1].lower() # title of current article
cur_title_id = items[0] # id of current article
found_class=False
# mapping this cat to a class
c_index=-1
# iterate over each class and its tags
for c_name,c_tags in zip(class_names_string,class_tags):
c_index+=1
# iterate over each tag for a match
for t in c_tags:
tag_loc=cur_title_str.find(t)
if tag_loc!=-1:
                                if tag_loc!=0: # may just be the tail of another word (don't include)
if cur_title_str[tag_loc-1]!="_":
continue
                                if tag_loc+len(t)!=len(cur_title_str): # if not at the end of the title string
if cur_title_str[tag_loc+len(t)] not in ["_","s"]: # if the beginning of another word
continue
if found_class:
continue
found_class=True
#class_map[cur_title_id]=c_index
f_dest.write("%s\t%s\n"%(cur_title_str,c_name))
f_dest_ids.write("%s\t%s\n"%(cur_title_id,c_name))
contained_classes[c_name].append(cur_title_str)
class_sizes[c_name]+=1
if found_class:
num_mapped+=1