From 8c9c3e97f76396d05748bc94fec5ec2731293ba2 Mon Sep 17 00:00:00 2001 From: Mayur1009 Date: Tue, 7 Apr 2026 19:24:07 +0200 Subject: [PATCH] Add plotting functions --- PyHierarchicalTsetlinMachineCUDA/kernels.py | 28 ++++ PyHierarchicalTsetlinMachineCUDA/tm.py | 67 ++++++++++ PyHierarchicalTsetlinMachineCUDA/utils.py | 105 +++++++++++++++ examples/visualize_clauses.py | 139 ++++++++++++++++++++ setup.py | 3 +- 5 files changed, 341 insertions(+), 1 deletion(-) create mode 100644 PyHierarchicalTsetlinMachineCUDA/utils.py create mode 100644 examples/visualize_clauses.py diff --git a/PyHierarchicalTsetlinMachineCUDA/kernels.py b/PyHierarchicalTsetlinMachineCUDA/kernels.py index 22f11d0..25ca555 100644 --- a/PyHierarchicalTsetlinMachineCUDA/kernels.py +++ b/PyHierarchicalTsetlinMachineCUDA/kernels.py @@ -483,3 +483,31 @@ } } """ + +code_clauses = """ + extern "C" __global__ void get_ta_states(const unsigned int* global_ta_state, unsigned int* unpacked_states) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + + for (unsigned long long i = index; i < CLAUSES * COMPONENTS * LITERALS_PER_LEAF; i += stride) { + unsigned long long clause = i / (COMPONENTS * LITERALS_PER_LEAF); + unsigned long long comp = (i / LITERALS_PER_LEAF) % COMPONENTS; + unsigned long long ta_idx = i % LITERALS_PER_LEAF; + + int chunk = ta_idx / 32; + int bit_pos = ta_idx % 32; + + unsigned int state = 0; + for (int b = 0; b < STATE_BITS; ++b) { + unsigned int plane = global_ta_state[ + (clause * COMPONENTS * TA_CHUNKS_PER_LEAF * STATE_BITS) + + (comp * TA_CHUNKS_PER_LEAF * STATE_BITS) + + (chunk * STATE_BITS) + b + ]; + if (plane & (1U << bit_pos)) state |= (1U << b); + } + + unpacked_states[(clause * COMPONENTS * LITERALS_PER_LEAF) + (comp * LITERALS_PER_LEAF) + ta_idx] = state; + } + } +""" diff --git a/PyHierarchicalTsetlinMachineCUDA/tm.py b/PyHierarchicalTsetlinMachineCUDA/tm.py index 55d2629..8bceeb5 100644 --- a/PyHierarchicalTsetlinMachineCUDA/tm.py +++ b/PyHierarchicalTsetlinMachineCUDA/tm.py @@ -18,6 +18,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +from collections import deque import numpy as np import PyHierarchicalTsetlinMachineCUDA.kernels as kernels @@ -157,6 +158,9 @@ def cuda_modules(self): self.prepare_encode_hierarchy = mod_encode.get_function("prepare_encode_hierarchy") self.encode_hierarchy = mod_encode.get_function("encode_hierarchy") + mod_clauses = SourceModule(parameters + kernels.code_clauses, no_extern_c=True) + self.kernel_get_ta_states = mod_clauses.get_function("get_ta_states") + def encode_X(self, X, encoded_X_hierarchy_gpu): number_of_examples = X.shape[0] @@ -408,6 +412,69 @@ def _score(self, X): return class_sum + def get_ta_states(self) -> np.ndarray: + """ + Get state value for each TA. + Returns: Numpy array of shape (number_of_clauses, number_of_clause_components, number_of_literals_per_leaf) + """ + # Mem Allocation + ta_states_gpu = gpuarray.to_gpu( + np.zeros((self.number_of_clauses, self.hierarchy_size[1], self.number_of_literals_per_leaf), dtype=np.uint32) + ) + + # Calculate grid size based on the kernel + total = self.number_of_clauses * self.hierarchy_size[1] * self.number_of_literals_per_leaf + grid = (((total + self.block[0] - 1) // self.block[0]), 1, 1) + self.kernel_get_ta_states(self.ta_state_hierarchy_gpu, ta_states_gpu, block=self.block, grid=grid) + + # Copy back to CPU + return ta_states_gpu.get() + + def get_literals(self): + """ + Get included literals for each clause. + Returns: Numpy array of shape (number_of_clauses, number_of_clause_components, number_of_literals_per_leaf) + """ + return (self.get_ta_states() >= (1 << (self.number_of_state_bits - 1))).astype(np.uint8) + + def map_ta_id_to_feature_id(self): + """ + Return an array of shape(number_of_clause_components, number_of_literals_per_leaf). That is the total number of TAs in a clause. Maps each TA id to a feature_id in the input. In each component, the first half of the TAs correspond to the positive features, and the second half correspond to the negated features. + """ + # BFS top-down traversal + q = deque() + q.append((self.depth - 1, 0, 0)) # (level, node_id, group_id) + + comp_grps = -1 * np.ones(self.hierarchy_size[1], dtype=np.int32) + while q: + level, node_id, group_id = q.popleft() + + if level == 0: + # This is the leaf component + comp_grps[node_id] = group_id + continue + + n_children = self.hierarchy_structure[level][1] + is_alt = (self.hierarchy_structure[level][0] == OR_ALTERNATIVES) + for child_pos in range(n_children): + child_id = node_id * n_children + child_pos + if is_alt: + # All children share the same features + child_group_id = group_id + else: + # Features are partitioned among the children + child_group_id = group_id * n_children + child_pos + + q.append((level - 1, child_id, child_group_id)) + + # map each TA in a component to a feature + half = self.number_of_literals_per_leaf // 2 + lit_ids = np.arange(self.number_of_literals_per_leaf) + local_feats = lit_ids % half if self.append_negated else lit_ids + fmap = comp_grps[:, None] * (half if self.append_negated else self.number_of_literals_per_leaf) + local_feats[None, :] + + return fmap + def print_hierarchy(self, print_ta_state=False): for i in range(self.number_of_clauses): print("\nCLAUSE #%d: " % (i), end='') diff --git a/PyHierarchicalTsetlinMachineCUDA/utils.py b/PyHierarchicalTsetlinMachineCUDA/utils.py new file mode 100644 index 0000000..8d4fd2b --- /dev/null +++ b/PyHierarchicalTsetlinMachineCUDA/utils.py @@ -0,0 +1,105 @@ +from collections import deque +import numpy as np +import networkx as nx + +from .tm import OR_ALTERNATIVES, CommonTsetlinMachine + + +def make_hierarchy_graph(G: nx.Graph, hier: list[tuple[int, int]]): + """Create Graph from the hierarchy.""" + # BFS + q = deque() + q.append((len(hier), 0)) # (level, idx) + while q: + level, idx = q.popleft() + G.add_node((level, idx), label=hier[level - 1][0], op=hier[level - 1][0]) + if level == 1: + continue + branching = hier[level - 1][1] + for cid in range(idx * branching, (idx + 1) * branching): + q.append((level - 1, cid)) + G.add_edge((level, idx), (level - 1, cid)) + +def clause_to_nx( + tm: CommonTsetlinMachine, + clause_idx: int, + feature_names: list[str] | None = None, + negation_prefix: str = '¬', + clause_literals: np.ndarray | None = None, + ta_to_fid_mapping = None, +): + """ + Create a networkx Graph for a single clause. + Args: + `tm`: The model. + `clause_idx`: Index of the clause to visualize. + `feature_names`: Optional list of feature names for labeling. If None, defaults to 'x0', 'x1', ... + `negation_prefix`: Prefix for negated features (default: '¬'). + `clause_literals`: Optional pre-extracted literals for the clause, must have shape (n_components, literals_per_component). If None, literals will be extracted from the model. + `ta_to_fid_mapping`: Optional mapping from (component_id, literal_id) to feature_id. If None, it will be obtained from the model. + """ + + assert tm.hierarchy_structure is not None, "Hierarchy structure not defined in the model. Are you sure tm belongs to CommonTsetlinMachine or its subclasses?" + + if feature_names is None: + feature_names = [f'x{i}' for i in range(tm.number_of_features_hierarchy)] + + literals = clause_literals if clause_literals is not None else tm.get_literals()[clause_idx] + ta_to_fid = ta_to_fid_mapping if ta_to_fid_mapping is not None else tm.map_ta_id_to_feature_id() + n_comp = tm.hierarchy_size[1] + lits_per_comp = tm.number_of_literals_per_leaf + + G = nx.Graph() + make_hierarchy_graph(G, tm.hierarchy_structure) + + feat_per_comp = lits_per_comp // 2 if tm.append_negated else lits_per_comp + + # Add included literals to leaf components + for comp_id in range(n_comp): + for fid in range(feat_per_comp): + pos_lit = literals[comp_id, fid] + if pos_lit: + G.add_node( + (0, comp_id * lits_per_comp + fid), + label=feature_names[ta_to_fid[comp_id, fid]], + ) + G.add_edge((1, comp_id), (0, comp_id * lits_per_comp + fid)) + + if tm.append_negated: + neg_lit = literals[comp_id, feat_per_comp + fid] + if neg_lit: + lab = f'{negation_prefix}{feature_names[ta_to_fid[comp_id, fid]]}' + G.add_node( + (0, comp_id * lits_per_comp + feat_per_comp + fid), label=lab + ) + G.add_edge( + (1, comp_id), (0, comp_id * lits_per_comp + feat_per_comp + fid) + ) + + return G + + +def clause_bank_to_nx( + tm: CommonTsetlinMachine, + feature_names: list[str] | None = None, + negation_prefix: str = '¬', +): + """ + Create a networkx Graph for the entire clause bank. + Args: + `tm`: The model. + `feature_names`: Optional list of feature names for labeling. If None, defaults to 'x0', 'x1', ... + `negation_prefix`: Prefix for negated features (default: '¬'). + """ + literals = tm.get_literals() + ta_to_fid = tm.map_ta_id_to_feature_id() + G = nx.Graph() + cb_root = "CB_ROOT" + G.add_node(cb_root, label=OR_ALTERNATIVES, op=OR_ALTERNATIVES) + for ci in range(tm.number_of_clauses): + clause_G = clause_to_nx(tm, ci, feature_names, negation_prefix, literals[ci], ta_to_fid) + mapping = { node: (ci, node) for node in clause_G.nodes } + G = nx.compose(G, nx.relabel_nodes(clause_G, mapping)) + G.add_edge(cb_root, (ci, (tm.depth, 0))) + + return G diff --git a/examples/visualize_clauses.py b/examples/visualize_clauses.py new file mode 100644 index 0000000..927bd88 --- /dev/null +++ b/examples/visualize_clauses.py @@ -0,0 +1,139 @@ +import matplotlib as mpl +import matplotlib.pyplot as plt +from matplotlib.patches import Patch +import networkx as nx +import numpy as np + +import PyHierarchicalTsetlinMachineCUDA.tm as tm +from PyHierarchicalTsetlinMachineCUDA.utils import clause_bank_to_nx, clause_to_nx + + +def load_data(): + train_data = np.loadtxt('./examples/NoisyParityTrainingData.txt').astype(np.uint32) + X_train = train_data[:, 0:-1] + Y_train = train_data[:, -1] + + test_data = np.loadtxt('./examples/NoisyParityTestingData.txt').astype(np.uint32) + X_test = test_data[:, 0:-1] + Y_test = test_data[:, -1] + + return X_train, Y_train, X_test, Y_test + + +def train(tm, X_train, Y_train, X_test, Y_test, epochs=1): + for epoch in range(epochs): + tm.fit(X_train, Y_train, epochs=1, incremental=True) + result = 100 * (tm.predict(X_test) == Y_test).mean() + print('Epoch %d: Accuracy: %.2f%%' % (epoch + 1, result)) + + +def get_node_colors(G, op_colors, feat_colors): + node_colors = [] + for node in G.nodes(): + data = G.nodes[node] + op = data.get('op', None) + if op in op_colors: + node_colors.append(op_colors[op]) + else: + # Literal node — extract feature index from label (x3 or ¬x3) + label = data.get('label', '') + feat_str = label.lstrip('¬~') + if feat_str.startswith('x'): + feat_idx = int(feat_str[1:]) + node_colors.append(feat_colors[feat_idx]) + else: + node_colors.append('lightgray') + return node_colors + + +def visualize_clause(model, clause_idx): + G = clause_to_nx(model, clause_idx) + labels = nx.get_node_attributes(G, 'label') + pos = nx.nx_agraph.pygraphviz_layout(G, prog='twopi') + + # Coloring the nodes. + op_colors = { + tm.AND_GROUP: 'lightblue', + tm.OR_ALTERNATIVES: 'lightyellow', + } + cmap = mpl.colormaps["tab20"].resampled(model.number_of_features_hierarchy) + feat_colors = [mpl.colors.to_hex(cmap(i)) for i in range(model.number_of_features_hierarchy)] + node_colors = get_node_colors(G, op_colors, feat_colors) + + fig, ax = plt.subplots(figsize=(16, 16), dpi=150, layout="compressed") + nx.draw( + G, + pos, + with_labels=True, + labels=labels, + node_size=100, + node_color=node_colors, + font_size=4, + ax=ax, + ) + ax.axis('off') + fig.suptitle(f'Clause {clause_idx}') + + # Add legend + handles = [Patch(color=color, label=f"x{i}") for i, color in enumerate(feat_colors)] + ax.legend(handles=handles, title="Features", loc='center right', bbox_to_anchor=(1.01, 0.5), fontsize=6) + return fig, ax + + +def visualize_clause_bank(model): + G = clause_bank_to_nx(model) + labels = nx.get_node_attributes(G, 'label') + pos = nx.nx_agraph.pygraphviz_layout(G, prog='twopi') + + op_colors = { + tm.AND_GROUP: 'lightblue', + tm.OR_ALTERNATIVES: 'lightyellow', + } + cmap = mpl.colormaps["tab20"].resampled(model.number_of_features_hierarchy) + feat_colors = [mpl.colors.to_hex(cmap(i)) for i in range(model.number_of_features_hierarchy)] + node_colors = get_node_colors(G, op_colors, feat_colors) + + fig, ax = plt.subplots(figsize=(16, 16), dpi=150, layout="compressed") + nx.draw( + G, + pos, + with_labels=True, + labels=labels, + node_size=100, + node_color=node_colors, + font_size=4, + ax=ax, + ) + ax.axis('off') + fig.suptitle('Clause Bank') + return fig, ax + + +if __name__ == '__main__': + X_train, Y_train, X_test, Y_test = load_data() + + # Training the model + model = tm.TsetlinMachine( + number_of_clauses=16, + T=100, + s=32.1, + number_of_state_bits=8, + boost_true_positive_feedback=0, + hierarchy_structure=( + (tm.AND_GROUP, 3), + (tm.OR_ALTERNATIVES, 3), + (tm.AND_GROUP, 2), + (tm.OR_ALTERNATIVES, 3), + (tm.AND_GROUP, 2), + ), + seed=10, + ) + train(model, X_train, Y_train, X_test, Y_test, epochs=1) + + # Visualize a single clause + fig_clause, _ = visualize_clause(model, clause_idx=0) + + # Visualize the entire clause bank + fig_all, _ = visualize_clause_bank(model) + + plt.show() diff --git a/setup.py b/setup.py index ad3c692..2c29426 100644 --- a/setup.py +++ b/setup.py @@ -16,8 +16,9 @@ 'pycuda', 'scipy', 'scikit-learn', + 'networkx', ], extras_require={ - 'examples': ['tensorflow'], + 'examples': ['tensorflow', 'matplotlib', 'pygraphviz'], } )