Skip to content
29 changes: 29 additions & 0 deletions tool/src/main/java/io/netbird/client/tool/EngineRestarter.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import android.util.Log;

import io.netbird.client.tool.networks.NetworkToggleListener;
import io.netbird.gomobile.android.ConnectionListener;

/**
* <p>EngineRestarter restarts the Go engine.</p>
Expand Down Expand Up @@ -52,6 +53,7 @@ private void restartEngine() {
if (isRestartInProgress) {
Log.e(LOGTAG, "engine restart timeout - forcing flag reset");
isRestartInProgress = false;
notifyDisconnected();
}
};

Expand All @@ -72,6 +74,7 @@ public void onStarted() {
@Override
public void onStopped() {
Log.d(LOGTAG, "engine is stopped, restarting...");
notifyConnecting();
engineRunner.runWithoutAuth();
}

Expand All @@ -81,6 +84,7 @@ public void onError(String msg) {
isRestartInProgress = false; // Resetting flag on error as well
handler.removeCallbacks(timeoutCallback); // Cancel timeout
engineRunner.removeServiceStateListener(this);
notifyDisconnected();
}
};
currentListener = serviceStateListener;
Expand All @@ -94,9 +98,34 @@ public void onError(String msg) {
}

Log.d(LOGTAG, "engine is running, stopping due to network change");
notifyConnecting();
engineRunner.stop();
}

private void notifyConnecting() {
ConnectionListener listener = engineRunner.getConnectionListener();
if (listener == null) {
return;
}
try {
listener.onConnecting();
} catch (Exception e) {
Log.w(LOGTAG, "onConnecting notification failed: " + e.getMessage());
}
}

private void notifyDisconnected() {
ConnectionListener listener = engineRunner.getConnectionListener();
if (listener == null) {
return;
}
try {
listener.onDisconnected();
} catch (Exception e) {
Log.w(LOGTAG, "onDisconnected notification failed: " + e.getMessage());
}
}

@Override
public void onNetworkTypeChanged() {
Log.d(LOGTAG, "network type changed, scheduling restart with "
Expand Down
7 changes: 7 additions & 0 deletions tool/src/main/java/io/netbird/client/tool/EngineRunner.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class EngineRunner {
private boolean engineIsRunning = false;
Set<ServiceStateListener> serviceStateListeners = ConcurrentHashMap.newKeySet();
private final Client goClient;
private ConnectionListener connectionListener;

public EngineRunner(Context context, NetworkChangeListener networkChangeListener, TunAdapter tunAdapter,
IFaceDiscover iFaceDiscover, String versionName, boolean isTraceLogEnabled, boolean isDebuggable,
Expand Down Expand Up @@ -124,13 +125,19 @@ public synchronized boolean isRunning() {
}

public synchronized void setConnectionListener(ConnectionListener listener) {
this.connectionListener = listener;
goClient.setConnectionListener(listener);
}

public synchronized void removeStatusListener() {
this.connectionListener = null;
goClient.removeConnectionListener();
}

synchronized ConnectionListener getConnectionListener() {
return connectionListener;
}

public synchronized void addServiceStateListener(ServiceStateListener serviceStateListener) {
if (engineIsRunning) {
serviceStateListener.onStarted();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ public void startForeground() {
NotificationChannel channel = new NotificationChannel(
channelId,
service.getResources().getString(R.string.fg_notification_channel_name),
NotificationManager.IMPORTANCE_DEFAULT);
NotificationManager.IMPORTANCE_LOW);
channel.setSound(null, null);
channel.enableVibration(false);
((NotificationManager) service.getSystemService(Context.NOTIFICATION_SERVICE)).createNotificationChannel(channel);

Intent notificationIntent = new Intent();
Expand Down
9 changes: 5 additions & 4 deletions tool/src/main/java/io/netbird/client/tool/VPNService.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,15 @@ public void onCreate() {
// Create foreground notification before initializing engine
fgNotification = new ForegroundNotification(this);

// Create network availability listener before initializing engine
networkAvailabilityListener = new ConcreteNetworkAvailabilityListener();


engineRunner = new EngineRunner(this, notifier, tunAdapter, iFaceDiscover, versionName,
preferences.isTraceLogEnabled(), Version.isDebuggable(this), profileManager);
engineRunner.addServiceStateListener(serviceStateListener);

// Create network availability listener after the engine runner so we
// can gate notifications on the engine actually being up; this avoids
// acting on Android's initial onAvailable burst during cold start.
networkAvailabilityListener = new ConcreteNetworkAvailabilityListener(engineRunner::isRunning);

engineRestarter = new EngineRestarter(engineRunner);
networkAvailabilityListener.subscribe(engineRestarter);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,24 @@

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.BooleanSupplier;

public class ConcreteNetworkAvailabilityListener implements NetworkAvailabilityListener {
private final Map<Integer, Boolean> availableNetworkTypes;
private final BooleanSupplier shouldNotify;
private NetworkToggleListener listener;

public ConcreteNetworkAvailabilityListener() {
this(() -> true);
}

// shouldNotify is consulted before each listener notification. Pass
// engineRunner::isRunning to swallow the initial onAvailable burst that
// fires right after registerNetworkCallback; until the engine is actually
// running there is nothing to restart.
public ConcreteNetworkAvailabilityListener(BooleanSupplier shouldNotify) {
this.availableNetworkTypes = new ConcurrentHashMap<>();
this.shouldNotify = shouldNotify;
}

@Override
Expand Down Expand Up @@ -38,9 +49,14 @@ public void onNetworkLost(@Constants.NetworkType int networkType) {
}

private void notifyListener() {
if (listener != null) {
listener.onNetworkTypeChanged();
NetworkToggleListener l = listener;
if (l == null) {
return;
}
if (!shouldNotify.getAsBoolean()) {
return;
}
l.onNetworkTypeChanged();
}

public void subscribe(NetworkToggleListener listener) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,22 @@
import androidx.annotation.NonNull;
import androidx.core.util.Consumer;

import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;

public class NetworkChangeDetector {
private static final String LOGTAG = NetworkChangeDetector.class.getSimpleName();
private final ConnectivityManager connectivityManager;
private ConnectivityManager.NetworkCallback networkCallback;
private ConnectivityManager.NetworkCallback defaultNetworkCallback;
private volatile NetworkAvailabilityListener listener;
private final AtomicBoolean defaultNetworkCallbackActive = new AtomicBoolean(false);
private final AtomicReference<Network> currentlyBoundDefaultNetwork = new AtomicReference<>(null);

public NetworkChangeDetector(ConnectivityManager connectivityManager) {
this.connectivityManager = connectivityManager;
initNetworkCallback();
initDefaultNetworkCallback();
}

private void checkNetworkCapabilities(Network network, Consumer<Integer> operation) {
Expand Down Expand Up @@ -58,18 +65,74 @@ public void onCapabilitiesChanged(@NonNull Network network, @NonNull NetworkCapa
};
}

private void initDefaultNetworkCallback() {
defaultNetworkCallback = new ConnectivityManager.NetworkCallback() {
@Override
public void onAvailable(@NonNull Network network) {
if (!defaultNetworkCallbackActive.get()) {
Log.d(LOGTAG, "ignoring onAvailable for " + network + "; default callback inactive");
return;
}
Log.d(LOGTAG, "default network became " + network + ", binding process to it");
try {
if (connectivityManager.bindProcessToNetwork(network)) {
currentlyBoundDefaultNetwork.set(network);
} else {
Log.w(LOGTAG, "bindProcessToNetwork returned false for " + network);
}
} catch (Exception e) {
Log.e(LOGTAG, "bindProcessToNetwork failed", e);
}
}

@Override
public void onLost(@NonNull Network network) {
if (!defaultNetworkCallbackActive.get()) {
Log.d(LOGTAG, "ignoring onLost for " + network + "; default callback inactive");
return;
}
if (!network.equals(currentlyBoundDefaultNetwork.get())) {
Log.d(LOGTAG, "ignoring onLost for " + network + "; not the currently bound default network");
return;
}
Log.d(LOGTAG, "default network " + network + " lost, clearing process binding");
try {
if (connectivityManager.bindProcessToNetwork(null)) {
currentlyBoundDefaultNetwork.compareAndSet(network, null);
}
} catch (Exception e) {
Log.e(LOGTAG, "bindProcessToNetwork(null) failed", e);
}
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
};
}
Comment on lines +68 to +108
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Residual TOCTOU between active gate and bindProcessToNetwork call.

The AtomicBoolean gate prevents a stale callback from starting work after unregisterNetworkCallback(), but it does not make the gate-check + bind operation atomic with unregister. Interleaving:

  1. onAvailable (callback thread) reads defaultNetworkCallbackActive == true and passes the gate.
  2. unregisterNetworkCallback() sets the flag false, unregisters, then calls bindProcessToNetwork(null) and resets currentlyBoundDefaultNetwork to null.
  3. onAvailable resumes and calls bindProcessToNetwork(network) and sets currentlyBoundDefaultNetwork = network.

The process ends up bound to a network after full shutdown, with active == false so nothing will ever clear it. Symmetric issue exists for onLost racing with a new onAvailable.

A small synchronized block over the gate check together with the bind call (and over the corresponding section in unregisterNetworkCallback) would close this window at effectively no cost and was the shape of the fix suggested previously.

🔧 Proposed hardening
-    private final AtomicBoolean defaultNetworkCallbackActive = new AtomicBoolean(false);
-    private final AtomicReference<Network> currentlyBoundDefaultNetwork = new AtomicReference<>(null);
+    private final Object defaultNetworkBindingLock = new Object();
+    private boolean defaultNetworkCallbackActive = false;
+    private Network currentlyBoundDefaultNetwork = null;
@@
             public void onAvailable(`@NonNull` Network network) {
-                if (!defaultNetworkCallbackActive.get()) {
-                    Log.d(LOGTAG, "ignoring onAvailable for " + network + "; default callback inactive");
-                    return;
-                }
-                Log.d(LOGTAG, "default network became " + network + ", binding process to it");
-                try {
-                    if (connectivityManager.bindProcessToNetwork(network)) {
-                        currentlyBoundDefaultNetwork.set(network);
-                    } else {
-                        Log.w(LOGTAG, "bindProcessToNetwork returned false for " + network);
-                    }
-                } catch (Exception e) {
-                    Log.e(LOGTAG, "bindProcessToNetwork failed", e);
-                }
+                synchronized (defaultNetworkBindingLock) {
+                    if (!defaultNetworkCallbackActive) {
+                        Log.d(LOGTAG, "ignoring onAvailable for " + network + "; default callback inactive");
+                        return;
+                    }
+                    Log.d(LOGTAG, "default network became " + network + ", binding process to it");
+                    try {
+                        if (connectivityManager.bindProcessToNetwork(network)) {
+                            currentlyBoundDefaultNetwork = network;
+                        } else {
+                            Log.w(LOGTAG, "bindProcessToNetwork returned false for " + network);
+                        }
+                    } catch (Exception e) {
+                        Log.e(LOGTAG, "bindProcessToNetwork failed", e);
+                    }
+                }
             }
@@
             public void onLost(`@NonNull` Network network) {
-                if (!defaultNetworkCallbackActive.get()) { ... }
-                if (!network.equals(currentlyBoundDefaultNetwork.get())) { ... }
-                ...
+                synchronized (defaultNetworkBindingLock) {
+                    if (!defaultNetworkCallbackActive) { return; }
+                    if (!network.equals(currentlyBoundDefaultNetwork)) { return; }
+                    try {
+                        if (connectivityManager.bindProcessToNetwork(null)) {
+                            currentlyBoundDefaultNetwork = null;
+                        }
+                    } catch (Exception e) {
+                        Log.e(LOGTAG, "bindProcessToNetwork(null) failed", e);
+                    }
+                }
             }
@@
     public void unregisterNetworkCallback() {
-        defaultNetworkCallbackActive.set(false);
+        synchronized (defaultNetworkBindingLock) {
+            defaultNetworkCallbackActive = false;
+        }
@@
-        try {
-            connectivityManager.bindProcessToNetwork(null);
-            currentlyBoundDefaultNetwork.set(null);
-        } catch (Exception e) {
+        synchronized (defaultNetworkBindingLock) {
+            try {
+                connectivityManager.bindProcessToNetwork(null);
+                currentlyBoundDefaultNetwork = null;
+            } catch (Exception e) {
                 Log.e(LOGTAG, "bindProcessToNetwork(null) on unregister failed", e);
+            }
         }
     }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@tool/src/main/java/io/netbird/client/tool/networks/NetworkChangeDetector.java`
around lines 68 - 108, The onAvailable/onLost callbacks in
initDefaultNetworkCallback have a TOCTOU race with unregisterNetworkCallback
because defaultNetworkCallbackActive is checked outside the bindProcessToNetwork
calls; make the gate check and the bind (and the currentlyBoundDefaultNetwork
updates) atomic by synchronizing them on a dedicated lock object (e.g., a
private final Object networkCallbackLock), i.e., wrap the checks of
defaultNetworkCallbackActive plus the subsequent
connectivityManager.bindProcessToNetwork(...) and
currentlyBoundDefaultNetwork.set/compareAndSet(...) in a
synchronized(networkCallbackLock) block, and also wrap the
unregisterNetworkCallback logic that sets defaultNetworkCallbackActive to false,
unregisters the callback, calls bindProcessToNetwork(null), and clears
currentlyBoundDefaultNetwork inside the same synchronized(networkCallbackLock)
to eliminate the race.


public void registerNetworkCallback() {
NetworkRequest.Builder builder = new NetworkRequest.Builder();
builder.addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET);
connectivityManager.registerNetworkCallback(builder.build(), networkCallback);
defaultNetworkCallbackActive.set(true);
connectivityManager.registerDefaultNetworkCallback(defaultNetworkCallback);
}

public void unregisterNetworkCallback() {
defaultNetworkCallbackActive.set(false);
try {
connectivityManager.unregisterNetworkCallback(networkCallback);
} catch (Exception e) {
Log.e(LOGTAG, "failed to unregister network callback", e);
}
try {
connectivityManager.unregisterNetworkCallback(defaultNetworkCallback);
} catch (Exception e) {
Log.e(LOGTAG, "failed to unregister default network callback", e);
}
try {
connectivityManager.bindProcessToNetwork(null);
currentlyBoundDefaultNetwork.set(null);
} catch (Exception e) {
Log.e(LOGTAG, "bindProcessToNetwork(null) on unregister failed", e);
}
}

public void subscribe(NetworkAvailabilityListener listener) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public void deactivateMobile() {
this.listener.onNetworkLost(Constants.NetworkType.MOBILE);
}
}

private static class MockNetworkToggleListener implements NetworkToggleListener {
private int totalTimesNetworkTypeChanged = 0;

Expand All @@ -47,7 +47,7 @@ public void resetCounter() {
public void shouldNotifyListenerNetworkUpgraded() {
// Assemble:
var networkToggleListener = new MockNetworkToggleListener();
var networkAvailabilityListener = new ConcreteNetworkAvailabilityListener();
var networkAvailabilityListener = new ConcreteNetworkAvailabilityListener(() -> true);
networkAvailabilityListener.subscribe(networkToggleListener);

var networkChangeDetector = new MockNetworkChangeDetector(networkAvailabilityListener);
Expand All @@ -64,7 +64,7 @@ public void shouldNotifyListenerNetworkUpgraded() {
public void shouldNotifyListenerNetworkDowngraded() {
// Assemble:
var networkToggleListener = new MockNetworkToggleListener();
var networkAvailabilityListener = new ConcreteNetworkAvailabilityListener();
var networkAvailabilityListener = new ConcreteNetworkAvailabilityListener(() -> true);
networkAvailabilityListener.subscribe(networkToggleListener);

var networkChangeDetector = new MockNetworkChangeDetector(networkAvailabilityListener);
Expand All @@ -82,7 +82,7 @@ public void shouldNotifyListenerNetworkDowngraded() {
public void shouldNotNotifyListenerNetworkDidNotUpgrade() {
// Assemble:
var networkToggleListener = new MockNetworkToggleListener();
var networkAvailabilityListener = new ConcreteNetworkAvailabilityListener();
var networkAvailabilityListener = new ConcreteNetworkAvailabilityListener(() -> true);
networkAvailabilityListener.subscribe(networkToggleListener);

var networkChangeDetector = new MockNetworkChangeDetector(networkAvailabilityListener);
Expand All @@ -103,7 +103,7 @@ public void shouldNotNotifyListenerNetworkDidNotUpgrade() {
public void shouldNotNotifyListenerNoNetworksAvailable() {
// Assemble:
var networkToggleListener = new MockNetworkToggleListener();
var networkAvailabilityListener = new ConcreteNetworkAvailabilityListener();
var networkAvailabilityListener = new ConcreteNetworkAvailabilityListener(() -> true);
networkAvailabilityListener.subscribe(networkToggleListener);

var networkChangeDetector = new MockNetworkChangeDetector(networkAvailabilityListener);
Expand All @@ -118,4 +118,23 @@ public void shouldNotNotifyListenerNoNetworksAvailable() {
// Assert:
assertEquals(0, networkToggleListener.totalTimesNetworkTypeChanged);
}

@Test
public void shouldNotNotifyListenerWhenEngineNotRunning() {
// Assemble: engine never running, so initial onAvailable burst from
// Android must not trigger a restart.
var networkToggleListener = new MockNetworkToggleListener();
var networkAvailabilityListener = new ConcreteNetworkAvailabilityListener(() -> false);
networkAvailabilityListener.subscribe(networkToggleListener);

var networkChangeDetector = new MockNetworkChangeDetector(networkAvailabilityListener);

// Act:
networkChangeDetector.activateMobile();
networkChangeDetector.activateWifi();
networkChangeDetector.deactivateWifi();

// Assert:
assertEquals(0, networkToggleListener.totalTimesNetworkTypeChanged);
}
}
Loading