Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
317 changes: 230 additions & 87 deletions sqlite/graph_net_sample_groups_insert.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,57 @@
import argparse
import sqlite3
import uuid as uuid_module
from collections import defaultdict, namedtuple
from datetime import datetime
from collections import namedtuple
from collections import defaultdict

from orm_models import (
get_session,
GraphNetSampleGroup,
from orm_models import get_session, GraphNetSampleGroup


# ── Types ──

# One row per (sample_type, op_seq, shapes) bucket: the deterministic head
# sample plus a comma-separated list of every member uid.
BucketGroup = namedtuple(
    "BucketGroup", "head_uid op_seq shapes sample_type all_uids_csv"
)

# One sample eligible for v2 grouping, together with its bucket identifiers.
Candidate = namedtuple("Candidate", "uid sample_type op_seq shapes dtypes")


GraphNetSampleUid = str
GraphNetSampleType = str
BucketId = str
# ── Helpers ──


def _new_group_id():
return str(uuid_module.uuid4())


def _merge_stats(dst, src):
for key, val in src.items():
dst[key]["records"] += val["records"]
dst[key]["groups"].update(val["groups"])


def _print_stats(stats):
rule_order = ["rule1", "rule2", "rule4", "rule3"]
sample_types = sorted({st for st, _ in stats})
total_records = 0
total_groups = 0
for sample_type in sample_types:
print(f"\n [{sample_type}]")
for rule in rule_order:
key = (sample_type, rule)
if key in stats:
n_records = stats[key]["records"]
n_groups = len(stats[key]["groups"])
print(f" {rule}: {n_records} records, {n_groups} groups")
total_records += n_records
total_groups += n_groups
print(f"\n Total: {total_records} records, {total_groups} groups.")


# ── Database Queries ──


class DB:
Expand All @@ -23,117 +61,222 @@ def __init__(self, path):
def connect(self):
self.conn = sqlite3.connect(self.path)
self.conn.row_factory = sqlite3.Row
self.cur = self.conn.cursor()
self.cursor = self.conn.cursor()

def query(self, sql, params=None):
self.cur.execute(sql, params or ())
return self.cur.fetchall()

def exec(self, sql, params=None):
self.cur.execute(sql, params or ())
self.conn.commit()
self.cursor.execute(sql, params or ())
return self.cursor.fetchall()

    def close(self):
        # Release the underlying SQLite connection; the cached cursor
        # becomes unusable after this call.
        self.conn.close()


SampleBucketInfo = namedtuple(
"SampleBucketInfo",
[
"sample_uid",
"op_seq_bucket_id",
"input_shapes_bucket_id",
"input_dtypes_bucket_id",
"sample_type",
"sample_uids",
],
)
def query_bucket_groups(db: DB) -> list[BucketGroup]:
    """Return one BucketGroup per (sample_type, op_seq, shape) bucket.

    Only live, non-full_graph samples are considered (keeps v1 selection
    consistent with query_v2_candidates). The bucket head is the
    deterministic MIN(sample_uid), so repeated runs pick the same head.
    All member uids are aggregated into a comma-separated string.

    NOTE(review): SQLite does not formally guarantee group_concat ordering;
    the inner ORDER BY is best-effort — confirm downstream stride sampling
    tolerates member-order changes.
    """
    sql = """
        SELECT
            MIN(sub.sample_uid) AS head_uid,
            sub.op_seq_bucket_id,
            sub.input_shapes_bucket_id,
            sub.sample_type,
            group_concat(sub.sample_uid, ',') AS all_uids
        FROM (
            SELECT
                s.uuid AS sample_uid,
                s.sample_type,
                b.op_seq_bucket_id,
                b.input_shapes_bucket_id
            FROM graph_sample s
            JOIN graph_net_sample_buckets b ON s.uuid = b.sample_uid
            WHERE s.deleted = 0 AND s.sample_type != 'full_graph'
            ORDER BY s.create_at ASC, s.uuid ASC
        ) sub
        GROUP BY sub.sample_type, sub.op_seq_bucket_id, sub.input_shapes_bucket_id;
    """
    return [BucketGroup(*row) for row in db.query(sql)]


def get_ai4c_group_members(sample_bucket_infos: list[SampleBucketInfo]):
for bucket_info in sample_bucket_infos:
head_sample_uid = bucket_info.sample_uid
sample_uids = bucket_info.sample_uids.split(",")
selected_other_sample_uids = [
other_sample_uid
for other_sample_uid in sample_uids[::5]
if other_sample_uid != head_sample_uid
]
for sample_uid in selected_other_sample_uids:
new_uuid = str(uuid_module.uuid4())
yield sample_uid, new_uuid
def query_v2_candidates(db: DB) -> list[Candidate]:
    """Return live, non-full_graph samples not already grouped by policy v1.

    Results are ordered deterministically by (sample_type, bucket ids, uuid)
    so the downstream window sampling in generate_v2_groups is reproducible.

    NOTE(review): only 'bucket_policy_v1' groups are excluded, so re-running
    the script may re-insert v2 groups for the same samples — confirm whether
    'bucket_policy_v2' should also be excluded for idempotency.
    """
    sql = """
        SELECT
            s.uuid,
            s.sample_type,
            b.op_seq_bucket_id,
            b.input_shapes_bucket_id,
            b.input_dtypes_bucket_id
        FROM graph_sample s
        JOIN graph_net_sample_buckets b ON s.uuid = b.sample_uid
        WHERE s.deleted = 0
          AND s.sample_type != 'full_graph'
          AND s.uuid NOT IN (
              SELECT g.sample_uid
              FROM graph_net_sample_groups g
              WHERE g.group_policy = 'bucket_policy_v1'
                AND g.deleted = 0
          )
        ORDER BY s.sample_type, b.op_seq_bucket_id, b.input_shapes_bucket_id,
                 b.input_dtypes_bucket_id, s.uuid;
    """
    return [Candidate(*row) for row in db.query(sql)]

grouped = defaultdict(list)
for bucket_info in sample_bucket_infos:
key = (bucket_info.op_seq_bucket_id, bucket_info.input_dtypes_bucket_id)
grouped[key].append(bucket_info.sample_uid)

grouped = dict(grouped)
for key, sample_uids in grouped.items():
new_uuid = str(uuid_module.uuid4())
for sample_uid in sample_uids:
yield sample_uid, new_uuid
# ═══════════════════════════════════════════════════════════════════
# V1: Rule 1 (bucket-internal stride sampling) + Rule 2 (cross-shape)
# ═══════════════════════════════════════════════════════════════════


def main():
parser = argparse.ArgumentParser(
description="Generate graph_net_sample_groups from graph_net_sample_buckets"
)
parser.add_argument(
"--db_path",
type=str,
required=True,
help="Path to the SQLite database file",
)
def generate_v1_groups(bucket_groups: list[BucketGroup]):
    """Yield (sample_type, uid, group_id, rule_name) tuples for policy v1.

    Rule 1: stride-16 sampling within each bucket; every selected member
    (except the bucket head) gets its own singleton group.
    Rule 2: all bucket heads sharing the same (sample_type, op_seq) are
    aggregated into a single group.
    """
    # Rule 1: take every 16th member of the bucket; skip the head because it
    # is already represented by its Rule 2 group.
    for bucket in bucket_groups:
        members = bucket.all_uids_csv.split(",")
        for uid in members[::16]:
            if uid != bucket.head_uid:
                yield bucket.sample_type, uid, _new_group_id(), "rule1"

    # Rule 2: cross-shape aggregation of heads per (sample_type, op_seq) —
    # one shared group id per key.
    heads_by_type_op = defaultdict(list)
    for bucket in bucket_groups:
        heads_by_type_op[(bucket.sample_type, bucket.op_seq)].append(bucket.head_uid)
    for (sample_type, _op), heads in heads_by_type_op.items():
        gid = _new_group_id()
        for uid in heads:
            yield sample_type, uid, gid, "rule2"


# ═══════════════════════════════════════════════════════════════════
# V2: Rule 4 (dtype coverage) + Rule 3 (sparse sampling on remainder)
# ═══════════════════════════════════════════════════════════════════


def generate_v2_groups(candidates: list[Candidate], num_dtypes: int):
    """Yield (sample_type, uid, group_id, rule_name) tuples for policy v2.

    Rule 4 (first): per (sample_type, op_seq, shape), pick up to *num_dtypes*
    samples with distinct dtypes; one group per (sample_type, op_seq).
    Rule 3 (second): window-based sparse sampling on the remainder with
    window_size = num_dtypes * 5, keeping the first num_dtypes per window.

    Raises:
        ValueError: if num_dtypes is not positive (window_size would be 0
            and ``i % window_size`` would raise ZeroDivisionError mid-stream).
    """
    if num_dtypes <= 0:
        # Fail fast with a clear message instead of a ZeroDivisionError
        # after some rule-4 tuples have already been yielded.
        raise ValueError("num_dtypes must be a positive integer")

    by_type_op = defaultdict(list)
    for cand in candidates:
        by_type_op[(cand.sample_type, cand.op_seq)].append(cand)

    covered_uids = set()

    # Rule 4: dtype coverage — per shape, keep the first occurrence of each
    # distinct dtype, capped at num_dtypes per shape.
    for (sample_type, _op), group in by_type_op.items():
        by_shape = defaultdict(list)
        for cand in group:
            by_shape[cand.shapes].append(cand)

        picked = []
        for _shape, shape_group in by_shape.items():
            seen_dtypes = set()
            for cand in shape_group:
                if cand.dtypes not in seen_dtypes and len(seen_dtypes) < num_dtypes:
                    seen_dtypes.add(cand.dtypes)
                    picked.append(cand.uid)
                    covered_uids.add(cand.uid)

        if picked:
            gid = _new_group_id()
            for uid in picked:
                yield sample_type, uid, gid, "rule4"

    # Rule 3: sparse sampling over whatever Rule 4 did not cover, sorted by
    # uid so the window positions are deterministic.
    window_size = num_dtypes * 5
    for (sample_type, _op), group in by_type_op.items():
        remaining = sorted(
            (cand for cand in group if cand.uid not in covered_uids),
            key=lambda cand: cand.uid,
        )
        picked = [
            cand.uid
            for idx, cand in enumerate(remaining)
            if (idx % window_size) < num_dtypes
        ]
        if picked:
            gid = _new_group_id()
            for uid in picked:
                yield sample_type, uid, gid, "rule3"


# ═══════════════════════════════════════════════════════════════════
# Insert
# ═══════════════════════════════════════════════════════════════════


def _insert_groups(session, rows, policy):
    """Persist grouping rows and return per-(sample_type, rule) statistics.

    *rows* is an iterable of (sample_type, uid, group_id, rule_name) tuples;
    *policy* is written to every record's group_policy column. All inserts
    are committed in one transaction at the end.

    Returns:
        {(sample_type, rule_name): {"records": int, "groups": set}}
    """
    stats = defaultdict(lambda: {"records": 0, "groups": set()})
    for sample_type, uid, group_id, rule_name in rows:
        session.add(
            GraphNetSampleGroup(
                sample_uid=uid,
                group_uid=group_id,
                group_type="ai4c",
                group_policy=policy,
                policy_version="1.0",
                # NOTE(review): naive local timestamp — confirm whether the
                # schema expects UTC before changing.
                create_at=datetime.now(),
                deleted=False,
            )
        )
        stats[(sample_type, rule_name)]["records"] += 1
        stats[(sample_type, rule_name)]["groups"].add(group_id)
    session.commit()
    return stats
# ═══════════════════════════════════════════════════════════════════
# Main
# ═══════════════════════════════════════════════════════════════════


def main():
    """CLI entry point.

    Generates v1 groups from the precomputed buckets, then v2 groups from
    the remaining ungrouped candidates, and prints per-rule statistics.
    Rolls back the ORM session on any failure; always closes both handles.
    """
    parser = argparse.ArgumentParser(
        description="Generate graph_net_sample_groups (v1 + v2)"
    )
    parser.add_argument("--db_path", type=str, required=True)
    parser.add_argument("--num_dtypes", type=int, default=3)
    args = parser.parse_args()

    db = DB(args.db_path)
    db.connect()
    session = get_session(args.db_path)

    all_stats = defaultdict(lambda: {"records": 0, "groups": set()})

    try:
        # V1: bucket-internal stride sampling + cross-shape head aggregation.
        buckets = query_bucket_groups(db)
        print(f"Bucket groups: {len(buckets)}")
        v1 = _insert_groups(session, generate_v1_groups(buckets), "bucket_policy_v1")
        _merge_stats(all_stats, v1)

        # V2: dtype coverage + sparse sampling over the leftovers.
        candidates = query_v2_candidates(db)
        print(f"V2 candidates: {len(candidates)}")
        if candidates:
            v2 = _insert_groups(
                session,
                generate_v2_groups(candidates, args.num_dtypes),
                "bucket_policy_v2",
            )
            _merge_stats(all_stats, v2)
        else:
            print("No V2 candidates found. Skipping.")
    except Exception:
        # Undo any partially-flushed inserts before propagating.
        session.rollback()
        raise
    finally:
        session.close()
        db.close()

    print("=" * 60)
    _print_stats(all_stats)
    print("\nDone!")


if __name__ == "__main__":
Expand Down
Loading