From 02f7a8c8ae32a1d176cc0f1973bffb1d8ac361f1 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 19 Nov 2025 16:43:46 +0000 Subject: [PATCH 01/20] Initial postgres loading done. --- docs/postgres.md | 52 ++++++++++++++++++++++++ pgext/.gitignore | 2 + pgext/build.sh | 7 ++++ pgext/funcs.c | 103 +++++++++++++++++++++++++++++++++++++++++++++++ pgext/load.sql | 20 +++++++++ 5 files changed, 184 insertions(+) create mode 100644 docs/postgres.md create mode 100644 pgext/.gitignore create mode 100755 pgext/build.sh create mode 100644 pgext/funcs.c create mode 100644 pgext/load.sql diff --git a/docs/postgres.md b/docs/postgres.md new file mode 100644 index 0000000..e5112d6 --- /dev/null +++ b/docs/postgres.md @@ -0,0 +1,52 @@ +# GraphAlg in Postgres +We need to write a postgres C extension to add GraphAlg support to it. +Resources: +- https://www.postgresql.org/docs/current/xfunc-c.html +- https://www.pgedge.com/blog/introduction-to-postgres-extension-development +- https://stackoverflow.com/questions/76056209/postgresql-c-extension-function-table-as-argument-and-as-result + +## Build from source +Download the latest postgres release source code. +```bash +cd thirdparty/ +wget https://ftp.postgresql.org/pub/source/v18.1/postgresql-18.1.tar.bz2 +``` + +Following https://www.postgresql.org/docs/current/install-make.html + +Packages to install: +- bison +- flex +- libreadline-dev + +```bash +./configure +make +sudo make install +``` + +Now the binaries are located at `/usr/local/pgsql/bin`. + +## Setup a database +```bash +export LC_ALL=C +export LC_CTYPE=C + +# Create a new DB +/usr/local/pgsql/bin/initdb -D ~/pgdata + +# Start the server +/usr/local/pgsql/bin/postgres -D ~/pgdata + +``` + +## Load extension +First build with `pgext/build.sh`. + +Then connect to the server with `/usr/local/pgsql/bin/psql postgres` and run: + +```bash +CREATE FUNCTION add_one(integer) RETURNS integer + AS '/workspaces/graphalg/pgext/funcs', 'add_one' + LANGUAGE C STRICT; +``` diff --git a/pgext/.gitignore b/pgext/.gitignore new file mode 100644 index 0000000..9d22eb4 --- /dev/null +++ b/pgext/.gitignore @@ -0,0 +1,2 @@ +*.o +*.so diff --git a/pgext/build.sh b/pgext/build.sh new file mode 100755 index 0000000..048359c --- /dev/null +++ b/pgext/build.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +PG_SERVER_INCLUDE_DIR=$(/usr/local/pgsql/bin/pg_config --includedir-server) + +gcc -I $PG_SERVER_INCLUDE_DIR -fPIC -c pgext/funcs.c -o pgext/funcs.o +gcc -shared -o pgext/funcs.so pgext/funcs.o + diff --git a/pgext/funcs.c b/pgext/funcs.c new file mode 100644 index 0000000..37a9174 --- /dev/null +++ b/pgext/funcs.c @@ -0,0 +1,103 @@ +#include "postgres.h" + +#include "fmgr.h" +#include "utils/geo_decls.h" +#include "varatt.h" + +#include + +PG_MODULE_MAGIC; + +/* by value */ + +PG_FUNCTION_INFO_V1(add_one); + +Datum add_one(PG_FUNCTION_ARGS) { + int32 arg = PG_GETARG_INT32(0); + + PG_RETURN_INT32(arg + 1); +} + +/* by reference, fixed length */ + +PG_FUNCTION_INFO_V1(add_one_float8); + +Datum add_one_float8(PG_FUNCTION_ARGS) { + /* The macros for FLOAT8 hide its pass-by-reference nature. */ + float8 arg = PG_GETARG_FLOAT8(0); + + PG_RETURN_FLOAT8(arg + 1.0); +} + +PG_FUNCTION_INFO_V1(makepoint); + +Datum makepoint(PG_FUNCTION_ARGS) { + /* Here, the pass-by-reference nature of Point is not hidden. */ + Point *pointx = PG_GETARG_POINT_P(0); + Point *pointy = PG_GETARG_POINT_P(1); + Point *new_point = (Point *)palloc(sizeof(Point)); + + new_point->x = pointx->x; + new_point->y = pointy->y; + + PG_RETURN_POINT_P(new_point); +} + +/* by reference, variable length */ + +PG_FUNCTION_INFO_V1(copytext); + +Datum copytext(PG_FUNCTION_ARGS) { + text *t = PG_GETARG_TEXT_PP(0); + + /* + * VARSIZE_ANY_EXHDR is the size of the struct in bytes, minus the + * VARHDRSZ or VARHDRSZ_SHORT of its header. Construct the copy with a + * full-length header. + */ + text *new_t = (text *)palloc(VARSIZE_ANY_EXHDR(t) + VARHDRSZ); + SET_VARSIZE(new_t, VARSIZE_ANY_EXHDR(t) + VARHDRSZ); + + /* + * VARDATA is a pointer to the data region of the new struct. The source + * could be a short datum, so retrieve its data through VARDATA_ANY. + */ + memcpy(VARDATA(new_t), /* destination */ + VARDATA_ANY(t), /* source */ + VARSIZE_ANY_EXHDR(t)); /* how many bytes */ + PG_RETURN_TEXT_P(new_t); +} + +PG_FUNCTION_INFO_V1(concat_text); + +Datum concat_text(PG_FUNCTION_ARGS) { + text *arg1 = PG_GETARG_TEXT_PP(0); + text *arg2 = PG_GETARG_TEXT_PP(1); + int32 arg1_size = VARSIZE_ANY_EXHDR(arg1); + int32 arg2_size = VARSIZE_ANY_EXHDR(arg2); + int32 new_text_size = arg1_size + arg2_size + VARHDRSZ; + text *new_text = (text *)palloc(new_text_size); + + SET_VARSIZE(new_text, new_text_size); + memcpy(VARDATA(new_text), VARDATA_ANY(arg1), arg1_size); + memcpy(VARDATA(new_text) + arg1_size, VARDATA_ANY(arg2), arg2_size); + PG_RETURN_TEXT_P(new_text); +} + +/* A wrapper around starts_with(text, text) */ +/* + +PG_FUNCTION_INFO_V1(t_starts_with); + +Datum t_starts_with(PG_FUNCTION_ARGS) { + text *t1 = PG_GETARG_TEXT_PP(0); + text *t2 = PG_GETARG_TEXT_PP(1); + Oid collid = PG_GET_COLLATION(); + bool result; + + result = DatumGetBool(DirectFunctionCall2Coll( + text_starts_with, collid, PointerGetDatum(t1), PointerGetDatum(t2))); + PG_RETURN_BOOL(result); +} + +*/ diff --git a/pgext/load.sql b/pgext/load.sql new file mode 100644 index 0000000..1cbc3a9 --- /dev/null +++ b/pgext/load.sql @@ -0,0 +1,20 @@ +CREATE FUNCTION add_one(integer) RETURNS integer + AS '/workspaces/graphalg/pgext/funcs', 'add_one' + LANGUAGE C STRICT; + +-- note overloading of SQL function name "add_one" +CREATE FUNCTION add_one(double precision) RETURNS double precision + AS '/workspaces/graphalg/pgext/funcs', 'add_one_float8' + LANGUAGE C STRICT; + +CREATE FUNCTION makepoint(point, point) RETURNS point + AS '/workspaces/graphalg/pgext/funcs', 'makepoint' + LANGUAGE C STRICT; + +CREATE FUNCTION copytext(text) RETURNS text + AS '/workspaces/graphalg/pgext/funcs', 'copytext' + LANGUAGE C STRICT; + +CREATE FUNCTION concat_text(text, text) RETURNS text + AS '/workspaces/graphalg/pgext/funcs', 'concat_text' + LANGUAGE C STRICT; From 6435d2451295c4edd5879661d76331e9ee44752e Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 22 Nov 2025 19:11:52 +0000 Subject: [PATCH 02/20] tutorial POC. --- docs/postgres.md | 12 ++ pgext/tutorial_fdw/Makefile | 9 ++ pgext/tutorial_fdw/smoke_test.sh | 14 ++ pgext/tutorial_fdw/smoke_test.sql | 4 + pgext/tutorial_fdw/tutorial_fdw--1.0.sql | 7 + pgext/tutorial_fdw/tutorial_fdw.c | 159 +++++++++++++++++++++++ pgext/tutorial_fdw/tutorial_fdw.control | 5 + 7 files changed, 210 insertions(+) create mode 100644 pgext/tutorial_fdw/Makefile create mode 100755 pgext/tutorial_fdw/smoke_test.sh create mode 100644 pgext/tutorial_fdw/smoke_test.sql create mode 100644 pgext/tutorial_fdw/tutorial_fdw--1.0.sql create mode 100644 pgext/tutorial_fdw/tutorial_fdw.c create mode 100644 pgext/tutorial_fdw/tutorial_fdw.control diff --git a/docs/postgres.md b/docs/postgres.md index e5112d6..d9b6bf2 100644 --- a/docs/postgres.md +++ b/docs/postgres.md @@ -18,6 +18,7 @@ Packages to install: - bison - flex - libreadline-dev +- libicu-dev ```bash ./configure @@ -50,3 +51,14 @@ CREATE FUNCTION add_one(integer) RETURNS integer AS '/workspaces/graphalg/pgext/funcs', 'add_one' LANGUAGE C STRICT; ``` + +## Foreign Data Wrapper +Resources: +- https://www.postgresql.org/docs/current/fdwhandler.html +- https://www.dolthub.com/blog/2022-01-26-creating-a-postgres-foreign-data-wrapper/ + +```bash +export PATH="/usr/local/pgsql/bin:$PATH" +cd pgext/tutorial_fdw +./smoke_test.sh +``` \ No newline at end of file diff --git a/pgext/tutorial_fdw/Makefile b/pgext/tutorial_fdw/Makefile new file mode 100644 index 0000000..7ba44af --- /dev/null +++ b/pgext/tutorial_fdw/Makefile @@ -0,0 +1,9 @@ +MODULE_big = tutorial_fdw +OBJS = tutorial_fdw.o + +EXTENSION = tutorial_fdw +DATA = tutorial_fdw--1.0.sql + +PG_CONFIG = /usr/local/pgsql/bin/pg_config +PGXS := $(shell $(PG_CONFIG) --pgxs) +include $(PGXS) \ No newline at end of file diff --git a/pgext/tutorial_fdw/smoke_test.sh b/pgext/tutorial_fdw/smoke_test.sh new file mode 100755 index 0000000..f3ee9cb --- /dev/null +++ b/pgext/tutorial_fdw/smoke_test.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +set -eo pipefail + +make +#sudo make install + +PGDATA=`mktemp -d -t tfdw-XXXXXXXXXXX` + +trap "PGDATA=\"$PGDATA\" pg_ctl stop >/dev/null || true; rm -rf \"$PGDATA\"" EXIT + +PGDATA="$PGDATA" pg_ctl initdb > /dev/null +PGDATA="$PGDATA" pg_ctl start +psql postgres -f smoke_test.sql \ No newline at end of file diff --git a/pgext/tutorial_fdw/smoke_test.sql b/pgext/tutorial_fdw/smoke_test.sql new file mode 100644 index 0000000..f989cb3 --- /dev/null +++ b/pgext/tutorial_fdw/smoke_test.sql @@ -0,0 +1,4 @@ +CREATE EXTENSION tutorial_fdw; +CREATE SERVER tutorial_server FOREIGN DATA WRAPPER tutorial_fdw; +CREATE FOREIGN TABLE sequential_ints ( val int ) SERVER tutorial_server; +SELECT * FROM sequential_ints; \ No newline at end of file diff --git a/pgext/tutorial_fdw/tutorial_fdw--1.0.sql b/pgext/tutorial_fdw/tutorial_fdw--1.0.sql new file mode 100644 index 0000000..9dc7f66 --- /dev/null +++ b/pgext/tutorial_fdw/tutorial_fdw--1.0.sql @@ -0,0 +1,7 @@ +CREATE FUNCTION tutorial_fdw_handler() +RETURNS fdw_handler +AS '/workspaces/graphalg/pgext/tutorial_fdw/tutorial_fdw' +LANGUAGE C STRICT; + +CREATE FOREIGN DATA WRAPPER tutorial_fdw + HANDLER tutorial_fdw_handler; diff --git a/pgext/tutorial_fdw/tutorial_fdw.c b/pgext/tutorial_fdw/tutorial_fdw.c new file mode 100644 index 0000000..f384436 --- /dev/null +++ b/pgext/tutorial_fdw/tutorial_fdw.c @@ -0,0 +1,159 @@ +#include "postgres.h" + +#include "access/table.h" +#include "fmgr.h" +#include "foreign/fdwapi.h" +#include "optimizer/pathnode.h" +#include "optimizer/planmain.h" +#include "optimizer/restrictinfo.h" +#include "utils/rel.h" + +Datum tutorial_fdw_handler(PG_FUNCTION_ARGS); + +PG_FUNCTION_INFO_V1(tutorial_fdw_handler); + +void tutorial_fdw_GetForeignRelSize(PlannerInfo *root, RelOptInfo *baserel, + Oid foreigntableid); + +void tutorial_fdw_GetForeignPaths(PlannerInfo *root, RelOptInfo *baserel, + Oid foreigntableid); + +ForeignScan *tutorial_fdw_GetForeignPlan(PlannerInfo *root, RelOptInfo *baserel, + Oid foreigntableid, + ForeignPath *best_path, List *tlist, + List *scan_clauses, Plan *outer_plan); + +void tutorial_fdw_BeginForeignScan(ForeignScanState *node, int eflags); + +TupleTableSlot *tutorial_fdw_IterateForeignScan(ForeignScanState *node); + +void tutorial_fdw_ReScanForeignScan(ForeignScanState *node); + +void tutorial_fdw_EndForeignScan(ForeignScanState *node); + +Datum tutorial_fdw_handler(PG_FUNCTION_ARGS) { + FdwRoutine *fdwroutine = makeNode(FdwRoutine); + fdwroutine->GetForeignRelSize = tutorial_fdw_GetForeignRelSize; + + fdwroutine->GetForeignPaths = tutorial_fdw_GetForeignPaths; + + fdwroutine->GetForeignPlan = tutorial_fdw_GetForeignPlan; + + fdwroutine->BeginForeignScan = tutorial_fdw_BeginForeignScan; + + fdwroutine->IterateForeignScan = tutorial_fdw_IterateForeignScan; + + fdwroutine->ReScanForeignScan = tutorial_fdw_ReScanForeignScan; + + fdwroutine->EndForeignScan = tutorial_fdw_EndForeignScan; + + PG_RETURN_POINTER(fdwroutine); +} + +void tutorial_fdw_GetForeignRelSize(PlannerInfo *root, RelOptInfo *baserel, + Oid foreigntableid) { + Relation rel = table_open(foreigntableid, NoLock); + + if (rel->rd_att->natts != 1) { + + ereport(ERROR, + + errcode(ERRCODE_FDW_INVALID_COLUMN_NUMBER), + + errmsg("incorrect schema for tutorial_fdw table %s: table must " + "have exactly one column", + NameStr(rel->rd_rel->relname))); + } + + /* + Oid typid = rel->rd_att->attrs[0].atttypid; + + if (typid != INT4OID) { + + ereport(ERROR, + + errcode(ERRCODE_FDW_INVALID_DATA_TYPE), + + errmsg("incorrect schema for tutorial_fdw table %s: table column " + "must have type int", + NameStr(rel->rd_rel->relname))); + } + */ + + table_close(rel, NoLock); +} + +void tutorial_fdw_GetForeignPaths(PlannerInfo *root, RelOptInfo *baserel, + Oid foreigntableid) { + Path *path = (Path *)create_foreignscan_path( + root, baserel, NULL, /* default pathtarget */ + baserel->rows, /* rows */ + 0, /* disabled_nodes */ + 1, /* startup cost */ + 1 + baserel->rows, /* total cost */ + NIL, /* no pathkeys */ + NULL, /* no required outer relids */ + NULL, /* no fdw_outerpath */ + NULL, /* no fdw_restrictinfo */ + NIL); /* no fdw_private */ + + add_path(baserel, path); +} + +ForeignScan *tutorial_fdw_GetForeignPlan(PlannerInfo *root, RelOptInfo *baserel, + Oid foreigntableid, + + ForeignPath *best_path, List *tlist, + List *scan_clauses, Plan *outer_plan) { + scan_clauses = extract_actual_clauses(scan_clauses, false); + return make_foreignscan( + tlist, scan_clauses, baserel->relid, + NIL, /* no expressions we will evaluate */ + NIL, /* no private data */ + NIL, /* no custom tlist; our scan tuple looks like tlist */ + NIL, /* no quals we will recheck */ + outer_plan); +} + +typedef struct tutorial_fdw_state { + + int current; + +} tutorial_fdw_state; + +void tutorial_fdw_BeginForeignScan(ForeignScanState *node, int eflags) { + tutorial_fdw_state *state = palloc0(sizeof(tutorial_fdw_state)); + + node->fdw_state = state; +} + +TupleTableSlot *tutorial_fdw_IterateForeignScan(ForeignScanState *node) { + TupleTableSlot *slot = node->ss.ss_ScanTupleSlot; + + ExecClearTuple(slot); + + tutorial_fdw_state *state = node->fdw_state; + + if (state->current < 64) { + + slot->tts_isnull[0] = false; + + slot->tts_values[0] = Int32GetDatum(state->current); + + ExecStoreVirtualTuple(slot); + + state->current++; + } + + return slot; +} + +void tutorial_fdw_ReScanForeignScan(ForeignScanState *node) { + tutorial_fdw_state *state = node->fdw_state; + + state->current = 0; +} + +void tutorial_fdw_EndForeignScan(ForeignScanState *node) {} + +PG_MODULE_MAGIC; diff --git a/pgext/tutorial_fdw/tutorial_fdw.control b/pgext/tutorial_fdw/tutorial_fdw.control new file mode 100644 index 0000000..28b88da --- /dev/null +++ b/pgext/tutorial_fdw/tutorial_fdw.control @@ -0,0 +1,5 @@ +comment = 'Tutorial FDW.' +default_version = '1.0' +module_pathname = '/tutorial_fdw' +relocatable = true + From c180fbd693b60e6acda64df9c199a442af72756e Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 22 Nov 2025 19:28:08 +0000 Subject: [PATCH 03/20] Stub pg_graphalg. --- pg_graphalg/.gitignore | 2 ++ pg_graphalg/CMakeLists.txt | 12 ++++++++++++ pg_graphalg/configure.sh | 11 +++++++++++ pg_graphalg/src/CMakeLists.txt | 5 +++++ pg_graphalg/src/pg_graphalg.cpp | 16 ++++++++++++++++ pg_graphalg/test.sql | 5 +++++ 6 files changed, 51 insertions(+) create mode 100644 pg_graphalg/.gitignore create mode 100644 pg_graphalg/CMakeLists.txt create mode 100755 pg_graphalg/configure.sh create mode 100644 pg_graphalg/src/CMakeLists.txt create mode 100644 pg_graphalg/src/pg_graphalg.cpp create mode 100644 pg_graphalg/test.sql diff --git a/pg_graphalg/.gitignore b/pg_graphalg/.gitignore new file mode 100644 index 0000000..80a6844 --- /dev/null +++ b/pg_graphalg/.gitignore @@ -0,0 +1,2 @@ +/build* +/.cache diff --git a/pg_graphalg/CMakeLists.txt b/pg_graphalg/CMakeLists.txt new file mode 100644 index 0000000..700d125 --- /dev/null +++ b/pg_graphalg/CMakeLists.txt @@ -0,0 +1,12 @@ +# https://cliutils.gitlab.io/modern-cmake/chapters/basics.html#introduction-to-the-basics +cmake_minimum_required(VERSION 3.15...4.0) + +project( + pg_graphalg + VERSION 0.1 + DESCRIPTION "GraphAlg extension for PostgreSQL" + LANGUAGES CXX) + +find_package(PostgreSQL COMPONENTS Server) + +add_subdirectory(src) diff --git a/pg_graphalg/configure.sh b/pg_graphalg/configure.sh new file mode 100755 index 0000000..d2e7779 --- /dev/null +++ b/pg_graphalg/configure.sh @@ -0,0 +1,11 @@ +#!/bin/bash +WORKSPACE_ROOT=pg_graphalg/ +BUILD_DIR=$WORKSPACE_ROOT/build +rm -rf $BUILD_DIR +cmake -S $WORKSPACE_ROOT -B $BUILD_DIR -G Ninja \ + -DCMAKE_BUILD_TYPE=Debug \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER=clang++-20 \ + -DCMAKE_EXPORT_COMPILE_COMMANDS=1 \ + -DCMAKE_LINKER_TYPE=MOLD \ + -DPostgreSQL_ROOT=/usr/local/pgsql diff --git a/pg_graphalg/src/CMakeLists.txt b/pg_graphalg/src/CMakeLists.txt new file mode 100644 index 0000000..4c6f0b7 --- /dev/null +++ b/pg_graphalg/src/CMakeLists.txt @@ -0,0 +1,5 @@ +add_library(pg_graphalg SHARED + pg_graphalg.cpp +) + +target_link_libraries(pg_graphalg PRIVATE PostgreSQL::PostgreSQL) diff --git a/pg_graphalg/src/pg_graphalg.cpp b/pg_graphalg/src/pg_graphalg.cpp new file mode 100644 index 0000000..9e23caa --- /dev/null +++ b/pg_graphalg/src/pg_graphalg.cpp @@ -0,0 +1,16 @@ +#include "postgres.h" + +#include "fmgr.h" + +extern "C" { + +PG_MODULE_MAGIC; + +PG_FUNCTION_INFO_V1(add_one); + +Datum add_one(PG_FUNCTION_ARGS) { + int32 arg = PG_GETARG_INT32(0); + + PG_RETURN_INT32(arg + 1); +} +} diff --git a/pg_graphalg/test.sql b/pg_graphalg/test.sql new file mode 100644 index 0000000..2170f22 --- /dev/null +++ b/pg_graphalg/test.sql @@ -0,0 +1,5 @@ +CREATE OR REPLACE FUNCTION add_one(integer) RETURNS integer + AS '/workspaces/graphalg/pg_graphalg/build/src/libpg_graphalg.so', 'add_one' + LANGUAGE C STRICT; + +SELECT add_one(42); From 82263bd1ecd6a654c4f858f5818628ded74fedce Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 22 Nov 2025 21:15:33 +0000 Subject: [PATCH 04/20] Port over tutorial impl. --- pg_graphalg/CMakeLists.txt | 2 +- pg_graphalg/src/CMakeLists.txt | 2 +- pg_graphalg/src/pg_graphalg.c | 105 ++++++++++++++++++++++++++++++++ pg_graphalg/src/pg_graphalg.cpp | 16 ----- pg_graphalg/test.sql | 19 ++++-- 5 files changed, 122 insertions(+), 22 deletions(-) create mode 100644 pg_graphalg/src/pg_graphalg.c delete mode 100644 pg_graphalg/src/pg_graphalg.cpp diff --git a/pg_graphalg/CMakeLists.txt b/pg_graphalg/CMakeLists.txt index 700d125..990686e 100644 --- a/pg_graphalg/CMakeLists.txt +++ b/pg_graphalg/CMakeLists.txt @@ -5,7 +5,7 @@ project( pg_graphalg VERSION 0.1 DESCRIPTION "GraphAlg extension for PostgreSQL" - LANGUAGES CXX) + LANGUAGES C CXX) find_package(PostgreSQL COMPONENTS Server) diff --git a/pg_graphalg/src/CMakeLists.txt b/pg_graphalg/src/CMakeLists.txt index 4c6f0b7..062098f 100644 --- a/pg_graphalg/src/CMakeLists.txt +++ b/pg_graphalg/src/CMakeLists.txt @@ -1,5 +1,5 @@ add_library(pg_graphalg SHARED - pg_graphalg.cpp + pg_graphalg.c ) target_link_libraries(pg_graphalg PRIVATE PostgreSQL::PostgreSQL) diff --git a/pg_graphalg/src/pg_graphalg.c b/pg_graphalg/src/pg_graphalg.c new file mode 100644 index 0000000..40b211e --- /dev/null +++ b/pg_graphalg/src/pg_graphalg.c @@ -0,0 +1,105 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +PG_MODULE_MAGIC; + +PG_FUNCTION_INFO_V1(add_one); + +PG_FUNCTION_INFO_V1(graphalg_fdw_handler); + +Datum add_one(PG_FUNCTION_ARGS) { + int32 arg = PG_GETARG_INT32(0); + + PG_RETURN_INT32(arg + 1); +} + +static void GetForeignRelSize(PlannerInfo *root, RelOptInfo *baserel, + Oid foreigntableid) { + baserel->rows = 42; +} + +static void GetForeignPaths(PlannerInfo *root, RelOptInfo *baserel, + Oid foreigntableid) { + ForeignPath *path = create_foreignscan_path(root, baserel, + /*target=*/NULL, + /*rows=*/baserel->rows, + /*disabled_nodes=*/0, + /*startup_cost=*/1, + /*total_cost=*/1 + baserel->rows, + /*pathkeys=*/NIL, + /*required_outer=*/NULL, + /*fdw_outerpath=*/NULL, + /*fdw_restrictinfo=*/NULL, + /*fdw_private=*/NULL); + add_path(baserel, (Path *)path); +} + +static ForeignScan *GetForeignPlan(PlannerInfo *root, RelOptInfo *baserel, + Oid foreigntableid, ForeignPath *best_path, + List *tlist, List *scan_clauses, + Plan *outer_plan) { + // On extract_actual_clauses: + // https://www.postgresql.org/docs/current/fdw-planning.html + scan_clauses = extract_actual_clauses(scan_clauses, false); + return make_foreignscan( + tlist, scan_clauses, baserel->relid, + NIL, /* no expressions we will evaluate */ + NIL, /* no private data */ + NIL, /* no custom tlist; our scan tuple looks like tlist */ + NIL, /* no quals we will recheck */ + outer_plan); +} + +typedef struct { + int current; +} GaScanState; + +static void BeginForeignScan(ForeignScanState *node, int eflags) { + node->fdw_state = palloc0(sizeof(GaScanState)); +} + +static TupleTableSlot *IterateForeignScan(ForeignScanState *node) { + TupleTableSlot *slot = node->ss.ss_ScanTupleSlot; + ExecClearTuple(slot); + + GaScanState *state = (GaScanState *)node->fdw_state; + if (state->current < 64) { + slot->tts_isnull[0] = false; + slot->tts_values[0] = Int32GetDatum(state->current); + ExecStoreVirtualTuple(slot); + state->current++; + } + + return slot; +} + +static void ReScanForeignScan(ForeignScanState *node) { + GaScanState *state = (GaScanState *)node->fdw_state; + state->current = 0; +} + +static void EndForeignScan(ForeignScanState *node) { + // No-Op +} + +Datum graphalg_fdw_handler(PG_FUNCTION_ARGS) { + FdwRoutine *fdwRoutine = makeNode(FdwRoutine); + + fdwRoutine->GetForeignRelSize = GetForeignRelSize; + fdwRoutine->GetForeignPaths = GetForeignPaths; + fdwRoutine->GetForeignPlan = GetForeignPlan; + fdwRoutine->BeginForeignScan = BeginForeignScan; + fdwRoutine->IterateForeignScan = IterateForeignScan; + fdwRoutine->ReScanForeignScan = ReScanForeignScan; + fdwRoutine->EndForeignScan = EndForeignScan; + + PG_RETURN_POINTER(fdwRoutine); +} diff --git a/pg_graphalg/src/pg_graphalg.cpp b/pg_graphalg/src/pg_graphalg.cpp deleted file mode 100644 index 9e23caa..0000000 --- a/pg_graphalg/src/pg_graphalg.cpp +++ /dev/null @@ -1,16 +0,0 @@ -#include "postgres.h" - -#include "fmgr.h" - -extern "C" { - -PG_MODULE_MAGIC; - -PG_FUNCTION_INFO_V1(add_one); - -Datum add_one(PG_FUNCTION_ARGS) { - int32 arg = PG_GETARG_INT32(0); - - PG_RETURN_INT32(arg + 1); -} -} diff --git a/pg_graphalg/test.sql b/pg_graphalg/test.sql index 2170f22..1bd11f0 100644 --- a/pg_graphalg/test.sql +++ b/pg_graphalg/test.sql @@ -1,5 +1,16 @@ -CREATE OR REPLACE FUNCTION add_one(integer) RETURNS integer - AS '/workspaces/graphalg/pg_graphalg/build/src/libpg_graphalg.so', 'add_one' - LANGUAGE C STRICT; +DROP FOREIGN TABLE IF EXISTS sequential_ints; +DROP SERVER IF EXISTS graphalg_server; +DROP FOREIGN DATA WRAPPER IF EXISTS graphalg_fdw; +DROP FUNCTION IF EXISTS graphalg_fdw_handler; -SELECT add_one(42); +CREATE FUNCTION graphalg_fdw_handler() +RETURNS fdw_handler +AS '/workspaces/graphalg/pg_graphalg/build/src/libpg_graphalg.so' +LANGUAGE C STRICT; + +CREATE FOREIGN DATA WRAPPER graphalg_fdw + HANDLER graphalg_fdw_handler; +CREATE SERVER graphalg_server FOREIGN DATA WRAPPER graphalg_fdw; + +CREATE FOREIGN TABLE sequential_ints ( val int ) SERVER graphalg_server; +SELECT * FROM sequential_ints; From e7b4d25a0b96706d5cb18787050eeb271d311a8d Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 22 Nov 2025 22:03:03 +0000 Subject: [PATCH 05/20] Getting somewhere. --- docs/postgres.md | 7 +------ pg_graphalg/src/pg_graphalg.c | 20 +++++++++++++++++++- pg_graphalg/test.sql | 2 ++ 3 files changed, 22 insertions(+), 7 deletions(-) diff --git a/docs/postgres.md b/docs/postgres.md index d9b6bf2..dbb348b 100644 --- a/docs/postgres.md +++ b/docs/postgres.md @@ -56,9 +56,4 @@ CREATE FUNCTION add_one(integer) RETURNS integer Resources: - https://www.postgresql.org/docs/current/fdwhandler.html - https://www.dolthub.com/blog/2022-01-26-creating-a-postgres-foreign-data-wrapper/ - -```bash -export PATH="/usr/local/pgsql/bin:$PATH" -cd pgext/tutorial_fdw -./smoke_test.sh -``` \ No newline at end of file +- https://github.com/Kentik-Archive/wdb_fdw diff --git a/pg_graphalg/src/pg_graphalg.c b/pg_graphalg/src/pg_graphalg.c index 40b211e..00b73dd 100644 --- a/pg_graphalg/src/pg_graphalg.c +++ b/pg_graphalg/src/pg_graphalg.c @@ -71,7 +71,7 @@ static TupleTableSlot *IterateForeignScan(ForeignScanState *node) { ExecClearTuple(slot); GaScanState *state = (GaScanState *)node->fdw_state; - if (state->current < 64) { + if (state->current < 10) { slot->tts_isnull[0] = false; slot->tts_values[0] = Int32GetDatum(state->current); ExecStoreVirtualTuple(slot); @@ -90,6 +90,22 @@ static void EndForeignScan(ForeignScanState *node) { // No-Op } +static TupleTableSlot *ExecForeignInsert(EState *estate, ResultRelInfo *rinfo, + TupleTableSlot *slot, + TupleTableSlot *planSlot) { + // TODO: Actually save it somewhere. + bool isnull; + Datum datum = slot_getattr(slot, 1, &isnull); + if (isnull) { + printf("NULL\n"); + } else { + int x = DatumGetInt32(datum); + printf("value: %d\n", x); + } + + return slot; +} + Datum graphalg_fdw_handler(PG_FUNCTION_ARGS) { FdwRoutine *fdwRoutine = makeNode(FdwRoutine); @@ -101,5 +117,7 @@ Datum graphalg_fdw_handler(PG_FUNCTION_ARGS) { fdwRoutine->ReScanForeignScan = ReScanForeignScan; fdwRoutine->EndForeignScan = EndForeignScan; + fdwRoutine->ExecForeignInsert = ExecForeignInsert; + PG_RETURN_POINTER(fdwRoutine); } diff --git a/pg_graphalg/test.sql b/pg_graphalg/test.sql index 1bd11f0..01d95d2 100644 --- a/pg_graphalg/test.sql +++ b/pg_graphalg/test.sql @@ -14,3 +14,5 @@ CREATE SERVER graphalg_server FOREIGN DATA WRAPPER graphalg_fdw; CREATE FOREIGN TABLE sequential_ints ( val int ) SERVER graphalg_server; SELECT * FROM sequential_ints; + +INSERT INTO sequential_ints VALUES (1), (2); From 56e748092ec1e2ba87838c5249217b4c9e5eeaa7 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 23 Nov 2025 10:43:04 +0000 Subject: [PATCH 06/20] Pull in GraphBLAS. --- docs/postgres.md | 13 ++++ pg_graphalg/CMakeLists.txt | 2 + pg_graphalg/configure.sh | 2 +- pg_graphalg/include/pg_graphalg/PgGraphAlg.h | 38 ++++++++++++ pg_graphalg/src/CMakeLists.txt | 11 +++- .../src/{pg_graphalg.c => pg_graphalg.cpp} | 62 +++++++++++-------- pg_graphalg/src/pg_graphalg/CMakeLists.txt | 5 ++ pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp | 37 +++++++++++ pg_graphalg/test.sql | 19 ++++-- 9 files changed, 154 insertions(+), 35 deletions(-) create mode 100644 pg_graphalg/include/pg_graphalg/PgGraphAlg.h rename pg_graphalg/src/{pg_graphalg.c => pg_graphalg.cpp} (72%) create mode 100644 pg_graphalg/src/pg_graphalg/CMakeLists.txt create mode 100644 pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp diff --git a/docs/postgres.md b/docs/postgres.md index dbb348b..6c0fb7e 100644 --- a/docs/postgres.md +++ b/docs/postgres.md @@ -57,3 +57,16 @@ Resources: - https://www.postgresql.org/docs/current/fdwhandler.html - https://www.dolthub.com/blog/2022-01-26-creating-a-postgres-foreign-data-wrapper/ - https://github.com/Kentik-Archive/wdb_fdw + +## GraphBLAS +We need to link to SuiteSparse:GraphBLAS. + +Install package `libgraphblas-dev`. + +Alternatively: +Download https://github.com/DrTimothyAldenDavis/SuiteSparse/archive/refs/tags/v7.12.1.tar.gz to `thirdparty/`. + +``` +make +sudo make install +``` diff --git a/pg_graphalg/CMakeLists.txt b/pg_graphalg/CMakeLists.txt index 990686e..424e614 100644 --- a/pg_graphalg/CMakeLists.txt +++ b/pg_graphalg/CMakeLists.txt @@ -9,4 +9,6 @@ project( find_package(PostgreSQL COMPONENTS Server) +find_library(NAMES GraphBLAS) + add_subdirectory(src) diff --git a/pg_graphalg/configure.sh b/pg_graphalg/configure.sh index d2e7779..5512c75 100755 --- a/pg_graphalg/configure.sh +++ b/pg_graphalg/configure.sh @@ -8,4 +8,4 @@ cmake -S $WORKSPACE_ROOT -B $BUILD_DIR -G Ninja \ -DCMAKE_CXX_COMPILER=clang++-20 \ -DCMAKE_EXPORT_COMPILE_COMMANDS=1 \ -DCMAKE_LINKER_TYPE=MOLD \ - -DPostgreSQL_ROOT=/usr/local/pgsql + -DPostgreSQL_ROOT=/usr/local/pgsql \ diff --git a/pg_graphalg/include/pg_graphalg/PgGraphAlg.h b/pg_graphalg/include/pg_graphalg/PgGraphAlg.h new file mode 100644 index 0000000..aa5101f --- /dev/null +++ b/pg_graphalg/include/pg_graphalg/PgGraphAlg.h @@ -0,0 +1,38 @@ +#pragma once + +#include +#include +#include +#include +#include + +extern "C" { +#include "GraphBLAS.h" +} + +namespace pg_graphalg { + +struct ScanState { + std::size_t row = 0; + std::size_t col = 0; + + void reset() { + row = 0; + col = 0; + } +}; + +class PgGraphAlg { +private: + std::map, std::int64_t> _values; + +public: + PgGraphAlg(); + + std::size_t size() { return _values.size(); } + void addTuple(std::size_t row, std::size_t col, std::int64_t value); + std::optional> + scan(ScanState &state); +}; + +} // namespace pg_graphalg diff --git a/pg_graphalg/src/CMakeLists.txt b/pg_graphalg/src/CMakeLists.txt index 062098f..69bab63 100644 --- a/pg_graphalg/src/CMakeLists.txt +++ b/pg_graphalg/src/CMakeLists.txt @@ -1,5 +1,10 @@ +add_subdirectory(pg_graphalg) + add_library(pg_graphalg SHARED - pg_graphalg.c + pg_graphalg.cpp ) - -target_link_libraries(pg_graphalg PRIVATE PostgreSQL::PostgreSQL) +target_include_directories(pg_graphalg PUBLIC ../include) +target_link_libraries(pg_graphalg + PRIVATE + PgGraphAlg + PostgreSQL::PostgreSQL) diff --git a/pg_graphalg/src/pg_graphalg.c b/pg_graphalg/src/pg_graphalg.cpp similarity index 72% rename from pg_graphalg/src/pg_graphalg.c rename to pg_graphalg/src/pg_graphalg.cpp index 00b73dd..894256c 100644 --- a/pg_graphalg/src/pg_graphalg.c +++ b/pg_graphalg/src/pg_graphalg.cpp @@ -1,3 +1,7 @@ +#include "pg_graphalg/PgGraphAlg.h" + +extern "C" { + #include #include @@ -11,19 +15,20 @@ PG_MODULE_MAGIC; -PG_FUNCTION_INFO_V1(add_one); - -PG_FUNCTION_INFO_V1(graphalg_fdw_handler); - -Datum add_one(PG_FUNCTION_ARGS) { - int32 arg = PG_GETARG_INT32(0); +static pg_graphalg::PgGraphAlg *SINGLETON = nullptr; +static pg_graphalg::PgGraphAlg &getInstance() { + if (!SINGLETON) { + SINGLETON = new pg_graphalg::PgGraphAlg(); + } - PG_RETURN_INT32(arg + 1); + return *SINGLETON; } +PG_FUNCTION_INFO_V1(graphalg_fdw_handler); + static void GetForeignRelSize(PlannerInfo *root, RelOptInfo *baserel, Oid foreigntableid) { - baserel->rows = 42; + baserel->rows = getInstance().size(); } static void GetForeignPaths(PlannerInfo *root, RelOptInfo *baserel, @@ -58,32 +63,34 @@ static ForeignScan *GetForeignPlan(PlannerInfo *root, RelOptInfo *baserel, outer_plan); } -typedef struct { - int current; -} GaScanState; - static void BeginForeignScan(ForeignScanState *node, int eflags) { - node->fdw_state = palloc0(sizeof(GaScanState)); + auto *state = palloc(sizeof(pg_graphalg::ScanState)); + new (state) pg_graphalg::ScanState(); + node->fdw_state = state; } static TupleTableSlot *IterateForeignScan(ForeignScanState *node) { TupleTableSlot *slot = node->ss.ss_ScanTupleSlot; ExecClearTuple(slot); - GaScanState *state = (GaScanState *)node->fdw_state; - if (state->current < 10) { + auto *scanState = static_cast(node->fdw_state); + if (auto res = getInstance().scan(*scanState)) { slot->tts_isnull[0] = false; - slot->tts_values[0] = Int32GetDatum(state->current); + slot->tts_isnull[1] = false; + slot->tts_isnull[2] = false; + auto [row, col, val] = *res; + slot->tts_values[0] = UInt64GetDatum(row); + slot->tts_values[1] = UInt64GetDatum(col); + slot->tts_values[2] = UInt64GetDatum(val); ExecStoreVirtualTuple(slot); - state->current++; } return slot; } static void ReScanForeignScan(ForeignScanState *node) { - GaScanState *state = (GaScanState *)node->fdw_state; - state->current = 0; + auto *scanState = static_cast(node->fdw_state); + scanState->reset(); } static void EndForeignScan(ForeignScanState *node) { @@ -93,16 +100,16 @@ static void EndForeignScan(ForeignScanState *node) { static TupleTableSlot *ExecForeignInsert(EState *estate, ResultRelInfo *rinfo, TupleTableSlot *slot, TupleTableSlot *planSlot) { - // TODO: Actually save it somewhere. - bool isnull; - Datum datum = slot_getattr(slot, 1, &isnull); - if (isnull) { - printf("NULL\n"); - } else { - int x = DatumGetInt32(datum); - printf("value: %d\n", x); + slot_getsomeattrs(slot, 3); + if (slot->tts_isnull[0] || slot->tts_isnull[1] || slot->tts_isnull[2]) { + // Ignore nulls + return nullptr; } + std::size_t row = DatumGetUInt64(slot->tts_values[0]); + std::size_t col = DatumGetUInt64(slot->tts_values[1]); + std::int64_t val = DatumGetInt64(slot->tts_values[2]); + getInstance().addTuple(row, col, val); return slot; } @@ -121,3 +128,4 @@ Datum graphalg_fdw_handler(PG_FUNCTION_ARGS) { PG_RETURN_POINTER(fdwRoutine); } +} diff --git a/pg_graphalg/src/pg_graphalg/CMakeLists.txt b/pg_graphalg/src/pg_graphalg/CMakeLists.txt new file mode 100644 index 0000000..cd03ea5 --- /dev/null +++ b/pg_graphalg/src/pg_graphalg/CMakeLists.txt @@ -0,0 +1,5 @@ +add_library(PgGraphAlg SHARED + PgGraphAlg.cpp +) +target_include_directories(PgGraphAlg PUBLIC ../../include) +target_link_libraries(PgGraphAlg graphblas) diff --git a/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp b/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp new file mode 100644 index 0000000..47c215a --- /dev/null +++ b/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp @@ -0,0 +1,37 @@ +#include +#include + +extern "C" { +#include "GraphBLAS.h" +} + +#include "pg_graphalg/PgGraphAlg.h" + +namespace pg_graphalg { + +PgGraphAlg::PgGraphAlg() { + // TODO: init graphblas +} + +void PgGraphAlg::addTuple(std::size_t row, std::size_t col, + std::int64_t value) { + _values[{row, col}] = value; +} + +std::optional> +PgGraphAlg::scan(ScanState &state) { + auto it = _values.lower_bound({state.row, state.col}); + if (it == _values.end()) { + return std::nullopt; + } + + std::size_t row = it->first.first; + std::size_t col = it->first.second; + std::int64_t val = it->second; + + state.row = row; + state.col = col + 1; + return std::make_tuple(row, col, val); +} + +} // namespace pg_graphalg diff --git a/pg_graphalg/test.sql b/pg_graphalg/test.sql index 01d95d2..6177cad 100644 --- a/pg_graphalg/test.sql +++ b/pg_graphalg/test.sql @@ -1,4 +1,4 @@ -DROP FOREIGN TABLE IF EXISTS sequential_ints; +DROP FOREIGN TABLE IF EXISTS mat; DROP SERVER IF EXISTS graphalg_server; DROP FOREIGN DATA WRAPPER IF EXISTS graphalg_fdw; DROP FUNCTION IF EXISTS graphalg_fdw_handler; @@ -12,7 +12,18 @@ CREATE FOREIGN DATA WRAPPER graphalg_fdw HANDLER graphalg_fdw_handler; CREATE SERVER graphalg_server FOREIGN DATA WRAPPER graphalg_fdw; -CREATE FOREIGN TABLE sequential_ints ( val int ) SERVER graphalg_server; -SELECT * FROM sequential_ints; +CREATE FOREIGN TABLE mat ( row bigint, col bigint, val bigint ) SERVER graphalg_server; +SELECT * FROM mat; -INSERT INTO sequential_ints VALUES (1), (2); +INSERT INTO mat VALUES + (0, 0, 42), + (0, 1, 43), + (1, 0, 44), + (1, 1, 45); + +SELECT * FROM mat; + +INSERT INTO mat VALUES + (0, 1, 4000); + +SELECT * FROM mat; From 612b4afd891ecba0ff2bb98813f13dc6ed7cd327 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 29 Nov 2025 22:02:01 +0000 Subject: [PATCH 07/20] Support multiple tables. --- docs/postgres.md | 20 +++-- pg_graphalg/README.md | 65 ++++++++++++++ pg_graphalg/include/pg_graphalg/PgGraphAlg.h | 36 +++++++- pg_graphalg/src/pg_graphalg.cpp | 92 +++++++++++++++++++- pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp | 30 +++++-- pg_graphalg/test.sql | 25 ++++-- 6 files changed, 242 insertions(+), 26 deletions(-) create mode 100644 pg_graphalg/README.md diff --git a/docs/postgres.md b/docs/postgres.md index 6c0fb7e..2aa9084 100644 --- a/docs/postgres.md +++ b/docs/postgres.md @@ -1,9 +1,14 @@ # GraphAlg in Postgres -We need to write a postgres C extension to add GraphAlg support to it. -Resources: -- https://www.postgresql.org/docs/current/xfunc-c.html -- https://www.pgedge.com/blog/introduction-to-postgres-extension-development -- https://stackoverflow.com/questions/76056209/postgresql-c-extension-function-table-as-argument-and-as-result +The goal: Run GraphAlg programs in PostgreSQL. +This is accomplished by writing a PostgreSQL extension. + +## Building and Testing the Extension +Key steps, explained in more detail below: +- Build postgreSQL v18 from source and install it to the default path `/usr/local/pgsql`. +- Setup a dummy database in `~/pgdata` +- Install SuiteSparse:GraphBLAS +- Build the extension + ## Build from source Download the latest postgres release source code. @@ -70,3 +75,8 @@ Download https://github.com/DrTimothyAldenDavis/SuiteSparse/archive/refs/tags/v7 make sudo make install ``` + +## Resources +- https://www.postgresql.org/docs/current/xfunc-c.html +- https://www.pgedge.com/blog/introduction-to-postgres-extension-development +- https://stackoverflow.com/questions/76056209/postgresql-c-extension-function-table-as-argument-and-as-result diff --git a/pg_graphalg/README.md b/pg_graphalg/README.md new file mode 100644 index 0000000..c0ae37b --- /dev/null +++ b/pg_graphalg/README.md @@ -0,0 +1,65 @@ +# pg_graphalg: Run GraphAlg in PostgreSQL +This is an extension for PostgreSQL to execute GraphAlg programs. + +## Building +In addition to setting up the devcontainer for the project, you need to: +- Build postgreSQL v18 from source and install it to the default path `/usr/local/pgsql`. +- Setup a dummy database in `~/pgdata` +- Install SuiteSparse:GraphBLAS + +### PostgreSQL +Download the latest postgres release source code. + +```bash +# Dependencies for building PostgreSQL +sudo apt install bison flex libreadline-dev libicu-dev + +# Download the sources +cd thirdparty/ +wget https://ftp.postgresql.org/pub/source/v18.1/postgresql-18.1.tar.bz2 +tar xf postgresql-18.1.tar.bz2 +cd postgresql-18.1/ + +# Build and install +./configure +make +sudo make install +``` + +Now PostgreSQL is installed at `/usr/local/pgsql`. + +### Setup a database +```bash +# Need to set or PostgreSQL will refuse to create the DB +export LC_ALL=C +export LC_CTYPE=C + +# Create a new DB +/usr/local/pgsql/bin/initdb -D ~/pgdata +``` + +### GraphBLAS +Install using APT: + +```bash +sudo apt install libgraphblas-dev +``` + +### Build the Extension +```bash +pg_graphalg/configure.sh + +# Extension is located at pg_graphalg/build/src/libpg_graphalg.so +cmake --build pg_graphalg/build +``` + +## Testing +After the extension has been built, start the server and then run the `test.sql` script: + +```bash +# Start the server +/usr/local/pgsql/bin/postgres -D ~/pgdata + +# Run the tests +/usr/local/pgsql/bin/psql postgres -f pg_graphalg/test.sql +``` diff --git a/pg_graphalg/include/pg_graphalg/PgGraphAlg.h b/pg_graphalg/include/pg_graphalg/PgGraphAlg.h index aa5101f..627dae5 100644 --- a/pg_graphalg/include/pg_graphalg/PgGraphAlg.h +++ b/pg_graphalg/include/pg_graphalg/PgGraphAlg.h @@ -4,6 +4,7 @@ #include #include #include +#include #include extern "C" { @@ -12,27 +13,54 @@ extern "C" { namespace pg_graphalg { +using TableId = unsigned int; + +class MatrixTable; + struct ScanState { + MatrixTable *table; std::size_t row = 0; std::size_t col = 0; + ScanState(MatrixTable *table) : table(table) {} + void reset() { row = 0; col = 0; } }; -class PgGraphAlg { +struct MatrixTableDef { + std::size_t nRows; + std::size_t nCols; + // TODO: data type. +}; + +class MatrixTable { private: + std::size_t _nRows; + std::size_t _nCols; std::map, std::int64_t> _values; public: - PgGraphAlg(); + MatrixTable(const MatrixTableDef &def); - std::size_t size() { return _values.size(); } - void addTuple(std::size_t row, std::size_t col, std::int64_t value); + void setValue(std::size_t row, std::size_t col, std::int64_t value); std::optional> scan(ScanState &state); + + std::size_t nValues() { return _values.size(); } +}; + +class PgGraphAlg { +private: + std::unordered_map _tables; + +public: + PgGraphAlg(); + + MatrixTable &getTable(TableId tableId); + MatrixTable &getOrCreateTable(TableId tableId, const MatrixTableDef &def); }; } // namespace pg_graphalg diff --git a/pg_graphalg/src/pg_graphalg.cpp b/pg_graphalg/src/pg_graphalg.cpp index 894256c..dd5d297 100644 --- a/pg_graphalg/src/pg_graphalg.cpp +++ b/pg_graphalg/src/pg_graphalg.cpp @@ -1,17 +1,25 @@ +#include +#include + #include "pg_graphalg/PgGraphAlg.h" extern "C" { #include +#include #include #include #include +#include +#include #include #include #include #include +#include #include +#include PG_MODULE_MAGIC; @@ -26,9 +34,66 @@ static pg_graphalg::PgGraphAlg &getInstance() { PG_FUNCTION_INFO_V1(graphalg_fdw_handler); +static std::optional +parseOptions(ForeignTable *table) { + ListCell *cell; + std::optional rows = 0; + std::optional cols = 0; + + foreach (cell, table->options) { + auto *def = lfirst_node(DefElem, cell); + std::string_view defName(def->defname); + if (defName == "rows") { + // TODO: Check type of option + rows = defGetInt64(def); + if (rows < 0) { + ereport(ERROR, (errcode(ERRCODE_FDW_ERROR), + errmsg("invalid value for option \"rows\": %d must be " + "a positive integer", + *rows))); + return std::nullopt; + } + } else if (defName == "columns") { + // TODO: Check type of option + cols = defGetInt64(def); + if (rows < 0) { + ereport(ERROR, (errcode(ERRCODE_FDW_ERROR), + errmsg("invalid value for option \"cols\": %d must be " + "a positive integer", + *cols))); + return std::nullopt; + } + } else { + ereport(ERROR, + (errcode(ERRCODE_FDW_INVALID_OPTION_NAME), + errmsg("invalid option \"%s\"", def->defname), + errhint("Valid table options for graphalg are \"rows\", and " + "\"columns\""))); + return std::nullopt; + } + } + + if (!rows) { + ereport(ERROR, (errcode(ERRCODE_FDW_ERROR), + errmsg("missing required option \"rows\""))); + return std::nullopt; + } + + return pg_graphalg::MatrixTableDef{ + static_cast(*rows), + static_cast(*cols), + }; +} + static void GetForeignRelSize(PlannerInfo *root, RelOptInfo *baserel, Oid foreigntableid) { - baserel->rows = getInstance().size(); + auto tableDef = parseOptions(GetForeignTable(foreigntableid)); + if (!tableDef) { + return; + } + + auto &table = getInstance().getOrCreateTable(foreigntableid, *tableDef); + baserel->rows = table.nValues(); } static void GetForeignPaths(PlannerInfo *root, RelOptInfo *baserel, @@ -64,8 +129,17 @@ static ForeignScan *GetForeignPlan(PlannerInfo *root, RelOptInfo *baserel, } static void BeginForeignScan(ForeignScanState *node, int eflags) { + auto tableId = RelationGetRelid(node->ss.ss_currentRelation); + // TODO: Avoid parsing options multiple times? + auto tableDef = parseOptions(GetForeignTable(tableId)); + if (!tableDef) { + return; + } + + auto &table = getInstance().getOrCreateTable(tableId, *tableDef); + auto *state = palloc(sizeof(pg_graphalg::ScanState)); - new (state) pg_graphalg::ScanState(); + new (state) pg_graphalg::ScanState(&table); node->fdw_state = state; } @@ -74,7 +148,8 @@ static TupleTableSlot *IterateForeignScan(ForeignScanState *node) { ExecClearTuple(slot); auto *scanState = static_cast(node->fdw_state); - if (auto res = getInstance().scan(*scanState)) { + auto &table = *scanState->table; + if (auto res = table.scan(*scanState)) { slot->tts_isnull[0] = false; slot->tts_isnull[1] = false; slot->tts_isnull[2] = false; @@ -100,6 +175,15 @@ static void EndForeignScan(ForeignScanState *node) { static TupleTableSlot *ExecForeignInsert(EState *estate, ResultRelInfo *rinfo, TupleTableSlot *slot, TupleTableSlot *planSlot) { + auto tableId = RelationGetRelid(rinfo->ri_RelationDesc); + // TODO: Avoid parsing options multiple times? + auto tableDef = parseOptions(GetForeignTable(tableId)); + if (!tableDef) { + return NULL; + } + + auto &table = getInstance().getOrCreateTable(tableId, *tableDef); + slot_getsomeattrs(slot, 3); if (slot->tts_isnull[0] || slot->tts_isnull[1] || slot->tts_isnull[2]) { // Ignore nulls @@ -109,7 +193,7 @@ static TupleTableSlot *ExecForeignInsert(EState *estate, ResultRelInfo *rinfo, std::size_t row = DatumGetUInt64(slot->tts_values[0]); std::size_t col = DatumGetUInt64(slot->tts_values[1]); std::int64_t val = DatumGetInt64(slot->tts_values[2]); - getInstance().addTuple(row, col, val); + table.setValue(row, col, val); return slot; } diff --git a/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp b/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp index 47c215a..261da36 100644 --- a/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp +++ b/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp @@ -1,3 +1,4 @@ +#include #include #include @@ -9,17 +10,16 @@ extern "C" { namespace pg_graphalg { -PgGraphAlg::PgGraphAlg() { - // TODO: init graphblas -} +MatrixTable::MatrixTable(const MatrixTableDef &def) + : _nRows(def.nRows), _nCols(def.nCols) {} -void PgGraphAlg::addTuple(std::size_t row, std::size_t col, - std::int64_t value) { +void MatrixTable::setValue(std::size_t row, std::size_t col, + std::int64_t value) { _values[{row, col}] = value; } std::optional> -PgGraphAlg::scan(ScanState &state) { +MatrixTable::scan(ScanState &state) { auto it = _values.lower_bound({state.row, state.col}); if (it == _values.end()) { return std::nullopt; @@ -34,4 +34,22 @@ PgGraphAlg::scan(ScanState &state) { return std::make_tuple(row, col, val); } +PgGraphAlg::PgGraphAlg() { + // TODO: init graphblas +} + +MatrixTable &PgGraphAlg::getTable(TableId tableId) { + assert(_tables.count(tableId) && "getTable called before getOrCreateTable"); + return _tables.at(tableId); +} + +MatrixTable &PgGraphAlg::getOrCreateTable(TableId tableId, + const MatrixTableDef &def) { + if (!_tables.count(tableId)) { + _tables.emplace(tableId, def); + } + + return getTable(tableId); +} + } // namespace pg_graphalg diff --git a/pg_graphalg/test.sql b/pg_graphalg/test.sql index 6177cad..f6d6145 100644 --- a/pg_graphalg/test.sql +++ b/pg_graphalg/test.sql @@ -1,4 +1,5 @@ -DROP FOREIGN TABLE IF EXISTS mat; +DROP FOREIGN TABLE IF EXISTS mat1; +DROP FOREIGN TABLE IF EXISTS mat2; DROP SERVER IF EXISTS graphalg_server; DROP FOREIGN DATA WRAPPER IF EXISTS graphalg_fdw; DROP FUNCTION IF EXISTS graphalg_fdw_handler; @@ -12,18 +13,28 @@ CREATE FOREIGN DATA WRAPPER graphalg_fdw HANDLER graphalg_fdw_handler; CREATE SERVER graphalg_server FOREIGN DATA WRAPPER graphalg_fdw; -CREATE FOREIGN TABLE mat ( row bigint, col bigint, val bigint ) SERVER graphalg_server; -SELECT * FROM mat; +CREATE FOREIGN TABLE mat1 ( row bigint, col bigint, val bigint ) SERVER graphalg_server; +CREATE FOREIGN TABLE mat2 ( row bigint, col bigint, val bigint ) SERVER graphalg_server; +SELECT * FROM mat1; +SELECT * FROM mat2; -INSERT INTO mat VALUES +INSERT INTO mat1 VALUES (0, 0, 42), (0, 1, 43), (1, 0, 44), (1, 1, 45); -SELECT * FROM mat; +INSERT INTO mat2 VALUES + (0, 0, 420), + (0, 1, 430), + (1, 0, 440), + (1, 1, 450); -INSERT INTO mat VALUES +SELECT * FROM mat1; + +SELECT * FROM mat2; + +INSERT INTO mat1 VALUES (0, 1, 4000); -SELECT * FROM mat; +SELECT * FROM mat1; From 6f5a4035c2a6352995651468542716c07ed41585 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 29 Nov 2025 22:17:46 +0000 Subject: [PATCH 08/20] Validate options. --- pg_graphalg/src/pg_graphalg.cpp | 48 ++++++++++++++++++++++----------- pg_graphalg/test.sql | 4 +-- 2 files changed, 34 insertions(+), 18 deletions(-) diff --git a/pg_graphalg/src/pg_graphalg.cpp b/pg_graphalg/src/pg_graphalg.cpp index dd5d297..3e20c72 100644 --- a/pg_graphalg/src/pg_graphalg.cpp +++ b/pg_graphalg/src/pg_graphalg.cpp @@ -1,3 +1,4 @@ +#include #include #include @@ -34,33 +35,42 @@ static pg_graphalg::PgGraphAlg &getInstance() { PG_FUNCTION_INFO_V1(graphalg_fdw_handler); +static std::optional parseDimension(const char *c) { + auto v = std::atoll(c); + if (v <= 0) { + return std::nullopt; + } else { + return v; + } +} + static std::optional parseOptions(ForeignTable *table) { ListCell *cell; - std::optional rows = 0; - std::optional cols = 0; + std::optional rows; + std::optional cols; foreach (cell, table->options) { auto *def = lfirst_node(DefElem, cell); std::string_view defName(def->defname); if (defName == "rows") { - // TODO: Check type of option - rows = defGetInt64(def); - if (rows < 0) { - ereport(ERROR, (errcode(ERRCODE_FDW_ERROR), - errmsg("invalid value for option \"rows\": %d must be " - "a positive integer", - *rows))); + rows = parseDimension(defGetString(def)); + if (!rows) { + ereport(ERROR, + (errcode(ERRCODE_FDW_ERROR), + errmsg("invalid value for option \"rows\": '%s' must be " + "a positive integer", + defGetString(def)))); return std::nullopt; } } else if (defName == "columns") { - // TODO: Check type of option - cols = defGetInt64(def); - if (rows < 0) { - ereport(ERROR, (errcode(ERRCODE_FDW_ERROR), - errmsg("invalid value for option \"cols\": %d must be " - "a positive integer", - *cols))); + cols = parseDimension(defGetString(def)); + if (!cols) { + ereport(ERROR, + (errcode(ERRCODE_FDW_ERROR), + errmsg("invalid value for option \"columns\": '%s' must be " + "a positive integer", + defGetString(def)))); return std::nullopt; } } else { @@ -79,6 +89,12 @@ parseOptions(ForeignTable *table) { return std::nullopt; } + if (!cols) { + ereport(ERROR, (errcode(ERRCODE_FDW_ERROR), + errmsg("missing required option \"columns\""))); + return std::nullopt; + } + return pg_graphalg::MatrixTableDef{ static_cast(*rows), static_cast(*cols), diff --git a/pg_graphalg/test.sql b/pg_graphalg/test.sql index f6d6145..d8f39c2 100644 --- a/pg_graphalg/test.sql +++ b/pg_graphalg/test.sql @@ -13,8 +13,8 @@ CREATE FOREIGN DATA WRAPPER graphalg_fdw HANDLER graphalg_fdw_handler; CREATE SERVER graphalg_server FOREIGN DATA WRAPPER graphalg_fdw; -CREATE FOREIGN TABLE mat1 ( row bigint, col bigint, val bigint ) SERVER graphalg_server; -CREATE FOREIGN TABLE mat2 ( row bigint, col bigint, val bigint ) SERVER graphalg_server; +CREATE FOREIGN TABLE mat1 ( row bigint, col bigint, val bigint ) SERVER graphalg_server OPTIONS (rows '10', columns '10'); +CREATE FOREIGN TABLE mat2 ( row bigint, col bigint, val bigint ) SERVER graphalg_server OPTIONS (rows '100', columns '100'); SELECT * FROM mat1; SELECT * FROM mat2; From 4e6837321954c09f613114c2fb2cdb8e4c23eb39 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 30 Nov 2025 17:12:00 +0000 Subject: [PATCH 09/20] Avoid parsing options repeatedly. --- docs/postgres.md | 6 +- pg_graphalg/src/pg_graphalg.cpp | 119 ++++++++++++++++++++------------ pg_graphalg/test.sql | 9 ++- 3 files changed, 87 insertions(+), 47 deletions(-) diff --git a/docs/postgres.md b/docs/postgres.md index 2aa9084..f03d357 100644 --- a/docs/postgres.md +++ b/docs/postgres.md @@ -59,9 +59,6 @@ CREATE FUNCTION add_one(integer) RETURNS integer ## Foreign Data Wrapper Resources: -- https://www.postgresql.org/docs/current/fdwhandler.html -- https://www.dolthub.com/blog/2022-01-26-creating-a-postgres-foreign-data-wrapper/ -- https://github.com/Kentik-Archive/wdb_fdw ## GraphBLAS We need to link to SuiteSparse:GraphBLAS. @@ -80,3 +77,6 @@ sudo make install - https://www.postgresql.org/docs/current/xfunc-c.html - https://www.pgedge.com/blog/introduction-to-postgres-extension-development - https://stackoverflow.com/questions/76056209/postgresql-c-extension-function-table-as-argument-and-as-result +- https://www.postgresql.org/docs/current/fdwhandler.html +- https://www.dolthub.com/blog/2022-01-26-creating-a-postgres-foreign-data-wrapper/ +- https://github.com/Kentik-Archive/wdb_fdw diff --git a/pg_graphalg/src/pg_graphalg.cpp b/pg_graphalg/src/pg_graphalg.cpp index 3e20c72..6eb3819 100644 --- a/pg_graphalg/src/pg_graphalg.cpp +++ b/pg_graphalg/src/pg_graphalg.cpp @@ -1,6 +1,7 @@ -#include +#include #include #include +#include #include "pg_graphalg/PgGraphAlg.h" @@ -8,6 +9,7 @@ extern "C" { #include +#include #include #include #include @@ -34,14 +36,41 @@ static pg_graphalg::PgGraphAlg &getInstance() { } PG_FUNCTION_INFO_V1(graphalg_fdw_handler); +PG_FUNCTION_INFO_V1(graphalg_fdw_validator); -static std::optional parseDimension(const char *c) { - auto v = std::atoll(c); - if (v <= 0) { +static std::optional parseDimension(std::string_view s) { + std::size_t v; + auto res = std::from_chars(s.data(), s.data() + s.size(), v); + if (res.ec == std::errc()) { + return v; + } else { return std::nullopt; + } +} + +static bool validateOption(DefElem *def) { + std::string_view optName{def->defname}; + const char *optValue = defGetString(def); + + bool isValid = false; + if (optName == "rows" || optName == "columns") { + // NOTE: foreign data wrapper options are always strings. + if (!parseDimension(optValue)) { + ereport(ERROR, (errcode(ERRCODE_FDW_ERROR), + errmsg("invalid value for option \"%s\": '%s' must be " + "a non-negative integer", + def->defname, optValue))); + isValid = true; + } } else { - return v; + ereport(ERROR, (errcode(ERRCODE_FDW_INVALID_OPTION_NAME), + errmsg("invalid option \"%s\"", def->defname), + errhint("Valid table options are \"rows\", and " + "\"columns\""))); + isValid = true; } + + return isValid; } static std::optional @@ -53,33 +82,13 @@ parseOptions(ForeignTable *table) { foreach (cell, table->options) { auto *def = lfirst_node(DefElem, cell); std::string_view defName(def->defname); - if (defName == "rows") { + + if (!validateOption(def)) { + return std::nullopt; + } else if (defName == "rows") { rows = parseDimension(defGetString(def)); - if (!rows) { - ereport(ERROR, - (errcode(ERRCODE_FDW_ERROR), - errmsg("invalid value for option \"rows\": '%s' must be " - "a positive integer", - defGetString(def)))); - return std::nullopt; - } } else if (defName == "columns") { cols = parseDimension(defGetString(def)); - if (!cols) { - ereport(ERROR, - (errcode(ERRCODE_FDW_ERROR), - errmsg("invalid value for option \"columns\": '%s' must be " - "a positive integer", - defGetString(def)))); - return std::nullopt; - } - } else { - ereport(ERROR, - (errcode(ERRCODE_FDW_INVALID_OPTION_NAME), - errmsg("invalid option \"%s\"", def->defname), - errhint("Valid table options for graphalg are \"rows\", and " - "\"columns\""))); - return std::nullopt; } } @@ -132,6 +141,15 @@ static ForeignScan *GetForeignPlan(PlannerInfo *root, RelOptInfo *baserel, Oid foreigntableid, ForeignPath *best_path, List *tlist, List *scan_clauses, Plan *outer_plan) { + // Resolve to a matrix table. + auto tableDef = parseOptions(GetForeignTable(foreigntableid)); + if (!tableDef) { + NULL; + } + + // Create table if it does not exist yet. + getInstance().getOrCreateTable(foreigntableid, *tableDef); + // On extract_actual_clauses: // https://www.postgresql.org/docs/current/fdw-planning.html scan_clauses = extract_actual_clauses(scan_clauses, false); @@ -146,13 +164,7 @@ static ForeignScan *GetForeignPlan(PlannerInfo *root, RelOptInfo *baserel, static void BeginForeignScan(ForeignScanState *node, int eflags) { auto tableId = RelationGetRelid(node->ss.ss_currentRelation); - // TODO: Avoid parsing options multiple times? - auto tableDef = parseOptions(GetForeignTable(tableId)); - if (!tableDef) { - return; - } - - auto &table = getInstance().getOrCreateTable(tableId, *tableDef); + auto &table = getInstance().getTable(tableId); auto *state = palloc(sizeof(pg_graphalg::ScanState)); new (state) pg_graphalg::ScanState(&table); @@ -188,17 +200,22 @@ static void EndForeignScan(ForeignScanState *node) { // No-Op } -static TupleTableSlot *ExecForeignInsert(EState *estate, ResultRelInfo *rinfo, - TupleTableSlot *slot, - TupleTableSlot *planSlot) { +static void BeginForeignModify(ModifyTableState *mtstate, ResultRelInfo *rinfo, + List *fdw_private, int subplan_index, + int eflags) { + // Ensure table exists before modifying it. auto tableId = RelationGetRelid(rinfo->ri_RelationDesc); - // TODO: Avoid parsing options multiple times? auto tableDef = parseOptions(GetForeignTable(tableId)); - if (!tableDef) { - return NULL; + if (tableDef) { + getInstance().getOrCreateTable(tableId, *tableDef); } +} - auto &table = getInstance().getOrCreateTable(tableId, *tableDef); +static TupleTableSlot *ExecForeignInsert(EState *estate, ResultRelInfo *rinfo, + TupleTableSlot *slot, + TupleTableSlot *planSlot) { + auto tableId = RelationGetRelid(rinfo->ri_RelationDesc); + auto &table = getInstance().getTable(tableId); slot_getsomeattrs(slot, 3); if (slot->tts_isnull[0] || slot->tts_isnull[1] || slot->tts_isnull[2]) { @@ -224,8 +241,24 @@ Datum graphalg_fdw_handler(PG_FUNCTION_ARGS) { fdwRoutine->ReScanForeignScan = ReScanForeignScan; fdwRoutine->EndForeignScan = EndForeignScan; + fdwRoutine->BeginForeignModify = BeginForeignModify; fdwRoutine->ExecForeignInsert = ExecForeignInsert; PG_RETURN_POINTER(fdwRoutine); } + +Datum graphalg_fdw_validator(PG_FUNCTION_ARGS) { + List *options = untransformRelOptions(PG_GETARG_DATUM(0)); + + ListCell *cell; + foreach (cell, options) { + auto *def = static_cast(lfirst(cell)); + validateOption(def); + } + + // NOTE: Not checking that required options are set, because this validator is + // also called when checking options defined on the wrapper or the server. + + PG_RETURN_VOID(); +} } diff --git a/pg_graphalg/test.sql b/pg_graphalg/test.sql index d8f39c2..78aa8be 100644 --- a/pg_graphalg/test.sql +++ b/pg_graphalg/test.sql @@ -3,14 +3,21 @@ DROP FOREIGN TABLE IF EXISTS mat2; DROP SERVER IF EXISTS graphalg_server; DROP FOREIGN DATA WRAPPER IF EXISTS graphalg_fdw; DROP FUNCTION IF EXISTS graphalg_fdw_handler; +DROP FUNCTION IF EXISTS graphalg_fdw_validator; CREATE FUNCTION graphalg_fdw_handler() RETURNS fdw_handler AS '/workspaces/graphalg/pg_graphalg/build/src/libpg_graphalg.so' LANGUAGE C STRICT; +CREATE FUNCTION graphalg_fdw_validator(text[], oid) +RETURNS void +AS '/workspaces/graphalg/pg_graphalg/build/src/libpg_graphalg.so' +LANGUAGE C STRICT; + CREATE FOREIGN DATA WRAPPER graphalg_fdw - HANDLER graphalg_fdw_handler; + HANDLER graphalg_fdw_handler + VALIDATOR graphalg_fdw_validator; CREATE SERVER graphalg_server FOREIGN DATA WRAPPER graphalg_fdw; CREATE FOREIGN TABLE mat1 ( row bigint, col bigint, val bigint ) SERVER graphalg_server OPTIONS (rows '10', columns '10'); From 64ef41871ce2b82e5d6c56a5ee33686c310a2fd6 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 4 Dec 2025 16:49:39 +0000 Subject: [PATCH 10/20] All the data we need to run algorithms. --- pg_graphalg/src/pg_graphalg.cpp | 72 +++++++++++++++++++++++++++++++++ pg_graphalg/test.sql | 54 +++++++++++++++++++++++++ 2 files changed, 126 insertions(+) diff --git a/pg_graphalg/src/pg_graphalg.cpp b/pg_graphalg/src/pg_graphalg.cpp index 6eb3819..2b2dddb 100644 --- a/pg_graphalg/src/pg_graphalg.cpp +++ b/pg_graphalg/src/pg_graphalg.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -9,7 +10,10 @@ extern "C" { #include +#include #include +#include +#include #include #include #include @@ -21,8 +25,10 @@ extern "C" { #include #include #include +#include #include #include +#include PG_MODULE_MAGIC; @@ -37,6 +43,9 @@ static pg_graphalg::PgGraphAlg &getInstance() { PG_FUNCTION_INFO_V1(graphalg_fdw_handler); PG_FUNCTION_INFO_V1(graphalg_fdw_validator); +PG_FUNCTION_INFO_V1(graphalg_pl_call_handler); +PG_FUNCTION_INFO_V1(graphalg_pl_inline_handler); +PG_FUNCTION_INFO_V1(graphalg_pl_validator); static std::optional parseDimension(std::string_view s) { std::size_t v; @@ -261,4 +270,67 @@ Datum graphalg_fdw_validator(PG_FUNCTION_ARGS) { PG_RETURN_VOID(); } + +static Datum executeCall(FunctionCallInfo fcinfo) { + auto procTuple = SearchSysCache( + PROCOID, ObjectIdGetDatum(fcinfo->flinfo->fn_oid), 0, 0, 0); + if (!HeapTupleIsValid(procTuple)) { + elog(ERROR, "cache lookup failed for function %s", fcinfo->flinfo->fn_oid); + PG_RETURN_VOID(); + } + + auto procStruct = (Form_pg_proc)GETSTRUCT(procTuple); + + bool isnull; + auto sourceDatum = + SysCacheGetAttr(PROCOID, procTuple, Anum_pg_proc_prosrc, &isnull); + if (isnull) { + elog(ERROR, "NULL procedure source"); + PG_RETURN_VOID(); + } + + char *procCode = DatumGetCString(DirectFunctionCall1(textout, sourceDatum)); + std::cerr << "GraphAlg source: " << procCode << "\n"; + + auto funcName = procStruct->proname.data; + std::cerr << "function name: " << funcName << "\n"; + + // TODO: Get string values of arguments. + for (int i = 0; i < fcinfo->nargs; i++) { + auto arg = fcinfo->args[i]; + if (arg.isnull) { + elog(ERROR, "Argument %d is NULL", i); + PG_RETURN_VOID(); + } + + auto argType = procStruct->proargtypes.values[i]; + auto typeTuple = SearchSysCache1(TYPEOID, ObjectIdGetDatum(argType)); + auto typeStruct = (Form_pg_type)GETSTRUCT(typeTuple); + FmgrInfo typeInfo; + fmgr_info(typeStruct->typoutput, &typeInfo); + + char *value = OutputFunctionCall(&typeInfo, arg.value); + std::cerr << "arg value: " << value << "\n"; + } + + // TODO: release sys cache. + + elog(ERROR, "execute not implemented"); + PG_RETURN_VOID(); +} + +Datum graphalg_pl_call_handler(PG_FUNCTION_ARGS) { + std::cerr << "call handler!\n"; + return executeCall(fcinfo); +} + +Datum graphalg_pl_inline_handler(PG_FUNCTION_ARGS) { + std::cerr << "inline handler!\n"; + PG_RETURN_VOID(); +} + +Datum graphalg_pl_validator(PG_FUNCTION_ARGS) { + std::cerr << "validator!\n"; + PG_RETURN_VOID(); +} } diff --git a/pg_graphalg/test.sql b/pg_graphalg/test.sql index 78aa8be..495ca0f 100644 --- a/pg_graphalg/test.sql +++ b/pg_graphalg/test.sql @@ -1,10 +1,20 @@ DROP FOREIGN TABLE IF EXISTS mat1; DROP FOREIGN TABLE IF EXISTS mat2; +DROP FOREIGN TABLE IF EXISTS lhs; +DROP FOREIGN TABLE IF EXISTS rhs; +DROP FOREIGN TABLE IF EXISTS matmul_out; DROP SERVER IF EXISTS graphalg_server; DROP FOREIGN DATA WRAPPER IF EXISTS graphalg_fdw; DROP FUNCTION IF EXISTS graphalg_fdw_handler; DROP FUNCTION IF EXISTS graphalg_fdw_validator; +DROP PROCEDURE IF EXISTS matmul; +DROP FUNCTION IF EXISTS graphalg_pl_call_handler; +DROP FUNCTION IF EXISTS graphalg_pl_inline_handler; +DROP FUNCTION IF EXISTS graphalg_pl_validator; +DROP LANGUAGE IF EXISTS graphalg; + +-- Foreign data wrapper CREATE FUNCTION graphalg_fdw_handler() RETURNS fdw_handler AS '/workspaces/graphalg/pg_graphalg/build/src/libpg_graphalg.so' @@ -20,6 +30,26 @@ CREATE FOREIGN DATA WRAPPER graphalg_fdw VALIDATOR graphalg_fdw_validator; CREATE SERVER graphalg_server FOREIGN DATA WRAPPER graphalg_fdw; +-- Procedural Language +CREATE FUNCTION graphalg_pl_call_handler() +RETURNS language_handler +AS '/workspaces/graphalg/pg_graphalg/build/src/libpg_graphalg.so' +LANGUAGE C STRICT; + +CREATE FUNCTION graphalg_pl_validator(oid) RETURNS void +AS '/workspaces/graphalg/pg_graphalg/build/src/libpg_graphalg.so' +LANGUAGE C STRICT; + +CREATE FUNCTION graphalg_pl_inline_handler(internal) +RETURNS void +AS '/workspaces/graphalg/pg_graphalg/build/src/libpg_graphalg.so' +LANGUAGE C STRICT; + +CREATE TRUSTED LANGUAGE graphalg +HANDLER graphalg_pl_call_handler +INLINE graphalg_pl_inline_handler +VALIDATOR graphalg_pl_validator; + CREATE FOREIGN TABLE mat1 ( row bigint, col bigint, val bigint ) SERVER graphalg_server OPTIONS (rows '10', columns '10'); CREATE FOREIGN TABLE mat2 ( row bigint, col bigint, val bigint ) SERVER graphalg_server OPTIONS (rows '100', columns '100'); SELECT * FROM mat1; @@ -45,3 +75,27 @@ INSERT INTO mat1 VALUES (0, 1, 4000); SELECT * FROM mat1; + +CREATE FOREIGN TABLE lhs(row bigint, col bigint, val bigint) +SERVER graphalg_server +OPTIONS (rows '10', columns '10'); + +CREATE FOREIGN TABLE rhs(row bigint, col bigint, val bigint) +SERVER graphalg_server +OPTIONS (rows '10', columns '10'); + +CREATE FOREIGN TABLE matmul_out(row bigint, col bigint, val bigint) +SERVER graphalg_server +OPTIONS (rows '10', columns '10'); + +CREATE PROCEDURE matmul(text, text, text) +LANGUAGE graphalg +AS $$ + func matmul( + lhs: Matrix, + rhs: Matrix) -> Matrix { + return lhs * rhs; + } +$$; + +CALL matmul('lhs', 'rhs', 'matmul_out'); From 5d5c645ef05e33c615ecee3ae6f5105f17e345b0 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 9 Dec 2025 10:02:36 +0000 Subject: [PATCH 11/20] Pull up graphalg compiler. --- pg_graphalg/CMakeLists.txt | 22 ++++++++++++++++++++ pg_graphalg/configure.sh | 2 ++ pg_graphalg/include/pg_graphalg/PgGraphAlg.h | 17 ++++++++++++--- pg_graphalg/src/CMakeLists.txt | 8 ++++++- pg_graphalg/src/pg_graphalg/CMakeLists.txt | 9 +++++++- pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp | 17 ++++++++------- 6 files changed, 63 insertions(+), 12 deletions(-) diff --git a/pg_graphalg/CMakeLists.txt b/pg_graphalg/CMakeLists.txt index 424e614..92a4eff 100644 --- a/pg_graphalg/CMakeLists.txt +++ b/pg_graphalg/CMakeLists.txt @@ -9,6 +9,28 @@ project( find_package(PostgreSQL COMPONENTS Server) +include(FetchContent) +FetchContent_Declare( + GraphAlg + SOURCE_DIR ../../compiler +) +FetchContent_MakeAvailable(GraphAlg) + +# ================================= BEGIN LLVM ================================= + +find_package(LLVM REQUIRED CONFIG) +include_directories(${LLVM_INCLUDE_DIRS}) +separate_arguments(LLVM_DEFINITIONS_LIST NATIVE_COMMAND ${LLVM_DEFINITIONS}) +add_definitions(${LLVM_DEFINITIONS_LIST}) +list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") +include(AddLLVM) + +llvm_map_components_to_libnames(llvm_libs support) + +# ================================== END LLVM ================================== + + + find_library(NAMES GraphBLAS) add_subdirectory(src) diff --git a/pg_graphalg/configure.sh b/pg_graphalg/configure.sh index 5512c75..ddf6cb0 100755 --- a/pg_graphalg/configure.sh +++ b/pg_graphalg/configure.sh @@ -6,6 +6,8 @@ cmake -S $WORKSPACE_ROOT -B $BUILD_DIR -G Ninja \ -DCMAKE_BUILD_TYPE=Debug \ -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ -DCMAKE_CXX_COMPILER=clang++-20 \ + -DCMAKE_POSITION_INDEPENDENT_CODE=TRUE \ -DCMAKE_EXPORT_COMPILE_COMMANDS=1 \ -DCMAKE_LINKER_TYPE=MOLD \ -DPostgreSQL_ROOT=/usr/local/pgsql \ + -DLLVM_ROOT="/opt/llvm-debug" \ diff --git a/pg_graphalg/include/pg_graphalg/PgGraphAlg.h b/pg_graphalg/include/pg_graphalg/PgGraphAlg.h index 627dae5..fe2f5c2 100644 --- a/pg_graphalg/include/pg_graphalg/PgGraphAlg.h +++ b/pg_graphalg/include/pg_graphalg/PgGraphAlg.h @@ -7,9 +7,18 @@ #include #include -extern "C" { -#include "GraphBLAS.h" -} +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include namespace pg_graphalg { @@ -54,6 +63,8 @@ class MatrixTable { class PgGraphAlg { private: + mlir::DialectRegistry _registry; + mlir::MLIRContext _ctx; std::unordered_map _tables; public: diff --git a/pg_graphalg/src/CMakeLists.txt b/pg_graphalg/src/CMakeLists.txt index 69bab63..2392f61 100644 --- a/pg_graphalg/src/CMakeLists.txt +++ b/pg_graphalg/src/CMakeLists.txt @@ -7,4 +7,10 @@ target_include_directories(pg_graphalg PUBLIC ../include) target_link_libraries(pg_graphalg PRIVATE PgGraphAlg - PostgreSQL::PostgreSQL) + PostgreSQL::PostgreSQL + ${llvm_libs} + GraphAlgEvaluate + GraphAlgIR + GraphAlgParse + GraphAlgPasses +) diff --git a/pg_graphalg/src/pg_graphalg/CMakeLists.txt b/pg_graphalg/src/pg_graphalg/CMakeLists.txt index cd03ea5..a5ecabd 100644 --- a/pg_graphalg/src/pg_graphalg/CMakeLists.txt +++ b/pg_graphalg/src/pg_graphalg/CMakeLists.txt @@ -2,4 +2,11 @@ add_library(PgGraphAlg SHARED PgGraphAlg.cpp ) target_include_directories(PgGraphAlg PUBLIC ../../include) -target_link_libraries(PgGraphAlg graphblas) +target_link_libraries(PgGraphAlg + PRIVATE + ${llvm_libs} + GraphAlgEvaluate + GraphAlgIR + GraphAlgParse + GraphAlgPasses +) diff --git a/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp b/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp index 261da36..0d44bef 100644 --- a/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp +++ b/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp @@ -2,11 +2,7 @@ #include #include -extern "C" { -#include "GraphBLAS.h" -} - -#include "pg_graphalg/PgGraphAlg.h" +#include namespace pg_graphalg { @@ -34,10 +30,17 @@ MatrixTable::scan(ScanState &state) { return std::make_tuple(row, col, val); } -PgGraphAlg::PgGraphAlg() { - // TODO: init graphblas +static mlir::DialectRegistry createDialectRegistry() { + mlir::DialectRegistry registry; + registry.insert(); + registry.insert(); + mlir::func::registerInlinerExtension(registry); + return registry; } +PgGraphAlg::PgGraphAlg() + : _registry(createDialectRegistry()), _ctx(_registry) {} + MatrixTable &PgGraphAlg::getTable(TableId tableId) { assert(_tables.count(tableId) && "getTable called before getOrCreateTable"); return _tables.at(tableId); From 8e594b06de0fe77e044519b693f492ec085e405c Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 9 Dec 2025 10:35:01 +0000 Subject: [PATCH 12/20] WIP: parse foreign table name. --- pg_graphalg/include/pg_graphalg/PgGraphAlg.h | 1 + pg_graphalg/src/pg_graphalg.cpp | 21 +++++++++++++++++++- pg_graphalg/test.sql | 2 +- 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/pg_graphalg/include/pg_graphalg/PgGraphAlg.h b/pg_graphalg/include/pg_graphalg/PgGraphAlg.h index fe2f5c2..756e3f5 100644 --- a/pg_graphalg/include/pg_graphalg/PgGraphAlg.h +++ b/pg_graphalg/include/pg_graphalg/PgGraphAlg.h @@ -40,6 +40,7 @@ struct ScanState { }; struct MatrixTableDef { + std::string name; std::size_t nRows; std::size_t nCols; // TODO: data type. diff --git a/pg_graphalg/src/pg_graphalg.cpp b/pg_graphalg/src/pg_graphalg.cpp index 2b2dddb..824010d 100644 --- a/pg_graphalg/src/pg_graphalg.cpp +++ b/pg_graphalg/src/pg_graphalg.cpp @@ -84,6 +84,21 @@ static bool validateOption(DefElem *def) { static std::optional parseOptions(ForeignTable *table) { + std::cerr << "parsing options\n"; + + // Get the name of the table. + auto relTuple = SearchSysCache1(RELOID, table->relid); + if (!HeapTupleIsValid(relTuple)) { + elog(ERROR, "Cannot retrieve table name for oid"); + return std::nullopt; + } + + auto relStruct = (Form_pg_class)GETSTRUCT(relTuple); + std::string tableName{NameStr(relStruct->relname)}; + std::cerr << "table name: " << tableName << "\n"; + ReleaseSysCache(relTuple); + + // Parse the options ListCell *cell; std::optional rows; std::optional cols; @@ -114,6 +129,7 @@ parseOptions(ForeignTable *table) { } return pg_graphalg::MatrixTableDef{ + tableName, static_cast(*rows), static_cast(*cols), }; @@ -153,7 +169,7 @@ static ForeignScan *GetForeignPlan(PlannerInfo *root, RelOptInfo *baserel, // Resolve to a matrix table. auto tableDef = parseOptions(GetForeignTable(foreigntableid)); if (!tableDef) { - NULL; + return nullptr; } // Create table if it does not exist yet. @@ -311,9 +327,12 @@ static Datum executeCall(FunctionCallInfo fcinfo) { char *value = OutputFunctionCall(&typeInfo, arg.value); std::cerr << "arg value: " << value << "\n"; + + // ReleaseSysCache(typeTuple); } // TODO: release sys cache. + // ReleaseSysCache(procTuple); elog(ERROR, "execute not implemented"); PG_RETURN_VOID(); diff --git a/pg_graphalg/test.sql b/pg_graphalg/test.sql index 495ca0f..6fa83bb 100644 --- a/pg_graphalg/test.sql +++ b/pg_graphalg/test.sql @@ -9,10 +9,10 @@ DROP FUNCTION IF EXISTS graphalg_fdw_handler; DROP FUNCTION IF EXISTS graphalg_fdw_validator; DROP PROCEDURE IF EXISTS matmul; +DROP LANGUAGE IF EXISTS graphalg; DROP FUNCTION IF EXISTS graphalg_pl_call_handler; DROP FUNCTION IF EXISTS graphalg_pl_inline_handler; DROP FUNCTION IF EXISTS graphalg_pl_validator; -DROP LANGUAGE IF EXISTS graphalg; -- Foreign data wrapper CREATE FUNCTION graphalg_fdw_handler() From bb3dab5570b0546bf534462d957597e9de9924ef Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 9 Dec 2025 11:12:09 +0000 Subject: [PATCH 13/20] Fix the bugs. --- pg_graphalg/src/pg_graphalg.cpp | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/pg_graphalg/src/pg_graphalg.cpp b/pg_graphalg/src/pg_graphalg.cpp index 824010d..7312f04 100644 --- a/pg_graphalg/src/pg_graphalg.cpp +++ b/pg_graphalg/src/pg_graphalg.cpp @@ -61,7 +61,6 @@ static bool validateOption(DefElem *def) { std::string_view optName{def->defname}; const char *optValue = defGetString(def); - bool isValid = false; if (optName == "rows" || optName == "columns") { // NOTE: foreign data wrapper options are always strings. if (!parseDimension(optValue)) { @@ -69,33 +68,30 @@ static bool validateOption(DefElem *def) { errmsg("invalid value for option \"%s\": '%s' must be " "a non-negative integer", def->defname, optValue))); - isValid = true; + return false; } } else { ereport(ERROR, (errcode(ERRCODE_FDW_INVALID_OPTION_NAME), errmsg("invalid option \"%s\"", def->defname), errhint("Valid table options are \"rows\", and " "\"columns\""))); - isValid = true; + return false; } - return isValid; + return true; } static std::optional parseOptions(ForeignTable *table) { - std::cerr << "parsing options\n"; - // Get the name of the table. auto relTuple = SearchSysCache1(RELOID, table->relid); if (!HeapTupleIsValid(relTuple)) { - elog(ERROR, "Cannot retrieve table name for oid"); + elog(ERROR, "cannot retrieve table name for oid"); return std::nullopt; } auto relStruct = (Form_pg_class)GETSTRUCT(relTuple); std::string tableName{NameStr(relStruct->relname)}; - std::cerr << "table name: " << tableName << "\n"; ReleaseSysCache(relTuple); // Parse the options @@ -328,11 +324,10 @@ static Datum executeCall(FunctionCallInfo fcinfo) { char *value = OutputFunctionCall(&typeInfo, arg.value); std::cerr << "arg value: " << value << "\n"; - // ReleaseSysCache(typeTuple); + ReleaseSysCache(typeTuple); } - // TODO: release sys cache. - // ReleaseSysCache(procTuple); + ReleaseSysCache(procTuple); elog(ERROR, "execute not implemented"); PG_RETURN_VOID(); From 75cd946cbcd09f1195d8b79f50a3b195be44d87e Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 9 Dec 2025 11:33:54 +0000 Subject: [PATCH 14/20] Fix matrix tables for args. --- pg_graphalg/include/pg_graphalg/PgGraphAlg.h | 7 +++++++ pg_graphalg/src/pg_graphalg.cpp | 9 +++++++-- pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp | 13 ++++++++++++- pg_graphalg/test.sql | 12 ++++++++++++ 4 files changed, 38 insertions(+), 3 deletions(-) diff --git a/pg_graphalg/include/pg_graphalg/PgGraphAlg.h b/pg_graphalg/include/pg_graphalg/PgGraphAlg.h index 756e3f5..b15d304 100644 --- a/pg_graphalg/include/pg_graphalg/PgGraphAlg.h +++ b/pg_graphalg/include/pg_graphalg/PgGraphAlg.h @@ -1,11 +1,15 @@ #pragma once +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" #include #include #include +#include #include #include #include +#include #include #include @@ -48,6 +52,7 @@ struct MatrixTableDef { class MatrixTable { private: + std::string _name; std::size_t _nRows; std::size_t _nCols; std::map, std::int64_t> _values; @@ -67,12 +72,14 @@ class PgGraphAlg { mlir::DialectRegistry _registry; mlir::MLIRContext _ctx; std::unordered_map _tables; + llvm::StringMap _nameToId; public: PgGraphAlg(); MatrixTable &getTable(TableId tableId); MatrixTable &getOrCreateTable(TableId tableId, const MatrixTableDef &def); + MatrixTable *lookupTable(llvm::StringRef tableName); }; } // namespace pg_graphalg diff --git a/pg_graphalg/src/pg_graphalg.cpp b/pg_graphalg/src/pg_graphalg.cpp index 7312f04..9621fb9 100644 --- a/pg_graphalg/src/pg_graphalg.cpp +++ b/pg_graphalg/src/pg_graphalg.cpp @@ -307,7 +307,6 @@ static Datum executeCall(FunctionCallInfo fcinfo) { auto funcName = procStruct->proname.data; std::cerr << "function name: " << funcName << "\n"; - // TODO: Get string values of arguments. for (int i = 0; i < fcinfo->nargs; i++) { auto arg = fcinfo->args[i]; if (arg.isnull) { @@ -315,6 +314,7 @@ static Datum executeCall(FunctionCallInfo fcinfo) { PG_RETURN_VOID(); } + // TODO: We know this is a string, so can we make this simpler? auto argType = procStruct->proargtypes.values[i]; auto typeTuple = SearchSysCache1(TYPEOID, ObjectIdGetDatum(argType)); auto typeStruct = (Form_pg_type)GETSTRUCT(typeTuple); @@ -322,7 +322,12 @@ static Datum executeCall(FunctionCallInfo fcinfo) { fmgr_info(typeStruct->typoutput, &typeInfo); char *value = OutputFunctionCall(&typeInfo, arg.value); - std::cerr << "arg value: " << value << "\n"; + + auto *argTable = getInstance().lookupTable(value); + if (!argTable) { + elog(ERROR, "No such matrix table '%s'", value); + PG_RETURN_VOID(); + } ReleaseSysCache(typeTuple); } diff --git a/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp b/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp index 0d44bef..cec1dfa 100644 --- a/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp +++ b/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -7,7 +8,7 @@ namespace pg_graphalg { MatrixTable::MatrixTable(const MatrixTableDef &def) - : _nRows(def.nRows), _nCols(def.nCols) {} + : _name(def.name), _nRows(def.nRows), _nCols(def.nCols) {} void MatrixTable::setValue(std::size_t row, std::size_t col, std::int64_t value) { @@ -50,9 +51,19 @@ MatrixTable &PgGraphAlg::getOrCreateTable(TableId tableId, const MatrixTableDef &def) { if (!_tables.count(tableId)) { _tables.emplace(tableId, def); + _nameToId[def.name] = tableId; + std::cerr << "registered table " << def.name << "\n"; } return getTable(tableId); } +MatrixTable *PgGraphAlg::lookupTable(llvm::StringRef tableName) { + if (_nameToId.contains(tableName)) { + return &getTable(_nameToId[tableName]); + } else { + return nullptr; + } +} + } // namespace pg_graphalg diff --git a/pg_graphalg/test.sql b/pg_graphalg/test.sql index 6fa83bb..efb69f8 100644 --- a/pg_graphalg/test.sql +++ b/pg_graphalg/test.sql @@ -80,13 +80,25 @@ CREATE FOREIGN TABLE lhs(row bigint, col bigint, val bigint) SERVER graphalg_server OPTIONS (rows '10', columns '10'); +INSERT INTO lhs VALUES + (0, 0, 42), + (0, 1, 43), + (1, 0, 44), + (1, 1, 45); CREATE FOREIGN TABLE rhs(row bigint, col bigint, val bigint) SERVER graphalg_server OPTIONS (rows '10', columns '10'); +INSERT INTO rhs VALUES + (0, 0, 46), + (0, 1, 47), + (1, 0, 48), + (1, 1, 49); CREATE FOREIGN TABLE matmul_out(row bigint, col bigint, val bigint) SERVER graphalg_server OPTIONS (rows '10', columns '10'); +-- HACK: Necessary because we don't get a callback from CREATE +SELECT * FROM matmul_out; CREATE PROCEDURE matmul(text, text, text) LANGUAGE graphalg From 9d32b2084c0fd8952ca6d9fc183e7f1fabc0504e Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 9 Dec 2025 16:30:26 +0000 Subject: [PATCH 15/20] Run in postgres. --- pg_graphalg/include/pg_graphalg/PgGraphAlg.h | 15 ++- pg_graphalg/src/pg_graphalg.cpp | 27 ++++- pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp | 115 ++++++++++++++++++- pg_graphalg/test.sql | 7 +- 4 files changed, 154 insertions(+), 10 deletions(-) diff --git a/pg_graphalg/include/pg_graphalg/PgGraphAlg.h b/pg_graphalg/include/pg_graphalg/PgGraphAlg.h index b15d304..857f22b 100644 --- a/pg_graphalg/include/pg_graphalg/PgGraphAlg.h +++ b/pg_graphalg/include/pg_graphalg/PgGraphAlg.h @@ -1,5 +1,9 @@ #pragma once +#include "mlir/IR/Diagnostics.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" #include @@ -60,6 +64,11 @@ class MatrixTable { public: MatrixTable(const MatrixTableDef &def); + std::size_t nRows() const { return _nRows; } + std::size_t nCols() const { return _nCols; } + const auto &values() const { return _values; } + + void clear(); void setValue(std::size_t row, std::size_t col, std::int64_t value); std::optional> scan(ScanState &state); @@ -75,11 +84,15 @@ class PgGraphAlg { llvm::StringMap _nameToId; public: - PgGraphAlg(); + PgGraphAlg(llvm::function_ref diagHandler); MatrixTable &getTable(TableId tableId); MatrixTable &getOrCreateTable(TableId tableId, const MatrixTableDef &def); MatrixTable *lookupTable(llvm::StringRef tableName); + + bool execute(llvm::StringRef programSource, llvm::StringRef function, + llvm::ArrayRef arguments, + MatrixTable &output); }; } // namespace pg_graphalg diff --git a/pg_graphalg/src/pg_graphalg.cpp b/pg_graphalg/src/pg_graphalg.cpp index 9621fb9..2325f74 100644 --- a/pg_graphalg/src/pg_graphalg.cpp +++ b/pg_graphalg/src/pg_graphalg.cpp @@ -4,6 +4,9 @@ #include #include +#include +#include + #include "pg_graphalg/PgGraphAlg.h" extern "C" { @@ -32,10 +35,15 @@ extern "C" { PG_MODULE_MAGIC; +static void diagHandler(mlir::Diagnostic &diag) { + std::string msg = diag.str(); + elog(ERROR, "%s", msg.c_str()); +} + static pg_graphalg::PgGraphAlg *SINGLETON = nullptr; static pg_graphalg::PgGraphAlg &getInstance() { if (!SINGLETON) { - SINGLETON = new pg_graphalg::PgGraphAlg(); + SINGLETON = new pg_graphalg::PgGraphAlg(diagHandler); } return *SINGLETON; @@ -307,6 +315,7 @@ static Datum executeCall(FunctionCallInfo fcinfo) { auto funcName = procStruct->proname.data; std::cerr << "function name: " << funcName << "\n"; + llvm::SmallVector arguments; for (int i = 0; i < fcinfo->nargs; i++) { auto arg = fcinfo->args[i]; if (arg.isnull) { @@ -322,6 +331,7 @@ static Datum executeCall(FunctionCallInfo fcinfo) { fmgr_info(typeStruct->typoutput, &typeInfo); char *value = OutputFunctionCall(&typeInfo, arg.value); + ReleaseSysCache(typeTuple); auto *argTable = getInstance().lookupTable(value); if (!argTable) { @@ -329,12 +339,21 @@ static Datum executeCall(FunctionCallInfo fcinfo) { PG_RETURN_VOID(); } - ReleaseSysCache(typeTuple); + arguments.push_back(argTable); } - ReleaseSysCache(procTuple); + if (arguments.empty()) { + elog(ERROR, "must have at least one argument"); + PG_RETURN_VOID(); + } + + auto *output = arguments.pop_back_val(); - elog(ERROR, "execute not implemented"); + // No need to check the result here, postgres infers success based on + // diagnostics. + getInstance().execute(procCode, funcName, arguments, *output); + + ReleaseSysCache(procTuple); PG_RETURN_VOID(); } diff --git a/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp b/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp index cec1dfa..f87d35b 100644 --- a/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp +++ b/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp @@ -1,8 +1,20 @@ +#include "graphalg/GraphAlgTypes.h" +#include "mlir/IR/BuiltinAttributes.h" #include #include #include #include +#include +#include +#include +#include +#include +#include + +#include +#include + #include namespace pg_graphalg { @@ -10,6 +22,8 @@ namespace pg_graphalg { MatrixTable::MatrixTable(const MatrixTableDef &def) : _name(def.name), _nRows(def.nRows), _nCols(def.nCols) {} +void MatrixTable::clear() { _values.clear(); } + void MatrixTable::setValue(std::size_t row, std::size_t col, std::int64_t value) { _values[{row, col}] = value; @@ -39,8 +53,11 @@ static mlir::DialectRegistry createDialectRegistry() { return registry; } -PgGraphAlg::PgGraphAlg() - : _registry(createDialectRegistry()), _ctx(_registry) {} +PgGraphAlg::PgGraphAlg(llvm::function_ref diagHandler) + : _registry(createDialectRegistry()), _ctx(_registry) { + auto &engine = _ctx.getDiagEngine(); + engine.registerHandler(diagHandler); +} MatrixTable &PgGraphAlg::getTable(TableId tableId) { assert(_tables.count(tableId) && "getTable called before getOrCreateTable"); @@ -66,4 +83,98 @@ MatrixTable *PgGraphAlg::lookupTable(llvm::StringRef tableName) { } } +bool PgGraphAlg::execute(llvm::StringRef programSource, + llvm::StringRef function, + llvm::ArrayRef arguments, + MatrixTable &output) { + // Parse + llvm::StringRef filename = ""; + auto loc = mlir::FileLineColLoc::get(&_ctx, filename, + /*line=*/1, /*column=*/1); + mlir::OwningOpRef moduleOp = + mlir::ModuleOp::create(loc, filename); + if (mlir::failed(graphalg::parse(programSource, *moduleOp))) { + return false; + } + + // Desugar + { + mlir::PassManager pm(&_ctx); + graphalg::GraphAlgToCorePipelineOptions toCoreOptions; + graphalg::buildGraphAlgToCorePipeline(pm, toCoreOptions); + if (mlir::failed(pm.run(*moduleOp))) { + return false; + } + } + + // Set dimensions + { + llvm::SmallVector argDims; + for (const auto *arg : arguments) { + argDims.push_back(graphalg::CallArgumentDimensions{ + .rows = arg->nRows(), + .cols = arg->nCols(), + }); + } + + graphalg::GraphAlgSetDimensionsOptions options{ + .functionName = function.str(), + .argDims = std::move(argDims), + }; + + mlir::PassManager pm(&_ctx); + pm.addNestedPass( + graphalg::createGraphAlgVerifyDimensions()); + pm.addPass(graphalg::createGraphAlgSetDimensions(options)); + pm.addPass(mlir::createCanonicalizerPass()); + if (mlir::failed(pm.run(*moduleOp))) { + return false; + } + } + + auto funcOp = + llvm::cast(moduleOp->lookupSymbol(function)); + + // TODO: Check semiring and value type are compatible + + // Build arguments + llvm::SmallVector argAttrs; + for (const auto &[arg, type] : + llvm::zip_equal(arguments, funcOp.getFunctionType().getInputs())) { + auto matType = llvm::cast(type); + graphalg::MatrixAttrBuilder builder(matType); + for (auto [pos, val] : arg->values()) { + auto [row, col] = pos; + // TODO: support more value types + auto valAttr = mlir::IntegerAttr::get(matType.getSemiring(), val); + builder.set(row, col, valAttr); + } + + argAttrs.push_back(builder.build()); + } + + auto result = graphalg::evaluate(funcOp, argAttrs); + if (!result) { + return false; + } + + graphalg::MatrixAttrReader resultReader(result); + // TODO: Check rows/cols match. + // TODO: Check semiring is compatible with value type. + output.clear(); + auto defaultValue = resultReader.ring().addIdentity(); + for (auto r : llvm::seq(resultReader.nRows())) { + for (auto c : llvm::seq(resultReader.nCols())) { + auto v = resultReader.at(r, c); + if (v != defaultValue) { + // TODO: Support bool/real value types + auto vInt = llvm::cast(v).getInt(); + output.setValue(r, c, vInt); + } + } + } + + return true; +} + } // namespace pg_graphalg diff --git a/pg_graphalg/test.sql b/pg_graphalg/test.sql index efb69f8..5a49429 100644 --- a/pg_graphalg/test.sql +++ b/pg_graphalg/test.sql @@ -78,7 +78,7 @@ SELECT * FROM mat1; CREATE FOREIGN TABLE lhs(row bigint, col bigint, val bigint) SERVER graphalg_server -OPTIONS (rows '10', columns '10'); +OPTIONS (rows '2', columns '2'); INSERT INTO lhs VALUES (0, 0, 42), @@ -87,7 +87,7 @@ INSERT INTO lhs VALUES (1, 1, 45); CREATE FOREIGN TABLE rhs(row bigint, col bigint, val bigint) SERVER graphalg_server -OPTIONS (rows '10', columns '10'); +OPTIONS (rows '2', columns '2'); INSERT INTO rhs VALUES (0, 0, 46), (0, 1, 47), @@ -96,7 +96,7 @@ INSERT INTO rhs VALUES CREATE FOREIGN TABLE matmul_out(row bigint, col bigint, val bigint) SERVER graphalg_server -OPTIONS (rows '10', columns '10'); +OPTIONS (rows '2', columns '2'); -- HACK: Necessary because we don't get a callback from CREATE SELECT * FROM matmul_out; @@ -111,3 +111,4 @@ AS $$ $$; CALL matmul('lhs', 'rhs', 'matmul_out'); +SELECT * FROM matmul_out; From fb6d76808f188e31311e56692fc8ca96aae4c3c2 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 9 Dec 2025 16:43:54 +0000 Subject: [PATCH 16/20] Cleanup. --- pg_graphalg/include/pg_graphalg/MatrixTable.h | 54 ++++++++++++++ pg_graphalg/include/pg_graphalg/PgGraphAlg.h | 71 ++----------------- pg_graphalg/src/pg_graphalg.cpp | 11 +-- pg_graphalg/src/pg_graphalg/CMakeLists.txt | 1 + pg_graphalg/src/pg_graphalg/MatrixTable.cpp | 31 ++++++++ pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp | 36 ++-------- 6 files changed, 103 insertions(+), 101 deletions(-) create mode 100644 pg_graphalg/include/pg_graphalg/MatrixTable.h create mode 100644 pg_graphalg/src/pg_graphalg/MatrixTable.cpp diff --git a/pg_graphalg/include/pg_graphalg/MatrixTable.h b/pg_graphalg/include/pg_graphalg/MatrixTable.h new file mode 100644 index 0000000..0916e55 --- /dev/null +++ b/pg_graphalg/include/pg_graphalg/MatrixTable.h @@ -0,0 +1,54 @@ +#pragma once + +#include +#include +#include +#include + +namespace pg_graphalg { + +struct MatrixTableDef { + std::string name; + std::size_t nRows; + std::size_t nCols; + // TODO: data type. +}; + +class MatrixTable; + +struct MatrixTableScanState { + MatrixTable *table; + std::size_t row = 0; + std::size_t col = 0; + + MatrixTableScanState(MatrixTable *table) : table(table) {} + + void reset() { + row = 0; + col = 0; + } +}; + +class MatrixTable { +private: + std::string _name; + std::size_t _nRows; + std::size_t _nCols; + std::map, std::int64_t> _values; + +public: + MatrixTable(const MatrixTableDef &def); + + std::size_t nRows() const { return _nRows; } + std::size_t nCols() const { return _nCols; } + const auto &values() const { return _values; } + + void clear(); + void setValue(std::size_t row, std::size_t col, std::int64_t value); + std::optional> + scan(MatrixTableScanState &state); + + std::size_t nValues() { return _values.size(); } +}; + +} // namespace pg_graphalg diff --git a/pg_graphalg/include/pg_graphalg/PgGraphAlg.h b/pg_graphalg/include/pg_graphalg/PgGraphAlg.h index 857f22b..682c6c3 100644 --- a/pg_graphalg/include/pg_graphalg/PgGraphAlg.h +++ b/pg_graphalg/include/pg_graphalg/PgGraphAlg.h @@ -1,81 +1,20 @@ #pragma once -#include "mlir/IR/Diagnostics.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLFunctionalExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringMap.h" -#include "llvm/ADT/StringRef.h" -#include -#include -#include -#include -#include #include -#include -#include -#include -#include +#include +#include +#include +#include #include #include -#include -#include -#include -#include -#include -#include -#include +#include "pg_graphalg/MatrixTable.h" namespace pg_graphalg { using TableId = unsigned int; -class MatrixTable; - -struct ScanState { - MatrixTable *table; - std::size_t row = 0; - std::size_t col = 0; - - ScanState(MatrixTable *table) : table(table) {} - - void reset() { - row = 0; - col = 0; - } -}; - -struct MatrixTableDef { - std::string name; - std::size_t nRows; - std::size_t nCols; - // TODO: data type. -}; - -class MatrixTable { -private: - std::string _name; - std::size_t _nRows; - std::size_t _nCols; - std::map, std::int64_t> _values; - -public: - MatrixTable(const MatrixTableDef &def); - - std::size_t nRows() const { return _nRows; } - std::size_t nCols() const { return _nCols; } - const auto &values() const { return _values; } - - void clear(); - void setValue(std::size_t row, std::size_t col, std::int64_t value); - std::optional> - scan(ScanState &state); - - std::size_t nValues() { return _values.size(); } -}; - class PgGraphAlg { private: mlir::DialectRegistry _registry; diff --git a/pg_graphalg/src/pg_graphalg.cpp b/pg_graphalg/src/pg_graphalg.cpp index 2325f74..edc7a2b 100644 --- a/pg_graphalg/src/pg_graphalg.cpp +++ b/pg_graphalg/src/pg_graphalg.cpp @@ -7,6 +7,7 @@ #include #include +#include "pg_graphalg/MatrixTable.h" #include "pg_graphalg/PgGraphAlg.h" extern "C" { @@ -195,8 +196,8 @@ static void BeginForeignScan(ForeignScanState *node, int eflags) { auto tableId = RelationGetRelid(node->ss.ss_currentRelation); auto &table = getInstance().getTable(tableId); - auto *state = palloc(sizeof(pg_graphalg::ScanState)); - new (state) pg_graphalg::ScanState(&table); + auto *state = palloc(sizeof(pg_graphalg::MatrixTableScanState)); + new (state) pg_graphalg::MatrixTableScanState(&table); node->fdw_state = state; } @@ -204,7 +205,8 @@ static TupleTableSlot *IterateForeignScan(ForeignScanState *node) { TupleTableSlot *slot = node->ss.ss_ScanTupleSlot; ExecClearTuple(slot); - auto *scanState = static_cast(node->fdw_state); + auto *scanState = + static_cast(node->fdw_state); auto &table = *scanState->table; if (auto res = table.scan(*scanState)) { slot->tts_isnull[0] = false; @@ -221,7 +223,8 @@ static TupleTableSlot *IterateForeignScan(ForeignScanState *node) { } static void ReScanForeignScan(ForeignScanState *node) { - auto *scanState = static_cast(node->fdw_state); + auto *scanState = + static_cast(node->fdw_state); scanState->reset(); } diff --git a/pg_graphalg/src/pg_graphalg/CMakeLists.txt b/pg_graphalg/src/pg_graphalg/CMakeLists.txt index a5ecabd..73be6e7 100644 --- a/pg_graphalg/src/pg_graphalg/CMakeLists.txt +++ b/pg_graphalg/src/pg_graphalg/CMakeLists.txt @@ -1,4 +1,5 @@ add_library(PgGraphAlg SHARED + MatrixTable.cpp PgGraphAlg.cpp ) target_include_directories(PgGraphAlg PUBLIC ../../include) diff --git a/pg_graphalg/src/pg_graphalg/MatrixTable.cpp b/pg_graphalg/src/pg_graphalg/MatrixTable.cpp new file mode 100644 index 0000000..1b8704d --- /dev/null +++ b/pg_graphalg/src/pg_graphalg/MatrixTable.cpp @@ -0,0 +1,31 @@ +#include + +namespace pg_graphalg { + +MatrixTable::MatrixTable(const MatrixTableDef &def) + : _name(def.name), _nRows(def.nRows), _nCols(def.nCols) {} + +void MatrixTable::clear() { _values.clear(); } + +void MatrixTable::setValue(std::size_t row, std::size_t col, + std::int64_t value) { + _values[{row, col}] = value; +} + +std::optional> +MatrixTable::scan(MatrixTableScanState &state) { + auto it = _values.lower_bound({state.row, state.col}); + if (it == _values.end()) { + return std::nullopt; + } + + std::size_t row = it->first.first; + std::size_t col = it->first.second; + std::int64_t val = it->second; + + state.row = row; + state.col = col + 1; + return std::make_tuple(row, col, val); +} + +} // namespace pg_graphalg diff --git a/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp b/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp index f87d35b..737dac4 100644 --- a/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp +++ b/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp @@ -1,11 +1,9 @@ -#include "graphalg/GraphAlgTypes.h" -#include "mlir/IR/BuiltinAttributes.h" #include -#include #include -#include #include +#include +#include #include #include #include @@ -14,37 +12,14 @@ #include #include +#include +#include +#include #include namespace pg_graphalg { -MatrixTable::MatrixTable(const MatrixTableDef &def) - : _name(def.name), _nRows(def.nRows), _nCols(def.nCols) {} - -void MatrixTable::clear() { _values.clear(); } - -void MatrixTable::setValue(std::size_t row, std::size_t col, - std::int64_t value) { - _values[{row, col}] = value; -} - -std::optional> -MatrixTable::scan(ScanState &state) { - auto it = _values.lower_bound({state.row, state.col}); - if (it == _values.end()) { - return std::nullopt; - } - - std::size_t row = it->first.first; - std::size_t col = it->first.second; - std::int64_t val = it->second; - - state.row = row; - state.col = col + 1; - return std::make_tuple(row, col, val); -} - static mlir::DialectRegistry createDialectRegistry() { mlir::DialectRegistry registry; registry.insert(); @@ -69,7 +44,6 @@ MatrixTable &PgGraphAlg::getOrCreateTable(TableId tableId, if (!_tables.count(tableId)) { _tables.emplace(tableId, def); _nameToId[def.name] = tableId; - std::cerr << "registered table " << def.name << "\n"; } return getTable(tableId); From e7515d98d43e519c04068dbab188e705b4d5b9ed Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 9 Dec 2025 17:14:23 +0000 Subject: [PATCH 17/20] Specialize MatrixTable for different types. --- pg_graphalg/include/pg_graphalg/MatrixTable.h | 44 ++++++++++++++++--- pg_graphalg/include/pg_graphalg/PgGraphAlg.h | 5 ++- pg_graphalg/src/pg_graphalg.cpp | 9 +++- pg_graphalg/src/pg_graphalg/MatrixTable.cpp | 12 +---- pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp | 15 +++++-- 5 files changed, 59 insertions(+), 26 deletions(-) diff --git a/pg_graphalg/include/pg_graphalg/MatrixTable.h b/pg_graphalg/include/pg_graphalg/MatrixTable.h index 0916e55..29996fc 100644 --- a/pg_graphalg/include/pg_graphalg/MatrixTable.h +++ b/pg_graphalg/include/pg_graphalg/MatrixTable.h @@ -7,11 +7,17 @@ namespace pg_graphalg { +enum class MatrixValueType { + BOOL, + INT, + FLOAT, +}; + struct MatrixTableDef { std::string name; std::size_t nRows; std::size_t nCols; - // TODO: data type. + MatrixValueType type; }; class MatrixTable; @@ -34,21 +40,45 @@ class MatrixTable { std::string _name; std::size_t _nRows; std::size_t _nCols; - std::map, std::int64_t> _values; + const MatrixValueType _type; public: - MatrixTable(const MatrixTableDef &def); + MatrixTable(const MatrixTableDef &def) + : _name(def.name), _nRows(def.nRows), _nCols(def.nCols), _type(def.type) { + } + + virtual ~MatrixTable() = default; std::size_t nRows() const { return _nRows; } std::size_t nCols() const { return _nCols; } + MatrixValueType getType() const { return _type; } + + virtual std::size_t nValues() = 0; + virtual void clear() = 0; +}; + +class MatrixTableInt : public MatrixTable { +private: + std::map, std::int64_t> _values; + +public: + MatrixTableInt(const MatrixTableDef &def) : MatrixTable(def) {} + + static bool classof(const MatrixTable *t) { + return t->getType() == MatrixValueType::INT; + } + + std::size_t nValues() override { return _values.size(); } + void clear() override { _values.clear(); } + + void setValue(std::size_t row, std::size_t col, std::int64_t value) { + _values[{row, col}] = value; + } + const auto &values() const { return _values; } - void clear(); - void setValue(std::size_t row, std::size_t col, std::int64_t value); std::optional> scan(MatrixTableScanState &state); - - std::size_t nValues() { return _values.size(); } }; } // namespace pg_graphalg diff --git a/pg_graphalg/include/pg_graphalg/PgGraphAlg.h b/pg_graphalg/include/pg_graphalg/PgGraphAlg.h index 682c6c3..1fc0114 100644 --- a/pg_graphalg/include/pg_graphalg/PgGraphAlg.h +++ b/pg_graphalg/include/pg_graphalg/PgGraphAlg.h @@ -1,8 +1,9 @@ #pragma once -#include +#include #include +#include #include #include #include @@ -19,7 +20,7 @@ class PgGraphAlg { private: mlir::DialectRegistry _registry; mlir::MLIRContext _ctx; - std::unordered_map _tables; + llvm::DenseMap> _tables; llvm::StringMap _nameToId; public: diff --git a/pg_graphalg/src/pg_graphalg.cpp b/pg_graphalg/src/pg_graphalg.cpp index edc7a2b..5baf9df 100644 --- a/pg_graphalg/src/pg_graphalg.cpp +++ b/pg_graphalg/src/pg_graphalg.cpp @@ -137,6 +137,8 @@ parseOptions(ForeignTable *table) { tableName, static_cast(*rows), static_cast(*cols), + // TODO: Determine from column type. + pg_graphalg::MatrixValueType::INT, }; } @@ -207,7 +209,8 @@ static TupleTableSlot *IterateForeignScan(ForeignScanState *node) { auto *scanState = static_cast(node->fdw_state); - auto &table = *scanState->table; + // TODO: Check type + auto &table = llvm::cast(*scanState->table); if (auto res = table.scan(*scanState)) { slot->tts_isnull[0] = false; slot->tts_isnull[1] = false; @@ -247,7 +250,9 @@ static TupleTableSlot *ExecForeignInsert(EState *estate, ResultRelInfo *rinfo, TupleTableSlot *slot, TupleTableSlot *planSlot) { auto tableId = RelationGetRelid(rinfo->ri_RelationDesc); - auto &table = getInstance().getTable(tableId); + // TODO: More types + auto &table = + llvm::cast(getInstance().getTable(tableId)); slot_getsomeattrs(slot, 3); if (slot->tts_isnull[0] || slot->tts_isnull[1] || slot->tts_isnull[2]) { diff --git a/pg_graphalg/src/pg_graphalg/MatrixTable.cpp b/pg_graphalg/src/pg_graphalg/MatrixTable.cpp index 1b8704d..41cc8ee 100644 --- a/pg_graphalg/src/pg_graphalg/MatrixTable.cpp +++ b/pg_graphalg/src/pg_graphalg/MatrixTable.cpp @@ -2,18 +2,8 @@ namespace pg_graphalg { -MatrixTable::MatrixTable(const MatrixTableDef &def) - : _name(def.name), _nRows(def.nRows), _nCols(def.nCols) {} - -void MatrixTable::clear() { _values.clear(); } - -void MatrixTable::setValue(std::size_t row, std::size_t col, - std::int64_t value) { - _values[{row, col}] = value; -} - std::optional> -MatrixTable::scan(MatrixTableScanState &state) { +MatrixTableInt::scan(MatrixTableScanState &state) { auto it = _values.lower_bound({state.row, state.col}); if (it == _values.end()) { return std::nullopt; diff --git a/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp b/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp index 737dac4..d319f4a 100644 --- a/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp +++ b/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp @@ -16,6 +16,7 @@ #include #include +#include #include namespace pg_graphalg { @@ -36,13 +37,14 @@ PgGraphAlg::PgGraphAlg(llvm::function_ref diagHandler) MatrixTable &PgGraphAlg::getTable(TableId tableId) { assert(_tables.count(tableId) && "getTable called before getOrCreateTable"); - return _tables.at(tableId); + return *_tables[tableId]; } MatrixTable &PgGraphAlg::getOrCreateTable(TableId tableId, const MatrixTableDef &def) { if (!_tables.count(tableId)) { - _tables.emplace(tableId, def); + // TODO: More types + _tables[tableId] = std::make_unique(def); _nameToId[def.name] = tableId; } @@ -117,7 +119,9 @@ bool PgGraphAlg::execute(llvm::StringRef programSource, llvm::zip_equal(arguments, funcOp.getFunctionType().getInputs())) { auto matType = llvm::cast(type); graphalg::MatrixAttrBuilder builder(matType); - for (auto [pos, val] : arg->values()) { + // TODO: support more value types. + const auto &values = llvm::cast(arg)->values(); + for (auto [pos, val] : values) { auto [row, col] = pos; // TODO: support more value types auto valAttr = mlir::IntegerAttr::get(matType.getSemiring(), val); @@ -136,6 +140,9 @@ bool PgGraphAlg::execute(llvm::StringRef programSource, // TODO: Check rows/cols match. // TODO: Check semiring is compatible with value type. output.clear(); + + // TODO: more value types + auto &outputInt = llvm::cast(output); auto defaultValue = resultReader.ring().addIdentity(); for (auto r : llvm::seq(resultReader.nRows())) { for (auto c : llvm::seq(resultReader.nCols())) { @@ -143,7 +150,7 @@ bool PgGraphAlg::execute(llvm::StringRef programSource, if (v != defaultValue) { // TODO: Support bool/real value types auto vInt = llvm::cast(v).getInt(); - output.setValue(r, c, vInt); + outputInt.setValue(r, c, vInt); } } } From 988077a37bee48bff1d5711e978edf5113d9567f Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 10 Dec 2025 14:28:14 +0000 Subject: [PATCH 18/20] Lookup foreign tables in catalog. --- pg_graphalg/README.md | 2 + pg_graphalg/include/pg_graphalg/PgGraphAlg.h | 9 +- pg_graphalg/src/pg_graphalg.cpp | 135 +++++++++++++------ pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp | 32 ++--- pg_graphalg/test.sql | 2 - 5 files changed, 111 insertions(+), 69 deletions(-) diff --git a/pg_graphalg/README.md b/pg_graphalg/README.md index c0ae37b..33439f6 100644 --- a/pg_graphalg/README.md +++ b/pg_graphalg/README.md @@ -12,9 +12,11 @@ Download the latest postgres release source code. ```bash # Dependencies for building PostgreSQL +sudo apt update sudo apt install bison flex libreadline-dev libicu-dev # Download the sources +mkdir -p thirdparty/ cd thirdparty/ wget https://ftp.postgresql.org/pub/source/v18.1/postgresql-18.1.tar.bz2 tar xf postgresql-18.1.tar.bz2 diff --git a/pg_graphalg/include/pg_graphalg/PgGraphAlg.h b/pg_graphalg/include/pg_graphalg/PgGraphAlg.h index 1fc0114..6323402 100644 --- a/pg_graphalg/include/pg_graphalg/PgGraphAlg.h +++ b/pg_graphalg/include/pg_graphalg/PgGraphAlg.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -9,6 +10,7 @@ #include #include #include +#include #include "pg_graphalg/MatrixTable.h" @@ -21,14 +23,13 @@ class PgGraphAlg { mlir::DialectRegistry _registry; mlir::MLIRContext _ctx; llvm::DenseMap> _tables; - llvm::StringMap _nameToId; public: PgGraphAlg(llvm::function_ref diagHandler); - MatrixTable &getTable(TableId tableId); - MatrixTable &getOrCreateTable(TableId tableId, const MatrixTableDef &def); - MatrixTable *lookupTable(llvm::StringRef tableName); + std::optional getOrCreateTable( + TableId tableId, + llvm::function_ref(TableId id)> createFunc); bool execute(llvm::StringRef programSource, llvm::StringRef function, llvm::ArrayRef arguments, diff --git a/pg_graphalg/src/pg_graphalg.cpp b/pg_graphalg/src/pg_graphalg.cpp index 5baf9df..ddcb333 100644 --- a/pg_graphalg/src/pg_graphalg.cpp +++ b/pg_graphalg/src/pg_graphalg.cpp @@ -19,6 +19,7 @@ extern "C" { #include #include #include +#include #include #include #include @@ -28,6 +29,7 @@ extern "C" { #include #include #include +#include #include #include #include @@ -142,15 +144,23 @@ parseOptions(ForeignTable *table) { }; } +static std::optional lookupMatrixTable(Oid relid) { + auto *fTable = GetForeignTable(relid); + if (!fTable) { + elog(ERROR, "relation with oid %u is not a foreign table", relid); + return std::nullopt; + } + + return parseOptions(fTable); +} + static void GetForeignRelSize(PlannerInfo *root, RelOptInfo *baserel, Oid foreigntableid) { - auto tableDef = parseOptions(GetForeignTable(foreigntableid)); - if (!tableDef) { - return; + auto table = + getInstance().getOrCreateTable(foreigntableid, lookupMatrixTable); + if (table) { + baserel->rows = (*table)->nValues(); } - - auto &table = getInstance().getOrCreateTable(foreigntableid, *tableDef); - baserel->rows = table.nValues(); } static void GetForeignPaths(PlannerInfo *root, RelOptInfo *baserel, @@ -173,15 +183,6 @@ static ForeignScan *GetForeignPlan(PlannerInfo *root, RelOptInfo *baserel, Oid foreigntableid, ForeignPath *best_path, List *tlist, List *scan_clauses, Plan *outer_plan) { - // Resolve to a matrix table. - auto tableDef = parseOptions(GetForeignTable(foreigntableid)); - if (!tableDef) { - return nullptr; - } - - // Create table if it does not exist yet. - getInstance().getOrCreateTable(foreigntableid, *tableDef); - // On extract_actual_clauses: // https://www.postgresql.org/docs/current/fdw-planning.html scan_clauses = extract_actual_clauses(scan_clauses, false); @@ -196,10 +197,13 @@ static ForeignScan *GetForeignPlan(PlannerInfo *root, RelOptInfo *baserel, static void BeginForeignScan(ForeignScanState *node, int eflags) { auto tableId = RelationGetRelid(node->ss.ss_currentRelation); - auto &table = getInstance().getTable(tableId); + auto table = getInstance().getOrCreateTable(tableId, lookupMatrixTable); + if (!table) { + return; + } auto *state = palloc(sizeof(pg_graphalg::MatrixTableScanState)); - new (state) pg_graphalg::MatrixTableScanState(&table); + new (state) pg_graphalg::MatrixTableScanState(*table); node->fdw_state = state; } @@ -240,19 +244,16 @@ static void BeginForeignModify(ModifyTableState *mtstate, ResultRelInfo *rinfo, int eflags) { // Ensure table exists before modifying it. auto tableId = RelationGetRelid(rinfo->ri_RelationDesc); - auto tableDef = parseOptions(GetForeignTable(tableId)); - if (tableDef) { - getInstance().getOrCreateTable(tableId, *tableDef); - } + getInstance().getOrCreateTable(tableId, lookupMatrixTable); } static TupleTableSlot *ExecForeignInsert(EState *estate, ResultRelInfo *rinfo, TupleTableSlot *slot, TupleTableSlot *planSlot) { auto tableId = RelationGetRelid(rinfo->ri_RelationDesc); + auto &table = **getInstance().getOrCreateTable(tableId, lookupMatrixTable); // TODO: More types - auto &table = - llvm::cast(getInstance().getTable(tableId)); + auto &tableInt = llvm::cast(table); slot_getsomeattrs(slot, 3); if (slot->tts_isnull[0] || slot->tts_isnull[1] || slot->tts_isnull[2]) { @@ -263,7 +264,7 @@ static TupleTableSlot *ExecForeignInsert(EState *estate, ResultRelInfo *rinfo, std::size_t row = DatumGetUInt64(slot->tts_values[0]); std::size_t col = DatumGetUInt64(slot->tts_values[1]); std::int64_t val = DatumGetInt64(slot->tts_values[2]); - table.setValue(row, col, val); + tableInt.setValue(row, col, val); return slot; } @@ -299,6 +300,58 @@ Datum graphalg_fdw_validator(PG_FUNCTION_ARGS) { PG_RETURN_VOID(); } +// NOTE: Assumes a working SPI connection +static std::optional lookupForeignTable(Oid argType, Datum argValue) { + constexpr bool READ_ONLY = true; + /* + * Allow up to 2 results. + * n == 0: ERROR Table not found + * n > 1: ERROR Multiple oids for the given name + * n == 1: OK Name uniquely identifies table + */ + constexpr long TCOUNT = 2; + int execRes = SPI_execute_with_args( + "SELECT oid FROM pg_class WHERE relname=$1 AND relkind = 'f'", 1, + &argType, &argValue, nullptr, READ_ONLY, TCOUNT); + if (execRes < 0) { + elog(ERROR, "internal error finding argument tables"); + PG_RETURN_VOID(); + } + + // Expect exactly one result. + if (SPI_processed != 1) { + // TODO: We know this is a string, so can we make this simpler? + auto typeTuple = SearchSysCache1(TYPEOID, ObjectIdGetDatum(argType)); + auto typeStruct = (Form_pg_type)GETSTRUCT(typeTuple); + FmgrInfo typeInfo; + fmgr_info(typeStruct->typoutput, &typeInfo); + char *value = OutputFunctionCall(&typeInfo, argValue); + ReleaseSysCache(typeTuple); + + if (SPI_processed == 0) { + elog(ERROR, "no such matrix table '%s'", value); + } else { + elog(ERROR, "multiple tables named '%s'", value); + } + + PG_RETURN_VOID(); + } + + auto *tuptable = SPI_tuptable; + auto tupdesc = tuptable->tupdesc; + bool oidNull = false; + auto tableOidDatum = + SPI_getbinval(tuptable->vals[0], tuptable->tupdesc, 1, &oidNull); + return DatumGetObjectId(tableOidDatum); +} + +/** Automatically calls SPI_finish() when it goes out of scope. */ +class SPIConnection { +public: + SPIConnection() { SPI_connect(); } + ~SPIConnection() { SPI_finish(); } +}; + static Datum executeCall(FunctionCallInfo fcinfo) { auto procTuple = SearchSysCache( PROCOID, ObjectIdGetDatum(fcinfo->flinfo->fn_oid), 0, 0, 0); @@ -309,20 +362,19 @@ static Datum executeCall(FunctionCallInfo fcinfo) { auto procStruct = (Form_pg_proc)GETSTRUCT(procTuple); - bool isnull; + // Extract program source code + bool sourceIsNull; auto sourceDatum = - SysCacheGetAttr(PROCOID, procTuple, Anum_pg_proc_prosrc, &isnull); - if (isnull) { + SysCacheGetAttr(PROCOID, procTuple, Anum_pg_proc_prosrc, &sourceIsNull); + if (sourceIsNull) { elog(ERROR, "NULL procedure source"); PG_RETURN_VOID(); } char *procCode = DatumGetCString(DirectFunctionCall1(textout, sourceDatum)); - std::cerr << "GraphAlg source: " << procCode << "\n"; - - auto funcName = procStruct->proname.data; - std::cerr << "function name: " << funcName << "\n"; + // Get argument matrix tables. + SPIConnection spiConnection; llvm::SmallVector arguments; for (int i = 0; i < fcinfo->nargs; i++) { auto arg = fcinfo->args[i]; @@ -331,23 +383,18 @@ static Datum executeCall(FunctionCallInfo fcinfo) { PG_RETURN_VOID(); } - // TODO: We know this is a string, so can we make this simpler? auto argType = procStruct->proargtypes.values[i]; - auto typeTuple = SearchSysCache1(TYPEOID, ObjectIdGetDatum(argType)); - auto typeStruct = (Form_pg_type)GETSTRUCT(typeTuple); - FmgrInfo typeInfo; - fmgr_info(typeStruct->typoutput, &typeInfo); - - char *value = OutputFunctionCall(&typeInfo, arg.value); - ReleaseSysCache(typeTuple); + auto tableOid = lookupForeignTable(argType, arg.value); + if (!tableOid) { + PG_RETURN_VOID(); + } - auto *argTable = getInstance().lookupTable(value); - if (!argTable) { - elog(ERROR, "No such matrix table '%s'", value); + auto table = getInstance().getOrCreateTable(*tableOid, lookupMatrixTable); + if (!table) { PG_RETURN_VOID(); } - arguments.push_back(argTable); + arguments.push_back(*table); } if (arguments.empty()) { @@ -355,10 +402,12 @@ static Datum executeCall(FunctionCallInfo fcinfo) { PG_RETURN_VOID(); } + // Output is written to the final procedure argument. auto *output = arguments.pop_back_val(); // No need to check the result here, postgres infers success based on // diagnostics. + auto funcName = procStruct->proname.data; getInstance().execute(procCode, funcName, arguments, *output); ReleaseSysCache(procTuple); diff --git a/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp b/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp index d319f4a..4758f0d 100644 --- a/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp +++ b/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -35,28 +36,19 @@ PgGraphAlg::PgGraphAlg(llvm::function_ref diagHandler) engine.registerHandler(diagHandler); } -MatrixTable &PgGraphAlg::getTable(TableId tableId) { - assert(_tables.count(tableId) && "getTable called before getOrCreateTable"); - return *_tables[tableId]; -} - -MatrixTable &PgGraphAlg::getOrCreateTable(TableId tableId, - const MatrixTableDef &def) { - if (!_tables.count(tableId)) { - // TODO: More types - _tables[tableId] = std::make_unique(def); - _nameToId[def.name] = tableId; - } - - return getTable(tableId); -} +std::optional PgGraphAlg::getOrCreateTable( + TableId tableId, + llvm::function_ref(TableId id)> createFunc) { + if (!_tables.contains(tableId)) { + auto def = createFunc(tableId); + if (!def) { + return std::nullopt; + } -MatrixTable *PgGraphAlg::lookupTable(llvm::StringRef tableName) { - if (_nameToId.contains(tableName)) { - return &getTable(_nameToId[tableName]); - } else { - return nullptr; + // TODO: Different types. + _tables[tableId] = std::make_unique(*def); } + return _tables[tableId].get(); } bool PgGraphAlg::execute(llvm::StringRef programSource, diff --git a/pg_graphalg/test.sql b/pg_graphalg/test.sql index 5a49429..90217c7 100644 --- a/pg_graphalg/test.sql +++ b/pg_graphalg/test.sql @@ -97,8 +97,6 @@ INSERT INTO rhs VALUES CREATE FOREIGN TABLE matmul_out(row bigint, col bigint, val bigint) SERVER graphalg_server OPTIONS (rows '2', columns '2'); --- HACK: Necessary because we don't get a callback from CREATE -SELECT * FROM matmul_out; CREATE PROCEDURE matmul(text, text, text) LANGUAGE graphalg From 070a3a90ad63af17e3f3a8810622f8c57a0b9edc Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 10 Dec 2025 19:31:33 +0000 Subject: [PATCH 19/20] Validate column types. --- pg_graphalg/src/pg_graphalg.cpp | 107 ++++++++++----- pg_graphalg/test.sql | 10 ++ pgext/.gitignore | 2 - pgext/build.sh | 7 - pgext/funcs.c | 103 --------------- pgext/load.sql | 20 --- pgext/tutorial_fdw/Makefile | 9 -- pgext/tutorial_fdw/smoke_test.sh | 14 -- pgext/tutorial_fdw/smoke_test.sql | 4 - pgext/tutorial_fdw/tutorial_fdw--1.0.sql | 7 - pgext/tutorial_fdw/tutorial_fdw.c | 159 ----------------------- pgext/tutorial_fdw/tutorial_fdw.control | 5 - 12 files changed, 88 insertions(+), 359 deletions(-) delete mode 100644 pgext/.gitignore delete mode 100755 pgext/build.sh delete mode 100644 pgext/funcs.c delete mode 100644 pgext/load.sql delete mode 100644 pgext/tutorial_fdw/Makefile delete mode 100755 pgext/tutorial_fdw/smoke_test.sh delete mode 100644 pgext/tutorial_fdw/smoke_test.sql delete mode 100644 pgext/tutorial_fdw/tutorial_fdw--1.0.sql delete mode 100644 pgext/tutorial_fdw/tutorial_fdw.c delete mode 100644 pgext/tutorial_fdw/tutorial_fdw.control diff --git a/pg_graphalg/src/pg_graphalg.cpp b/pg_graphalg/src/pg_graphalg.cpp index ddcb333..5e8a785 100644 --- a/pg_graphalg/src/pg_graphalg.cpp +++ b/pg_graphalg/src/pg_graphalg.cpp @@ -1,9 +1,8 @@ #include -#include #include #include -#include +#include #include #include @@ -14,8 +13,10 @@ extern "C" { #include +#include #include #include +#include #include #include #include @@ -34,6 +35,7 @@ extern "C" { #include #include #include +#include #include PG_MODULE_MAGIC; @@ -92,25 +94,84 @@ static bool validateOption(DefElem *def) { return true; } -static std::optional -parseOptions(ForeignTable *table) { - // Get the name of the table. - auto relTuple = SearchSysCache1(RELOID, table->relid); - if (!HeapTupleIsValid(relTuple)) { - elog(ERROR, "cannot retrieve table name for oid"); +static std::optional mapValueType(Oid typeId) { + switch (typeId) { + case BOOLOID: + return pg_graphalg::MatrixValueType::BOOL; + case INT8OID: + return pg_graphalg::MatrixValueType::INT; + case FLOAT8OID: + return pg_graphalg::MatrixValueType::FLOAT; + default: + return std::nullopt; + } +} + +struct SysCacheTupleScope { + HeapTuple tup; + + SysCacheTupleScope(SysCacheIdentifier cacheId, Oid key) + : tup(SearchSysCache1(cacheId, key)) {} + ~SysCacheTupleScope() { + if (HeapTupleIsValid(tup)) { + ReleaseSysCache(tup); + } + } +}; + +struct RelationScope { + Relation rel; + + RelationScope(Oid relid) : rel(RelationIdGetRelation(relid)) {} + ~RelationScope() { RelationClose(rel); } +}; + +static std::optional lookupMatrixTable(Oid relid) { + // Must be a foreign table + // TODO: Check that it uses a graphalg server. + auto *fTable = GetForeignTable(relid); + if (!fTable) { + elog(ERROR, "relation with oid %u is not a foreign table", relid); + return std::nullopt; + } + + RelationScope rel(relid); + std::string tableName{NameStr(rel.rel->rd_rel->relname)}; + + // Validate the column types. + int nAttrs = RelationGetNumberOfAttributes(rel.rel); + if (nAttrs != 3) { + elog(ERROR, "matrix table must have 3 columns, got %d", nAttrs); + return std::nullopt; + } + + TupleDesc tupleDesc = rel.rel->rd_att; + auto rowAttr = TupleDescAttr(tupleDesc, 0); + if (rowAttr->atttypid != INT8OID) { + elog(ERROR, "first column (row index) must have type bigint"); + return std::nullopt; + } + + auto colAttr = TupleDescAttr(tupleDesc, 1); + if (colAttr->atttypid != INT8OID) { + elog(ERROR, "second column (column index) must have type bigint"); return std::nullopt; } - auto relStruct = (Form_pg_class)GETSTRUCT(relTuple); - std::string tableName{NameStr(relStruct->relname)}; - ReleaseSysCache(relTuple); + auto valAttr = TupleDescAttr(tupleDesc, 2); + auto valType = mapValueType(valAttr->atttypid); + if (!valType) { + elog(ERROR, "third column (value) must have type boolean, bigint or double " + "precision"); + return std::nullopt; + } // Parse the options ListCell *cell; std::optional rows; std::optional cols; - foreach (cell, table->options) { + foreach (cell, fTable->options) { auto *def = lfirst_node(DefElem, cell); std::string_view defName(def->defname); @@ -140,20 +201,10 @@ parseOptions(ForeignTable *table) { static_cast(*rows), static_cast(*cols), // TODO: Determine from column type. - pg_graphalg::MatrixValueType::INT, + *valType, }; } -static std::optional lookupMatrixTable(Oid relid) { - auto *fTable = GetForeignTable(relid); - if (!fTable) { - elog(ERROR, "relation with oid %u is not a foreign table", relid); - return std::nullopt; - } - - return parseOptions(fTable); -} - static void GetForeignRelSize(PlannerInfo *root, RelOptInfo *baserel, Oid foreigntableid) { auto table = @@ -291,6 +342,7 @@ Datum graphalg_fdw_validator(PG_FUNCTION_ARGS) { ListCell *cell; foreach (cell, options) { auto *def = static_cast(lfirst(cell)); + // TODO: Only allow options at the table level. validateOption(def); } @@ -414,18 +466,15 @@ static Datum executeCall(FunctionCallInfo fcinfo) { PG_RETURN_VOID(); } -Datum graphalg_pl_call_handler(PG_FUNCTION_ARGS) { - std::cerr << "call handler!\n"; - return executeCall(fcinfo); -} +Datum graphalg_pl_call_handler(PG_FUNCTION_ARGS) { return executeCall(fcinfo); } Datum graphalg_pl_inline_handler(PG_FUNCTION_ARGS) { - std::cerr << "inline handler!\n"; + elog(ERROR, "inline handler not implemented"); PG_RETURN_VOID(); } Datum graphalg_pl_validator(PG_FUNCTION_ARGS) { - std::cerr << "validator!\n"; + elog(INFO, "NOTE: language validator not implemented"); PG_RETURN_VOID(); } } diff --git a/pg_graphalg/test.sql b/pg_graphalg/test.sql index 90217c7..59f4ed5 100644 --- a/pg_graphalg/test.sql +++ b/pg_graphalg/test.sql @@ -110,3 +110,13 @@ $$; CALL matmul('lhs', 'rhs', 'matmul_out'); SELECT * FROM matmul_out; + +CREATE FOREIGN TABLE matbool ( row bigint, col bigint, val boolean ) SERVER graphalg_server OPTIONS (rows '10', columns '10'); +INSERT INTO matbool VALUES (0, 0, true); +CREATE FOREIGN TABLE matreal ( row bigint, col bigint, val double precision ) SERVER graphalg_server OPTIONS (rows '10', columns '10'); +INSERT INTO matreal VALUES (0, 0, 4.2); +DROP FOREIGN TABLE matbool; +DROP FOREIGN TABLE matreal; +CREATE FOREIGN TABLE matbad ( row bigint, col bigint, val text ) SERVER graphalg_server OPTIONS (rows '10', columns '10'); +INSERT INTO matbad VALUES (0, 0, '42'); +DROP FOREIGN TABLE matbad; diff --git a/pgext/.gitignore b/pgext/.gitignore deleted file mode 100644 index 9d22eb4..0000000 --- a/pgext/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -*.o -*.so diff --git a/pgext/build.sh b/pgext/build.sh deleted file mode 100755 index 048359c..0000000 --- a/pgext/build.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/bash - -PG_SERVER_INCLUDE_DIR=$(/usr/local/pgsql/bin/pg_config --includedir-server) - -gcc -I $PG_SERVER_INCLUDE_DIR -fPIC -c pgext/funcs.c -o pgext/funcs.o -gcc -shared -o pgext/funcs.so pgext/funcs.o - diff --git a/pgext/funcs.c b/pgext/funcs.c deleted file mode 100644 index 37a9174..0000000 --- a/pgext/funcs.c +++ /dev/null @@ -1,103 +0,0 @@ -#include "postgres.h" - -#include "fmgr.h" -#include "utils/geo_decls.h" -#include "varatt.h" - -#include - -PG_MODULE_MAGIC; - -/* by value */ - -PG_FUNCTION_INFO_V1(add_one); - -Datum add_one(PG_FUNCTION_ARGS) { - int32 arg = PG_GETARG_INT32(0); - - PG_RETURN_INT32(arg + 1); -} - -/* by reference, fixed length */ - -PG_FUNCTION_INFO_V1(add_one_float8); - -Datum add_one_float8(PG_FUNCTION_ARGS) { - /* The macros for FLOAT8 hide its pass-by-reference nature. */ - float8 arg = PG_GETARG_FLOAT8(0); - - PG_RETURN_FLOAT8(arg + 1.0); -} - -PG_FUNCTION_INFO_V1(makepoint); - -Datum makepoint(PG_FUNCTION_ARGS) { - /* Here, the pass-by-reference nature of Point is not hidden. */ - Point *pointx = PG_GETARG_POINT_P(0); - Point *pointy = PG_GETARG_POINT_P(1); - Point *new_point = (Point *)palloc(sizeof(Point)); - - new_point->x = pointx->x; - new_point->y = pointy->y; - - PG_RETURN_POINT_P(new_point); -} - -/* by reference, variable length */ - -PG_FUNCTION_INFO_V1(copytext); - -Datum copytext(PG_FUNCTION_ARGS) { - text *t = PG_GETARG_TEXT_PP(0); - - /* - * VARSIZE_ANY_EXHDR is the size of the struct in bytes, minus the - * VARHDRSZ or VARHDRSZ_SHORT of its header. Construct the copy with a - * full-length header. - */ - text *new_t = (text *)palloc(VARSIZE_ANY_EXHDR(t) + VARHDRSZ); - SET_VARSIZE(new_t, VARSIZE_ANY_EXHDR(t) + VARHDRSZ); - - /* - * VARDATA is a pointer to the data region of the new struct. The source - * could be a short datum, so retrieve its data through VARDATA_ANY. - */ - memcpy(VARDATA(new_t), /* destination */ - VARDATA_ANY(t), /* source */ - VARSIZE_ANY_EXHDR(t)); /* how many bytes */ - PG_RETURN_TEXT_P(new_t); -} - -PG_FUNCTION_INFO_V1(concat_text); - -Datum concat_text(PG_FUNCTION_ARGS) { - text *arg1 = PG_GETARG_TEXT_PP(0); - text *arg2 = PG_GETARG_TEXT_PP(1); - int32 arg1_size = VARSIZE_ANY_EXHDR(arg1); - int32 arg2_size = VARSIZE_ANY_EXHDR(arg2); - int32 new_text_size = arg1_size + arg2_size + VARHDRSZ; - text *new_text = (text *)palloc(new_text_size); - - SET_VARSIZE(new_text, new_text_size); - memcpy(VARDATA(new_text), VARDATA_ANY(arg1), arg1_size); - memcpy(VARDATA(new_text) + arg1_size, VARDATA_ANY(arg2), arg2_size); - PG_RETURN_TEXT_P(new_text); -} - -/* A wrapper around starts_with(text, text) */ -/* - -PG_FUNCTION_INFO_V1(t_starts_with); - -Datum t_starts_with(PG_FUNCTION_ARGS) { - text *t1 = PG_GETARG_TEXT_PP(0); - text *t2 = PG_GETARG_TEXT_PP(1); - Oid collid = PG_GET_COLLATION(); - bool result; - - result = DatumGetBool(DirectFunctionCall2Coll( - text_starts_with, collid, PointerGetDatum(t1), PointerGetDatum(t2))); - PG_RETURN_BOOL(result); -} - -*/ diff --git a/pgext/load.sql b/pgext/load.sql deleted file mode 100644 index 1cbc3a9..0000000 --- a/pgext/load.sql +++ /dev/null @@ -1,20 +0,0 @@ -CREATE FUNCTION add_one(integer) RETURNS integer - AS '/workspaces/graphalg/pgext/funcs', 'add_one' - LANGUAGE C STRICT; - --- note overloading of SQL function name "add_one" -CREATE FUNCTION add_one(double precision) RETURNS double precision - AS '/workspaces/graphalg/pgext/funcs', 'add_one_float8' - LANGUAGE C STRICT; - -CREATE FUNCTION makepoint(point, point) RETURNS point - AS '/workspaces/graphalg/pgext/funcs', 'makepoint' - LANGUAGE C STRICT; - -CREATE FUNCTION copytext(text) RETURNS text - AS '/workspaces/graphalg/pgext/funcs', 'copytext' - LANGUAGE C STRICT; - -CREATE FUNCTION concat_text(text, text) RETURNS text - AS '/workspaces/graphalg/pgext/funcs', 'concat_text' - LANGUAGE C STRICT; diff --git a/pgext/tutorial_fdw/Makefile b/pgext/tutorial_fdw/Makefile deleted file mode 100644 index 7ba44af..0000000 --- a/pgext/tutorial_fdw/Makefile +++ /dev/null @@ -1,9 +0,0 @@ -MODULE_big = tutorial_fdw -OBJS = tutorial_fdw.o - -EXTENSION = tutorial_fdw -DATA = tutorial_fdw--1.0.sql - -PG_CONFIG = /usr/local/pgsql/bin/pg_config -PGXS := $(shell $(PG_CONFIG) --pgxs) -include $(PGXS) \ No newline at end of file diff --git a/pgext/tutorial_fdw/smoke_test.sh b/pgext/tutorial_fdw/smoke_test.sh deleted file mode 100755 index f3ee9cb..0000000 --- a/pgext/tutorial_fdw/smoke_test.sh +++ /dev/null @@ -1,14 +0,0 @@ -#!/bin/bash - -set -eo pipefail - -make -#sudo make install - -PGDATA=`mktemp -d -t tfdw-XXXXXXXXXXX` - -trap "PGDATA=\"$PGDATA\" pg_ctl stop >/dev/null || true; rm -rf \"$PGDATA\"" EXIT - -PGDATA="$PGDATA" pg_ctl initdb > /dev/null -PGDATA="$PGDATA" pg_ctl start -psql postgres -f smoke_test.sql \ No newline at end of file diff --git a/pgext/tutorial_fdw/smoke_test.sql b/pgext/tutorial_fdw/smoke_test.sql deleted file mode 100644 index f989cb3..0000000 --- a/pgext/tutorial_fdw/smoke_test.sql +++ /dev/null @@ -1,4 +0,0 @@ -CREATE EXTENSION tutorial_fdw; -CREATE SERVER tutorial_server FOREIGN DATA WRAPPER tutorial_fdw; -CREATE FOREIGN TABLE sequential_ints ( val int ) SERVER tutorial_server; -SELECT * FROM sequential_ints; \ No newline at end of file diff --git a/pgext/tutorial_fdw/tutorial_fdw--1.0.sql b/pgext/tutorial_fdw/tutorial_fdw--1.0.sql deleted file mode 100644 index 9dc7f66..0000000 --- a/pgext/tutorial_fdw/tutorial_fdw--1.0.sql +++ /dev/null @@ -1,7 +0,0 @@ -CREATE FUNCTION tutorial_fdw_handler() -RETURNS fdw_handler -AS '/workspaces/graphalg/pgext/tutorial_fdw/tutorial_fdw' -LANGUAGE C STRICT; - -CREATE FOREIGN DATA WRAPPER tutorial_fdw - HANDLER tutorial_fdw_handler; diff --git a/pgext/tutorial_fdw/tutorial_fdw.c b/pgext/tutorial_fdw/tutorial_fdw.c deleted file mode 100644 index f384436..0000000 --- a/pgext/tutorial_fdw/tutorial_fdw.c +++ /dev/null @@ -1,159 +0,0 @@ -#include "postgres.h" - -#include "access/table.h" -#include "fmgr.h" -#include "foreign/fdwapi.h" -#include "optimizer/pathnode.h" -#include "optimizer/planmain.h" -#include "optimizer/restrictinfo.h" -#include "utils/rel.h" - -Datum tutorial_fdw_handler(PG_FUNCTION_ARGS); - -PG_FUNCTION_INFO_V1(tutorial_fdw_handler); - -void tutorial_fdw_GetForeignRelSize(PlannerInfo *root, RelOptInfo *baserel, - Oid foreigntableid); - -void tutorial_fdw_GetForeignPaths(PlannerInfo *root, RelOptInfo *baserel, - Oid foreigntableid); - -ForeignScan *tutorial_fdw_GetForeignPlan(PlannerInfo *root, RelOptInfo *baserel, - Oid foreigntableid, - ForeignPath *best_path, List *tlist, - List *scan_clauses, Plan *outer_plan); - -void tutorial_fdw_BeginForeignScan(ForeignScanState *node, int eflags); - -TupleTableSlot *tutorial_fdw_IterateForeignScan(ForeignScanState *node); - -void tutorial_fdw_ReScanForeignScan(ForeignScanState *node); - -void tutorial_fdw_EndForeignScan(ForeignScanState *node); - -Datum tutorial_fdw_handler(PG_FUNCTION_ARGS) { - FdwRoutine *fdwroutine = makeNode(FdwRoutine); - fdwroutine->GetForeignRelSize = tutorial_fdw_GetForeignRelSize; - - fdwroutine->GetForeignPaths = tutorial_fdw_GetForeignPaths; - - fdwroutine->GetForeignPlan = tutorial_fdw_GetForeignPlan; - - fdwroutine->BeginForeignScan = tutorial_fdw_BeginForeignScan; - - fdwroutine->IterateForeignScan = tutorial_fdw_IterateForeignScan; - - fdwroutine->ReScanForeignScan = tutorial_fdw_ReScanForeignScan; - - fdwroutine->EndForeignScan = tutorial_fdw_EndForeignScan; - - PG_RETURN_POINTER(fdwroutine); -} - -void tutorial_fdw_GetForeignRelSize(PlannerInfo *root, RelOptInfo *baserel, - Oid foreigntableid) { - Relation rel = table_open(foreigntableid, NoLock); - - if (rel->rd_att->natts != 1) { - - ereport(ERROR, - - errcode(ERRCODE_FDW_INVALID_COLUMN_NUMBER), - - errmsg("incorrect schema for tutorial_fdw table %s: table must " - "have exactly one column", - NameStr(rel->rd_rel->relname))); - } - - /* - Oid typid = rel->rd_att->attrs[0].atttypid; - - if (typid != INT4OID) { - - ereport(ERROR, - - errcode(ERRCODE_FDW_INVALID_DATA_TYPE), - - errmsg("incorrect schema for tutorial_fdw table %s: table column " - "must have type int", - NameStr(rel->rd_rel->relname))); - } - */ - - table_close(rel, NoLock); -} - -void tutorial_fdw_GetForeignPaths(PlannerInfo *root, RelOptInfo *baserel, - Oid foreigntableid) { - Path *path = (Path *)create_foreignscan_path( - root, baserel, NULL, /* default pathtarget */ - baserel->rows, /* rows */ - 0, /* disabled_nodes */ - 1, /* startup cost */ - 1 + baserel->rows, /* total cost */ - NIL, /* no pathkeys */ - NULL, /* no required outer relids */ - NULL, /* no fdw_outerpath */ - NULL, /* no fdw_restrictinfo */ - NIL); /* no fdw_private */ - - add_path(baserel, path); -} - -ForeignScan *tutorial_fdw_GetForeignPlan(PlannerInfo *root, RelOptInfo *baserel, - Oid foreigntableid, - - ForeignPath *best_path, List *tlist, - List *scan_clauses, Plan *outer_plan) { - scan_clauses = extract_actual_clauses(scan_clauses, false); - return make_foreignscan( - tlist, scan_clauses, baserel->relid, - NIL, /* no expressions we will evaluate */ - NIL, /* no private data */ - NIL, /* no custom tlist; our scan tuple looks like tlist */ - NIL, /* no quals we will recheck */ - outer_plan); -} - -typedef struct tutorial_fdw_state { - - int current; - -} tutorial_fdw_state; - -void tutorial_fdw_BeginForeignScan(ForeignScanState *node, int eflags) { - tutorial_fdw_state *state = palloc0(sizeof(tutorial_fdw_state)); - - node->fdw_state = state; -} - -TupleTableSlot *tutorial_fdw_IterateForeignScan(ForeignScanState *node) { - TupleTableSlot *slot = node->ss.ss_ScanTupleSlot; - - ExecClearTuple(slot); - - tutorial_fdw_state *state = node->fdw_state; - - if (state->current < 64) { - - slot->tts_isnull[0] = false; - - slot->tts_values[0] = Int32GetDatum(state->current); - - ExecStoreVirtualTuple(slot); - - state->current++; - } - - return slot; -} - -void tutorial_fdw_ReScanForeignScan(ForeignScanState *node) { - tutorial_fdw_state *state = node->fdw_state; - - state->current = 0; -} - -void tutorial_fdw_EndForeignScan(ForeignScanState *node) {} - -PG_MODULE_MAGIC; diff --git a/pgext/tutorial_fdw/tutorial_fdw.control b/pgext/tutorial_fdw/tutorial_fdw.control deleted file mode 100644 index 28b88da..0000000 --- a/pgext/tutorial_fdw/tutorial_fdw.control +++ /dev/null @@ -1,5 +0,0 @@ -comment = 'Tutorial FDW.' -default_version = '1.0' -module_pathname = '/tutorial_fdw' -relocatable = true - From 4f14f1f59de77e39371e404f420defc6834e5104 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 11 Dec 2025 14:28:39 +0000 Subject: [PATCH 20/20] Works with different types. --- pg_graphalg/README.md | 2 +- pg_graphalg/include/pg_graphalg/MatrixTable.h | 50 +++++++----- pg_graphalg/src/pg_graphalg.cpp | 48 +++++++++--- pg_graphalg/src/pg_graphalg/CMakeLists.txt | 1 - pg_graphalg/src/pg_graphalg/MatrixTable.cpp | 21 ----- pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp | 77 ++++++++++++++++--- pg_graphalg/test/sssp.sql | 52 +++++++++++++ pg_graphalg/{ => test}/test.sql | 3 - 8 files changed, 187 insertions(+), 67 deletions(-) delete mode 100644 pg_graphalg/src/pg_graphalg/MatrixTable.cpp create mode 100644 pg_graphalg/test/sssp.sql rename pg_graphalg/{ => test}/test.sql (94%) diff --git a/pg_graphalg/README.md b/pg_graphalg/README.md index 33439f6..535ec60 100644 --- a/pg_graphalg/README.md +++ b/pg_graphalg/README.md @@ -63,5 +63,5 @@ After the extension has been built, start the server and then run the `test.sql` /usr/local/pgsql/bin/postgres -D ~/pgdata # Run the tests -/usr/local/pgsql/bin/psql postgres -f pg_graphalg/test.sql +/usr/local/pgsql/bin/psql postgres -f pg_graphalg/test/test.sql ``` diff --git a/pg_graphalg/include/pg_graphalg/MatrixTable.h b/pg_graphalg/include/pg_graphalg/MatrixTable.h index 29996fc..ad1fb02 100644 --- a/pg_graphalg/include/pg_graphalg/MatrixTable.h +++ b/pg_graphalg/include/pg_graphalg/MatrixTable.h @@ -1,9 +1,11 @@ #pragma once +#include #include #include #include #include +#include namespace pg_graphalg { @@ -42,43 +44,49 @@ class MatrixTable { std::size_t _nCols; const MatrixValueType _type; + using AnyValue = std::variant; + std::map, AnyValue> _values; + public: MatrixTable(const MatrixTableDef &def) : _name(def.name), _nRows(def.nRows), _nCols(def.nCols), _type(def.type) { } - virtual ~MatrixTable() = default; - std::size_t nRows() const { return _nRows; } std::size_t nCols() const { return _nCols; } MatrixValueType getType() const { return _type; } - virtual std::size_t nValues() = 0; - virtual void clear() = 0; -}; - -class MatrixTableInt : public MatrixTable { -private: - std::map, std::int64_t> _values; - -public: - MatrixTableInt(const MatrixTableDef &def) : MatrixTable(def) {} - - static bool classof(const MatrixTable *t) { - return t->getType() == MatrixValueType::INT; - } + std::size_t nValues() { return _values.size(); } - std::size_t nValues() override { return _values.size(); } - void clear() override { _values.clear(); } + void clear() { _values.clear(); } - void setValue(std::size_t row, std::size_t col, std::int64_t value) { + void setValue(std::size_t row, std::size_t col, AnyValue value) { + assert(getType() != MatrixValueType::BOOL || + std::holds_alternative(value)); + assert(getType() != MatrixValueType::INT || + std::holds_alternative(value)); + assert(getType() != MatrixValueType::FLOAT || + std::holds_alternative(value)); _values[{row, col}] = value; } const auto &values() const { return _values; } - std::optional> - scan(MatrixTableScanState &state); + std::optional> + scan(MatrixTableScanState &state) { + auto it = _values.lower_bound({state.row, state.col}); + if (it == _values.end()) { + return std::nullopt; + } + + std::size_t row = it->first.first; + std::size_t col = it->first.second; + AnyValue val = it->second; + + state.row = row; + state.col = col + 1; + return std::make_tuple(row, col, val); + } }; } // namespace pg_graphalg diff --git a/pg_graphalg/src/pg_graphalg.cpp b/pg_graphalg/src/pg_graphalg.cpp index 5e8a785..00e89c5 100644 --- a/pg_graphalg/src/pg_graphalg.cpp +++ b/pg_graphalg/src/pg_graphalg.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include @@ -42,7 +43,18 @@ PG_MODULE_MAGIC; static void diagHandler(mlir::Diagnostic &diag) { std::string msg = diag.str(); - elog(ERROR, "%s", msg.c_str()); + switch (diag.getSeverity()) { + case mlir::DiagnosticSeverity::Note: + case mlir::DiagnosticSeverity::Remark: + elog(INFO, "%s", msg.c_str()); + break; + case mlir::DiagnosticSeverity::Warning: + elog(WARNING, "%s", msg.c_str()); + break; + case mlir::DiagnosticSeverity::Error: + elog(ERROR, "%s", msg.c_str()); + break; + } } static pg_graphalg::PgGraphAlg *SINGLETON = nullptr; @@ -200,7 +212,6 @@ static std::optional lookupMatrixTable(Oid relid) { tableName, static_cast(*rows), static_cast(*cols), - // TODO: Determine from column type. *valType, }; } @@ -258,14 +269,35 @@ static void BeginForeignScan(ForeignScanState *node, int eflags) { node->fdw_state = state; } +static Datum matrixValueGetDatum(std::variant v) { + if (auto *b = std::get_if(&v)) { + return BoolGetDatum(*b); + } else if (auto *i = std::get_if(&v)) { + return Int64GetDatum(*i); + } else { + return Float8GetDatum(std::get(v)); + } +} + +static std::variant +datumGetMatrixValue(pg_graphalg::MatrixValueType type, Datum v) { + switch (type) { + case pg_graphalg::MatrixValueType::BOOL: + return DatumGetBool(v); + case pg_graphalg::MatrixValueType::INT: + return DatumGetInt64(v); + case pg_graphalg::MatrixValueType::FLOAT: + return DatumGetFloat8(v); + } +} + static TupleTableSlot *IterateForeignScan(ForeignScanState *node) { TupleTableSlot *slot = node->ss.ss_ScanTupleSlot; ExecClearTuple(slot); auto *scanState = static_cast(node->fdw_state); - // TODO: Check type - auto &table = llvm::cast(*scanState->table); + auto &table = *scanState->table; if (auto res = table.scan(*scanState)) { slot->tts_isnull[0] = false; slot->tts_isnull[1] = false; @@ -273,7 +305,7 @@ static TupleTableSlot *IterateForeignScan(ForeignScanState *node) { auto [row, col, val] = *res; slot->tts_values[0] = UInt64GetDatum(row); slot->tts_values[1] = UInt64GetDatum(col); - slot->tts_values[2] = UInt64GetDatum(val); + slot->tts_values[2] = matrixValueGetDatum(val); ExecStoreVirtualTuple(slot); } @@ -303,8 +335,6 @@ static TupleTableSlot *ExecForeignInsert(EState *estate, ResultRelInfo *rinfo, TupleTableSlot *planSlot) { auto tableId = RelationGetRelid(rinfo->ri_RelationDesc); auto &table = **getInstance().getOrCreateTable(tableId, lookupMatrixTable); - // TODO: More types - auto &tableInt = llvm::cast(table); slot_getsomeattrs(slot, 3); if (slot->tts_isnull[0] || slot->tts_isnull[1] || slot->tts_isnull[2]) { @@ -314,8 +344,8 @@ static TupleTableSlot *ExecForeignInsert(EState *estate, ResultRelInfo *rinfo, std::size_t row = DatumGetUInt64(slot->tts_values[0]); std::size_t col = DatumGetUInt64(slot->tts_values[1]); - std::int64_t val = DatumGetInt64(slot->tts_values[2]); - tableInt.setValue(row, col, val); + auto val = datumGetMatrixValue(table.getType(), slot->tts_values[2]); + table.setValue(row, col, val); return slot; } diff --git a/pg_graphalg/src/pg_graphalg/CMakeLists.txt b/pg_graphalg/src/pg_graphalg/CMakeLists.txt index 73be6e7..a5ecabd 100644 --- a/pg_graphalg/src/pg_graphalg/CMakeLists.txt +++ b/pg_graphalg/src/pg_graphalg/CMakeLists.txt @@ -1,5 +1,4 @@ add_library(PgGraphAlg SHARED - MatrixTable.cpp PgGraphAlg.cpp ) target_include_directories(PgGraphAlg PUBLIC ../../include) diff --git a/pg_graphalg/src/pg_graphalg/MatrixTable.cpp b/pg_graphalg/src/pg_graphalg/MatrixTable.cpp deleted file mode 100644 index 41cc8ee..0000000 --- a/pg_graphalg/src/pg_graphalg/MatrixTable.cpp +++ /dev/null @@ -1,21 +0,0 @@ -#include - -namespace pg_graphalg { - -std::optional> -MatrixTableInt::scan(MatrixTableScanState &state) { - auto it = _values.lower_bound({state.row, state.col}); - if (it == _values.end()) { - return std::nullopt; - } - - std::size_t row = it->first.first; - std::size_t col = it->first.second; - std::int64_t val = it->second; - - state.row = row; - state.col = col + 1; - return std::make_tuple(row, col, val); -} - -} // namespace pg_graphalg diff --git a/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp b/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp index 4758f0d..0d9fa21 100644 --- a/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp +++ b/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp @@ -1,19 +1,24 @@ #include #include #include +#include #include +#include #include +#include #include #include #include #include +#include #include #include #include #include #include +#include #include #include @@ -45,12 +50,68 @@ std::optional PgGraphAlg::getOrCreateTable( return std::nullopt; } - // TODO: Different types. - _tables[tableId] = std::make_unique(*def); + _tables[tableId] = std::make_unique(*def); } return _tables[tableId].get(); } +static mlir::TypedAttr +matrixValueToAttr(mlir::Type t, std::variant v) { + auto *ctx = t.getContext(); + auto intType = graphalg::SemiringTypes::forInt(ctx); + auto realType = graphalg::SemiringTypes::forReal(ctx); + auto tropIntType = graphalg::SemiringTypes::forTropInt(ctx); + auto tropRealType = graphalg::SemiringTypes::forTropReal(ctx); + auto tropMaxIntType = graphalg::SemiringTypes::forTropMaxInt(ctx); + if (t == graphalg::SemiringTypes::forBool(ctx)) { + assert(std::holds_alternative(v)); + return mlir::BoolAttr::get(ctx, std::get(v)); + } else if (t == intType) { + assert(std::holds_alternative(v)); + return mlir::IntegerAttr::get(intType, std::get(v)); + } else if (t == realType) { + assert(std::holds_alternative(v)); + return mlir::FloatAttr::get(realType, std::get(v)); + } else if (t == tropIntType) { + assert(std::holds_alternative(v)); + return graphalg::TropIntAttr::get( + ctx, tropIntType, + mlir::IntegerAttr::get(intType, std::get(v))); + } else if (t == tropRealType) { + assert(std::holds_alternative(v)); + return graphalg::TropFloatAttr::get( + ctx, tropRealType, mlir::FloatAttr::get(realType, std::get(v))); + } else if (t == tropMaxIntType) { + assert(std::holds_alternative(v)); + return graphalg::TropIntAttr::get( + ctx, tropMaxIntType, + mlir::IntegerAttr::get(intType, std::get(v))); + } else { + mlir::emitError(mlir::UnknownLoc::get(ctx)) + << "invalid target type for matrix value: " << t; + return nullptr; + } +} + +static std::variant +attrToMatrixValue(mlir::TypedAttr attr) { + if (auto b = llvm::dyn_cast(attr)) { + return b.getValue(); + } else if (auto i = llvm::dyn_cast(attr)) { + return i.getInt(); + } else if (auto f = llvm::dyn_cast(attr)) { + return f.getValueAsDouble(); + } else if (auto i = llvm::dyn_cast(attr)) { + return i.getValue().getInt(); + } else if (auto f = llvm::dyn_cast(attr)) { + return f.getValue().getValueAsDouble(); + } else { + mlir::emitError(mlir::UnknownLoc::get(attr.getContext())) + << "attribute cannot be converted to matrix value: " << attr; + std::abort(); + } +} + bool PgGraphAlg::execute(llvm::StringRef programSource, llvm::StringRef function, llvm::ArrayRef arguments, @@ -111,12 +172,10 @@ bool PgGraphAlg::execute(llvm::StringRef programSource, llvm::zip_equal(arguments, funcOp.getFunctionType().getInputs())) { auto matType = llvm::cast(type); graphalg::MatrixAttrBuilder builder(matType); - // TODO: support more value types. - const auto &values = llvm::cast(arg)->values(); + const auto &values = arg->values(); for (auto [pos, val] : values) { auto [row, col] = pos; - // TODO: support more value types - auto valAttr = mlir::IntegerAttr::get(matType.getSemiring(), val); + auto valAttr = matrixValueToAttr(matType.getSemiring(), val); builder.set(row, col, valAttr); } @@ -133,16 +192,12 @@ bool PgGraphAlg::execute(llvm::StringRef programSource, // TODO: Check semiring is compatible with value type. output.clear(); - // TODO: more value types - auto &outputInt = llvm::cast(output); auto defaultValue = resultReader.ring().addIdentity(); for (auto r : llvm::seq(resultReader.nRows())) { for (auto c : llvm::seq(resultReader.nCols())) { auto v = resultReader.at(r, c); if (v != defaultValue) { - // TODO: Support bool/real value types - auto vInt = llvm::cast(v).getInt(); - outputInt.setValue(r, c, vInt); + output.setValue(r, c, attrToMatrixValue(v)); } } } diff --git a/pg_graphalg/test/sssp.sql b/pg_graphalg/test/sssp.sql new file mode 100644 index 0000000..b1d7971 --- /dev/null +++ b/pg_graphalg/test/sssp.sql @@ -0,0 +1,52 @@ +-- Create the matrix table for the graph +CREATE FOREIGN TABLE graph(source bigint, target bigint, dist double precision) +SERVER graphalg_server OPTIONS (rows '10', columns '10'); +INSERT INTO graph VALUES + (0, 1, 0.5), + (0, 2, 5.0), + (0, 3, 5.0), + (1, 4, 0.5), + (2, 3, 2.0), + (4, 5, 0.5), + (5, 2, 0.5), + (5, 9, 23.0), + (6, 0, 1.0), + (6, 7, 3.2), + (7, 9, 0.2), + (8, 9, 0.1), + (9, 6, 8.0); + +-- Define the algorithm +CREATE PROCEDURE SSSP(text, text, text) +LANGUAGE graphalg +AS $$ +func sssp( + graph: Matrix, + source: Vector) -> Vector { + dist = source; + for i in graph.nrows { + dist += dist * graph; + } + return dist; +} +$$; + +-- Start from source node 0 +CREATE FOREIGN TABLE source(vertex_id bigint, nop bigint, init_dist double precision) +SERVER graphalg_server OPTIONS (rows '10', columns '1'); +INSERT INTO source VALUES (0, 0, 0.0); + +-- Output of the algorithm +CREATE FOREIGN TABLE dist_out(vertex_id bigint, nop bigint, dist double precision) +SERVER graphalg_server OPTIONS (rows '10', columns '1'); + +-- Run he algorithm +CALL SSSP('graph', 'source', 'dist_out'); + +-- Read the results. +SELECT vertex_id, dist FROM dist_out; + +DROP FOREIGN TABLE graph; +DROP FOREIGN TABLE source; +DROP FOREIGN TABLE dist_out; +DROP PROCEDURE SSSP; diff --git a/pg_graphalg/test.sql b/pg_graphalg/test/test.sql similarity index 94% rename from pg_graphalg/test.sql rename to pg_graphalg/test/test.sql index 59f4ed5..7e113e3 100644 --- a/pg_graphalg/test.sql +++ b/pg_graphalg/test/test.sql @@ -117,6 +117,3 @@ CREATE FOREIGN TABLE matreal ( row bigint, col bigint, val double precision ) SE INSERT INTO matreal VALUES (0, 0, 4.2); DROP FOREIGN TABLE matbool; DROP FOREIGN TABLE matreal; -CREATE FOREIGN TABLE matbad ( row bigint, col bigint, val text ) SERVER graphalg_server OPTIONS (rows '10', columns '10'); -INSERT INTO matbad VALUES (0, 0, '42'); -DROP FOREIGN TABLE matbad;