-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcluster.py
More file actions
152 lines (118 loc) · 5.45 KB
/
cluster.py
File metadata and controls
152 lines (118 loc) · 5.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from sentence_transformers import SentenceTransformer, util
import pandas as pd
class Cluster(object):
    """Cluster similar documents using sentence embeddings.

    Documents are embedded with a SentenceTransformer model, then grouped
    with Fast Community Detection: rows whose top-k cosine similarities all
    exceed ``min_threshold`` seed a community, and overlapping communities
    are removed largest-first.

    :param pre_trained_name: SentenceTransformer pre-trained model name
    :param min_threshold: minimum cosine-similarity for a pair of documents
        to count as similar
    :param min_community_size: minimum number of documents required to form
        a cluster
    :param show_progress_bar: show a progress bar while computing embeddings
    :param convert_to_numpy: convert embeddings to numpy format
    :param largest: passed through to ``topk``; True (the default) selects
        the highest-similarity entries, which is what the detection logic
        assumes
    """

    def __init__(self, pre_trained_name='distilbert-base-nli-stsb-quora-ranking', min_threshold=0.75,
                 min_community_size=10, show_progress_bar=True, convert_to_numpy=True, largest=True):
        self.largest = largest
        self.min_threshold = min_threshold
        self.convert_to_numpy = convert_to_numpy
        self.show_progress_bar = show_progress_bar
        self.min_community_size = min_community_size
        self.pre_trained_name = pre_trained_name

    def cluster(self, docs):
        """Generate clusters from a list of documents.

        :param docs: list of documents
        :return: dict mapping 1-based cluster number to list of documents
        """
        corpus = self.generate_corpus(docs)
        embeddings = self.generate_embeddings(corpus)
        communities = self.community_detection(embeddings)
        return self.get_cluster_docs(communities, corpus)

    def generate_embeddings(self, docs):
        """Generate sentence embeddings from a list of documents.

        :param docs: list of documents
        :return: sentence embeddings produced by ``SentenceTransformer.encode``
        """
        model = SentenceTransformer(self.pre_trained_name)
        embeddings = model.encode(docs, show_progress_bar=self.show_progress_bar,
                                  convert_to_numpy=self.convert_to_numpy)
        return embeddings

    def generate_corpus(self, docs):
        """Deduplicate documents into a corpus, preserving first-seen order.

        ``dict.fromkeys`` keeps insertion order, unlike ``set`` whose
        iteration order varies between runs (string hash randomization),
        so cluster ids map to a stable corpus ordering.

        :param docs: list of documents
        :return: list of unique documents in first-occurrence order
        """
        return list(dict.fromkeys(docs))

    def get_cluster_docs(self, clusters, corpus):
        """Map each community of document ids back to its documents.

        :param clusters: communities (lists of corpus indices)
        :param corpus: corpus documents
        :return: dict where key is the 1-based cluster number and value is
            the list of documents in that cluster
        """
        return {cluster_no: [corpus[doc_id] for doc_id in community]
                for cluster_no, community in enumerate(clusters, start=1)}

    def get_distribution(self, clusters):
        """Get the distribution of documents per cluster.

        :param clusters: dict where key is the cluster number and value is
            the list of documents in that cluster
        :return: pandas DataFrame with columns ``Cluster`` and ``Num_Docs``
        """
        df = pd.DataFrame(list(clusters.items()), columns=['Cluster', 'Docs'])
        df['Num_Docs'] = df['Docs'].apply(len)
        df.drop('Docs', axis=1, inplace=True)
        return df

    def community_detection(self, embeddings):
        """Extract groups of highly similar documents from embeddings
        using Fast Community Detection.

        :param embeddings: sentence embeddings
        :return: non-overlapping communities with at least
            ``min_community_size`` members, largest first
        """
        # Pairwise cosine similarity matrix (n x n).
        cos_scores = util.pytorch_cos_sim(embeddings, embeddings)
        # Guard: topk raises when k exceeds the number of documents, so cap
        # k at the corpus size for small inputs.
        k = min(self.min_community_size, len(embeddings))
        top_k_values, _ = cos_scores.topk(k=k, largest=self.largest)
        extracted_communities = []
        for i in range(len(top_k_values)):
            # Row i seeds a community only if even its k-th best match
            # clears the threshold (i.e. at least k similar documents).
            if top_k_values[i][-1] >= self.min_threshold:
                # Full descending sort of row i; collect indices until the
                # similarity drops below the threshold.
                top_val_large, top_idx_large = cos_scores[i].topk(k=len(embeddings), largest=self.largest)
                new_cluster = []
                for idx, val in zip(top_idx_large.tolist(), top_val_large.tolist()):
                    if val < self.min_threshold:
                        break
                    new_cluster.append(idx)
                extracted_communities.append(new_cluster)
        # Largest cluster first so overlap removal keeps the biggest groups.
        extracted_communities.sort(key=len, reverse=True)
        # Remove overlapping communities: a community is kept only if none
        # of its members already belong to an accepted community.
        unique_communities = []
        extracted_ids = set()
        for community in extracted_communities:
            if all(idx not in extracted_ids for idx in community):
                unique_communities.append(community)
                extracted_ids.update(community)
        return unique_communities