Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@
<scope>runtime</scope>
</dependency>

<!-- Rate limiting -->
<dependency>
<groupId>com.bucket4j</groupId>
<artifactId>bucket4j-core</artifactId>
<version>8.10.1</version>
</dependency>

<!-- Test dependencies -->
<dependency>
<groupId>org.springframework.boot</groupId>
Expand Down
21 changes: 19 additions & 2 deletions src/main/java/com/example/config/DefaultSecurityConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.security.provisioning.JdbcUserDetailsManager;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter;
import org.springframework.security.web.header.writers.ReferrerPolicyHeaderWriter;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.cors.CorsConfigurationSource;
import org.springframework.web.cors.UrlBasedCorsConfigurationSource;
Expand All @@ -29,6 +31,12 @@ public class DefaultSecurityConfig {
@Value("${cors.allowed-origins:}")
private String corsAllowedOrigins;

private final RateLimitingFilter rateLimitingFilter;

public DefaultSecurityConfig(RateLimitingFilter rateLimitingFilter) {
this.rateLimitingFilter = rateLimitingFilter;
}

@Bean
public SecurityFilterChain defaultSecurityFilterChain(HttpSecurity http) throws Exception {
http
Expand All @@ -42,9 +50,18 @@ public SecurityFilterChain defaultSecurityFilterChain(HttpSecurity http) throws
)
.formLogin(withDefaults())
.headers(headers -> headers
.frameOptions(frameOptions -> frameOptions.sameOrigin()) // For H2 console
.frameOptions(frameOptions -> frameOptions.sameOrigin()) // For H2 console
.httpStrictTransportSecurity(hsts -> hsts
.includeSubDomains(true)
.maxAgeInSeconds(31536000))
.contentTypeOptions(withDefaults())
.referrerPolicy(referrer -> referrer
.policy(ReferrerPolicyHeaderWriter.ReferrerPolicy.STRICT_ORIGIN_WHEN_CROSS_ORIGIN))
.contentSecurityPolicy(csp -> csp
.policyDirectives("default-src 'self'; frame-ancestors 'self'"))
)
.csrf(csrf -> csrf.ignoringRequestMatchers("/h2-console/**")); // For H2 console
.csrf(csrf -> csrf.ignoringRequestMatchers("/h2-console/**"))
.addFilterBefore(rateLimitingFilter, UsernamePasswordAuthenticationFilter.class);

return http.build();
}
Expand Down
9 changes: 7 additions & 2 deletions src/main/java/com/example/config/ProviderConfig.java
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
package com.example.config;

import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings;

@Configuration
public class ProviderConfig {

@Value("${oauth2.issuer-uri:http://localhost:9000}")
private String issuerUri;

@Bean
public AuthorizationServerSettings authorizationServerSettings() {
return AuthorizationServerSettings.builder()
.issuer("http://localhost:9000")
.issuer(issuerUri)
.build();
}
}
}
52 changes: 52 additions & 0 deletions src/main/java/com/example/config/RateLimitingFilter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package com.example.config;

import io.github.bucket4j.Bandwidth;
import io.github.bucket4j.Bucket;
import io.github.bucket4j.Refill;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter;

import java.io.IOException;
import java.time.Duration;
import java.util.concurrent.ConcurrentHashMap;

@Component
public class RateLimitingFilter extends OncePerRequestFilter {

private final ConcurrentHashMap<String, Bucket> buckets = new ConcurrentHashMap<>();

@Override
protected boolean shouldNotFilter(HttpServletRequest request) {
return !"POST".equalsIgnoreCase(request.getMethod())
|| !"/oauth2/token".equals(request.getServletPath());
}

@Override
protected void doFilterInternal(HttpServletRequest request,
HttpServletResponse response,
FilterChain filterChain) throws ServletException, IOException {
String clientIp = request.getRemoteAddr();
Bucket bucket = buckets.computeIfAbsent(clientIp, ip -> newBucket());

if (bucket.tryConsume(1)) {
filterChain.doFilter(request, response);
} else {
response.setStatus(429);
response.setContentType(MediaType.APPLICATION_JSON_VALUE);
response.getWriter().write(
"{\"error\":\"too_many_requests\",\"error_description\":\"Rate limit exceeded\"}"
);
}
}

private Bucket newBucket() {
// 20 requests per minute per IP
Bandwidth limit = Bandwidth.classic(20, Refill.greedy(20, Duration.ofMinutes(1)));
return Bucket.builder().addLimit(limit).build();
}
}
1 change: 1 addition & 0 deletions src/main/resources/application.properties
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ jwt.key-path=${JWT_KEY_PATH:keys/jwt}
# OAuth2 client credentials (override in production)
oauth2.client.secret=${OAUTH2_CLIENT_SECRET:secret}
oauth2.extra-redirect-uri=${OAUTH2_EXTRA_REDIRECT_URI:}
oauth2.issuer-uri=${OAUTH2_ISSUER_URI:http://localhost:9000}

# CORS — comma-separated list of additional allowed origin patterns (e.g. https://*.example.com)
cors.allowed-origins=${CORS_ALLOWED_ORIGINS:}
Expand Down
Loading