feat(server): per-user rate limiting on protected sync endpoints

Adds a UserIdKeyExtractor that decodes the Authorization JWT to rate-limit
each user individually (falls back to client IP for unauthenticated
requests). Protected routes now throttle at 10-request burst / 1 token
per 10 s steady-state (6/min), matching the surface attack area of the
1 MB sync/push endpoint.

Also adds an integration test: sync_push_rate_limit_returns_429_on_11th_request.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
funman300
2026-05-12 13:55:07 -07:00
parent 549a817bb1
commit 6e6f3ef1ff
2 changed files with 129 additions and 2 deletions
+67 -2
View File
@@ -19,15 +19,61 @@ use axum::{
routing::{delete, get, post},
Router,
};
use jsonwebtoken::{decode, DecodingKey, Validation};
use sqlx::SqlitePool;
use std::sync::Arc;
use tower_governor::{
errors::GovernorError,
governor::GovernorConfigBuilder,
key_extractor::SmartIpKeyExtractor,
key_extractor::{KeyExtractor, SmartIpKeyExtractor},
GovernorLayer,
};
use tower_http::services::ServeDir;
/// Rate-limiting key extractor for authenticated endpoints.
///
/// Extracts the authenticated user's UUID from the `Authorization: Bearer` JWT
/// so each user gets their own bucket. Falls back to the client IP address when
/// the header is absent or the token fails signature verification — this
/// protects the server from unauthenticated request floods while ensuring
/// legitimate users are always identified by identity rather than IP.
///
/// Expiry is intentionally **not** checked here: `require_auth` validates the
/// full token (including `exp`) and returns 401. Counting an expired token
/// against the user's bucket is harmless and avoids returning 500 (the
/// `UnableToExtractKey` outcome) for a request that would get 401 anyway.
#[derive(Clone)]
struct UserIdKeyExtractor {
jwt_secret: String,
}
impl KeyExtractor for UserIdKeyExtractor {
type Key = String;
fn extract<T>(&self, req: &axum::http::Request<T>) -> Result<Self::Key, GovernorError> {
if let Some(user_id) = self.try_extract_user_id(req.headers()) {
return Ok(user_id);
}
// Fall back to IP so unauthenticated bursts don't bypass throttling.
SmartIpKeyExtractor
.extract(req)
.map(|ip| ip.to_string())
}
}
impl UserIdKeyExtractor {
fn try_extract_user_id(&self, headers: &axum::http::HeaderMap) -> Option<String> {
let value = headers.get("Authorization")?.to_str().ok()?;
let token = value.strip_prefix("Bearer ")?;
let key = DecodingKey::from_secret(self.jwt_secret.as_bytes());
let mut validation = Validation::default();
validation.validate_exp = false;
decode::<middleware::Claims>(token, &key, &validation)
.ok()
.map(|d| d.claims.sub)
}
}
/// Shared application state injected into every Axum handler via [`axum::extract::State`].
///
/// Loaded once at startup so a missing `JWT_SECRET` causes an immediate startup
@@ -64,7 +110,7 @@ pub fn build_test_router(pool: SqlitePool) -> Router {
fn build_router_inner(state: AppState, rate_limit: bool) -> Router {
// Protected routes require a valid JWT (injected by require_auth middleware).
let protected = Router::new()
let protected_base = Router::new()
.route("/api/sync/pull", get(sync::pull))
.route("/api/sync/push", post(sync::push))
.route("/api/replays", post(replays::upload))
@@ -77,6 +123,25 @@ fn build_router_inner(state: AppState, rate_limit: bool) -> Router {
middleware::require_auth,
));
// Per-user rate limit on protected endpoints: 10-request burst, then 1
// token replenished every 10 seconds (6/min steady-state). This prevents
// a single compromised account from hammering the 1 MB sync/push endpoint.
let protected = if rate_limit {
let governor_conf = Arc::new(
GovernorConfigBuilder::default()
.key_extractor(UserIdKeyExtractor {
jwt_secret: state.jwt_secret.clone(),
})
.per_second(10)
.burst_size(10)
.finish()
.expect("invalid sync governor config"),
);
protected_base.layer(GovernorLayer::new(governor_conf))
} else {
protected_base
};
// Auth endpoints — rate-limited in production, unrestricted in tests.
let auth_routes = Router::new()
.route("/api/auth/register", post(auth::register))