Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions PyHierarchicalTsetlinMachineCUDA/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,3 +483,31 @@
}
}
"""

# CUDA source for the kernel that unpacks Tsetlin Automaton (TA) states.
#
# get_ta_states(global_ta_state, unpacked_states):
#   * Each thread starts at blockIdx.x * blockDim.x + threadIdx.x and advances
#     by blockDim.x * gridDim.x until it passes
#     CLAUSES * COMPONENTS * LITERALS_PER_LEAF, so any launch size covers the
#     whole index range.
#   * A flat index i is split into (clause, comp, ta_idx). The TA lives in the
#     32-bit chunk ta_idx / 32 at bit position ta_idx % 32.
#   * For that chunk, STATE_BITS consecutive words are read from
#     global_ta_state; bit `bit_pos` of word b becomes bit b of the state.
#   * The rebuilt state is stored in unpacked_states at
#     clause * COMPONENTS * LITERALS_PER_LEAF + comp * LITERALS_PER_LEAF + ta_idx.
#
# CLAUSES, COMPONENTS, LITERALS_PER_LEAF, TA_CHUNKS_PER_LEAF and STATE_BITS are
# not defined in this string; presumably they are supplied by the preamble that
# is prepended when the module is compiled -- confirm against the caller.
code_clauses = """
extern "C" __global__ void get_ta_states(const unsigned int* global_ta_state, unsigned int* unpacked_states) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;

