diff --git a/docs/postgres.md b/docs/postgres.md new file mode 100644 index 0000000..f03d357 --- /dev/null +++ b/docs/postgres.md @@ -0,0 +1,82 @@ +# GraphAlg in Postgres +The goal: Run GraphAlg programs in PostgreSQL. +This is accomplished by writing a PostgreSQL extension. + +## Building and Testing the Extension +Key steps, explained in more detail below: +- Build PostgreSQL v18 from source and install it to the default path `/usr/local/pgsql`. +- Set up a dummy database in `~/pgdata` +- Install SuiteSparse:GraphBLAS +- Build the extension + + +## Build from source +Download the latest PostgreSQL release source code. +```bash +cd thirdparty/ +wget https://ftp.postgresql.org/pub/source/v18.1/postgresql-18.1.tar.bz2 +``` + +Following https://www.postgresql.org/docs/current/install-make.html + +Packages to install: +- bison +- flex +- libreadline-dev +- libicu-dev + +```bash +./configure +make +sudo make install +``` + +Now the binaries are located at `/usr/local/pgsql/bin`. + +## Set up a database +```bash +export LC_ALL=C +export LC_CTYPE=C + +# Create a new DB +/usr/local/pgsql/bin/initdb -D ~/pgdata + +# Start the server +/usr/local/pgsql/bin/postgres -D ~/pgdata + +``` + +## Load extension +First build with `pgext/build.sh`. + +Then connect to the server with `/usr/local/pgsql/bin/psql postgres` and run: + +```sql +CREATE FUNCTION add_one(integer) RETURNS integer +    AS '/workspaces/graphalg/pgext/funcs', 'add_one' +    LANGUAGE C STRICT; +``` + +## Foreign Data Wrapper +Resources: + +## GraphBLAS +We need to link to SuiteSparse:GraphBLAS. + +Install package `libgraphblas-dev`. + +Alternatively: +Download https://github.com/DrTimothyAldenDavis/SuiteSparse/archive/refs/tags/v7.12.1.tar.gz to `thirdparty/`.
+ +``` +make +sudo make install +``` + +## Resources +- https://www.postgresql.org/docs/current/xfunc-c.html +- https://www.pgedge.com/blog/introduction-to-postgres-extension-development +- https://stackoverflow.com/questions/76056209/postgresql-c-extension-function-table-as-argument-and-as-result +- https://www.postgresql.org/docs/current/fdwhandler.html +- https://www.dolthub.com/blog/2022-01-26-creating-a-postgres-foreign-data-wrapper/ +- https://github.com/Kentik-Archive/wdb_fdw diff --git a/pg_graphalg/.gitignore b/pg_graphalg/.gitignore new file mode 100644 index 0000000..80a6844 --- /dev/null +++ b/pg_graphalg/.gitignore @@ -0,0 +1,2 @@ +/build* +/.cache diff --git a/pg_graphalg/CMakeLists.txt b/pg_graphalg/CMakeLists.txt new file mode 100644 index 0000000..92a4eff --- /dev/null +++ b/pg_graphalg/CMakeLists.txt @@ -0,0 +1,36 @@ +# https://cliutils.gitlab.io/modern-cmake/chapters/basics.html#introduction-to-the-basics +cmake_minimum_required(VERSION 3.15...4.0) + +project( +  pg_graphalg +  VERSION 0.1 +  DESCRIPTION "GraphAlg extension for PostgreSQL" +  LANGUAGES C CXX) + +find_package(PostgreSQL REQUIRED COMPONENTS Server)  # src/ links PostgreSQL::PostgreSQL unconditionally + +include(FetchContent) +FetchContent_Declare( +  GraphAlg +  SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../compiler"  # in-tree compiler, independent of where the build dir lives +) +FetchContent_MakeAvailable(GraphAlg) + +# ================================= BEGIN LLVM ================================= + +find_package(LLVM REQUIRED CONFIG) +include_directories(${LLVM_INCLUDE_DIRS}) +separate_arguments(LLVM_DEFINITIONS_LIST NATIVE_COMMAND ${LLVM_DEFINITIONS}) +add_definitions(${LLVM_DEFINITIONS_LIST}) +list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") +include(AddLLVM) + +llvm_map_components_to_libnames(llvm_libs support) + +# ================================== END LLVM ================================== + + + +find_library(GRAPHBLAS_LIBRARY NAMES GraphBLAS graphblas)  # result variable comes first; not linked by any target yet + +add_subdirectory(src) diff --git a/pg_graphalg/README.md b/pg_graphalg/README.md new file mode 100644 index 0000000..535ec60 --- /dev/null +++ b/pg_graphalg/README.md @@ -0,0 +1,67 @@ +# pg_graphalg:
Run GraphAlg in PostgreSQL +This is an extension for PostgreSQL to execute GraphAlg programs. + +## Building +In addition to setting up the devcontainer for the project, you need to: +- Build PostgreSQL v18 from source and install it to the default path `/usr/local/pgsql`. +- Set up a dummy database in `~/pgdata` +- Install SuiteSparse:GraphBLAS + +### PostgreSQL +Download the latest PostgreSQL release source code. + +```bash +# Dependencies for building PostgreSQL +sudo apt update +sudo apt install bison flex libreadline-dev libicu-dev + +# Download the sources +mkdir -p thirdparty/ +cd thirdparty/ +wget https://ftp.postgresql.org/pub/source/v18.1/postgresql-18.1.tar.bz2 +tar xf postgresql-18.1.tar.bz2 +cd postgresql-18.1/ + +# Build and install +./configure +make +sudo make install +``` + +Now PostgreSQL is installed at `/usr/local/pgsql`. + +### Set up a database +```bash +# These must be set, or PostgreSQL will refuse to create the DB +export LC_ALL=C +export LC_CTYPE=C + +# Create a new DB +/usr/local/pgsql/bin/initdb -D ~/pgdata +``` + +### GraphBLAS +Install using APT: + +```bash +sudo apt install libgraphblas-dev +``` + +### Build the Extension +```bash +pg_graphalg/configure.sh + +# After this step the extension is located at pg_graphalg/build/src/libpg_graphalg.so +cmake --build pg_graphalg/build +``` + +## Testing +After the extension has been built, start the server and then run the `test.sql` script: + +```bash +# Start the server (runs in the foreground; run the tests from a second terminal) +/usr/local/pgsql/bin/postgres -D ~/pgdata + +# Run the tests +/usr/local/pgsql/bin/psql postgres -f pg_graphalg/test/test.sql +``` diff --git a/pg_graphalg/configure.sh b/pg_graphalg/configure.sh new file mode 100755 index 0000000..ddf6cb0 --- /dev/null +++ b/pg_graphalg/configure.sh @@ -0,0 +1,13 @@ +#!/bin/bash +WORKSPACE_ROOT=pg_graphalg/ +BUILD_DIR=$WORKSPACE_ROOT/build +rm -rf $BUILD_DIR +cmake -S $WORKSPACE_ROOT -B $BUILD_DIR -G Ninja \ +  -DCMAKE_BUILD_TYPE=Debug \ +  -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ +  -DCMAKE_CXX_COMPILER=clang++-20 \ +
-DCMAKE_POSITION_INDEPENDENT_CODE=TRUE \ + -DCMAKE_EXPORT_COMPILE_COMMANDS=1 \ + -DCMAKE_LINKER_TYPE=MOLD \ + -DPostgreSQL_ROOT=/usr/local/pgsql \ + -DLLVM_ROOT="/opt/llvm-debug" \ diff --git a/pg_graphalg/include/pg_graphalg/MatrixTable.h b/pg_graphalg/include/pg_graphalg/MatrixTable.h new file mode 100644 index 0000000..ad1fb02 --- /dev/null +++ b/pg_graphalg/include/pg_graphalg/MatrixTable.h @@ -0,0 +1,92 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace pg_graphalg { + +enum class MatrixValueType { + BOOL, + INT, + FLOAT, +}; + +struct MatrixTableDef { + std::string name; + std::size_t nRows; + std::size_t nCols; + MatrixValueType type; +}; + +class MatrixTable; + +struct MatrixTableScanState { + MatrixTable *table; + std::size_t row = 0; + std::size_t col = 0; + + MatrixTableScanState(MatrixTable *table) : table(table) {} + + void reset() { + row = 0; + col = 0; + } +}; + +class MatrixTable { +private: + std::string _name; + std::size_t _nRows; + std::size_t _nCols; + const MatrixValueType _type; + + using AnyValue = std::variant; + std::map, AnyValue> _values; + +public: + MatrixTable(const MatrixTableDef &def) + : _name(def.name), _nRows(def.nRows), _nCols(def.nCols), _type(def.type) { + } + + std::size_t nRows() const { return _nRows; } + std::size_t nCols() const { return _nCols; } + MatrixValueType getType() const { return _type; } + + std::size_t nValues() { return _values.size(); } + + void clear() { _values.clear(); } + + void setValue(std::size_t row, std::size_t col, AnyValue value) { + assert(getType() != MatrixValueType::BOOL || + std::holds_alternative(value)); + assert(getType() != MatrixValueType::INT || + std::holds_alternative(value)); + assert(getType() != MatrixValueType::FLOAT || + std::holds_alternative(value)); + _values[{row, col}] = value; + } + + const auto &values() const { return _values; } + + std::optional> + scan(MatrixTableScanState &state) { + auto it = _values.lower_bound({state.row, 
state.col}); + if (it == _values.end()) { + return std::nullopt; + } + + std::size_t row = it->first.first; + std::size_t col = it->first.second; + AnyValue val = it->second; + + state.row = row; + state.col = col + 1; + return std::make_tuple(row, col, val); + } +}; + +} // namespace pg_graphalg diff --git a/pg_graphalg/include/pg_graphalg/PgGraphAlg.h b/pg_graphalg/include/pg_graphalg/PgGraphAlg.h new file mode 100644 index 0000000..6323402 --- /dev/null +++ b/pg_graphalg/include/pg_graphalg/PgGraphAlg.h @@ -0,0 +1,39 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "pg_graphalg/MatrixTable.h" + +namespace pg_graphalg { + +using TableId = unsigned int; + +class PgGraphAlg { +private: + mlir::DialectRegistry _registry; + mlir::MLIRContext _ctx; + llvm::DenseMap> _tables; + +public: + PgGraphAlg(llvm::function_ref diagHandler); + + std::optional getOrCreateTable( + TableId tableId, + llvm::function_ref(TableId id)> createFunc); + + bool execute(llvm::StringRef programSource, llvm::StringRef function, + llvm::ArrayRef arguments, + MatrixTable &output); +}; + +} // namespace pg_graphalg diff --git a/pg_graphalg/src/CMakeLists.txt b/pg_graphalg/src/CMakeLists.txt new file mode 100644 index 0000000..2392f61 --- /dev/null +++ b/pg_graphalg/src/CMakeLists.txt @@ -0,0 +1,16 @@ +add_subdirectory(pg_graphalg) + +add_library(pg_graphalg SHARED + pg_graphalg.cpp +) +target_include_directories(pg_graphalg PUBLIC ../include) +target_link_libraries(pg_graphalg + PRIVATE + PgGraphAlg + PostgreSQL::PostgreSQL + ${llvm_libs} + GraphAlgEvaluate + GraphAlgIR + GraphAlgParse + GraphAlgPasses +) diff --git a/pg_graphalg/src/pg_graphalg.cpp b/pg_graphalg/src/pg_graphalg.cpp new file mode 100644 index 0000000..00e89c5 --- /dev/null +++ b/pg_graphalg/src/pg_graphalg.cpp @@ -0,0 +1,510 @@ +#include +#include +#include +#include + +#include +#include +#include + +#include "pg_graphalg/MatrixTable.h" 
+#include "pg_graphalg/PgGraphAlg.h" + +extern "C" { + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +PG_MODULE_MAGIC; + +static void diagHandler(mlir::Diagnostic &diag) { + std::string msg = diag.str(); + switch (diag.getSeverity()) { + case mlir::DiagnosticSeverity::Note: + case mlir::DiagnosticSeverity::Remark: + elog(INFO, "%s", msg.c_str()); + break; + case mlir::DiagnosticSeverity::Warning: + elog(WARNING, "%s", msg.c_str()); + break; + case mlir::DiagnosticSeverity::Error: + elog(ERROR, "%s", msg.c_str()); + break; + } +} + +static pg_graphalg::PgGraphAlg *SINGLETON = nullptr; +static pg_graphalg::PgGraphAlg &getInstance() { + if (!SINGLETON) { + SINGLETON = new pg_graphalg::PgGraphAlg(diagHandler); + } + + return *SINGLETON; +} + +PG_FUNCTION_INFO_V1(graphalg_fdw_handler); +PG_FUNCTION_INFO_V1(graphalg_fdw_validator); +PG_FUNCTION_INFO_V1(graphalg_pl_call_handler); +PG_FUNCTION_INFO_V1(graphalg_pl_inline_handler); +PG_FUNCTION_INFO_V1(graphalg_pl_validator); + +static std::optional parseDimension(std::string_view s) { + std::size_t v; + auto res = std::from_chars(s.data(), s.data() + s.size(), v); + if (res.ec == std::errc()) { + return v; + } else { + return std::nullopt; + } +} + +static bool validateOption(DefElem *def) { + std::string_view optName{def->defname}; + const char *optValue = defGetString(def); + + if (optName == "rows" || optName == "columns") { + // NOTE: foreign data wrapper options are always strings. 
+ if (!parseDimension(optValue)) { + ereport(ERROR, (errcode(ERRCODE_FDW_ERROR), + errmsg("invalid value for option \"%s\": '%s' must be " + "a non-negative integer", + def->defname, optValue))); + return false; + } + } else { + ereport(ERROR, (errcode(ERRCODE_FDW_INVALID_OPTION_NAME), + errmsg("invalid option \"%s\"", def->defname), + errhint("Valid table options are \"rows\", and " + "\"columns\""))); + return false; + } + + return true; +} + +static std::optional mapValueType(Oid typeId) { + switch (typeId) { + case BOOLOID: + return pg_graphalg::MatrixValueType::BOOL; + case INT8OID: + return pg_graphalg::MatrixValueType::INT; + case FLOAT8OID: + return pg_graphalg::MatrixValueType::FLOAT; + default: + return std::nullopt; + } +} + +struct SysCacheTupleScope { + HeapTuple tup; + + SysCacheTupleScope(SysCacheIdentifier cacheId, Oid key) + : tup(SearchSysCache1(cacheId, key)) {} + ~SysCacheTupleScope() { + if (HeapTupleIsValid(tup)) { + ReleaseSysCache(tup); + } + } +}; + +struct RelationScope { + Relation rel; + + RelationScope(Oid relid) : rel(RelationIdGetRelation(relid)) {} + ~RelationScope() { RelationClose(rel); } +}; + +static std::optional lookupMatrixTable(Oid relid) { + // Must be a foreign table + // TODO: Check that it uses a graphalg server. + auto *fTable = GetForeignTable(relid); + if (!fTable) { + elog(ERROR, "relation with oid %u is not a foreign table", relid); + return std::nullopt; + } + + RelationScope rel(relid); + std::string tableName{NameStr(rel.rel->rd_rel->relname)}; + + // Validate the column types. 
+ int nAttrs = RelationGetNumberOfAttributes(rel.rel); + if (nAttrs != 3) { + elog(ERROR, "matrix table must have 3 columns, got %d", nAttrs); + return std::nullopt; + } + + TupleDesc tupleDesc = rel.rel->rd_att; + auto rowAttr = TupleDescAttr(tupleDesc, 0); + if (rowAttr->atttypid != INT8OID) { + elog(ERROR, "first column (row index) must have type bigint"); + return std::nullopt; + } + + auto colAttr = TupleDescAttr(tupleDesc, 1); + if (colAttr->atttypid != INT8OID) { + elog(ERROR, "second column (column index) must have type bigint"); + return std::nullopt; + } + + auto valAttr = TupleDescAttr(tupleDesc, 2); + auto valType = mapValueType(valAttr->atttypid); + if (!valType) { + elog(ERROR, "third column (value) must have type boolean, bigint or double " + "precision"); + return std::nullopt; + } + + // Parse the options + ListCell *cell; + std::optional rows; + std::optional cols; + + foreach (cell, fTable->options) { + auto *def = lfirst_node(DefElem, cell); + std::string_view defName(def->defname); + + if (!validateOption(def)) { + return std::nullopt; + } else if (defName == "rows") { + rows = parseDimension(defGetString(def)); + } else if (defName == "columns") { + cols = parseDimension(defGetString(def)); + } + } + + if (!rows) { + ereport(ERROR, (errcode(ERRCODE_FDW_ERROR), + errmsg("missing required option \"rows\""))); + return std::nullopt; + } + + if (!cols) { + ereport(ERROR, (errcode(ERRCODE_FDW_ERROR), + errmsg("missing required option \"columns\""))); + return std::nullopt; + } + + return pg_graphalg::MatrixTableDef{ + tableName, + static_cast(*rows), + static_cast(*cols), + *valType, + }; +} + +static void GetForeignRelSize(PlannerInfo *root, RelOptInfo *baserel, + Oid foreigntableid) { + auto table = + getInstance().getOrCreateTable(foreigntableid, lookupMatrixTable); + if (table) { + baserel->rows = (*table)->nValues(); + } +} + +static void GetForeignPaths(PlannerInfo *root, RelOptInfo *baserel, + Oid foreigntableid) { + ForeignPath *path = 
create_foreignscan_path(root, baserel, + /*target=*/NULL, + /*rows=*/baserel->rows, + /*disabled_nodes=*/0, + /*startup_cost=*/1, + /*total_cost=*/1 + baserel->rows, + /*pathkeys=*/NIL, + /*required_outer=*/NULL, + /*fdw_outerpath=*/NULL, + /*fdw_restrictinfo=*/NULL, + /*fdw_private=*/NULL); + add_path(baserel, (Path *)path); +} + +static ForeignScan *GetForeignPlan(PlannerInfo *root, RelOptInfo *baserel, + Oid foreigntableid, ForeignPath *best_path, + List *tlist, List *scan_clauses, + Plan *outer_plan) { + // On extract_actual_clauses: + // https://www.postgresql.org/docs/current/fdw-planning.html + scan_clauses = extract_actual_clauses(scan_clauses, false); + return make_foreignscan( + tlist, scan_clauses, baserel->relid, + NIL, /* no expressions we will evaluate */ + NIL, /* no private data */ + NIL, /* no custom tlist; our scan tuple looks like tlist */ + NIL, /* no quals we will recheck */ + outer_plan); +} + +static void BeginForeignScan(ForeignScanState *node, int eflags) { + auto tableId = RelationGetRelid(node->ss.ss_currentRelation); + auto table = getInstance().getOrCreateTable(tableId, lookupMatrixTable); + if (!table) { + return; + } + + auto *state = palloc(sizeof(pg_graphalg::MatrixTableScanState)); + new (state) pg_graphalg::MatrixTableScanState(*table); + node->fdw_state = state; +} + +static Datum matrixValueGetDatum(std::variant v) { + if (auto *b = std::get_if(&v)) { + return BoolGetDatum(*b); + } else if (auto *i = std::get_if(&v)) { + return Int64GetDatum(*i); + } else { + return Float8GetDatum(std::get(v)); + } +} + +static std::variant +datumGetMatrixValue(pg_graphalg::MatrixValueType type, Datum v) { + switch (type) { + case pg_graphalg::MatrixValueType::BOOL: + return DatumGetBool(v); + case pg_graphalg::MatrixValueType::INT: + return DatumGetInt64(v); + case pg_graphalg::MatrixValueType::FLOAT: + return DatumGetFloat8(v); + } +} + +static TupleTableSlot *IterateForeignScan(ForeignScanState *node) { + TupleTableSlot *slot = 
node->ss.ss_ScanTupleSlot; + ExecClearTuple(slot); + + auto *scanState = + static_cast(node->fdw_state); + auto &table = *scanState->table; + if (auto res = table.scan(*scanState)) { + slot->tts_isnull[0] = false; + slot->tts_isnull[1] = false; + slot->tts_isnull[2] = false; + auto [row, col, val] = *res; + slot->tts_values[0] = UInt64GetDatum(row); + slot->tts_values[1] = UInt64GetDatum(col); + slot->tts_values[2] = matrixValueGetDatum(val); + ExecStoreVirtualTuple(slot); + } + + return slot; +} + +static void ReScanForeignScan(ForeignScanState *node) { + auto *scanState = + static_cast(node->fdw_state); + scanState->reset(); +} + +static void EndForeignScan(ForeignScanState *node) { + // No-Op +} + +static void BeginForeignModify(ModifyTableState *mtstate, ResultRelInfo *rinfo, + List *fdw_private, int subplan_index, + int eflags) { + // Ensure table exists before modifying it. + auto tableId = RelationGetRelid(rinfo->ri_RelationDesc); + getInstance().getOrCreateTable(tableId, lookupMatrixTable); +} + +static TupleTableSlot *ExecForeignInsert(EState *estate, ResultRelInfo *rinfo, + TupleTableSlot *slot, + TupleTableSlot *planSlot) { + auto tableId = RelationGetRelid(rinfo->ri_RelationDesc); + auto &table = **getInstance().getOrCreateTable(tableId, lookupMatrixTable); + + slot_getsomeattrs(slot, 3); + if (slot->tts_isnull[0] || slot->tts_isnull[1] || slot->tts_isnull[2]) { + // Ignore nulls + return nullptr; + } + + std::size_t row = DatumGetUInt64(slot->tts_values[0]); + std::size_t col = DatumGetUInt64(slot->tts_values[1]); + auto val = datumGetMatrixValue(table.getType(), slot->tts_values[2]); + table.setValue(row, col, val); + return slot; +} + +Datum graphalg_fdw_handler(PG_FUNCTION_ARGS) { + FdwRoutine *fdwRoutine = makeNode(FdwRoutine); + + fdwRoutine->GetForeignRelSize = GetForeignRelSize; + fdwRoutine->GetForeignPaths = GetForeignPaths; + fdwRoutine->GetForeignPlan = GetForeignPlan; + fdwRoutine->BeginForeignScan = BeginForeignScan; + 
fdwRoutine->IterateForeignScan = IterateForeignScan; + fdwRoutine->ReScanForeignScan = ReScanForeignScan; + fdwRoutine->EndForeignScan = EndForeignScan; + + fdwRoutine->BeginForeignModify = BeginForeignModify; + fdwRoutine->ExecForeignInsert = ExecForeignInsert; + + PG_RETURN_POINTER(fdwRoutine); +} + +Datum graphalg_fdw_validator(PG_FUNCTION_ARGS) { + List *options = untransformRelOptions(PG_GETARG_DATUM(0)); + + ListCell *cell; + foreach (cell, options) { + auto *def = static_cast(lfirst(cell)); + // TODO: Only allow options at the table level. + validateOption(def); + } + + // NOTE: Not checking that required options are set, because this validator is + // also called when checking options defined on the wrapper or the server. + + PG_RETURN_VOID(); +} + +// NOTE: Assumes a working SPI connection +static std::optional lookupForeignTable(Oid argType, Datum argValue) { + constexpr bool READ_ONLY = true; + /* + * Allow up to 2 results. + * n == 0: ERROR Table not found + * n > 1: ERROR Multiple oids for the given name + * n == 1: OK Name uniquely identifies table + */ + constexpr long TCOUNT = 2; + int execRes = SPI_execute_with_args( + "SELECT oid FROM pg_class WHERE relname=$1 AND relkind = 'f'", 1, + &argType, &argValue, nullptr, READ_ONLY, TCOUNT); + if (execRes < 0) { + elog(ERROR, "internal error finding argument tables"); + PG_RETURN_VOID(); + } + + // Expect exactly one result. + if (SPI_processed != 1) { + // TODO: We know this is a string, so can we make this simpler? 
+ auto typeTuple = SearchSysCache1(TYPEOID, ObjectIdGetDatum(argType)); + auto typeStruct = (Form_pg_type)GETSTRUCT(typeTuple); + FmgrInfo typeInfo; + fmgr_info(typeStruct->typoutput, &typeInfo); + char *value = OutputFunctionCall(&typeInfo, argValue); + ReleaseSysCache(typeTuple); + + if (SPI_processed == 0) { + elog(ERROR, "no such matrix table '%s'", value); + } else { + elog(ERROR, "multiple tables named '%s'", value); + } + + PG_RETURN_VOID(); + } + + auto *tuptable = SPI_tuptable; + auto tupdesc = tuptable->tupdesc; + bool oidNull = false; + auto tableOidDatum = + SPI_getbinval(tuptable->vals[0], tuptable->tupdesc, 1, &oidNull); + return DatumGetObjectId(tableOidDatum); +} + +/** Automatically calls SPI_finish() when it goes out of scope. */ +class SPIConnection { +public: + SPIConnection() { SPI_connect(); } + ~SPIConnection() { SPI_finish(); } +}; + +static Datum executeCall(FunctionCallInfo fcinfo) { + auto procTuple = SearchSysCache( + PROCOID, ObjectIdGetDatum(fcinfo->flinfo->fn_oid), 0, 0, 0); + if (!HeapTupleIsValid(procTuple)) { + elog(ERROR, "cache lookup failed for function %s", fcinfo->flinfo->fn_oid); + PG_RETURN_VOID(); + } + + auto procStruct = (Form_pg_proc)GETSTRUCT(procTuple); + + // Extract program source code + bool sourceIsNull; + auto sourceDatum = + SysCacheGetAttr(PROCOID, procTuple, Anum_pg_proc_prosrc, &sourceIsNull); + if (sourceIsNull) { + elog(ERROR, "NULL procedure source"); + PG_RETURN_VOID(); + } + + char *procCode = DatumGetCString(DirectFunctionCall1(textout, sourceDatum)); + + // Get argument matrix tables. 
+ SPIConnection spiConnection; + llvm::SmallVector arguments; + for (int i = 0; i < fcinfo->nargs; i++) { + auto arg = fcinfo->args[i]; + if (arg.isnull) { + elog(ERROR, "Argument %d is NULL", i); + PG_RETURN_VOID(); + } + + auto argType = procStruct->proargtypes.values[i]; + auto tableOid = lookupForeignTable(argType, arg.value); + if (!tableOid) { + PG_RETURN_VOID(); + } + + auto table = getInstance().getOrCreateTable(*tableOid, lookupMatrixTable); + if (!table) { + PG_RETURN_VOID(); + } + + arguments.push_back(*table); + } + + if (arguments.empty()) { + elog(ERROR, "must have at least one argument"); + PG_RETURN_VOID(); + } + + // Output is written to the final procedure argument. + auto *output = arguments.pop_back_val(); + + // No need to check the result here, postgres infers success based on + // diagnostics. + auto funcName = procStruct->proname.data; + getInstance().execute(procCode, funcName, arguments, *output); + + ReleaseSysCache(procTuple); + PG_RETURN_VOID(); +} + +Datum graphalg_pl_call_handler(PG_FUNCTION_ARGS) { return executeCall(fcinfo); } + +Datum graphalg_pl_inline_handler(PG_FUNCTION_ARGS) { + elog(ERROR, "inline handler not implemented"); + PG_RETURN_VOID(); +} + +Datum graphalg_pl_validator(PG_FUNCTION_ARGS) { + elog(INFO, "NOTE: language validator not implemented"); + PG_RETURN_VOID(); +} +} diff --git a/pg_graphalg/src/pg_graphalg/CMakeLists.txt b/pg_graphalg/src/pg_graphalg/CMakeLists.txt new file mode 100644 index 0000000..a5ecabd --- /dev/null +++ b/pg_graphalg/src/pg_graphalg/CMakeLists.txt @@ -0,0 +1,12 @@ +add_library(PgGraphAlg SHARED + PgGraphAlg.cpp +) +target_include_directories(PgGraphAlg PUBLIC ../../include) +target_link_libraries(PgGraphAlg + PRIVATE + ${llvm_libs} + GraphAlgEvaluate + GraphAlgIR + GraphAlgParse + GraphAlgPasses +) diff --git a/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp b/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp new file mode 100644 index 0000000..0d9fa21 --- /dev/null +++ 
b/pg_graphalg/src/pg_graphalg/PgGraphAlg.cpp @@ -0,0 +1,208 @@ +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace pg_graphalg { + +static mlir::DialectRegistry createDialectRegistry() { + mlir::DialectRegistry registry; + registry.insert(); + registry.insert(); + mlir::func::registerInlinerExtension(registry); + return registry; +} + +PgGraphAlg::PgGraphAlg(llvm::function_ref diagHandler) + : _registry(createDialectRegistry()), _ctx(_registry) { + auto &engine = _ctx.getDiagEngine(); + engine.registerHandler(diagHandler); +} + +std::optional PgGraphAlg::getOrCreateTable( + TableId tableId, + llvm::function_ref(TableId id)> createFunc) { + if (!_tables.contains(tableId)) { + auto def = createFunc(tableId); + if (!def) { + return std::nullopt; + } + + _tables[tableId] = std::make_unique(*def); + } + return _tables[tableId].get(); +} + +static mlir::TypedAttr +matrixValueToAttr(mlir::Type t, std::variant v) { + auto *ctx = t.getContext(); + auto intType = graphalg::SemiringTypes::forInt(ctx); + auto realType = graphalg::SemiringTypes::forReal(ctx); + auto tropIntType = graphalg::SemiringTypes::forTropInt(ctx); + auto tropRealType = graphalg::SemiringTypes::forTropReal(ctx); + auto tropMaxIntType = graphalg::SemiringTypes::forTropMaxInt(ctx); + if (t == graphalg::SemiringTypes::forBool(ctx)) { + assert(std::holds_alternative(v)); + return mlir::BoolAttr::get(ctx, std::get(v)); + } else if (t == intType) { + assert(std::holds_alternative(v)); + return mlir::IntegerAttr::get(intType, std::get(v)); + } else if (t == realType) { + assert(std::holds_alternative(v)); + return mlir::FloatAttr::get(realType, std::get(v)); + } else if (t == tropIntType) { + assert(std::holds_alternative(v)); + return graphalg::TropIntAttr::get( + ctx, tropIntType, + mlir::IntegerAttr::get(intType, 
std::get(v))); + } else if (t == tropRealType) { + assert(std::holds_alternative(v)); + return graphalg::TropFloatAttr::get( + ctx, tropRealType, mlir::FloatAttr::get(realType, std::get(v))); + } else if (t == tropMaxIntType) { + assert(std::holds_alternative(v)); + return graphalg::TropIntAttr::get( + ctx, tropMaxIntType, + mlir::IntegerAttr::get(intType, std::get(v))); + } else { + mlir::emitError(mlir::UnknownLoc::get(ctx)) + << "invalid target type for matrix value: " << t; + return nullptr; + } +} + +static std::variant +attrToMatrixValue(mlir::TypedAttr attr) { + if (auto b = llvm::dyn_cast(attr)) { + return b.getValue(); + } else if (auto i = llvm::dyn_cast(attr)) { + return i.getInt(); + } else if (auto f = llvm::dyn_cast(attr)) { + return f.getValueAsDouble(); + } else if (auto i = llvm::dyn_cast(attr)) { + return i.getValue().getInt(); + } else if (auto f = llvm::dyn_cast(attr)) { + return f.getValue().getValueAsDouble(); + } else { + mlir::emitError(mlir::UnknownLoc::get(attr.getContext())) + << "attribute cannot be converted to matrix value: " << attr; + std::abort(); + } +} + +bool PgGraphAlg::execute(llvm::StringRef programSource, + llvm::StringRef function, + llvm::ArrayRef arguments, + MatrixTable &output) { + // Parse + llvm::StringRef filename = ""; + auto loc = mlir::FileLineColLoc::get(&_ctx, filename, + /*line=*/1, /*column=*/1); + mlir::OwningOpRef moduleOp = + mlir::ModuleOp::create(loc, filename); + if (mlir::failed(graphalg::parse(programSource, *moduleOp))) { + return false; + } + + // Desugar + { + mlir::PassManager pm(&_ctx); + graphalg::GraphAlgToCorePipelineOptions toCoreOptions; + graphalg::buildGraphAlgToCorePipeline(pm, toCoreOptions); + if (mlir::failed(pm.run(*moduleOp))) { + return false; + } + } + + // Set dimensions + { + llvm::SmallVector argDims; + for (const auto *arg : arguments) { + argDims.push_back(graphalg::CallArgumentDimensions{ + .rows = arg->nRows(), + .cols = arg->nCols(), + }); + } + + 
graphalg::GraphAlgSetDimensionsOptions options{ + .functionName = function.str(), + .argDims = std::move(argDims), + }; + + mlir::PassManager pm(&_ctx); + pm.addNestedPass( + graphalg::createGraphAlgVerifyDimensions()); + pm.addPass(graphalg::createGraphAlgSetDimensions(options)); + pm.addPass(mlir::createCanonicalizerPass()); + if (mlir::failed(pm.run(*moduleOp))) { + return false; + } + } + + auto funcOp = + llvm::cast(moduleOp->lookupSymbol(function)); + + // TODO: Check semiring and value type are compatible + + // Build arguments + llvm::SmallVector argAttrs; + for (const auto &[arg, type] : + llvm::zip_equal(arguments, funcOp.getFunctionType().getInputs())) { + auto matType = llvm::cast(type); + graphalg::MatrixAttrBuilder builder(matType); + const auto &values = arg->values(); + for (auto [pos, val] : values) { + auto [row, col] = pos; + auto valAttr = matrixValueToAttr(matType.getSemiring(), val); + builder.set(row, col, valAttr); + } + + argAttrs.push_back(builder.build()); + } + + auto result = graphalg::evaluate(funcOp, argAttrs); + if (!result) { + return false; + } + + graphalg::MatrixAttrReader resultReader(result); + // TODO: Check rows/cols match. + // TODO: Check semiring is compatible with value type. 
+  output.clear(); + +  auto defaultValue = resultReader.ring().addIdentity(); +  for (auto r : llvm::seq(resultReader.nRows())) { +    for (auto c : llvm::seq(resultReader.nCols())) { +      auto v = resultReader.at(r, c); +      if (v != defaultValue) { +        output.setValue(r, c, attrToMatrixValue(v)); +      } +    } +  } + +  return true; +} + +} // namespace pg_graphalg diff --git a/pg_graphalg/test/sssp.sql b/pg_graphalg/test/sssp.sql new file mode 100644 index 0000000..b1d7971 --- /dev/null +++ b/pg_graphalg/test/sssp.sql @@ -0,0 +1,52 @@ +-- Create the matrix table for the graph +CREATE FOREIGN TABLE graph(source bigint, target bigint, dist double precision) +SERVER graphalg_server OPTIONS (rows '10', columns '10'); +INSERT INTO graph VALUES +  (0, 1, 0.5), +  (0, 2, 5.0), +  (0, 3, 5.0), +  (1, 4, 0.5), +  (2, 3, 2.0), +  (4, 5, 0.5), +  (5, 2, 0.5), +  (5, 9, 23.0), +  (6, 0, 1.0), +  (6, 7, 3.2), +  (7, 9, 0.2), +  (8, 9, 0.1), +  (9, 6, 8.0); + +-- Define the algorithm +CREATE PROCEDURE SSSP(text, text, text) +LANGUAGE graphalg +AS $$ +func sssp( +    graph: Matrix, +    source: Vector) -> Vector { +    dist = source; +    for i in graph.nrows { +        dist += dist * graph; +    } +    return dist; +} +$$; + +-- Start from source node 0 +CREATE FOREIGN TABLE source(vertex_id bigint, nop bigint, init_dist double precision) +SERVER graphalg_server OPTIONS (rows '10', columns '1'); +INSERT INTO source VALUES (0, 0, 0.0); + +-- Output of the algorithm +CREATE FOREIGN TABLE dist_out(vertex_id bigint, nop bigint, dist double precision) +SERVER graphalg_server OPTIONS (rows '10', columns '1'); + +-- Run the algorithm +CALL SSSP('graph', 'source', 'dist_out'); + +-- Read the results.
+SELECT vertex_id, dist FROM dist_out; + +DROP FOREIGN TABLE graph; +DROP FOREIGN TABLE source; +DROP FOREIGN TABLE dist_out; +DROP PROCEDURE SSSP; diff --git a/pg_graphalg/test/test.sql b/pg_graphalg/test/test.sql new file mode 100644 index 0000000..7e113e3 --- /dev/null +++ b/pg_graphalg/test/test.sql @@ -0,0 +1,119 @@ +DROP FOREIGN TABLE IF EXISTS mat1; +DROP FOREIGN TABLE IF EXISTS mat2; +DROP FOREIGN TABLE IF EXISTS lhs; +DROP FOREIGN TABLE IF EXISTS rhs; +DROP FOREIGN TABLE IF EXISTS matmul_out; +DROP SERVER IF EXISTS graphalg_server; +DROP FOREIGN DATA WRAPPER IF EXISTS graphalg_fdw; +DROP FUNCTION IF EXISTS graphalg_fdw_handler; +DROP FUNCTION IF EXISTS graphalg_fdw_validator; + +DROP PROCEDURE IF EXISTS matmul; +DROP LANGUAGE IF EXISTS graphalg; +DROP FUNCTION IF EXISTS graphalg_pl_call_handler; +DROP FUNCTION IF EXISTS graphalg_pl_inline_handler; +DROP FUNCTION IF EXISTS graphalg_pl_validator; + +-- Foreign data wrapper +CREATE FUNCTION graphalg_fdw_handler() +RETURNS fdw_handler +AS '/workspaces/graphalg/pg_graphalg/build/src/libpg_graphalg.so' +LANGUAGE C STRICT; + +CREATE FUNCTION graphalg_fdw_validator(text[], oid) +RETURNS void +AS '/workspaces/graphalg/pg_graphalg/build/src/libpg_graphalg.so' +LANGUAGE C STRICT; + +CREATE FOREIGN DATA WRAPPER graphalg_fdw + HANDLER graphalg_fdw_handler + VALIDATOR graphalg_fdw_validator; +CREATE SERVER graphalg_server FOREIGN DATA WRAPPER graphalg_fdw; + +-- Procedural Language +CREATE FUNCTION graphalg_pl_call_handler() +RETURNS language_handler +AS '/workspaces/graphalg/pg_graphalg/build/src/libpg_graphalg.so' +LANGUAGE C STRICT; + +CREATE FUNCTION graphalg_pl_validator(oid) RETURNS void +AS '/workspaces/graphalg/pg_graphalg/build/src/libpg_graphalg.so' +LANGUAGE C STRICT; + +CREATE FUNCTION graphalg_pl_inline_handler(internal) +RETURNS void +AS '/workspaces/graphalg/pg_graphalg/build/src/libpg_graphalg.so' +LANGUAGE C STRICT; + +CREATE TRUSTED LANGUAGE graphalg +HANDLER graphalg_pl_call_handler +INLINE 
graphalg_pl_inline_handler +VALIDATOR graphalg_pl_validator; + +CREATE FOREIGN TABLE mat1 ( row bigint, col bigint, val bigint ) SERVER graphalg_server OPTIONS (rows '10', columns '10'); +CREATE FOREIGN TABLE mat2 ( row bigint, col bigint, val bigint ) SERVER graphalg_server OPTIONS (rows '100', columns '100'); +SELECT * FROM mat1; +SELECT * FROM mat2; + +INSERT INTO mat1 VALUES + (0, 0, 42), + (0, 1, 43), + (1, 0, 44), + (1, 1, 45); + +INSERT INTO mat2 VALUES + (0, 0, 420), + (0, 1, 430), + (1, 0, 440), + (1, 1, 450); + +SELECT * FROM mat1; + +SELECT * FROM mat2; + +INSERT INTO mat1 VALUES + (0, 1, 4000); + +SELECT * FROM mat1; + +CREATE FOREIGN TABLE lhs(row bigint, col bigint, val bigint) +SERVER graphalg_server +OPTIONS (rows '2', columns '2'); + +INSERT INTO lhs VALUES + (0, 0, 42), + (0, 1, 43), + (1, 0, 44), + (1, 1, 45); +CREATE FOREIGN TABLE rhs(row bigint, col bigint, val bigint) +SERVER graphalg_server +OPTIONS (rows '2', columns '2'); +INSERT INTO rhs VALUES + (0, 0, 46), + (0, 1, 47), + (1, 0, 48), + (1, 1, 49); + +CREATE FOREIGN TABLE matmul_out(row bigint, col bigint, val bigint) +SERVER graphalg_server +OPTIONS (rows '2', columns '2'); + +CREATE PROCEDURE matmul(text, text, text) +LANGUAGE graphalg +AS $$ + func matmul( + lhs: Matrix, + rhs: Matrix) -> Matrix { + return lhs * rhs; + } +$$; + +CALL matmul('lhs', 'rhs', 'matmul_out'); +SELECT * FROM matmul_out; + +CREATE FOREIGN TABLE matbool ( row bigint, col bigint, val boolean ) SERVER graphalg_server OPTIONS (rows '10', columns '10'); +INSERT INTO matbool VALUES (0, 0, true); +CREATE FOREIGN TABLE matreal ( row bigint, col bigint, val double precision ) SERVER graphalg_server OPTIONS (rows '10', columns '10'); +INSERT INTO matreal VALUES (0, 0, 4.2); +DROP FOREIGN TABLE matbool; +DROP FOREIGN TABLE matreal;