Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 29 additions & 21 deletions src/main/java/us/kbase/auth2/lib/Authentication.java
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
import us.kbase.auth2.lib.user.LocalUser;
import us.kbase.auth2.lib.user.NewUser;
import us.kbase.auth2.lib.token.IncomingToken;
import us.kbase.auth2.lib.token.MFAStatus;

/** The main class for the Authentication application.
*
Expand Down Expand Up @@ -512,7 +513,7 @@ public LocalLoginResult localLogin(
userName.getName());
return new LocalLoginResult(u.getUserName());
}
return new LocalLoginResult(login(u.getUserName(), tokenCtx));
return new LocalLoginResult(login(u.getUserName(), tokenCtx, MFAStatus.UNKNOWN));
}

private LocalUser getLocalUser(final UserName userName, final Password password)
Expand Down Expand Up @@ -744,13 +745,15 @@ public void forceResetAllPasswords(final IncomingToken token)
admin.getUserName().getName());
}

private NewToken login(final UserName userName, final TokenCreationContext tokenCtx)
private NewToken login(
final UserName userName, final TokenCreationContext tokenCtx, final MFAStatus mfa)
throws AuthStorageException {
final NewToken nt = new NewToken(StoredToken.getBuilder(
TokenType.LOGIN, randGen.randomUUID(), userName)
.withLifeTime(clock.instant(),
cfg.getAppConfig().getTokenLifetimeMS(TokenLifetimeType.LOGIN))
.withContext(tokenCtx)
.withMFA(mfa)
.build(),
randGen.getToken());
storage.storeToken(nt.getStoredToken(), nt.getTokenHash());
Expand Down Expand Up @@ -1795,9 +1798,11 @@ public LoginToken login( // enough args here to start considering a builder
final LoginState lstate = getLoginState(ipr.getIdentities(), Instant.MIN);
final ProviderConfig pc = cfg.getAppConfig().getProviderConfig(idp.getProviderName());
final LoginToken loginToken;
if (lstate.getUsers().size() == 1 &&
lstate.getIdentities().isEmpty() &&
!pc.isForceLoginChoice()) {
if (
lstate.getUsers().size() == 1
&& lstate.getIdentities().isEmpty()
&& !pc.isForceLoginChoice())
{
final UserName userName = lstate.getUsers().iterator().next();
final AuthUser user = lstate.getUser(userName);
/* Don't throw an error here since an auth UI may not be controlling the call -
Expand All @@ -1811,16 +1816,16 @@ public LoginToken login( // enough args here to start considering a builder
* so who cares.
*/
if (!cfg.getAppConfig().isLoginAllowed() && !Role.isAdmin(user.getRoles())) {
loginToken = storeIdentitiesTemporarily(lstate);
loginToken = storeIdentitiesTemporarily(lstate, ipr.getMFA());
} else if (user.isDisabled()) {
loginToken = storeIdentitiesTemporarily(lstate);
loginToken = storeIdentitiesTemporarily(lstate, ipr.getMFA());
} else {
loginToken = new LoginToken(login(user.getUserName(), tokenCtx));
loginToken = new LoginToken(login(user.getUserName(), tokenCtx, ipr.getMFA()));
}
} else {
// store the identities so the user can create an account or choose from more than one
// account
loginToken = storeIdentitiesTemporarily(lstate);
loginToken = storeIdentitiesTemporarily(lstate, ipr.getMFA());
}
return loginToken;
}
Expand All @@ -1834,13 +1839,13 @@ private void checkState(final TemporarySessionData tids, final String state)
}

