Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.locks.ReentrantLock;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -37,10 +38,12 @@
public class DefaultCmabService implements CmabService {
public static final int DEFAULT_CMAB_CACHE_SIZE = 10000;
public static final int DEFAULT_CMAB_CACHE_TIMEOUT_SECS = 30*60; // 30 minutes
private static final int NUM_LOCK_STRIPES = 1000;

private final Cache<CmabCacheValue> cmabCache;
private final CmabClient cmabClient;
private final Logger logger;
private final ReentrantLock[] locks;

public DefaultCmabService(CmabClient cmabClient, Cache<CmabCacheValue> cmabCache) {
this(cmabClient, cmabCache, null);
Expand All @@ -50,52 +53,64 @@ public DefaultCmabService(CmabClient cmabClient, Cache<CmabCacheValue> cmabCache
this.cmabCache = cmabCache;
this.cmabClient = cmabClient;
this.logger = logger != null ? logger : LoggerFactory.getLogger(DefaultCmabService.class);
this.locks = new ReentrantLock[NUM_LOCK_STRIPES];
for (int i = 0; i < NUM_LOCK_STRIPES; i++) {
this.locks[i] = new ReentrantLock();
}
}

@Override
public CmabDecision getDecision(ProjectConfig projectConfig, OptimizelyUserContext userContext, String ruleId, List<OptimizelyDecideOption> options) {
options = options == null ? Collections.emptyList() : options;
String userId = userContext.getUserId();
Map<String, Object> filteredAttributes = filterAttributes(projectConfig, userContext, ruleId);

if (options.contains(OptimizelyDecideOption.IGNORE_CMAB_CACHE)) {
logger.debug("Ignoring CMAB cache for user '{}' and rule '{}'", userId, ruleId);
return fetchDecision(ruleId, userId, filteredAttributes);
}
int lockIndex = getLockIndex(userId, ruleId);
ReentrantLock lock = locks[lockIndex];
lock.lock();
try {
Map<String, Object> filteredAttributes = filterAttributes(projectConfig, userContext, ruleId);

if (options.contains(OptimizelyDecideOption.RESET_CMAB_CACHE)) {
logger.debug("Resetting CMAB cache for user '{}' and rule '{}'", userId, ruleId);
cmabCache.reset();
}
if (options.contains(OptimizelyDecideOption.IGNORE_CMAB_CACHE)) {
logger.debug("Ignoring CMAB cache for user '{}' and rule '{}'", userId, ruleId);
return fetchDecision(ruleId, userId, filteredAttributes);
}

String cacheKey = getCacheKey(userContext.getUserId(), ruleId);
if (options.contains(OptimizelyDecideOption.INVALIDATE_USER_CMAB_CACHE)) {
logger.debug("Invalidating CMAB cache for user '{}' and rule '{}'", userId, ruleId);
cmabCache.remove(cacheKey);
}
if (options.contains(OptimizelyDecideOption.RESET_CMAB_CACHE)) {
logger.debug("Resetting CMAB cache for user '{}' and rule '{}'", userId, ruleId);
cmabCache.reset();
}

String cacheKey = getCacheKey(userContext.getUserId(), ruleId);
if (options.contains(OptimizelyDecideOption.INVALIDATE_USER_CMAB_CACHE)) {
logger.debug("Invalidating CMAB cache for user '{}' and rule '{}'", userId, ruleId);
cmabCache.remove(cacheKey);
}

CmabCacheValue cachedValue = cmabCache.lookup(cacheKey);
CmabCacheValue cachedValue = cmabCache.lookup(cacheKey);

String attributesHash = hashAttributes(filteredAttributes);
String attributesHash = hashAttributes(filteredAttributes);

if (cachedValue != null) {
if (cachedValue.getAttributesHash().equals(attributesHash)) {
logger.debug("CMAB cache hit for user '{}' and rule '{}'", userId, ruleId);
return new CmabDecision(cachedValue.getVariationId(), cachedValue.getCmabUuid());
if (cachedValue != null) {
if (cachedValue.getAttributesHash().equals(attributesHash)) {
logger.debug("CMAB cache hit for user '{}' and rule '{}'", userId, ruleId);
return new CmabDecision(cachedValue.getVariationId(), cachedValue.getCmabUuid());
} else {
logger.debug("CMAB cache attributes mismatch for user '{}' and rule '{}', fetching new decision", userId, ruleId);
cmabCache.remove(cacheKey);
}
} else {
logger.debug("CMAB cache attributes mismatch for user '{}' and rule '{}', fetching new decision", userId, ruleId);
cmabCache.remove(cacheKey);
logger.debug("CMAB cache miss for user '{}' and rule '{}'", userId, ruleId);
}
} else {
logger.debug("CMAB cache miss for user '{}' and rule '{}'", userId, ruleId);
}

CmabDecision cmabDecision = fetchDecision(ruleId, userId, filteredAttributes);
logger.debug("CMAB decision is {}", cmabDecision);

cmabCache.save(cacheKey, new CmabCacheValue(attributesHash, cmabDecision.getVariationId(), cmabDecision.getCmabUUID()));
CmabDecision cmabDecision = fetchDecision(ruleId, userId, filteredAttributes);
logger.debug("CMAB decision is {}", cmabDecision);

return cmabDecision;
cmabCache.save(cacheKey, new CmabCacheValue(attributesHash, cmabDecision.getVariationId(), cmabDecision.getCmabUUID()));

return cmabDecision;
} finally {
lock.unlock();
}
}

private CmabDecision fetchDecision(String ruleId, String userId, Map<String, Object> attributes) {
Expand Down Expand Up @@ -192,6 +207,13 @@ private String hashAttributes(Map<String, Object> attributes) {
return Integer.toHexString(hash);
}

private int getLockIndex(String userId, String ruleId) {
// Create a hash of userId + ruleId for consistent lock selection
String combined = userId + ruleId;
int hash = MurmurHash3.murmurhash3_x86_32(combined, 0, combined.length(), 0);
return Math.abs(hash) % NUM_LOCK_STRIPES;
}

public static Builder builder() {
return new Builder();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,14 @@
*/
package com.optimizely.ab.cmab;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import java.lang.reflect.Method;
import java.util.*;

import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;

import static org.junit.Assert.*;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyString;
import static org.mockito.Matchers.eq;
Expand Down Expand Up @@ -375,4 +371,61 @@ public void testAttributeOrderDoesNotMatterForCaching() {
assertNotNull(decision.getCmabUUID());
verify(mockCmabCache).save(eq(cacheKey), any(CmabCacheValue.class));
}
}
@Test
public void testLockStripingDistribution() {
// Test different combinations to ensure they get different lock indices
String[][] testCases = {
{"user1", "rule1"},
{"user2", "rule1"},
{"user1", "rule2"},
{"user3", "rule3"},
{"user4", "rule4"}
};

Set<Integer> lockIndices = new HashSet<>();
for (String[] testCase : testCases) {
String userId = testCase[0];
String ruleId = testCase[1];

// Use reflection to access the private getLockIndex method
try {
Method getLockIndexMethod = DefaultCmabService.class.getDeclaredMethod("getLockIndex", String.class, String.class);
getLockIndexMethod.setAccessible(true);

int index = (Integer) getLockIndexMethod.invoke(cmabService, userId, ruleId);

// Verify index is within expected range
assertTrue("Lock index should be non-negative", index >= 0);
assertTrue("Lock index should be less than NUM_LOCK_STRIPES", index < 1000);

lockIndices.add(index);
} catch (Exception e) {
fail("Failed to invoke getLockIndex method: " + e.getMessage());
}
}

assertTrue("Different user/rule combinations should generally use different locks", lockIndices.size() > 1);
}

@Test
public void testSameUserRuleCombinationUsesConsistentLock() {
String userId = "test_user";
String ruleId = "test_rule";

try {
Method getLockIndexMethod = DefaultCmabService.class.getDeclaredMethod("getLockIndex", String.class, String.class);
getLockIndexMethod.setAccessible(true);

// Get lock index multiple times
int index1 = (Integer) getLockIndexMethod.invoke(cmabService, userId, ruleId);
int index2 = (Integer) getLockIndexMethod.invoke(cmabService, userId, ruleId);
int index3 = (Integer) getLockIndexMethod.invoke(cmabService, userId, ruleId);

// All should be the same
assertEquals("Same user/rule should always use same lock", index1, index2);
assertEquals("Same user/rule should always use same lock", index2, index3);
} catch (Exception e) {
fail("Failed to invoke getLockIndex method: " + e.getMessage());
}
}
}
Loading