diff --git a/pom.xml b/pom.xml index a948486..c107f83 100644 --- a/pom.xml +++ b/pom.xml @@ -61,6 +61,13 @@ runtime + + + com.bucket4j + bucket4j-core + 8.10.1 + + org.springframework.boot diff --git a/src/main/java/com/example/config/DefaultSecurityConfig.java b/src/main/java/com/example/config/DefaultSecurityConfig.java index 8d83f44..4a0c009 100644 --- a/src/main/java/com/example/config/DefaultSecurityConfig.java +++ b/src/main/java/com/example/config/DefaultSecurityConfig.java @@ -11,6 +11,8 @@ import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.security.provisioning.JdbcUserDetailsManager; import org.springframework.security.web.SecurityFilterChain; +import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter; +import org.springframework.security.web.header.writers.ReferrerPolicyHeaderWriter; import org.springframework.web.cors.CorsConfiguration; import org.springframework.web.cors.CorsConfigurationSource; import org.springframework.web.cors.UrlBasedCorsConfigurationSource; @@ -29,6 +31,12 @@ public class DefaultSecurityConfig { @Value("${cors.allowed-origins:}") private String corsAllowedOrigins; + private final RateLimitingFilter rateLimitingFilter; + + public DefaultSecurityConfig(RateLimitingFilter rateLimitingFilter) { + this.rateLimitingFilter = rateLimitingFilter; + } + @Bean public SecurityFilterChain defaultSecurityFilterChain(HttpSecurity http) throws Exception { http @@ -42,9 +50,18 @@ public SecurityFilterChain defaultSecurityFilterChain(HttpSecurity http) throws ) .formLogin(withDefaults()) .headers(headers -> headers - .frameOptions(frameOptions -> frameOptions.sameOrigin()) // For H2 console + .frameOptions(frameOptions -> frameOptions.sameOrigin()) // For H2 console + .httpStrictTransportSecurity(hsts -> hsts + .includeSubDomains(true) + .maxAgeInSeconds(31536000)) + .contentTypeOptions(withDefaults()) + .referrerPolicy(referrer -> referrer + .policy(ReferrerPolicyHeaderWriter.ReferrerPolicy.STRICT_ORIGIN_WHEN_CROSS_ORIGIN)) + .contentSecurityPolicy(csp -> csp + .policyDirectives("default-src 'self'; frame-ancestors 'self'")) ) - .csrf(csrf -> csrf.ignoringRequestMatchers("/h2-console/**")); // For H2 console + .csrf(csrf -> csrf.ignoringRequestMatchers("/h2-console/**")) + .addFilterBefore(rateLimitingFilter, UsernamePasswordAuthenticationFilter.class); return http.build(); } diff --git a/src/main/java/com/example/config/ProviderConfig.java b/src/main/java/com/example/config/ProviderConfig.java index 98551c9..0cf0b1d 100644 --- a/src/main/java/com/example/config/ProviderConfig.java +++ b/src/main/java/com/example/config/ProviderConfig.java @@ -1,15 +1,20 @@ package com.example.config; +import org.springframework.beans.factory.annotation.Value; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings; @Configuration public class ProviderConfig { + + @Value("${oauth2.issuer-uri:http://localhost:9000}") + private String issuerUri; + @Bean public AuthorizationServerSettings authorizationServerSettings() { return AuthorizationServerSettings.builder() - .issuer("http://localhost:9000") + .issuer(issuerUri) .build(); } -} \ No newline at end of file +} diff --git a/src/main/java/com/example/config/RateLimitingFilter.java b/src/main/java/com/example/config/RateLimitingFilter.java new file mode 100644 index 0000000..0637e5c --- /dev/null +++ b/src/main/java/com/example/config/RateLimitingFilter.java @@ -0,0 +1,52 @@ +package com.example.config; + +import io.github.bucket4j.Bandwidth; +import io.github.bucket4j.Bucket; +import io.github.bucket4j.Refill; +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.springframework.http.MediaType; +import org.springframework.stereotype.Component; +import org.springframework.web.filter.OncePerRequestFilter; + +import java.io.IOException; +import java.time.Duration; +import java.util.concurrent.ConcurrentHashMap; + +@Component +public class RateLimitingFilter extends OncePerRequestFilter { + + private final ConcurrentHashMap buckets = new ConcurrentHashMap<>(); + + @Override + protected boolean shouldNotFilter(HttpServletRequest request) { + return !"POST".equalsIgnoreCase(request.getMethod()) + || !"/oauth2/token".equals(request.getServletPath()); + } + + @Override + protected void doFilterInternal(HttpServletRequest request, + HttpServletResponse response, + FilterChain filterChain) throws ServletException, IOException { + String clientIp = request.getRemoteAddr(); + Bucket bucket = buckets.computeIfAbsent(clientIp, ip -> newBucket()); + + if (bucket.tryConsume(1)) { + filterChain.doFilter(request, response); + } else { + response.setStatus(429); + response.setContentType(MediaType.APPLICATION_JSON_VALUE); + response.getWriter().write( + "{\"error\":\"too_many_requests\",\"error_description\":\"Rate limit exceeded\"}" + ); + } + } + + private Bucket newBucket() { + // 20 requests per minute per IP + Bandwidth limit = Bandwidth.classic(20, Refill.greedy(20, Duration.ofMinutes(1))); + return Bucket.builder().addLimit(limit).build(); + } +} diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties index 55d001a..0b2275b 100644 --- a/src/main/resources/application.properties +++ b/src/main/resources/application.properties @@ -14,6 +14,7 @@ jwt.key-path=${JWT_KEY_PATH:keys/jwt} # OAuth2 client credentials (override in production) oauth2.client.secret=${OAUTH2_CLIENT_SECRET:secret} oauth2.extra-redirect-uri=${OAUTH2_EXTRA_REDIRECT_URI:} +oauth2.issuer-uri=${OAUTH2_ISSUER_URI:http://localhost:9000} # CORS — comma-separated list of additional allowed origin patterns (e.g. https://*.example.com) cors.allowed-origins=${CORS_ALLOWED_ORIGINS:}