Skip to content

Commit fd1901e

Browse files
committed
chore: integrate splinter runtime with analyse
1 parent 0379759 commit fd1901e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+1134
-253
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

PLAN.md

Lines changed: 457 additions & 74 deletions
Large diffs are not rendered by default.

crates/pgls_configuration/src/rules/configuration.rs

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
1+
pub mod linter;
2+
pub mod splinter;
3+
4+
pub use crate::analyser::linter::*;
15
use biome_deserialize::Merge;
26
use biome_deserialize_macros::Deserializable;
37
use pgls_analyser::RuleOptions;
48
use pgls_diagnostics::Severity;
59
#[cfg(feature = "schema")]
610
use schemars::JsonSchema;
711
use serde::{Deserialize, Serialize};
12+
use std::str::FromStr;
813

914
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
1015
#[cfg_attr(feature = "schema", derive(JsonSchema))]
@@ -295,3 +300,127 @@ impl<T: Default> Merge for RuleWithFixOptions<T> {
295300
self.options = other.options;
296301
}
297302
}
303+
304+
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
305+
pub enum RuleSelector {
306+
LinterGroup(linter::RuleGroup),
307+
LinterRule(linter::RuleGroup, &'static str),
308+
SplinterGroup(splinter::RuleGroup),
309+
SplinterRule(splinter::RuleGroup, &'static str),
310+
}
311+
312+
impl From<RuleSelector> for RuleFilter<'static> {
313+
fn from(value: RuleSelector) -> Self {
314+
match value {
315+
RuleSelector::LinterGroup(group) => RuleFilter::Group(group.as_str()),
316+
RuleSelector::LinterRule(group, name) => RuleFilter::Rule(group.as_str(), name),
317+
RuleSelector::SplinterGroup(group) => RuleFilter::Group(group.as_str()),
318+
RuleSelector::SplinterRule(group, name) => RuleFilter::Rule(group.as_str(), name),
319+
}
320+
}
321+
}
322+
323+
impl<'a> From<&'a RuleSelector> for RuleFilter<'static> {
324+
fn from(value: &'a RuleSelector) -> Self {
325+
match value {
326+
RuleSelector::LinterGroup(group) => RuleFilter::Group(group.as_str()),
327+
RuleSelector::LinterRule(group, name) => RuleFilter::Rule(group.as_str(), name),
328+
RuleSelector::SplinterGroup(group) => RuleFilter::Group(group.as_str()),
329+
RuleSelector::SplinterRule(group, name) => RuleFilter::Rule(group.as_str(), name),
330+
}
331+
}
332+
}
333+
334+
impl FromStr for RuleSelector {
335+
type Err = &'static str;
336+
fn from_str(selector: &str) -> Result<Self, Self::Err> {
337+
// Check for explicit prefixes
338+
if let Some(linter_selector) = selector.strip_prefix("lint/") {
339+
return parse_linter_selector(linter_selector);
340+
}
341+
if let Some(splinter_selector) = selector.strip_prefix("splinter/") {
342+
return parse_splinter_selector(splinter_selector);
343+
}
344+
345+
// No prefix: try linter first (for backward compatibility), then splinter
346+
parse_linter_selector(selector)
347+
.or_else(|_| parse_splinter_selector(selector))
348+
.map_err(|_| "This rule or group doesn't exist in linter or splinter.")
349+
}
350+
}
351+
352+
fn parse_linter_selector(selector: &str) -> Result<RuleSelector, &'static str> {
353+
if let Some((group_name, rule_name)) = selector.split_once('/') {
354+
let group = linter::RuleGroup::from_str(group_name)?;
355+
if let Some(rule_name) = linter::Rules::has_rule(group, rule_name) {
356+
Ok(RuleSelector::LinterRule(group, rule_name))
357+
} else {
358+
Err("This linter rule doesn't exist.")
359+
}
360+
} else {
361+
let group = linter::RuleGroup::from_str(selector)?;
362+
Ok(RuleSelector::LinterGroup(group))
363+
}
364+
}
365+
366+
fn parse_splinter_selector(selector: &str) -> Result<RuleSelector, &'static str> {
367+
if let Some((group_name, rule_name)) = selector.split_once('/') {
368+
let group = splinter::RuleGroup::from_str(group_name)?;
369+
if let Some(rule_name) = splinter::Rules::has_rule(group, rule_name) {
370+
Ok(RuleSelector::SplinterRule(group, rule_name))
371+
} else {
372+
Err("This splinter rule doesn't exist.")
373+
}
374+
} else {
375+
let group = splinter::RuleGroup::from_str(selector)?;
376+
Ok(RuleSelector::SplinterGroup(group))
377+
}
378+
}
379+
380+
impl serde::Serialize for RuleSelector {
381+
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
382+
match self {
383+
RuleSelector::LinterGroup(group) => serializer.serialize_str(group.as_str()),
384+
RuleSelector::LinterRule(group, rule_name) => {
385+
let group_name = group.as_str();
386+
serializer.serialize_str(&format!("{group_name}/{rule_name}"))
387+
}
388+
RuleSelector::SplinterGroup(group) => {
389+
serializer.serialize_str(&format!("splinter/{}", group.as_str()))
390+
}
391+
RuleSelector::SplinterRule(group, rule_name) => {
392+
let group_name = group.as_str();
393+
serializer.serialize_str(&format!("splinter/{group_name}/{rule_name}"))
394+
}
395+
}
396+
}
397+
}
398+
399+
impl<'de> serde::Deserialize<'de> for RuleSelector {
400+
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
401+
struct Visitor;
402+
impl serde::de::Visitor<'_> for Visitor {
403+
type Value = RuleSelector;
404+
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
405+
formatter.write_str("<group>/<ruyle_name>")
406+
}
407+
fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Self::Value, E> {
408+
match RuleSelector::from_str(v) {
409+
Ok(result) => Ok(result),
410+
Err(error) => Err(serde::de::Error::custom(error)),
411+
}
412+
}
413+
}
414+
deserializer.deserialize_str(Visitor)
415+
}
416+
}
417+
418+
#[cfg(feature = "schema")]
419+
impl schemars::JsonSchema for RuleSelector {
420+
fn schema_name() -> String {
421+
"RuleCode".to_string()
422+
}
423+
fn json_schema(r#gen: &mut schemars::r#gen::SchemaGenerator) -> schemars::schema::Schema {
424+
String::json_schema(r#gen)
425+
}
426+
}

crates/pgls_splinter/Cargo.toml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@ repository.workspace = true
1111
version = "0.0.0"
1212

1313
[dependencies]
14-
pgls_analyse.workspace = true
15-
pgls_diagnostics.workspace = true
16-
serde.workspace = true
17-
serde_json.workspace = true
18-
sqlx.workspace = true
14+
pgls_analyse.workspace = true
15+
pgls_diagnostics.workspace = true
16+
pgls_schema_cache.workspace = true
17+
serde.workspace = true
18+
serde_json.workspace = true
19+
sqlx.workspace = true
1920

2021
[build-dependencies]
2122
ureq = "2.10"

crates/pgls_splinter/src/convert.rs

Lines changed: 4 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,12 @@ impl From<SplinterQueryResult> for SplinterDiagnostic {
1111
let (schema, object_name, object_type, additional_metadata) =
1212
extract_metadata_fields(&result.metadata);
1313

14-
// for now, we just take the first category as the group
15-
let group = result
16-
.categories
17-
.first()
18-
.map(|s| s.to_lowercase())
19-
.unwrap_or_else(|| "unknown".to_string());
14+
// Look up category from generated registry function
15+
let category = crate::registry::get_rule_category(&result.name)
16+
.expect("Rule name should map to a valid category");
2017

2118
SplinterDiagnostic {
22-
category: rule_name_to_category(&result.name, &group),
19+
category,
2320
message: result.detail.into(),
2421
severity,
2522
db_object: object_name.as_ref().map(|name| DatabaseObjectOwned {
@@ -49,68 +46,6 @@ fn parse_severity(level: &str) -> Severity {
4946
}
5047
}
5148

52-
/// Convert rule name and group to a Category
53-
/// Note: Rule names use snake_case, but categories use camelCase
54-
fn rule_name_to_category(name: &str, group: &str) -> &'static Category {
55-
// we cannot use convert_case here because category! macro requires a string literal
56-
match (group, name) {
57-
("performance", "unindexed_foreign_keys") => {
58-
category!("splinter/performance/unindexedForeignKeys")
59-
}
60-
("performance", "auth_rls_initplan") => {
61-
category!("splinter/performance/authRlsInitplan")
62-
}
63-
("performance", "no_primary_key") => category!("splinter/performance/noPrimaryKey"),
64-
("performance", "unused_index") => category!("splinter/performance/unusedIndex"),
65-
("performance", "duplicate_index") => category!("splinter/performance/duplicateIndex"),
66-
("performance", "table_bloat") => category!("splinter/performance/tableBloat"),
67-
("performance", "multiple_permissive_policies") => {
68-
category!("splinter/performance/multiplePermissivePolicies")
69-
}
70-
("security", "auth_users_exposed") => category!("splinter/security/authUsersExposed"),
71-
("security", "extension_versions_outdated") => {
72-
category!("splinter/security/extensionVersionsOutdated")
73-
}
74-
("security", "policy_exists_rls_disabled") => {
75-
category!("splinter/security/policyExistsRlsDisabled")
76-
}
77-
("security", "rls_enabled_no_policy") => {
78-
category!("splinter/security/rlsEnabledNoPolicy")
79-
}
80-
("security", "security_definer_view") => {
81-
category!("splinter/security/securityDefinerView")
82-
}
83-
("security", "function_search_path_mutable") => {
84-
category!("splinter/security/functionSearchPathMutable")
85-
}
86-
("security", "rls_disabled_in_public") => {
87-
category!("splinter/security/rlsDisabledInPublic")
88-
}
89-
("security", "extension_in_public") => category!("splinter/security/extensionInPublic"),
90-
("security", "rls_references_user_metadata") => {
91-
category!("splinter/security/rlsReferencesUserMetadata")
92-
}
93-
("security", "materialized_view_in_api") => {
94-
category!("splinter/security/materializedViewInApi")
95-
}
96-
("security", "foreign_table_in_api") => {
97-
category!("splinter/security/foreignTableInApi")
98-
}
99-
("security", "unsupported_reg_types") => {
100-
category!("splinter/security/unsupportedRegTypes")
101-
}
102-
("security", "insecure_queue_exposed_in_api") => {
103-
category!("splinter/security/insecureQueueExposedInApi")
104-
}
105-
("security", "fkey_to_auth_unique") => category!("splinter/security/fkeyToAuthUnique"),
106-
_ => {
107-
// Log a warning for unknown rules but provide a fallback
108-
eprintln!("Warning: Unknown splinter rule: {}/{}", group, name);
109-
category!("splinter/performance/unindexedForeignKeys") // Fallback to first rule
110-
}
111-
}
112-
}
113-
11449
/// Extract common metadata fields and return the rest as additional_metadata
11550
fn extract_metadata_fields(
11651
metadata: &Value,

crates/pgls_splinter/src/lib.rs

Lines changed: 85 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ pub mod registry;
55
pub mod rule;
66
pub mod rules;
77

8+
use pgls_analyse::{AnalysisFilter, RegistryVisitor, RuleMeta};
9+
use pgls_schema_cache::SchemaCache;
810
use sqlx::PgPool;
911

1012
pub use diagnostics::{SplinterAdvices, SplinterDiagnostic};
@@ -14,41 +16,103 @@ pub use rule::SplinterRule;
1416
#[derive(Debug)]
1517
pub struct SplinterParams<'a> {
1618
pub conn: &'a PgPool,
19+
pub schema_cache: Option<&'a SchemaCache>,
1720
}
1821

19-
async fn check_supabase_roles(conn: &PgPool) -> Result<bool, sqlx::Error> {
20-
let required_roles = ["anon", "authenticated", "service_role"];
22+
/// Visitor that collects enabled splinter rules based on filter
23+
struct SplinterRuleCollector<'a> {
24+
filter: &'a AnalysisFilter<'a>,
25+
enabled_rules: Vec<String>, // rule names in camelCase
26+
}
2127

22-
let existing_roles: Vec<String> =
23-
sqlx::query_scalar("SELECT rolname FROM pg_roles WHERE rolname = ANY($1)")
24-
.bind(&required_roles[..])
25-
.fetch_all(conn)
26-
.await?;
28+
impl<'a> RegistryVisitor for SplinterRuleCollector<'a> {
29+
fn record_category<C: pgls_analyse::GroupCategory>(&mut self) {
30+
if self.filter.match_category::<C>() {
31+
C::record_groups(self);
32+
}
33+
}
2734

28-
// Check if all required roles exist
29-
let all_exist = required_roles
30-
.iter()
31-
.all(|role| existing_roles.contains(&(*role).to_string()));
35+
fn record_group<G: pgls_analyse::RuleGroup>(&mut self) {
36+
if self.filter.match_group::<G>() {
37+
G::record_rules(self);
38+
}
39+
}
3240

33-
Ok(all_exist)
41+
fn record_rule<R: RuleMeta>(&mut self) {
42+
if self.filter.match_rule::<R>() {
43+
self.enabled_rules.push(R::METADATA.name.to_string());
44+
}
45+
}
3446
}
3547

3648
pub async fn run_splinter(
3749
params: SplinterParams<'_>,
50+
filter: &AnalysisFilter<'_>,
3851
) -> Result<Vec<SplinterDiagnostic>, sqlx::Error> {
39-
let mut all_results = Vec::new();
52+
// Use visitor pattern to collect enabled rules
53+
let mut collector = SplinterRuleCollector {
54+
filter,
55+
enabled_rules: Vec::new(),
56+
};
57+
crate::registry::visit_registry(&mut collector);
58+
59+
// If no rules are enabled, return early
60+
if collector.enabled_rules.is_empty() {
61+
return Ok(Vec::new());
62+
}
63+
64+
// Check if Supabase roles exist (anon, authenticated, service_role)
65+
let has_supabase_roles = params.schema_cache.map_or(false, |cache| {
66+
let required_roles = ["anon", "authenticated", "service_role"];
67+
required_roles.iter().all(|role_name| {
68+
cache
69+
.roles
70+
.iter()
71+
.any(|role| role.name.as_str() == *role_name)
72+
})
73+
});
74+
75+
// Build dynamic SQL query from enabled rules
76+
// Filter out Supabase-specific rules if Supabase roles don't exist
77+
// SQL content is embedded at compile time using include_str! for performance
78+
let mut sql_queries = Vec::new();
4079

41-
let generic_results = query::load_generic_splinter_results(params.conn).await?;
42-
all_results.extend(generic_results);
80+
for rule_name in &collector.enabled_rules {
81+
// Skip Supabase-specific rules if Supabase roles don't exist
82+
if !has_supabase_roles && crate::registry::rule_requires_supabase(rule_name) {
83+
continue;
84+
}
4385

44-
// Only run Supabase-specific rules if the required roles exist
45-
let has_supabase_roles = check_supabase_roles(params.conn).await?;
46-
if has_supabase_roles {
47-
let supabase_results = query::load_supabase_splinter_results(params.conn).await?;
48-
all_results.extend(supabase_results);
86+
// Get embedded SQL content (compile-time included)
87+
if let Some(sql) = crate::registry::get_sql_content(rule_name) {
88+
sql_queries.push(sql);
89+
}
4990
}
5091

51-
let diagnostics: Vec<SplinterDiagnostic> = all_results.into_iter().map(Into::into).collect();
92+
// If no SQL files could be read, return early
93+
if sql_queries.is_empty() {
94+
return Ok(Vec::new());
95+
}
96+
97+
// Combine SQL queries with UNION ALL
98+
let combined_sql = sql_queries.join("\n\nUNION ALL\n\n");
99+
100+
// Execute the combined query
101+
let mut tx = params.conn.begin().await?;
102+
103+
// Set search path as done in the original implementation
104+
sqlx::query("set local search_path = ''")
105+
.execute(&mut *tx)
106+
.await?;
107+
108+
let results = sqlx::query_as::<_, SplinterQueryResult>(&combined_sql)
109+
.fetch_all(&mut *tx)
110+
.await?;
111+
112+
tx.commit().await?;
113+
114+
// Convert results to diagnostics
115+
let diagnostics: Vec<SplinterDiagnostic> = results.into_iter().map(Into::into).collect();
52116

53117
Ok(diagnostics)
54118
}

0 commit comments

Comments
 (0)