// ignores expiration date of login state
private LoginToken storeIdentitiesTemporarily(final LoginState ls)
private LoginToken storeIdentitiesTemporarily(final LoginState ls, final MFAStatus mfa)
throws AuthStorageException {
final Set<RemoteIdentity> store = new HashSet<>(ls.getIdentities());
ls.getUsers().stream().forEach(u -> store.addAll(ls.getIdentities(u)));
final TemporarySessionData data = TemporarySessionData.create(
randGen.randomUUID(), clock.instant(), LOGIN_TOKEN_LIFETIME_MS)
.login(store);
.login(store, mfa);
final TemporaryToken tt = storeTemporarySessionData(data);
logInfo("Stored temporary token {} with {} login identities", tt.getId(), store.size());
return new LoginToken(tt);
Expand Down Expand Up @@ -1871,6 +1876,7 @@ private TemporaryToken storeTemporarySessionData(final TemporarySessionData data
public LoginState getLoginState(final IncomingToken token)
throws AuthStorageException, InvalidTokenException, IdentityProviderErrorException,
UnauthorizedException {
// TODO CODE this ignores the MFA state. May want to add it to LoginState in the future
final TemporarySessionData ids = getTemporarySessionData(
Optional.empty(), Operation.LOGINIDENTS, token);
logInfo("Accessed temporary login token {} with {} identities", ids.getId(),
Expand Down Expand Up @@ -1984,11 +1990,12 @@ public NewToken createUser(
if (!cfg.getAppConfig().isLoginAllowed()) {
throw new UnauthorizedException("Account creation is disabled");
}
// allow mutation of the identity set
final Set<RemoteIdentity> ids = new HashSet<>(
getTemporarySessionData(Optional.empty(), Operation.LOGINIDENTS, token)
.getIdentities().get());
final TemporarySessionData tsd = getTemporarySessionData(
Optional.empty(), Operation.LOGINIDENTS, token
);
storage.deleteTemporarySessionData(token.getHashedToken());
// allow mutation of the identity set
final Set<RemoteIdentity> ids = new HashSet<>(tsd.getIdentities().get());
final Optional<RemoteIdentity> match = getIdentity(identityID, ids);
if (!match.isPresent()) {
throw new UnauthorizedException(String.format(
Expand Down Expand Up @@ -2020,7 +2027,7 @@ public NewToken createUser(
linked, userName.getName());
}
}
return login(userName, tokenCtx);
return login(userName, tokenCtx, tsd.getMFA().get());
}

/** Create a test token. The token is entirely separate from standard tokens and is
Expand Down Expand Up @@ -2308,11 +2315,12 @@ public NewToken login(
requireNonNull(policyIDs, "policyIDs");
requireNonNull(tokenCtx, "tokenCtx");
noNulls(policyIDs, "null item in policyIDs");
// allow mutation of the identity set
final Set<RemoteIdentity> ids = new HashSet<>(
getTemporarySessionData(Optional.empty(), Operation.LOGINIDENTS, token)
.getIdentities().get());
final TemporarySessionData tsd = getTemporarySessionData(
Optional.empty(), Operation.LOGINIDENTS, token
);
storage.deleteTemporarySessionData(token.getHashedToken());
// allow mutation of the identity set
final Set<RemoteIdentity> ids = new HashSet<>(tsd.getIdentities().get());
final Optional<RemoteIdentity> ri = getIdentity(identityID, ids);
if (!ri.isPresent()) {
throw new UnauthorizedException(String.format(
Expand Down Expand Up @@ -2342,7 +2350,7 @@ public NewToken login(
linked, u.get().getUserName().getName());
}
}
return login(u.get().getUserName(), tokenCtx);
return login(u.get().getUserName(), tokenCtx, tsd.getMFA().get());
}

private Optional<RemoteIdentity> getIdentity(
Expand Down
72 changes: 51 additions & 21 deletions src/main/java/us/kbase/auth2/lib/TemporarySessionData.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import us.kbase.auth2.lib.exceptions.ErrorType;
import us.kbase.auth2.lib.identity.RemoteIdentity;
import us.kbase.auth2.lib.token.MFAStatus;

/** Temporary session data that may include a set of temporary identities and / or an associated
* user, or an error that was stored instead of the identities.
Expand All @@ -33,6 +34,7 @@ public class TemporarySessionData {
private final String error;
private final ErrorType errorType;
private final UserName user;
private final MFAStatus mfa;

private TemporarySessionData(
final Operation op,
Expand All @@ -44,7 +46,9 @@ private TemporarySessionData(
final Set<RemoteIdentity> identities,
final UserName user,
final String error,
final ErrorType errorType) {
final ErrorType errorType,
final MFAStatus mfa
) {
this.op = op;
this.id = id;
this.created = created;
Expand All @@ -55,6 +59,7 @@ private TemporarySessionData(
this.user = user;
this.error = error;
this.errorType = errorType;
this.mfa = mfa;
}

/** Get the operation this temporary session data supports.
Expand All @@ -77,6 +82,13 @@ public UUID getId() {
public Optional<Set<RemoteIdentity>> getIdentities() {
return Optional.ofNullable(identities);
}

/** Get the MFA status, if any.
* @return the MFA status.
*/
public Optional<MFAStatus> getMFA() {
return Optional.ofNullable(mfa);
}

/** Get the date of creation for the session data.
* @return the creation date.
Expand Down Expand Up @@ -139,27 +151,31 @@ public boolean hasError() {

@Override
public int hashCode() {
return Objects.hash(created, error, errorType, expires, id, identities, oauth2State, op, pkceCodeVerifier,
user);
return Objects.hash(created, error, errorType, expires, id, identities, mfa, oauth2State,
op, pkceCodeVerifier, user
);
}

@Override
public boolean equals(Object obj) {
if (this == obj) {
if (this == obj)
return true;
}
if (obj == null) {
if (obj == null)
return false;
}
if (getClass() != obj.getClass()) {
if (getClass() != obj.getClass())
return false;
}
TemporarySessionData other = (TemporarySessionData) obj;
return Objects.equals(created, other.created) && Objects.equals(error, other.error)
&& errorType == other.errorType && Objects.equals(expires, other.expires)
&& Objects.equals(id, other.id) && Objects.equals(identities, other.identities)
&& Objects.equals(oauth2State, other.oauth2State) && op == other.op
&& Objects.equals(pkceCodeVerifier, other.pkceCodeVerifier) && Objects.equals(user, other.user);
return Objects.equals(created, other.created)
&& Objects.equals(error, other.error)
&& errorType == other.errorType
&& Objects.equals(expires, other.expires)
&& Objects.equals(id, other.id)
&& Objects.equals(identities, other.identities)
&& mfa == other.mfa
&& Objects.equals(oauth2State, other.oauth2State)
&& op == other.op
&& Objects.equals(pkceCodeVerifier, other.pkceCodeVerifier)
&& Objects.equals(user, other.user);
}

/** The operation this session data is associated with.
Expand Down Expand Up @@ -242,7 +258,7 @@ public TemporarySessionData error(final String error, final ErrorType errorType)
requireNonNull(errorType, "errorType");
return new TemporarySessionData(
Operation.ERROR, id, created, expires,
null, null, null, null, error, errorType);
null, null, null, null, error, errorType, null);
}

/** Create temporary session data for the start of a login operation.
Expand All @@ -258,18 +274,32 @@ public TemporarySessionData login(
checkStringNoCheckedException(pkceCodeVerifier, "pkceCodeVerifier");
return new TemporarySessionData(
Operation.LOGINSTART, id, created, expires,
oauth2State, pkceCodeVerifier, null, null, null, null);
oauth2State, pkceCodeVerifier, null, null, null, null, null);
}

/** Create temporary session data for a login operation where remote identities are
* involved.
* @param identities the remote identities involved in the login.
* @param mfa the MFA state from the login.
* @return the temporary session data.
*/
public TemporarySessionData login(final Set<RemoteIdentity> identities) {
public TemporarySessionData login(
final Set<RemoteIdentity> identities,
final MFAStatus mfa
) {
return new TemporarySessionData(
Operation.LOGINIDENTS, id, created, expires,
null, null, checkIdents(identities), null, null, null);
Operation.LOGINIDENTS,
id,
created,
expires,
null,
null,
checkIdents(identities),
null,
null,
null,
requireNonNull(mfa, "mfa")
);
}

private Set<RemoteIdentity> checkIdents(final Set<RemoteIdentity> identities) {
Expand Down Expand Up @@ -299,7 +329,7 @@ public TemporarySessionData link(
requireNonNull(userName, "userName");
return new TemporarySessionData(
Operation.LINKSTART, id, created, expires,
oauth2State, pkceCodeVerifier, null, userName, null, null);
oauth2State, pkceCodeVerifier, null, userName, null, null, null);
}

/** Create temporary session data for a linking operation when remote identities are
Expand All @@ -314,7 +344,7 @@ public TemporarySessionData link(
requireNonNull(userName, "userName");
return new TemporarySessionData(
Operation.LINKIDENTS, id, created, expires,
null, null, checkIdents(identities), userName, null, null);
null, null, checkIdents(identities), userName, null, null, null);
}
}
}
2 changes: 2 additions & 0 deletions src/main/java/us/kbase/auth2/lib/storage/mongo/Fields.java
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ public class Fields {
public static final String TEMP_SESSION_USER = "user";
/** The remote identities associated with the temporary token. */
public static final String TEMP_SESSION_IDENTITIES = "idents";
/** The MFA status associated with the temporary token. */
public static final String TEMP_SESSION_MFA = "mfa";
/** The error associated with the temporary token. */
public static final String TEMP_SESSION_ERROR = "err";
/** The type of the error associated with the temporary token. */
Expand Down
12 changes: 7 additions & 5 deletions src/main/java/us/kbase/auth2/lib/storage/mongo/MongoStorage.java
Original file line number Diff line number Diff line change
Expand Up @@ -1698,13 +1698,14 @@ public void storeTemporarySessionData(final TemporarySessionData data, final Str
.append(Fields.TEMP_SESSION_OAUTH2STATE, data.getOAuth2State().orElse(null))
.append(Fields.TEMP_SESSION_PKCE_CODE_VERIFIER,
data.getPKCECodeVerifier().orElse(null))
.append(Fields.TEMP_SESSION_ERROR,
data.getError().isPresent() ? data.getError().get() : null)
.append(Fields.TEMP_SESSION_ERROR, data.getError().orElse(null))
.append(Fields.TEMP_SESSION_ERROR_TYPE, data.getErrorType().isPresent() ?
data.getErrorType().get().getErrorCode() : null)
.append(Fields.TEMP_SESSION_IDENTITIES, ids)
.append(Fields.TEMP_SESSION_USER,
data.getUser().isPresent() ? data.getUser().get().getName() : null);
.append(Fields.TEMP_SESSION_USER, data.getUser().isPresent() ?
data.getUser().get().getName() : null)
.append(Fields.TEMP_SESSION_MFA, data.getMFA().isPresent() ?
data.getMFA().get().getID() : null);
storeTemporarySessionData(td);
}

Expand Down Expand Up @@ -1772,7 +1773,8 @@ public TemporarySessionData getTemporarySessionData(
d.getString(Fields.TEMP_SESSION_PKCE_CODE_VERIFIER)
);
} else if (op.equals(Operation.LOGINIDENTS)) {
tis = b.login(toIdentities(ids));
final MFAStatus mfa = MFAStatus.fromID(d.getString(Fields.TEMP_SESSION_MFA));
tis = b.login(toIdentities(ids), mfa);
} else if (op.equals(Operation.LINKSTART)) {
tis = b.link(
d.getString(Fields.TEMP_SESSION_OAUTH2STATE),
Expand Down
4 changes: 4 additions & 0 deletions src/test/java/us/kbase/test/auth2/TestCommon.java
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ public static <T> List<T> list(T... objects) {

public static final Optional<String> ES = Optional.empty();

public static <T> Optional<T> opt() {
return Optional.empty();
}

public static <T> Optional<T> opt(final T obj) {
return Optional.of(obj);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
import us.kbase.auth2.lib.storage.AuthStorage;
import us.kbase.auth2.lib.storage.exceptions.AuthStorageException;
import us.kbase.auth2.lib.token.IncomingToken;
import us.kbase.auth2.lib.token.MFAStatus;
import us.kbase.auth2.lib.token.StoredToken;
import us.kbase.auth2.lib.token.TemporaryToken;
import us.kbase.auth2.lib.token.TokenType;
Expand Down Expand Up @@ -885,7 +886,7 @@ public void linkWithTokenFailBadTokenOp() throws Exception {
final UUID tid = UUID.randomUUID();
when(storage.getTemporarySessionData(token.getHashedToken())).thenReturn(
TemporarySessionData.create(tid, Instant.now(), Instant.now())
.login(set(REMOTE)))
.login(set(REMOTE), MFAStatus.UNKNOWN))
.thenReturn(null);

failLinkWithToken(auth, token, "prov", "foo", null, "state", new InvalidTokenException(
Expand Down Expand Up @@ -1573,7 +1574,7 @@ public void getLinkStateFailBadTokenOp() throws Exception {

when(storage.getTemporarySessionData(tempToken.getHashedToken())).thenReturn(
TemporarySessionData.create(tempTokenID, NOW, NOW)
.login(set(REMOTE)))
.login(set(REMOTE), MFAStatus.UNKNOWN))
.thenReturn(null);

failGetLinkState(auth, userToken, tempToken, new InvalidTokenException(
Expand Down Expand Up @@ -1867,7 +1868,7 @@ public void linkIdentityFailBadTokenOp() throws Exception {
final UUID id = UUID.randomUUID();
when(storage.getTemporarySessionData(tempToken.getHashedToken())).thenReturn(
TemporarySessionData.create(id, NOW, NOW)
.login(set(REMOTE)))
.login(set(REMOTE), MFAStatus.UNKNOWN))
.thenReturn(null);

failLinkIdentity(auth, userToken, tempToken, "fakeid", new InvalidTokenException(
Expand Down Expand Up @@ -2384,7 +2385,7 @@ public void linkAllFailLinkFailBadTokenOp() throws Exception {
final UUID id = UUID.randomUUID();
when(storage.getTemporarySessionData(tempToken.getHashedToken())).thenReturn(
TemporarySessionData.create(id, NOW, NOW)
.login(set(REMOTE)))
.login(set(REMOTE), MFAStatus.UNKNOWN))
.thenReturn(null);

failLinkAll(auth, userToken, tempToken, new InvalidTokenException(
Expand Down
Loading
Loading