Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,12 @@ public SuccessResponse<AuthResponse> signupAndLogin(
return SuccessResponse.of(
UserHttpResponseCode.SIGNUP_SUCCESS, new AuthResponse(authResult.tokenPair()));
}

@PreAuthorize("permitAll()")
@PostMapping("/refresh")
public SuccessResponse<AuthResponse> refresh(@Valid @RequestBody AuthRefreshRequest request) {
AuthResult authResult = userAuthService.refresh(request.getRefreshToken());
return SuccessResponse.of(
UserHttpResponseCode.REFRESH_SUCCESS, new AuthResponse(authResult.tokenPair()));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package com.run_us.server.domains.user.controller.model.request;

import jakarta.validation.constraints.NotBlank;
import lombok.Getter;
import lombok.NoArgsConstructor;

@Getter
@NoArgsConstructor
public class AuthRefreshRequest {
@NotBlank
private String refreshToken;
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ public enum UserHttpResponseCode implements CustomResponseCode {
MY_PAGE_DATA_FETCHED("USH2001", "마이페이지 데이터 조회 성공", "마이페이지 데이터 조회 성공"),
SIGNUP_SUCCESS("USH2002", "회원가입 성공", "회원가입 성공"),
LOGIN_SUCCESS("USH2003", "로그인 성공", "로그인 성공"),
REFRESH_SUCCESS("USH2004", "토큰 재발급 성공", "토큰 재발급 성공"),
;

private final String code;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

public enum AuthResultType {
LOGIN_SUCCESS,
REFRESH_SUCCESS,
SIGNUP_REQUIRED,
AUTH_FAILED,
SIGNUP_FAILED
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ public enum UserErrorCode implements CustomResponseCode {
JWT_NOT_FOUND("UEH4012", HttpStatus.UNAUTHORIZED, "JWT 토큰이 존재하지 않습니다.", "JWT 토큰이 존재하지 않습니다."),
JWT_EXPIRED("UEH4013", HttpStatus.UNAUTHORIZED, "JWT 토큰이 만료되었습니다.", "JWT 토큰이 만료되었습니다."),
JWT_BROKEN("UEH4014", HttpStatus.UNAUTHORIZED, "JWT 토큰이 손상되었습니다", "JWT 토큰이 손상되었습니다"),
REFRESH_FAILED("UEH4015", HttpStatus.UNAUTHORIZED, "리프레시 토큰이 만료되었습니다.", "리프레시 토큰이 만료되었습니다."),

// 404
USER_NOT_FOUND("UEH4041", HttpStatus.NOT_FOUND, "사용자를 찾을 수 없음", "사용자를 찾을 수 없음"),;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,27 @@
import com.run_us.server.domains.user.domain.TokenPair;
import com.run_us.server.domains.user.domain.User;
import com.run_us.server.domains.user.service.verifier.TokenVerifierFactory;
import com.run_us.server.global.common.cache.InMemoryCache;
import java.time.Duration;
import java.util.Date;
import lombok.RequiredArgsConstructor;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;

@Service
@RequiredArgsConstructor
public class JwtService {

private static final String ISSUER = "RunUSAuthService";
private final TokenVerifierFactory tokenVerifierFactory;
private final InMemoryCache<String, String> refreshTokenCache;
@Value("${jwt.secret}")
private String jwtSecret;
@Value("${jwt.expiration}")
private long jwtExpiration;
@Value("${jwt.refresh.expiration}")
private long jwtRefreshExpiration;

public JwtService(TokenVerifierFactory tokenVerifierFactory) {
this.tokenVerifierFactory = tokenVerifierFactory;
}

public String generateAccessToken(User user) {
Date now = new Date();
Date expiryDate = new Date(now.getTime() + jwtExpiration);
Expand All @@ -43,13 +44,22 @@ public String generateRefreshToken(User user) {
Date now = new Date();
Date expiryDate = new Date(now.getTime() + jwtRefreshExpiration);

return JWT.create()
.withSubject(user.getPublicId())
.withIssuedAt(now)
.withExpiresAt(expiryDate)
.withIssuer(ISSUER)
.withClaim("tokenType", "refresh")
.sign(Algorithm.HMAC256(jwtSecret));
String refreshToken = JWT.create()
.withSubject(user.getPublicId())
.withIssuedAt(now)
.withExpiresAt(expiryDate)
.withIssuer(ISSUER)
.withClaim("tokenType", "refresh")
.sign(Algorithm.HMAC256(jwtSecret));

refreshTokenCache.put("auth:refresh:"+user.getPublicId(),
refreshToken, Duration.ofSeconds(jwtExpiration));
return refreshToken;
}

public boolean nonceRefreshToken(String refreshToken) {
String userPublicId = getUserIdFromAccessToken(refreshToken);
return refreshTokenCache.remove("auth:refresh:"+userPublicId, refreshToken);
}

public TokenPair generateTokenPair(User user) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ public class UserAuthService {
private final OAuthInfoRepository oAuthInfoRepository;
private final OAuthTokenRepository oAuthTokenRepository;
private final JwtService jwtService;
private final UserService userService;

@Transactional(readOnly = true)
public AuthResult authenticateOAuth(String rawToken, SocialProvider provider) {
Expand Down Expand Up @@ -54,6 +55,22 @@ public AuthResult signupAndLogin(String rawToken, SocialProvider provider, Profi
}
}

@Transactional(readOnly = true)
public AuthResult refresh(String refreshToken) {
if (!jwtService.nonceRefreshToken(refreshToken)) {
throw UserAuthException.of(UserErrorCode.REFRESH_FAILED);
}

String userPublicId = jwtService.getUserIdFromAccessToken(refreshToken);

User user = userService.getUserByPublicId(userPublicId);
if (user == null) {
throw UserAuthException.of(UserErrorCode.USER_NOT_FOUND);
}

return new AuthResult(AuthResultType.REFRESH_SUCCESS, login(user));
}

private TokenPair login(User user) {
return jwtService.generateTokenPair(user);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ public interface InMemoryCache<K, V> {

Optional<V> get(K key);
Optional<CacheEntry<V>> getEntry(K key);
void remove(K key);
void cleanup();

boolean remove(K key);
boolean remove(K key, V value);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package com.run_us.server.global.common.cache;

import java.time.Duration;
import java.util.Optional;
import lombok.RequiredArgsConstructor;
import org.springframework.data.redis.core.RedisTemplate;

@RequiredArgsConstructor
public class RedisInMemoryCache<K, V> implements InMemoryCache<K, V> {
private final RedisTemplate<K, V> cache;

@Override
public void put(K key, V value) {
cache.opsForValue().set(key, value);
}

@Override
public void put(K key, V value, Duration ttl) {
cache.opsForValue().set(key, value, ttl);
}

@Override
public boolean putIfAbsent(K key, V value) {
return Boolean.TRUE.equals(
cache.opsForValue().setIfAbsent(key, value));
}

@Override
public boolean putIfAbsent(K key, V value, Duration ttl) {
return Boolean.TRUE.equals(
cache.opsForValue().setIfAbsent(key, value, ttl));
}

@Override
public Optional<V> get(K key) {
V value = cache.opsForValue().get(key);
if (value == null) {
return Optional.empty();
}
return Optional.of(value);
}

@Override
public Optional<CacheEntry<V>> getEntry(K key) {
V value = cache.opsForValue().get(key);
if (value == null) {
return Optional.empty();
}
Long ttl = cache.getExpire(key);
return Optional.of(
CacheEntry.withTtl(value, Duration.ofSeconds(ttl)));
}

@Override
public boolean remove(K key) {
return Boolean.TRUE.equals(
cache.delete(key));
}

@Override
public boolean remove(K key, V value) {
V currentValue = cache.opsForValue().get(key);
if (value == null || !value.equals(currentValue)) {
return false;
}
return Boolean.TRUE.equals(
cache.delete(key));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ public Optional<V> get(K key) {
return Optional.of(entry.value());
}

@Override
public Optional<CacheEntry<V>> getEntry(K key) {
CacheEntry<V> entry = cache.get(key);
if (entry == null || entry.isExpired()) {
Expand All @@ -78,11 +77,19 @@ public Optional<CacheEntry<V>> getEntry(K key) {
}

@Override
public void remove(K key) {
cache.remove(key);
public boolean remove(K key) {
return cache.remove(key) != null;
}

@Override
public boolean remove(K key, V value) {
CacheEntry<V> entry = cache.get(key);
if(entry == null || !entry.value().equals(value)) {
return false;
}
return cache.remove(key) != null;
}

public void cleanup() {
cache.entrySet().removeIf(entry ->
entry.getValue().expiresAt() != null &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import com.run_us.server.domains.crew.domain.CrewPrincipal;
import com.run_us.server.domains.user.domain.UserPrincipal;
import com.run_us.server.global.common.cache.InMemoryCache;
import com.run_us.server.global.common.cache.RedisInMemoryCache;
import com.run_us.server.global.common.cache.SpringInMemoryCache;
import com.run_us.server.domains.user.domain.TokenStatus;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;

Expand Down Expand Up @@ -60,4 +62,11 @@ public InMemoryCache<String, CrewPrincipal> crewPrincipalCache(
Duration.ofSeconds(cleanupIntervalSeconds)
);
}

@Bean
public InMemoryCache<String, String> generalStringCache(
RedisTemplate<String, String> redisTemplate
) {
return new RedisInMemoryCache<>(redisTemplate);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ public RedisConnectionFactory redisConnectionFactory() {
}

@Bean
public RedisTemplate<String, String> redisTemplate() {
RedisTemplate<String, String> redisTemplate = new RedisTemplate<>();
public <K, V> RedisTemplate<K, V> cacheTemplate() {
RedisTemplate<K, V> redisTemplate = new RedisTemplate<>();
redisTemplate.setConnectionFactory(redisConnectionFactory());
redisTemplate.setKeySerializer(new StringRedisSerializer());
redisTemplate.setValueSerializer(new StringRedisSerializer());
Expand Down