|
7 | 7 | use error_stack::{Report, ResultExt}; |
8 | 8 | use fastly::{Request, Response}; |
9 | 9 | use serde::{Deserialize, Serialize}; |
| 10 | +use std::collections::HashSet; |
| 11 | +use url::Host; |
10 | 12 |
|
11 | 13 | use crate::error::TrustedServerError; |
12 | 14 |
|
@@ -64,6 +66,64 @@ fn default_pull_sync_rate_limit() -> u32 { |
64 | 66 | 10 |
65 | 67 | } |
66 | 68 |
|
| 69 | +fn bad_request(message: impl Into<String>) -> Report<TrustedServerError> { |
| 70 | + Report::new(TrustedServerError::BadRequest { |
| 71 | + message: message.into(), |
| 72 | + }) |
| 73 | +} |
| 74 | + |
| 75 | +fn normalize_required_text( |
| 76 | + value: &str, |
| 77 | + field_name: &str, |
| 78 | +) -> Result<String, Report<TrustedServerError>> { |
| 79 | + let trimmed = value.trim(); |
| 80 | + if trimmed.is_empty() { |
| 81 | + return Err(bad_request(format!("{field_name} is required"))); |
| 82 | + } |
| 83 | + Ok(trimmed.to_owned()) |
| 84 | +} |
| 85 | + |
| 86 | +fn normalize_hostname(value: &str, field_name: &str) -> Result<String, Report<TrustedServerError>> { |
| 87 | + let trimmed = value.trim().trim_end_matches('.'); |
| 88 | + if trimmed.is_empty() { |
| 89 | + return Err(bad_request(format!("{field_name} is required"))); |
| 90 | + } |
| 91 | + |
| 92 | + let normalized = trimmed.to_ascii_lowercase(); |
| 93 | + Host::parse(&normalized) |
| 94 | + .map_err(|_| bad_request(format!("{field_name} must be a valid hostname")))?; |
| 95 | + |
| 96 | + Ok(normalized) |
| 97 | +} |
| 98 | + |
| 99 | +fn normalize_hostname_list( |
| 100 | + values: Vec<String>, |
| 101 | + field_name: &str, |
| 102 | +) -> Result<Vec<String>, Report<TrustedServerError>> { |
| 103 | + let mut normalized_values = Vec::with_capacity(values.len()); |
| 104 | + let mut seen = HashSet::with_capacity(values.len()); |
| 105 | + |
| 106 | + for value in values { |
| 107 | + let trimmed = value.trim().trim_end_matches('.'); |
| 108 | + if trimmed.is_empty() { |
| 109 | + return Err(bad_request(format!( |
| 110 | + "{field_name} entries must not be empty" |
| 111 | + ))); |
| 112 | + } |
| 113 | + |
| 114 | + let normalized = trimmed.to_ascii_lowercase(); |
| 115 | + Host::parse(&normalized).map_err(|_| { |
| 116 | + bad_request(format!("{field_name} contains invalid hostname '{value}'")) |
| 117 | + })?; |
| 118 | + |
| 119 | + if seen.insert(normalized.clone()) { |
| 120 | + normalized_values.push(normalized); |
| 121 | + } |
| 122 | + } |
| 123 | + |
| 124 | + Ok(normalized_values) |
| 125 | +} |
| 126 | + |
67 | 127 | /// Response body for `POST /admin/partners/register`. |
68 | 128 | /// |
69 | 129 | /// Echoes key fields without exposing sensitive data (`api_key_hash`, |
@@ -97,54 +157,65 @@ pub fn handle_register_partner( |
97 | 157 | message: "Invalid JSON in request body".to_owned(), |
98 | 158 | })?; |
99 | 159 |
|
| 160 | + let RegisterPartnerRequest { |
| 161 | + id, |
| 162 | + name, |
| 163 | + allowed_return_domains, |
| 164 | + api_key, |
| 165 | + bidstream_enabled, |
| 166 | + source_domain, |
| 167 | + openrtb_atype, |
| 168 | + sync_rate_limit, |
| 169 | + batch_rate_limit, |
| 170 | + pull_sync_enabled, |
| 171 | + pull_sync_url, |
| 172 | + pull_sync_allowed_domains, |
| 173 | + pull_sync_ttl_sec, |
| 174 | + pull_sync_rate_limit, |
| 175 | + ts_pull_token, |
| 176 | + } = request; |
| 177 | + |
100 | 178 | // Validate partner ID. |
101 | | - validate_partner_id(&request.id) |
102 | | - .map_err(|msg| Report::new(TrustedServerError::BadRequest { message: msg }))?; |
103 | | - |
104 | | - // Validate required fields. |
105 | | - if request.name.is_empty() { |
106 | | - return Err(Report::new(TrustedServerError::BadRequest { |
107 | | - message: "name is required".to_owned(), |
108 | | - })); |
109 | | - } |
110 | | - if request.api_key.is_empty() { |
111 | | - return Err(Report::new(TrustedServerError::BadRequest { |
112 | | - message: "api_key is required".to_owned(), |
113 | | - })); |
114 | | - } |
115 | | - if request.source_domain.is_empty() { |
116 | | - return Err(Report::new(TrustedServerError::BadRequest { |
117 | | - message: "source_domain is required".to_owned(), |
118 | | - })); |
| 179 | + validate_partner_id(&id).map_err(bad_request)?; |
| 180 | + |
| 181 | + // Validate and normalize required fields. |
| 182 | + let name = normalize_required_text(&name, "name")?; |
| 183 | + if api_key.trim().is_empty() { |
| 184 | + return Err(bad_request("api_key is required")); |
119 | 185 | } |
120 | | - if request.allowed_return_domains.is_empty() { |
121 | | - return Err(Report::new(TrustedServerError::BadRequest { |
122 | | - message: "allowed_return_domains must have at least one entry".to_owned(), |
123 | | - })); |
| 186 | + let source_domain = normalize_hostname(&source_domain, "source_domain")?; |
| 187 | + |
| 188 | + if allowed_return_domains.is_empty() { |
| 189 | + return Err(bad_request( |
| 190 | + "allowed_return_domains must have at least one entry", |
| 191 | + )); |
124 | 192 | } |
| 193 | + let allowed_return_domains = |
| 194 | + normalize_hostname_list(allowed_return_domains, "allowed_return_domains")?; |
| 195 | + let pull_sync_allowed_domains = |
| 196 | + normalize_hostname_list(pull_sync_allowed_domains, "pull_sync_allowed_domains")?; |
125 | 197 |
|
126 | 198 | // Build the PartnerRecord with hashed API key. |
127 | 199 | let record = PartnerRecord { |
128 | | - id: request.id.clone(), |
129 | | - name: request.name.clone(), |
130 | | - allowed_return_domains: request.allowed_return_domains, |
131 | | - api_key_hash: hash_api_key(&request.api_key), |
132 | | - bidstream_enabled: request.bidstream_enabled, |
133 | | - source_domain: request.source_domain, |
134 | | - openrtb_atype: request.openrtb_atype, |
135 | | - sync_rate_limit: request.sync_rate_limit, |
136 | | - batch_rate_limit: request.batch_rate_limit, |
137 | | - pull_sync_enabled: request.pull_sync_enabled, |
138 | | - pull_sync_url: request.pull_sync_url, |
139 | | - pull_sync_allowed_domains: request.pull_sync_allowed_domains, |
140 | | - pull_sync_ttl_sec: request.pull_sync_ttl_sec, |
141 | | - pull_sync_rate_limit: request.pull_sync_rate_limit, |
142 | | - ts_pull_token: request.ts_pull_token, |
| 200 | + id, |
| 201 | + name, |
| 202 | + allowed_return_domains, |
| 203 | + api_key_hash: hash_api_key(&api_key), |
| 204 | + bidstream_enabled, |
| 205 | + source_domain, |
| 206 | + openrtb_atype, |
| 207 | + sync_rate_limit, |
| 208 | + batch_rate_limit, |
| 209 | + pull_sync_enabled, |
| 210 | + pull_sync_url, |
| 211 | + pull_sync_allowed_domains, |
| 212 | + pull_sync_ttl_sec, |
| 213 | + pull_sync_rate_limit, |
| 214 | + ts_pull_token, |
143 | 215 | }; |
144 | 216 |
|
145 | 217 | // Validate pull sync configuration. |
146 | | - validate_pull_sync_config(&record) |
147 | | - .map_err(|msg| Report::new(TrustedServerError::BadRequest { message: msg }))?; |
| 218 | + validate_pull_sync_config(&record).map_err(bad_request)?; |
148 | 219 |
|
149 | 220 | // Persist to KV store. |
150 | 221 | let created = partner_store.upsert(&record)?; |
@@ -258,4 +329,52 @@ mod tests { |
258 | 329 | Some("https://sync.example-ssp.com/pull") |
259 | 330 | ); |
260 | 331 | } |
| 332 | + |
| 333 | + #[test] |
| 334 | + fn normalize_required_text_rejects_whitespace_only() { |
| 335 | + let err = normalize_required_text(" ", "name") |
| 336 | + .expect_err("should reject whitespace-only required field"); |
| 337 | + assert!( |
| 338 | + err.to_string().contains("name is required"), |
| 339 | + "should mention required field" |
| 340 | + ); |
| 341 | + } |
| 342 | + |
| 343 | + #[test] |
| 344 | + fn normalize_hostname_normalizes_case_and_trailing_dot() { |
| 345 | + let normalized = normalize_hostname(" Sync.Example.COM. ", "source_domain") |
| 346 | + .expect("should parse host"); |
| 347 | + assert_eq!(normalized, "sync.example.com"); |
| 348 | + } |
| 349 | + |
| 350 | + #[test] |
| 351 | + fn normalize_hostname_list_rejects_empty_entry() { |
| 352 | + let err = normalize_hostname_list( |
| 353 | + vec!["sync.example.com".to_owned(), " ".to_owned()], |
| 354 | + "allowed_return_domains", |
| 355 | + ) |
| 356 | + .expect_err("should reject empty domain entries"); |
| 357 | + assert!( |
| 358 | + err.to_string() |
| 359 | + .contains("allowed_return_domains entries must not be empty"), |
| 360 | + "should surface empty-entry error" |
| 361 | + ); |
| 362 | + } |
| 363 | + |
| 364 | + #[test] |
| 365 | + fn normalize_hostname_list_deduplicates_normalized_values() { |
| 366 | + let normalized = normalize_hostname_list( |
| 367 | + vec![ |
| 368 | + "Sync.Example.com".to_owned(), |
| 369 | + "sync.example.com.".to_owned(), |
| 370 | + "cdn.example.com".to_owned(), |
| 371 | + ], |
| 372 | + "allowed_return_domains", |
| 373 | + ) |
| 374 | + .expect("should normalize hostnames"); |
| 375 | + assert_eq!( |
| 376 | + normalized, |
| 377 | + vec!["sync.example.com".to_owned(), "cdn.example.com".to_owned()] |
| 378 | + ); |
| 379 | + } |
261 | 380 | } |
0 commit comments