feat(server): per-user rate limiting on protected sync endpoints
Adds a UserIdKeyExtractor that decodes the Authorization JWT to rate-limit each user individually (falls back to client IP for unauthenticated requests). Protected routes now throttle at 10-request burst / 1 token per 10 s steady-state (6/min), matching the surface attack area of the 1 MB sync/push endpoint. Also adds an integration test: sync_push_rate_limit_returns_429_on_11th_request. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -19,15 +19,61 @@ use axum::{
|
|||||||
routing::{delete, get, post},
|
routing::{delete, get, post},
|
||||||
Router,
|
Router,
|
||||||
};
|
};
|
||||||
|
use jsonwebtoken::{decode, DecodingKey, Validation};
|
||||||
use sqlx::SqlitePool;
|
use sqlx::SqlitePool;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tower_governor::{
|
use tower_governor::{
|
||||||
|
errors::GovernorError,
|
||||||
governor::GovernorConfigBuilder,
|
governor::GovernorConfigBuilder,
|
||||||
key_extractor::SmartIpKeyExtractor,
|
key_extractor::{KeyExtractor, SmartIpKeyExtractor},
|
||||||
GovernorLayer,
|
GovernorLayer,
|
||||||
};
|
};
|
||||||
use tower_http::services::ServeDir;
|
use tower_http::services::ServeDir;
|
||||||
|
|
||||||
|
/// Rate-limiting key extractor for authenticated endpoints.
|
||||||
|
///
|
||||||
|
/// Extracts the authenticated user's UUID from the `Authorization: Bearer` JWT
|
||||||
|
/// so each user gets their own bucket. Falls back to the client IP address when
|
||||||
|
/// the header is absent or the token fails signature verification — this
|
||||||
|
/// protects the server from unauthenticated request floods while ensuring
|
||||||
|
/// legitimate users are always identified by identity rather than IP.
|
||||||
|
///
|
||||||
|
/// Expiry is intentionally **not** checked here: `require_auth` validates the
|
||||||
|
/// full token (including `exp`) and returns 401. Counting an expired token
|
||||||
|
/// against the user's bucket is harmless and avoids returning 500 (the
|
||||||
|
/// `UnableToExtractKey` outcome) for a request that would get 401 anyway.
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct UserIdKeyExtractor {
|
||||||
|
jwt_secret: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl KeyExtractor for UserIdKeyExtractor {
|
||||||
|
type Key = String;
|
||||||
|
|
||||||
|
fn extract<T>(&self, req: &axum::http::Request<T>) -> Result<Self::Key, GovernorError> {
|
||||||
|
if let Some(user_id) = self.try_extract_user_id(req.headers()) {
|
||||||
|
return Ok(user_id);
|
||||||
|
}
|
||||||
|
// Fall back to IP so unauthenticated bursts don't bypass throttling.
|
||||||
|
SmartIpKeyExtractor
|
||||||
|
.extract(req)
|
||||||
|
.map(|ip| ip.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UserIdKeyExtractor {
|
||||||
|
fn try_extract_user_id(&self, headers: &axum::http::HeaderMap) -> Option<String> {
|
||||||
|
let value = headers.get("Authorization")?.to_str().ok()?;
|
||||||
|
let token = value.strip_prefix("Bearer ")?;
|
||||||
|
let key = DecodingKey::from_secret(self.jwt_secret.as_bytes());
|
||||||
|
let mut validation = Validation::default();
|
||||||
|
validation.validate_exp = false;
|
||||||
|
decode::<middleware::Claims>(token, &key, &validation)
|
||||||
|
.ok()
|
||||||
|
.map(|d| d.claims.sub)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Shared application state injected into every Axum handler via [`axum::extract::State`].
|
/// Shared application state injected into every Axum handler via [`axum::extract::State`].
|
||||||
///
|
///
|
||||||
/// Loaded once at startup so a missing `JWT_SECRET` causes an immediate startup
|
/// Loaded once at startup so a missing `JWT_SECRET` causes an immediate startup
|
||||||
@@ -64,7 +110,7 @@ pub fn build_test_router(pool: SqlitePool) -> Router {
|
|||||||
|
|
||||||
fn build_router_inner(state: AppState, rate_limit: bool) -> Router {
|
fn build_router_inner(state: AppState, rate_limit: bool) -> Router {
|
||||||
// Protected routes require a valid JWT (injected by require_auth middleware).
|
// Protected routes require a valid JWT (injected by require_auth middleware).
|
||||||
let protected = Router::new()
|
let protected_base = Router::new()
|
||||||
.route("/api/sync/pull", get(sync::pull))
|
.route("/api/sync/pull", get(sync::pull))
|
||||||
.route("/api/sync/push", post(sync::push))
|
.route("/api/sync/push", post(sync::push))
|
||||||
.route("/api/replays", post(replays::upload))
|
.route("/api/replays", post(replays::upload))
|
||||||
@@ -77,6 +123,25 @@ fn build_router_inner(state: AppState, rate_limit: bool) -> Router {
|
|||||||
middleware::require_auth,
|
middleware::require_auth,
|
||||||
));
|
));
|
||||||
|
|
||||||
|
// Per-user rate limit on protected endpoints: 10-request burst, then 1
|
||||||
|
// token replenished every 10 seconds (6/min steady-state). This prevents
|
||||||
|
// a single compromised account from hammering the 1 MB sync/push endpoint.
|
||||||
|
let protected = if rate_limit {
|
||||||
|
let governor_conf = Arc::new(
|
||||||
|
GovernorConfigBuilder::default()
|
||||||
|
.key_extractor(UserIdKeyExtractor {
|
||||||
|
jwt_secret: state.jwt_secret.clone(),
|
||||||
|
})
|
||||||
|
.per_second(10)
|
||||||
|
.burst_size(10)
|
||||||
|
.finish()
|
||||||
|
.expect("invalid sync governor config"),
|
||||||
|
);
|
||||||
|
protected_base.layer(GovernorLayer::new(governor_conf))
|
||||||
|
} else {
|
||||||
|
protected_base
|
||||||
|
};
|
||||||
|
|
||||||
// Auth endpoints — rate-limited in production, unrestricted in tests.
|
// Auth endpoints — rate-limited in production, unrestricted in tests.
|
||||||
let auth_routes = Router::new()
|
let auth_routes = Router::new()
|
||||||
.route("/api/auth/register", post(auth::register))
|
.route("/api/auth/register", post(auth::register))
|
||||||
|
|||||||
@@ -1523,6 +1523,68 @@ async fn auth_rate_limit_returns_429_on_11th_request() {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The 11th `POST /api/sync/push` from the same authenticated user within the
|
||||||
|
/// rate-limit window must return 429 Too Many Requests.
|
||||||
|
///
|
||||||
|
/// Uses [`solitaire_server::build_router`] (rate limiting ON) so the
|
||||||
|
/// GovernorLayer is applied. We register a fresh account, then send 10 pushes
|
||||||
|
/// (consuming the burst allowance), and assert the 11th is throttled.
|
||||||
|
///
|
||||||
|
/// Note: the push body deliberately omits valid `SyncPayload` structure —
|
||||||
|
/// that would return 422, but the rate limiter fires before deserialization,
|
||||||
|
/// so the response code for the first 10 is 422 and for the 11th is 429.
|
||||||
|
/// The test only asserts `!= 429` for requests 1–10 and `== 429` for request 11.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn sync_push_rate_limit_returns_429_on_11th_request() {
|
||||||
|
let state = solitaire_server::AppState {
|
||||||
|
pool: test_pool().await,
|
||||||
|
jwt_secret: TEST_SECRET.to_string(),
|
||||||
|
};
|
||||||
|
let app = solitaire_server::build_router(state);
|
||||||
|
|
||||||
|
// Register a user to obtain a valid JWT for the UserIdKeyExtractor.
|
||||||
|
let (token, _) = register_user(app.clone(), "sync_ratelimit_user", "p4ssword!").await;
|
||||||
|
|
||||||
|
let stub_body = serde_json::to_vec(&serde_json::json!({})).unwrap();
|
||||||
|
|
||||||
|
// First 10 requests consume the burst allowance (burst_size = 10).
|
||||||
|
// The body is intentionally invalid — the rate limiter fires before
|
||||||
|
// deserialization, so we get 422 rather than 200. We only assert != 429.
|
||||||
|
for i in 0..10 {
|
||||||
|
let req = Request::builder()
|
||||||
|
.method("POST")
|
||||||
|
.uri("/api/sync/push")
|
||||||
|
.header("content-type", "application/json")
|
||||||
|
.header("Authorization", format!("Bearer {token}"))
|
||||||
|
.header("x-forwarded-for", TEST_CLIENT_IP)
|
||||||
|
.body(Body::from(stub_body.clone()))
|
||||||
|
.expect("failed to build request");
|
||||||
|
let resp = app.clone().oneshot(req).await.expect("oneshot failed");
|
||||||
|
assert_ne!(
|
||||||
|
resp.status(),
|
||||||
|
StatusCode::TOO_MANY_REQUESTS,
|
||||||
|
"request {} of 10 must not be rate-limited",
|
||||||
|
i + 1
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// The 11th request must be throttled.
|
||||||
|
let req = Request::builder()
|
||||||
|
.method("POST")
|
||||||
|
.uri("/api/sync/push")
|
||||||
|
.header("content-type", "application/json")
|
||||||
|
.header("Authorization", format!("Bearer {token}"))
|
||||||
|
.header("x-forwarded-for", TEST_CLIENT_IP)
|
||||||
|
.body(Body::from(stub_body))
|
||||||
|
.expect("failed to build 11th request");
|
||||||
|
let resp = app.clone().oneshot(req).await.expect("oneshot failed");
|
||||||
|
assert_eq!(
|
||||||
|
resp.status(),
|
||||||
|
StatusCode::TOO_MANY_REQUESTS,
|
||||||
|
"11th sync push must be rate-limited with 429"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
// Replay endpoints
|
// Replay endpoints
|
||||||
//
|
//
|
||||||
|
|||||||
Reference in New Issue
Block a user