From 0cfc7b20d29124478b41d8cc0a1ba764310374b8 Mon Sep 17 00:00:00 2001 From: Cristian Greco Date: Mon, 15 Dec 2025 11:57:03 +0100 Subject: [PATCH] feat: allow excluding database or user names in `stat_statements` Add a couple of options to the `stat_statements` collector that allow excluding certain database names or user names from metrics. This is useful if you don't care about certain preconfigured databases with admin activity. Signed-off-by: Cristian Greco --- README.md | 9 + collector/pg_stat_statements.go | 77 +++++++- collector/pg_stat_statements_test.go | 264 +++++++++++++++++++++++++-- 3 files changed, 333 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index beff2c793..35cc72276 100644 --- a/README.md +++ b/README.md @@ -156,6 +156,15 @@ This will build the docker image as `prometheuscommunity/postgres_exporter:${bra * `--collector.stat_statements.query_length` Maximum length of the statement text. Default is 120. +* `--collector.stat_statements.limit` + Maximum number of statements to return. Default is 100. + +* `--collector.stat_statements.exclude_databases` + Comma-separated list of database names to exclude from `pg_stat_statements` metrics. Default is none. + +* `--collector.stat_statements.exclude_users` + Comma-separated list of user names to exclude from `pg_stat_statements` metrics. Default is none. + * `[no-]collector.stat_user_tables` Enable the `stat_user_tables` collector (default: enabled). diff --git a/collector/pg_stat_statements.go b/collector/pg_stat_statements.go index 824605190..87e68424e 100644 --- a/collector/pg_stat_statements.go +++ b/collector/pg_stat_statements.go @@ -18,6 +18,7 @@ import ( "database/sql" "fmt" "log/slog" + "strings" "github.com/alecthomas/kingpin/v2" "github.com/blang/semver/v4" @@ -30,9 +31,11 @@ const ( ) var ( - includeQueryFlag *bool = nil - statementLengthFlag *uint = nil - statementLimitFlag *uint = nil + includeQueryFlag *bool = nil + statementLengthFlag *uint = nil + statementLimitFlag *uint = nil + excludedDatabasesFlag *string = nil + excludedUsersFlag *string = nil ) func init() { @@ -56,6 +59,16 @@ func init() { "Maximum number of statements to return."). Default(defaultStatementLimit). Uint() + excludedDatabasesFlag = kingpin.Flag( + fmt.Sprint(collectorFlagPrefix, statStatementsSubsystem, ".exclude_databases"), + "Comma-separated list of database names to exclude. (default: none)"). + Default(""). + String() + excludedUsersFlag = kingpin.Flag( + fmt.Sprint(collectorFlagPrefix, statStatementsSubsystem, ".exclude_users"), + "Comma-separated list of user names to exclude. (default: none)"). + Default("").
+ String() } type PGStatStatementsCollector struct { @@ -63,14 +76,36 @@ type PGStatStatementsCollector struct { includeQueryStatement bool statementLength uint statementLimit uint + excludedDatabases []string + excludedUsers []string } func NewPGStatStatementsCollector(config collectorConfig) (Collector, error) { + var excludedDatabases []string + if *excludedDatabasesFlag != "" { + for db := range strings.SplitSeq(*excludedDatabasesFlag, ",") { + if trimmed := strings.TrimSpace(db); trimmed != "" { + excludedDatabases = append(excludedDatabases, trimmed) + } + } + } + + var excludedUsers []string + if *excludedUsersFlag != "" { + for user := range strings.SplitSeq(*excludedUsersFlag, ",") { + if trimmed := strings.TrimSpace(user); trimmed != "" { + excludedUsers = append(excludedUsers, trimmed) + } + } + } + return &PGStatStatementsCollector{ log: config.logger, includeQueryStatement: *includeQueryFlag, statementLength: *statementLengthFlag, statementLimit: *statementLimitFlag, + excludedDatabases: excludedDatabases, + excludedUsers: excludedUsers, }, nil } @@ -115,7 +150,9 @@ var ( ) const ( - pgStatStatementQuerySelect = `LEFT(pg_stat_statements.query, %d) as query,` + pgStatStatementQuerySelect = `LEFT(pg_stat_statements.query, %d) as query,` + pgStatStatementExcludeDatabases = `AND pg_database.datname NOT IN (%s) ` + pgStatStatementExcludeUsers = `AND pg_get_userbyid(userid) NOT IN (%s) ` pgStatStatementsQuery = `SELECT pg_get_userbyid(userid) as user, @@ -136,6 +173,7 @@ const ( WITHIN GROUP (ORDER BY total_time) FROM pg_stat_statements ) + %s ORDER BY seconds_total DESC LIMIT %s;` @@ -158,6 +196,7 @@ const ( WITHIN GROUP (ORDER BY total_exec_time) FROM pg_stat_statements ) + %s ORDER BY seconds_total DESC LIMIT %s;` @@ -180,6 +219,7 @@ const ( WITHIN GROUP (ORDER BY total_exec_time) FROM pg_stat_statements ) + %s ORDER BY seconds_total DESC LIMIT %s;` ) @@ -198,11 +238,13 @@ func (c PGStatStatementsCollector) Update(ctx context.Context, instance *instanc if c.includeQueryStatement { querySelect = fmt.Sprintf(pgStatStatementQuerySelect, c.statementLength) } + databaseFilter := c.buildExcludedDatabasesClause() + userFilter := c.buildExcludedUsersClause() statementLimit := defaultStatementLimit if c.statementLimit > 0 { statementLimit = fmt.Sprintf("%d", c.statementLimit) } - query := fmt.Sprintf(queryTemplate, querySelect, statementLimit) + query := fmt.Sprintf(queryTemplate, querySelect, databaseFilter+userFilter, statementLimit) db := instance.getDB() rows, err := db.QueryContext(ctx, query) @@ -319,3 +361,28 @@ func (c PGStatStatementsCollector) Update(ctx context.Context, instance *instanc } return nil } + +func (c PGStatStatementsCollector) buildExcludedDatabasesClause() string { + if len(c.excludedDatabases) == 0 { + return "" + } + + databases := make([]string, 0, len(c.excludedDatabases)) + for _, db := range c.excludedDatabases { + databases = append(databases, fmt.Sprintf("'%s'", strings.ReplaceAll(db, "'", "''"))) + } + + return fmt.Sprintf(pgStatStatementExcludeDatabases, strings.Join(databases, ", ")) +} + +func (c PGStatStatementsCollector) buildExcludedUsersClause() string { + if len(c.excludedUsers) == 0 { + return "" + } + + users := make([]string, 0, len(c.excludedUsers)) + for _, user := range c.excludedUsers { + users = append(users, fmt.Sprintf("'%s'", strings.ReplaceAll(user, "'", "''"))) + } + return fmt.Sprintf(pgStatStatementExcludeUsers, strings.Join(users, ", ")) +} diff --git a/collector/pg_stat_statements_test.go b/collector/pg_stat_statements_test.go 
index 8763f693d..a313f837f 100644 --- a/collector/pg_stat_statements_test.go +++ b/collector/pg_stat_statements_test.go @@ -36,7 +36,7 @@ func TestPGStatStatementsCollector(t *testing.T) { columns := []string{"user", "datname", "queryid", "calls_total", "seconds_total", "rows_total", "block_read_seconds_total", "block_write_seconds_total"} rows := sqlmock.NewRows(columns). AddRow("postgres", "postgres", 1500, 5, 0.4, 100, 0.1, 0.2) - mock.ExpectQuery(sanitizeQuery(fmt.Sprintf(pgStatStatementsQuery, "", defaultStatementLimit))).WillReturnRows(rows) + mock.ExpectQuery(sanitizeQuery(fmt.Sprintf(pgStatStatementsQuery, "", "", defaultStatementLimit))).WillReturnRows(rows) ch := make(chan prometheus.Metric) go func() { @@ -79,7 +79,7 @@ func TestPGStatStatementsCollectorWithStatement(t *testing.T) { columns := []string{"user", "datname", "queryid", "LEFT(pg_stat_statements.query, 100) as query", "calls_total", "seconds_total", "rows_total", "block_read_seconds_total", "block_write_seconds_total"} rows := sqlmock.NewRows(columns). AddRow("postgres", "postgres", 1500, "select 1 from foo", 5, 0.4, 100, 0.1, 0.2) - mock.ExpectQuery(sanitizeQuery(fmt.Sprintf(pgStatStatementsQuery, fmt.Sprintf(pgStatStatementQuerySelect, 100), defaultStatementLimit))).WillReturnRows(rows) + mock.ExpectQuery(sanitizeQuery(fmt.Sprintf(pgStatStatementsQuery, fmt.Sprintf(pgStatStatementQuerySelect, 100), "", defaultStatementLimit))).WillReturnRows(rows) ch := make(chan prometheus.Metric) go func() { @@ -123,7 +123,7 @@ func TestPGStatStatementsCollectorWithStatementAndLimit(t *testing.T) { columns := []string{"user", "datname", "queryid", "LEFT(pg_stat_statements.query, 100) as query", "calls_total", "seconds_total", "rows_total", "block_read_seconds_total", "block_write_seconds_total"} rows := sqlmock.NewRows(columns). AddRow("postgres", "postgres", 1500, "select 1 from foo", 5, 0.4, 100, 0.1, 0.2) - mock.ExpectQuery(sanitizeQuery(fmt.Sprintf(pgStatStatementsQuery, fmt.Sprintf(pgStatStatementQuerySelect, 100), "10"))).WillReturnRows(rows) + mock.ExpectQuery(sanitizeQuery(fmt.Sprintf(pgStatStatementsQuery, fmt.Sprintf(pgStatStatementQuerySelect, 100), "", "10"))).WillReturnRows(rows) ch := make(chan prometheus.Metric) go func() { @@ -167,7 +167,7 @@ func TestPGStatStatementsCollectorNull(t *testing.T) { columns := []string{"user", "datname", "queryid", "calls_total", "seconds_total", "rows_total", "block_read_seconds_total", "block_write_seconds_total"} rows := sqlmock.NewRows(columns). AddRow(nil, nil, nil, nil, nil, nil, nil, nil) - mock.ExpectQuery(sanitizeQuery(fmt.Sprintf(pgStatStatementsQuery_PG13, "", defaultStatementLimit))).WillReturnRows(rows) + mock.ExpectQuery(sanitizeQuery(fmt.Sprintf(pgStatStatementsQuery_PG13, "", "", defaultStatementLimit))).WillReturnRows(rows) ch := make(chan prometheus.Metric) go func() { @@ -210,7 +210,7 @@ func TestPGStatStatementsCollectorNullWithStatement(t *testing.T) { columns := []string{"user", "datname", "queryid", "LEFT(pg_stat_statements.query, 200) as query", "calls_total", "seconds_total", "rows_total", "block_read_seconds_total", "block_write_seconds_total"} rows := sqlmock.NewRows(columns). 
AddRow(nil, nil, nil, nil, nil, nil, nil, nil, nil) - mock.ExpectQuery(sanitizeQuery(fmt.Sprintf(pgStatStatementsQuery_PG13, fmt.Sprintf(pgStatStatementQuerySelect, 200), defaultStatementLimit))).WillReturnRows(rows) + mock.ExpectQuery(sanitizeQuery(fmt.Sprintf(pgStatStatementsQuery_PG13, fmt.Sprintf(pgStatStatementQuerySelect, 200), "", defaultStatementLimit))).WillReturnRows(rows) ch := make(chan prometheus.Metric) go func() { @@ -254,7 +254,7 @@ func TestPGStatStatementsCollectorNullWithStatementAndLimit(t *testing.T) { columns := []string{"user", "datname", "queryid", "LEFT(pg_stat_statements.query, 200) as query", "calls_total", "seconds_total", "rows_total", "block_read_seconds_total", "block_write_seconds_total"} rows := sqlmock.NewRows(columns). AddRow(nil, nil, nil, nil, nil, nil, nil, nil, nil) - mock.ExpectQuery(sanitizeQuery(fmt.Sprintf(pgStatStatementsQuery_PG13, fmt.Sprintf(pgStatStatementQuerySelect, 200), "10"))).WillReturnRows(rows) + mock.ExpectQuery(sanitizeQuery(fmt.Sprintf(pgStatStatementsQuery_PG13, fmt.Sprintf(pgStatStatementQuerySelect, 200), "", "10"))).WillReturnRows(rows) ch := make(chan prometheus.Metric) go func() { @@ -298,7 +298,7 @@ func TestPGStatStatementsCollector_PG13(t *testing.T) { columns := []string{"user", "datname", "queryid", "calls_total", "seconds_total", "rows_total", "block_read_seconds_total", "block_write_seconds_total"} rows := sqlmock.NewRows(columns). AddRow("postgres", "postgres", 1500, 5, 0.4, 100, 0.1, 0.2) - mock.ExpectQuery(sanitizeQuery(fmt.Sprintf(pgStatStatementsQuery_PG13, "", defaultStatementLimit))).WillReturnRows(rows) + mock.ExpectQuery(sanitizeQuery(fmt.Sprintf(pgStatStatementsQuery_PG13, "", "", defaultStatementLimit))).WillReturnRows(rows) ch := make(chan prometheus.Metric) go func() { @@ -341,7 +341,7 @@ func TestPGStatStatementsCollector_PG13_WithStatement(t *testing.T) { columns := []string{"user", "datname", "queryid", "LEFT(pg_stat_statements.query, 300) as query", "calls_total", "seconds_total", "rows_total", "block_read_seconds_total", "block_write_seconds_total"} rows := sqlmock.NewRows(columns). AddRow("postgres", "postgres", 1500, "select 1 from foo", 5, 0.4, 100, 0.1, 0.2) - mock.ExpectQuery(sanitizeQuery(fmt.Sprintf(pgStatStatementsQuery_PG13, fmt.Sprintf(pgStatStatementQuerySelect, 300), defaultStatementLimit))).WillReturnRows(rows) + mock.ExpectQuery(sanitizeQuery(fmt.Sprintf(pgStatStatementsQuery_PG13, fmt.Sprintf(pgStatStatementQuerySelect, 300), "", defaultStatementLimit))).WillReturnRows(rows) ch := make(chan prometheus.Metric) go func() { @@ -385,7 +385,7 @@ func TestPGStatStatementsCollector_PG13_WithStatementAndLimit(t *testing.T) { columns := []string{"user", "datname", "queryid", "LEFT(pg_stat_statements.query, 300) as query", "calls_total", "seconds_total", "rows_total", "block_read_seconds_total", "block_write_seconds_total"} rows := sqlmock.NewRows(columns). 
AddRow("postgres", "postgres", 1500, "select 1 from foo", 5, 0.4, 100, 0.1, 0.2) - mock.ExpectQuery(sanitizeQuery(fmt.Sprintf(pgStatStatementsQuery_PG13, fmt.Sprintf(pgStatStatementQuerySelect, 300), "10"))).WillReturnRows(rows) + mock.ExpectQuery(sanitizeQuery(fmt.Sprintf(pgStatStatementsQuery_PG13, fmt.Sprintf(pgStatStatementQuerySelect, 300), "", "10"))).WillReturnRows(rows) ch := make(chan prometheus.Metric) go func() { @@ -429,7 +429,7 @@ func TestPGStatStatementsCollector_PG17(t *testing.T) { columns := []string{"user", "datname", "queryid", "calls_total", "seconds_total", "rows_total", "block_read_seconds_total", "block_write_seconds_total"} rows := sqlmock.NewRows(columns). AddRow("postgres", "postgres", 1500, 5, 0.4, 100, 0.1, 0.2) - mock.ExpectQuery(sanitizeQuery(fmt.Sprintf(pgStatStatementsQuery_PG17, "", defaultStatementLimit))).WillReturnRows(rows) + mock.ExpectQuery(sanitizeQuery(fmt.Sprintf(pgStatStatementsQuery_PG17, "", "", defaultStatementLimit))).WillReturnRows(rows) ch := make(chan prometheus.Metric) go func() { @@ -472,7 +472,7 @@ func TestPGStatStatementsCollector_PG17_WithStatement(t *testing.T) { columns := []string{"user", "datname", "queryid", "LEFT(pg_stat_statements.query, 300) as query", "calls_total", "seconds_total", "rows_total", "block_read_seconds_total", "block_write_seconds_total"} rows := sqlmock.NewRows(columns). AddRow("postgres", "postgres", 1500, "select 1 from foo", 5, 0.4, 100, 0.1, 0.2) - mock.ExpectQuery(sanitizeQuery(fmt.Sprintf(pgStatStatementsQuery_PG17, fmt.Sprintf(pgStatStatementQuerySelect, 300), defaultStatementLimit))).WillReturnRows(rows) + mock.ExpectQuery(sanitizeQuery(fmt.Sprintf(pgStatStatementsQuery_PG17, fmt.Sprintf(pgStatStatementQuerySelect, 300), "", defaultStatementLimit))).WillReturnRows(rows) ch := make(chan prometheus.Metric) go func() { @@ -516,7 +516,7 @@ func TestPGStatStatementsCollector_PG17_WithStatementAndLimit(t *testing.T) { columns := []string{"user", "datname", "queryid", "LEFT(pg_stat_statements.query, 300) as query", "calls_total", "seconds_total", "rows_total", "block_read_seconds_total", "block_write_seconds_total"} rows := sqlmock.NewRows(columns). AddRow("postgres", "postgres", 1500, "select 1 from foo", 5, 0.4, 100, 0.1, 0.2) - mock.ExpectQuery(sanitizeQuery(fmt.Sprintf(pgStatStatementsQuery_PG17, fmt.Sprintf(pgStatStatementQuerySelect, 300), "10"))).WillReturnRows(rows) + mock.ExpectQuery(sanitizeQuery(fmt.Sprintf(pgStatStatementsQuery_PG17, fmt.Sprintf(pgStatStatementQuerySelect, 300), "", "10"))).WillReturnRows(rows) ch := make(chan prometheus.Metric) go func() { @@ -547,3 +547,243 @@ func TestPGStatStatementsCollector_PG17_WithStatementAndLimit(t *testing.T) { t.Errorf("there were unfulfilled exceptions: %s", err) } } + +func TestPGStatStatementsCollectorWithExcludedDatabases(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("Error opening a stub db connection: %s", err) + } + defer db.Close() + + inst := &instance{db: db, version: semver.MustParse("12.0.0")} + + excludedDatabases := []string{"rdsadmin", "cloudsqladmin"} + + columns := []string{"user", "datname", "queryid", "calls_total", "seconds_total", "rows_total", "block_read_seconds_total", "block_write_seconds_total"} + rows := sqlmock.NewRows(columns). 
+ AddRow("postgres", "postgres", 1500, 5, 0.4, 100, 0.1, 0.2) + + // Expected query should include database exclusion filters + expectedFilter := " AND pg_database.datname NOT IN ('rdsadmin', 'cloudsqladmin')" + mock.ExpectQuery(sanitizeQuery(fmt.Sprintf(pgStatStatementsQuery, "", expectedFilter, defaultStatementLimit))).WillReturnRows(rows) + + ch := make(chan prometheus.Metric) + go func() { + defer close(ch) + c := PGStatStatementsCollector{excludedDatabases: excludedDatabases} + + if err := c.Update(context.Background(), inst, ch); err != nil { + t.Errorf("Error calling PGStatStatementsCollector.Update: %s", err) + } + }() + + expected := []MetricResult{ + {labels: labelMap{"user": "postgres", "datname": "postgres", "queryid": "1500"}, metricType: dto.MetricType_COUNTER, value: 5}, + {labels: labelMap{"user": "postgres", "datname": "postgres", "queryid": "1500"}, metricType: dto.MetricType_COUNTER, value: 0.4}, + {labels: labelMap{"user": "postgres", "datname": "postgres", "queryid": "1500"}, metricType: dto.MetricType_COUNTER, value: 100}, + {labels: labelMap{"user": "postgres", "datname": "postgres", "queryid": "1500"}, metricType: dto.MetricType_COUNTER, value: 0.1}, + {labels: labelMap{"user": "postgres", "datname": "postgres", "queryid": "1500"}, metricType: dto.MetricType_COUNTER, value: 0.2}, + } + + convey.Convey("Metrics comparison with excluded databases", t, func() { + for _, expect := range expected { + m := readMetric(<-ch) + convey.So(expect, convey.ShouldResemble, m) + } + }) + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled exceptions: %s", err) + } +} + +func TestPGStatStatementsCollectorWithExcludedDatabasesSQLEscaping(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("Error opening a stub db connection: %s", err) + } + defer db.Close() + + inst := &instance{db: db, version: semver.MustParse("13.0.0")} + + // exclude a database name with a single quote (SQL injection test) + excludedDatabases := []string{"test'db"} + + columns := []string{"user", "datname", "queryid", "calls_total", "seconds_total", "rows_total", "block_read_seconds_total", "block_write_seconds_total"} + rows := sqlmock.NewRows(columns). 
+ AddRow("postgres", "postgres", 1500, 5, 0.4, 100, 0.1, 0.2) + + expectedFilter := " AND pg_database.datname NOT IN ('test''db')" + mock.ExpectQuery(sanitizeQuery(fmt.Sprintf(pgStatStatementsQuery_PG13, "", expectedFilter, defaultStatementLimit))).WillReturnRows(rows) + + ch := make(chan prometheus.Metric) + go func() { + defer close(ch) + c := PGStatStatementsCollector{excludedDatabases: excludedDatabases} + + if err := c.Update(context.Background(), inst, ch); err != nil { + t.Errorf("Error calling PGStatStatementsCollector.Update: %s", err) + } + }() + + expected := []MetricResult{ + {labels: labelMap{"user": "postgres", "datname": "postgres", "queryid": "1500"}, metricType: dto.MetricType_COUNTER, value: 5}, + {labels: labelMap{"user": "postgres", "datname": "postgres", "queryid": "1500"}, metricType: dto.MetricType_COUNTER, value: 0.4}, + {labels: labelMap{"user": "postgres", "datname": "postgres", "queryid": "1500"}, metricType: dto.MetricType_COUNTER, value: 100}, + {labels: labelMap{"user": "postgres", "datname": "postgres", "queryid": "1500"}, metricType: dto.MetricType_COUNTER, value: 0.1}, + {labels: labelMap{"user": "postgres", "datname": "postgres", "queryid": "1500"}, metricType: dto.MetricType_COUNTER, value: 0.2}, + } + + convey.Convey("Metrics comparison with SQL escaping", t, func() { + for _, expect := range expected { + m := readMetric(<-ch) + convey.So(expect, convey.ShouldResemble, m) + } + }) + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled exceptions: %s", err) + } +} + +func TestPGStatStatementsCollectorWithExcludedUsers(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("Error opening a stub db connection: %s", err) + } + defer db.Close() + + inst := &instance{db: db, version: semver.MustParse("12.0.0")} + + excludedUsers := []string{"monitoring", "readonly"} + + columns := []string{"user", "datname", "queryid", "calls_total", "seconds_total", "rows_total", "block_read_seconds_total", "block_write_seconds_total"} + rows := sqlmock.NewRows(columns). 
+ AddRow("postgres", "postgres", 1500, 5, 0.4, 100, 0.1, 0.2) + + expectedFilter := " AND pg_get_userbyid(userid) NOT IN ('monitoring', 'readonly')" + mock.ExpectQuery(sanitizeQuery(fmt.Sprintf(pgStatStatementsQuery, "", expectedFilter, defaultStatementLimit))).WillReturnRows(rows) + + ch := make(chan prometheus.Metric) + go func() { + defer close(ch) + c := PGStatStatementsCollector{excludedUsers: excludedUsers} + + if err := c.Update(context.Background(), inst, ch); err != nil { + t.Errorf("Error calling PGStatStatementsCollector.Update: %s", err) + } + }() + + expected := []MetricResult{ + {labels: labelMap{"user": "postgres", "datname": "postgres", "queryid": "1500"}, metricType: dto.MetricType_COUNTER, value: 5}, + {labels: labelMap{"user": "postgres", "datname": "postgres", "queryid": "1500"}, metricType: dto.MetricType_COUNTER, value: 0.4}, + {labels: labelMap{"user": "postgres", "datname": "postgres", "queryid": "1500"}, metricType: dto.MetricType_COUNTER, value: 100}, + {labels: labelMap{"user": "postgres", "datname": "postgres", "queryid": "1500"}, metricType: dto.MetricType_COUNTER, value: 0.1}, + {labels: labelMap{"user": "postgres", "datname": "postgres", "queryid": "1500"}, metricType: dto.MetricType_COUNTER, value: 0.2}, + } + + convey.Convey("Metrics comparison with excluded users", t, func() { + for _, expect := range expected { + m := readMetric(<-ch) + convey.So(expect, convey.ShouldResemble, m) + } + }) + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled exceptions: %s", err) + } +} + +func TestPGStatStatementsCollectorWithExcludedUsersSQLEscaping(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("Error opening a stub db connection: %s", err) + } + defer db.Close() + + inst := &instance{db: db, version: semver.MustParse("13.0.0")} + + // exclude a user name with a single quote (SQL injection test) + excludedUsers := []string{"test'user"} + + columns := []string{"user", "datname", "queryid", "calls_total", "seconds_total", "rows_total", "block_read_seconds_total", "block_write_seconds_total"} + rows := sqlmock.NewRows(columns). 
+ AddRow("postgres", "postgres", 1500, 5, 0.4, 100, 0.1, 0.2) + + expectedFilter := " AND pg_get_userbyid(userid) NOT IN ('test''user')" + mock.ExpectQuery(sanitizeQuery(fmt.Sprintf(pgStatStatementsQuery_PG13, "", expectedFilter, defaultStatementLimit))).WillReturnRows(rows) + + ch := make(chan prometheus.Metric) + go func() { + defer close(ch) + c := PGStatStatementsCollector{excludedUsers: excludedUsers} + + if err := c.Update(context.Background(), inst, ch); err != nil { + t.Errorf("Error calling PGStatStatementsCollector.Update: %s", err) + } + }() + + expected := []MetricResult{ + {labels: labelMap{"user": "postgres", "datname": "postgres", "queryid": "1500"}, metricType: dto.MetricType_COUNTER, value: 5}, + {labels: labelMap{"user": "postgres", "datname": "postgres", "queryid": "1500"}, metricType: dto.MetricType_COUNTER, value: 0.4}, + {labels: labelMap{"user": "postgres", "datname": "postgres", "queryid": "1500"}, metricType: dto.MetricType_COUNTER, value: 100}, + {labels: labelMap{"user": "postgres", "datname": "postgres", "queryid": "1500"}, metricType: dto.MetricType_COUNTER, value: 0.1}, + {labels: labelMap{"user": "postgres", "datname": "postgres", "queryid": "1500"}, metricType: dto.MetricType_COUNTER, value: 0.2}, + } + + convey.Convey("Metrics comparison with SQL escaping for users", t, func() { + for _, expect := range expected { + m := readMetric(<-ch) + convey.So(expect, convey.ShouldResemble, m) + } + }) + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled exceptions: %s", err) + } +} + +func TestPGStatStatementsCollectorWithExcludedDatabasesAndUsers(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("Error opening a stub db connection: %s", err) + } + defer db.Close() + + inst := &instance{db: db, version: semver.MustParse("17.0.0")} + + // exclude both databases and users + excludedDatabases := []string{"rdsadmin", "cloudsqladmin"} + excludedUsers := []string{"monitoring", "readonly"} + + columns := []string{"user", "datname", "queryid", "calls_total", "seconds_total", "rows_total", "block_read_seconds_total", "block_write_seconds_total"} + rows := sqlmock.NewRows(columns). 
+ AddRow("postgres", "postgres", 1500, 5, 0.4, 100, 0.1, 0.2) + + expectedFilter := " AND pg_database.datname NOT IN ('rdsadmin', 'cloudsqladmin') AND pg_get_userbyid(userid) NOT IN ('monitoring', 'readonly')" + mock.ExpectQuery(sanitizeQuery(fmt.Sprintf(pgStatStatementsQuery_PG17, "", expectedFilter, defaultStatementLimit))).WillReturnRows(rows) + + ch := make(chan prometheus.Metric) + go func() { + defer close(ch) + c := PGStatStatementsCollector{excludedDatabases: excludedDatabases, excludedUsers: excludedUsers} + + if err := c.Update(context.Background(), inst, ch); err != nil { + t.Errorf("Error calling PGStatStatementsCollector.Update: %s", err) + } + }() + + expected := []MetricResult{ + {labels: labelMap{"user": "postgres", "datname": "postgres", "queryid": "1500"}, metricType: dto.MetricType_COUNTER, value: 5}, + {labels: labelMap{"user": "postgres", "datname": "postgres", "queryid": "1500"}, metricType: dto.MetricType_COUNTER, value: 0.4}, + {labels: labelMap{"user": "postgres", "datname": "postgres", "queryid": "1500"}, metricType: dto.MetricType_COUNTER, value: 100}, + {labels: labelMap{"user": "postgres", "datname": "postgres", "queryid": "1500"}, metricType: dto.MetricType_COUNTER, value: 0.1}, + {labels: labelMap{"user": "postgres", "datname": "postgres", "queryid": "1500"}, metricType: dto.MetricType_COUNTER, value: 0.2}, + } + + convey.Convey("Metrics comparison with excluded databases and users", t, func() { + for _, expect := range expected { + m := readMetric(<-ch) + convey.So(expect, convey.ShouldResemble, m) + } + }) + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled exceptions: %s", err) + } +}