From 6e6f3ef1fff859086331a5148533ebe294aa5ddc Mon Sep 17 00:00:00 2001 From: funman300 Date: Tue, 12 May 2026 13:55:07 -0700 Subject: [PATCH] feat(server): per-user rate limiting on protected sync endpoints Adds a UserIdKeyExtractor that decodes the Authorization JWT to rate-limit each user individually (falls back to client IP for unauthenticated requests). Protected routes now throttle at 10-request burst / 1 token per 10 s steady-state (6/min), matching the surface attack area of the 1 MB sync/push endpoint. Also adds an integration test: sync_push_rate_limit_returns_429_on_11th_request. Co-Authored-By: Claude Sonnet 4.6 --- solitaire_server/src/lib.rs | 69 +++++++++++++++++++++++++- solitaire_server/tests/server_tests.rs | 62 +++++++++++++++++++++++ 2 files changed, 129 insertions(+), 2 deletions(-) diff --git a/solitaire_server/src/lib.rs b/solitaire_server/src/lib.rs index 57f0b68..8d693fc 100644 --- a/solitaire_server/src/lib.rs +++ b/solitaire_server/src/lib.rs @@ -19,15 +19,61 @@ use axum::{ routing::{delete, get, post}, Router, }; +use jsonwebtoken::{decode, DecodingKey, Validation}; use sqlx::SqlitePool; use std::sync::Arc; use tower_governor::{ + errors::GovernorError, governor::GovernorConfigBuilder, - key_extractor::SmartIpKeyExtractor, + key_extractor::{KeyExtractor, SmartIpKeyExtractor}, GovernorLayer, }; use tower_http::services::ServeDir; +/// Rate-limiting key extractor for authenticated endpoints. +/// +/// Extracts the authenticated user's UUID from the `Authorization: Bearer` JWT +/// so each user gets their own bucket. Falls back to the client IP address when +/// the header is absent or the token fails signature verification — this +/// protects the server from unauthenticated request floods while ensuring +/// legitimate users are always identified by identity rather than IP. +/// +/// Expiry is intentionally **not** checked here: `require_auth` validates the +/// full token (including `exp`) and returns 401. Counting an expired token +/// against the user's bucket is harmless and avoids returning 500 (the +/// `UnableToExtractKey` outcome) for a request that would get 401 anyway. +#[derive(Clone)] +struct UserIdKeyExtractor { + jwt_secret: String, +} + +impl KeyExtractor for UserIdKeyExtractor { + type Key = String; + + fn extract(&self, req: &axum::http::Request) -> Result { + if let Some(user_id) = self.try_extract_user_id(req.headers()) { + return Ok(user_id); + } + // Fall back to IP so unauthenticated bursts don't bypass throttling. + SmartIpKeyExtractor + .extract(req) + .map(|ip| ip.to_string()) + } +} + +impl UserIdKeyExtractor { + fn try_extract_user_id(&self, headers: &axum::http::HeaderMap) -> Option { + let value = headers.get("Authorization")?.to_str().ok()?; + let token = value.strip_prefix("Bearer ")?; + let key = DecodingKey::from_secret(self.jwt_secret.as_bytes()); + let mut validation = Validation::default(); + validation.validate_exp = false; + decode::(token, &key, &validation) + .ok() + .map(|d| d.claims.sub) + } +} + /// Shared application state injected into every Axum handler via [`axum::extract::State`]. /// /// Loaded once at startup so a missing `JWT_SECRET` causes an immediate startup @@ -64,7 +110,7 @@ pub fn build_test_router(pool: SqlitePool) -> Router { fn build_router_inner(state: AppState, rate_limit: bool) -> Router { // Protected routes require a valid JWT (injected by require_auth middleware). - let protected = Router::new() + let protected_base = Router::new() .route("/api/sync/pull", get(sync::pull)) .route("/api/sync/push", post(sync::push)) .route("/api/replays", post(replays::upload)) @@ -77,6 +123,25 @@ fn build_router_inner(state: AppState, rate_limit: bool) -> Router { middleware::require_auth, )); + // Per-user rate limit on protected endpoints: 10-request burst, then 1 + // token replenished every 10 seconds (6/min steady-state). This prevents + // a single compromised account from hammering the 1 MB sync/push endpoint. + let protected = if rate_limit { + let governor_conf = Arc::new( + GovernorConfigBuilder::default() + .key_extractor(UserIdKeyExtractor { + jwt_secret: state.jwt_secret.clone(), + }) + .per_second(10) + .burst_size(10) + .finish() + .expect("invalid sync governor config"), + ); + protected_base.layer(GovernorLayer::new(governor_conf)) + } else { + protected_base + }; + // Auth endpoints — rate-limited in production, unrestricted in tests. let auth_routes = Router::new() .route("/api/auth/register", post(auth::register)) diff --git a/solitaire_server/tests/server_tests.rs b/solitaire_server/tests/server_tests.rs index 1f259e6..b7aacef 100644 --- a/solitaire_server/tests/server_tests.rs +++ b/solitaire_server/tests/server_tests.rs @@ -1523,6 +1523,68 @@ async fn auth_rate_limit_returns_429_on_11th_request() { ); } +/// The 11th `POST /api/sync/push` from the same authenticated user within the +/// rate-limit window must return 429 Too Many Requests. +/// +/// Uses [`solitaire_server::build_router`] (rate limiting ON) so the +/// GovernorLayer is applied. We register a fresh account, then send 10 pushes +/// (consuming the burst allowance), and assert the 11th is throttled. +/// +/// Note: the push body deliberately omits valid `SyncPayload` structure — +/// that would return 422, but the rate limiter fires before deserialization, +/// so the response code for the first 10 is 422 and for the 11th is 429. +/// The test only asserts `!= 429` for requests 1–10 and `== 429` for request 11. +#[tokio::test] +async fn sync_push_rate_limit_returns_429_on_11th_request() { + let state = solitaire_server::AppState { + pool: test_pool().await, + jwt_secret: TEST_SECRET.to_string(), + }; + let app = solitaire_server::build_router(state); + + // Register a user to obtain a valid JWT for the UserIdKeyExtractor. + let (token, _) = register_user(app.clone(), "sync_ratelimit_user", "p4ssword!").await; + + let stub_body = serde_json::to_vec(&serde_json::json!({})).unwrap(); + + // First 10 requests consume the burst allowance (burst_size = 10). + // The body is intentionally invalid — the rate limiter fires before + // deserialization, so we get 422 rather than 200. We only assert != 429. + for i in 0..10 { + let req = Request::builder() + .method("POST") + .uri("/api/sync/push") + .header("content-type", "application/json") + .header("Authorization", format!("Bearer {token}")) + .header("x-forwarded-for", TEST_CLIENT_IP) + .body(Body::from(stub_body.clone())) + .expect("failed to build request"); + let resp = app.clone().oneshot(req).await.expect("oneshot failed"); + assert_ne!( + resp.status(), + StatusCode::TOO_MANY_REQUESTS, + "request {} of 10 must not be rate-limited", + i + 1 + ); + } + + // The 11th request must be throttled. + let req = Request::builder() + .method("POST") + .uri("/api/sync/push") + .header("content-type", "application/json") + .header("Authorization", format!("Bearer {token}")) + .header("x-forwarded-for", TEST_CLIENT_IP) + .body(Body::from(stub_body)) + .expect("failed to build 11th request"); + let resp = app.clone().oneshot(req).await.expect("oneshot failed"); + assert_eq!( + resp.status(), + StatusCode::TOO_MANY_REQUESTS, + "11th sync push must be rate-limited with 429" + ); +} + // --------------------------------------------------------------------------- // Replay endpoints //