for (unsigned long long i = index; i < CLAUSES * COMPONENTS * LITERALS_PER_LEAF; i += stride) {
unsigned long long clause = i / (COMPONENTS * LITERALS_PER_LEAF);
unsigned long long comp = (i / LITERALS_PER_LEAF) % COMPONENTS;
unsigned long long ta_idx = i % LITERALS_PER_LEAF;

int chunk = ta_idx / 32;
int bit_pos = ta_idx % 32;

unsigned int state = 0;
for (int b = 0; b < STATE_BITS; ++b) {
unsigned int plane = global_ta_state[
(clause * COMPONENTS * TA_CHUNKS_PER_LEAF * STATE_BITS) +
(comp * TA_CHUNKS_PER_LEAF * STATE_BITS) +
(chunk * STATE_BITS) + b
];
if (plane & (1U << bit_pos)) state |= (1U << b);
}

unpacked_states[(clause * COMPONENTS * LITERALS_PER_LEAF) + (comp * LITERALS_PER_LEAF) + ta_idx] = state;
}
}
"""
67 changes: 67 additions & 0 deletions PyHierarchicalTsetlinMachineCUDA/tm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from collections import deque
import numpy as np

import PyHierarchicalTsetlinMachineCUDA.kernels as kernels
Expand Down Expand Up @@ -157,6 +158,9 @@ def cuda_modules(self):
self.prepare_encode_hierarchy = mod_encode.get_function("prepare_encode_hierarchy")
self.encode_hierarchy = mod_encode.get_function("encode_hierarchy")

mod_clauses = SourceModule(parameters + kernels.code_clauses, no_extern_c=True)
self.kernel_get_ta_states = mod_clauses.get_function("get_ta_states")

def encode_X(self, X, encoded_X_hierarchy_gpu):
number_of_examples = X.shape[0]

Expand Down Expand Up @@ -408,6 +412,69 @@ def _score(self, X):

return class_sum

def get_ta_states(self) -> np.ndarray:
    """
    Get state value for each TA.

    Runs the ``get_ta_states`` CUDA kernel on ``self.ta_state_hierarchy_gpu``
    and copies the unpacked result back to the host.

    Returns: ``np.uint32`` Numpy array of shape (number_of_clauses,
    number_of_clause_components, number_of_literals_per_leaf)
    """
    shape = (
        self.number_of_clauses,
        self.hierarchy_size[1],
        self.number_of_literals_per_leaf,
    )

    # Zero-fill the output buffer directly on the device. This keeps the
    # zero-initialised semantics of the previous implementation without
    # building a host-side array and uploading it first.
    ta_states_gpu = gpuarray.zeros(shape, dtype=np.uint32)

    # One thread per TA, rounded up to whole blocks.
    total = shape[0] * shape[1] * shape[2]
    threads_per_block = self.block[0]
    grid = ((total + threads_per_block - 1) // threads_per_block, 1, 1)
    self.kernel_get_ta_states(
        self.ta_state_hierarchy_gpu,
        ta_states_gpu,
        block=self.block,
        grid=grid,
    )

    # Copy back to CPU
    return ta_states_gpu.get()

def get_literals(self):
    """
    Report which literals every clause includes.

    A literal counts as included when its TA state is at least
    ``1 << (number_of_state_bits - 1)``.

    Returns: ``np.uint8`` Numpy array (1 = included, 0 = excluded) with the
    same shape as the array returned by ``get_ta_states()``.
    """
    ta_states = self.get_ta_states()
    include_threshold = 1 << (self.number_of_state_bits - 1)
    included = ta_states >= include_threshold
    return included.astype(np.uint8)

def map_ta_id_to_feature_id(self):
    """
    Map every TA of a clause to the input feature it refers to.

    The hierarchy is walked breadth-first from the root at level
    ``depth - 1`` down to the leaf components at level 0. While descending,
    each node carries a feature-group id: an ``OR_ALTERNATIVES`` node passes
    its own group to every child, any other node gives child number ``p`` the
    group ``group_id * n_children + p``.

    Returns: integer array of shape (number_of_clause_components,
    number_of_literals_per_leaf). When ``append_negated`` is set, the first
    half of a component's TAs and the second half map to the same feature
    ids (positive and negated literals respectively).
    """
    # Feature group of each leaf component; -1 marks "never reached".
    leaf_group = np.full(self.hierarchy_size[1], -1, dtype=np.int32)

    pending = deque([(self.depth - 1, 0, 0)])  # (level, node_id, group_id)
    while pending:
        level, node, group = pending.popleft()

        if level == 0:
            leaf_group[node] = group
            continue

        spec = self.hierarchy_structure[level]
        fanout = spec[1]
        shares_features = spec[0] == OR_ALTERNATIVES
        for pos in range(fanout):
            child_group = group if shares_features else group * fanout + pos
            pending.append((level - 1, node * fanout + pos, child_group))

    # Translate (component group, TA position) into a feature id.
    n_literals = self.number_of_literals_per_leaf
    ta_ids = np.arange(n_literals)
    if self.append_negated:
        feats_per_group = n_literals // 2
        local_feature = ta_ids % feats_per_group
    else:
        feats_per_group = n_literals
        local_feature = ta_ids

    return leaf_group[:, None] * feats_per_group + local_feature[None, :]

def print_hierarchy(self, print_ta_state=False):
for i in range(self.number_of_clauses):
print("\nCLAUSE #%d: " % (i), end='')
Expand Down
105 changes: 105 additions & 0 deletions PyHierarchicalTsetlinMachineCUDA/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from collections import deque
import numpy as np
import networkx as nx

from .tm import OR_ALTERNATIVES, CommonTsetlinMachine


def make_hierarchy_graph(G: nx.Graph, hier: list[tuple[int, int]]):
    """Add one node per hierarchy node to G, walking breadth-first from the root.

    Nodes are keyed ``(level, idx)``. The root is ``(len(hier), 0)`` and the
    lowest level added here is 1. Every node stores ``hier[level - 1][0]`` as
    both its ``label`` and ``op`` attribute, and each node above level 1 is
    connected to ``hier[level - 1][1]`` children on the level below.
    """
    pending = deque([(len(hier), 0)])  # (level, idx)
    while pending:
        node = pending.popleft()
        level, idx = node

        op = hier[level - 1][0]
        G.add_node(node, label=op, op=op)

        if level != 1:
            branching = hier[level - 1][1]
            first_child = idx * branching
            for cid in range(first_child, first_child + branching):
                child = (level - 1, cid)
                pending.append(child)
                G.add_edge(node, child)

def clause_to_nx(
    tm: CommonTsetlinMachine,
    clause_idx: int,
    feature_names: list[str] | None = None,
    negation_prefix: str = '¬',
    clause_literals: np.ndarray | None = None,
    ta_to_fid_mapping = None,
):
    """
    Build a networkx Graph describing one clause.

    The hierarchy nodes come from ``make_hierarchy_graph``; every included
    literal is then attached as a level-0 node below its leaf component
    ``(1, comp_id)``.

    Args:
        `tm`: The model.
        `clause_idx`: Index of the clause to convert.
        `feature_names`: Optional feature names used as labels. Defaults to 'x0', 'x1', ...
        `negation_prefix`: Text put in front of a negated feature's name (default: '¬').
        `clause_literals`: Optional literals of this clause, indexed as [component, literal]. Fetched from the model when None.
        `ta_to_fid_mapping`: Optional [component, literal] -> feature id mapping. Fetched from the model when None.
    """

    assert tm.hierarchy_structure is not None, "Hierarchy structure not defined in the model. Are you sure tm belongs to CommonTsetlinMachine or its subclasses?"

    if feature_names is None:
        feature_names = [f'x{i}' for i in range(tm.number_of_features_hierarchy)]

    if clause_literals is None:
        literals = tm.get_literals()[clause_idx]
    else:
        literals = clause_literals

    if ta_to_fid_mapping is None:
        ta_to_fid = tm.map_ta_id_to_feature_id()
    else:
        ta_to_fid = ta_to_fid_mapping

    n_comp = tm.hierarchy_size[1]
    lits_per_comp = tm.number_of_literals_per_leaf

    G = nx.Graph()
    make_hierarchy_graph(G, tm.hierarchy_structure)

    with_negations = tm.append_negated
    feat_per_comp = lits_per_comp // 2 if with_negations else lits_per_comp

    def attach_literal(comp_id, ta_offset, label):
        # Level-0 node ids are unique per (component, TA position).
        leaf = (0, comp_id * lits_per_comp + ta_offset)
        G.add_node(leaf, label=label)
        G.add_edge((1, comp_id), leaf)

    for comp_id in range(n_comp):
        for fid in range(feat_per_comp):
            # Positive literal of this feature.
            if literals[comp_id, fid]:
                attach_literal(comp_id, fid, feature_names[ta_to_fid[comp_id, fid]])

            if not with_negations:
                continue

            # Negated literal: stored in the second half of the component.
            if literals[comp_id, feat_per_comp + fid]:
                negated_label = f'{negation_prefix}{feature_names[ta_to_fid[comp_id, fid]]}'
                attach_literal(comp_id, feat_per_comp + fid, negated_label)

    return G


def clause_bank_to_nx(
    tm: CommonTsetlinMachine,
    feature_names: list[str] | None = None,
    negation_prefix: str = '¬',
):
    """
    Create a networkx Graph for the entire clause bank.

    Every clause graph from ``clause_to_nx`` is added with its nodes renamed
    to ``(clause_idx, node)`` and its root ``(clause_idx, (tm.depth, 0))``
    linked to a shared ``"CB_ROOT"`` node.

    Args:
        `tm`: The model.
        `feature_names`: Optional list of feature names for labeling. If None, defaults to 'x0', 'x1', ...
        `negation_prefix`: Prefix for negated features (default: '¬').
    """
    # Fetch once and pass down so each clause does not query the model again.
    literals = tm.get_literals()
    ta_to_fid = tm.map_ta_id_to_feature_id()

    G = nx.Graph()
    cb_root = "CB_ROOT"
    G.add_node(cb_root, label=OR_ALTERNATIVES, op=OR_ALTERNATIVES)

    for ci in range(tm.number_of_clauses):
        clause_G = clause_to_nx(tm, ci, feature_names, negation_prefix, literals[ci], ta_to_fid)
        mapping = {node: (ci, node) for node in clause_G.nodes}
        # Merge in place: nx.compose would copy the whole accumulated graph
        # on every iteration, making the loop quadratic in the clause count.
        G.update(nx.relabel_nodes(clause_G, mapping))
        G.add_edge(cb_root, (ci, (tm.depth, 0)))

    return G
139 changes: 139 additions & 0 deletions examples/visualize_clauses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import networkx as nx
import numpy as np

import PyHierarchicalTsetlinMachineCUDA.tm as tm
from PyHierarchicalTsetlinMachineCUDA.utils import clause_bank_to_nx, clause_to_nx


def load_data(data_dir='./examples'):
    """
    Load the noisy-parity training and testing sets.

    Args:
        data_dir: Directory holding 'NoisyParityTrainingData.txt' and
            'NoisyParityTestingData.txt'. The default matches the previous
            hard-coded location, which is relative to the working directory.

    Returns:
        (X_train, Y_train, X_test, Y_test) as ``np.uint32`` arrays, where the
        last column of each file is the label and the rest are the features.
    """
    from pathlib import Path

    def read_split(filename):
        # Last column is the target, everything before it is the input.
        table = np.loadtxt(Path(data_dir) / filename).astype(np.uint32)
        return table[:, :-1], table[:, -1]

    X_train, Y_train = read_split('NoisyParityTrainingData.txt')
    X_test, Y_test = read_split('NoisyParityTestingData.txt')

    return X_train, Y_train, X_test, Y_test


def train(tm, X_train, Y_train, X_test, Y_test, epochs=1):
    """
    Fit the model one epoch at a time and print test accuracy after each.

    Note that the ``tm`` parameter is the model instance; inside this
    function it hides the module imported under the same name.
    """
    for epoch_no in range(1, epochs + 1):
        tm.fit(X_train, Y_train, epochs=1, incremental=True)
        predictions = tm.predict(X_test)
        accuracy = 100 * (predictions == Y_test).mean()
        print('Epoch %d: Accuracy: %.2f%%' % (epoch_no, accuracy))


def _feature_index_from_label(label, negation_chars):
    """Return N for a label shaped like 'xN' (optionally negated), else None."""
    if not isinstance(label, str):
        return None
    name = label.lstrip(negation_chars)
    if not name.startswith('x'):
        return None
    try:
        return int(name[1:])
    except ValueError:
        # Starts with 'x' but is not a default 'x<number>' name.
        return None


def get_node_colors(G, op_colors, feat_colors, negation_chars='¬~'):
    """
    Pick one color per node of G, in ``G.nodes()`` order.

    Args:
        G: Graph whose nodes may carry 'op' and 'label' attributes.
        op_colors: Mapping from operator value to color; used when a node's
            'op' attribute is a key of this mapping.
        feat_colors: Sequence of colors indexed by feature number; used for
            nodes labelled 'xN' or a negated 'xN'.
        negation_chars: Characters stripped from the front of a label before
            parsing it (default: '¬~').

    Returns:
        List of colors. Nodes that match neither rule get 'lightgray'; this
        includes labels that are not of the form 'xN' and feature numbers
        outside ``feat_colors``, instead of raising.
    """
    node_colors = []
    for node in G.nodes():
        data = G.nodes[node]
        op = data.get('op', None)
        if op in op_colors:
            node_colors.append(op_colors[op])
            continue

        # Literal node — extract feature index from label (x3 or ¬x3)
        feat_idx = _feature_index_from_label(data.get('label', ''), negation_chars)
        if feat_idx is not None and 0 <= feat_idx < len(feat_colors):
            node_colors.append(feat_colors[feat_idx])
        else:
            node_colors.append('lightgray')
    return node_colors


def _draw_hierarchy_graph(G, model, title):
    """Draw G with the shared layout and coloring; return (fig, ax, feat_colors)."""
    labels = nx.get_node_attributes(G, 'label')
    pos = nx.nx_agraph.pygraphviz_layout(G, prog='twopi')

    # Operator nodes get fixed colors, literal nodes one color per feature.
    op_colors = {
        tm.AND_GROUP: 'lightblue',
        tm.OR_ALTERNATIVES: 'lightyellow',
    }
    n_features = model.number_of_features_hierarchy
    cmap = mpl.colormaps["tab20"].resampled(n_features)
    feat_colors = [mpl.colors.to_hex(cmap(i)) for i in range(n_features)]
    node_colors = get_node_colors(G, op_colors, feat_colors)

    fig, ax = plt.subplots(figsize=(16, 16), dpi=150, layout="compressed")
    nx.draw(
        G,
        pos,
        with_labels=True,
        labels=labels,
        node_size=100,
        node_color=node_colors,
        font_size=4,
        ax=ax,
    )
    ax.axis('off')
    fig.suptitle(title)
    return fig, ax, feat_colors


def visualize_clause(model, clause_idx):
    """
    Plot a single clause of the model as a graph with a feature-color legend.

    Returns: the matplotlib (fig, ax) pair.
    """
    G = clause_to_nx(model, clause_idx)
    fig, ax, feat_colors = _draw_hierarchy_graph(G, model, f'Clause {clause_idx}')

    # Add legend
    handles = [Patch(color=color, label=f"x{i}") for i, color in enumerate(feat_colors)]
    ax.legend(handles=handles, title="Features", loc='center right', bbox_to_anchor=(1.01, 0.5), fontsize=6)
    return fig, ax


def visualize_clause_bank(model):
    """
    Plot every clause of the model in one graph (no legend is added).

    Returns: the matplotlib (fig, ax) pair.
    """
    G = clause_bank_to_nx(model)
    fig, ax, _ = _draw_hierarchy_graph(G, model, 'Clause Bank')
    return fig, ax


if __name__ == '__main__':
    X_train, Y_train, X_test, Y_test = load_data()

    # Hierarchy listed from the leaf level (first entry) up to the root (last).
    hierarchy = (
        (tm.AND_GROUP, 3),
        (tm.OR_ALTERNATIVES, 3),
        (tm.AND_GROUP, 2),
        (tm.OR_ALTERNATIVES, 3),
        (tm.AND_GROUP, 2),
    )

    # Build and train the model for a single epoch.
    model = tm.TsetlinMachine(
        number_of_clauses=16,
        T=100,
        s=32.1,
        number_of_state_bits=8,
        boost_true_positive_feedback=0,
        hierarchy_structure=hierarchy,
        seed=10,
    )
    train(model, X_train, Y_train, X_test, Y_test, epochs=1)

    # One figure for clause 0, one for the whole clause bank.
    fig_clause, _ = visualize_clause(model, clause_idx=0)
    fig_all, _ = visualize_clause_bank(model)

    plt.show()
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
'pycuda',
'scipy',
'scikit-learn',
'networkx',
],
extras_require={
'examples': ['tensorflow'],
'examples': ['tensorflow', 'matplotlib', 'pygraphviz'],
}
)