diff --git a/keras/src/backend/common/stateless_scope.py b/keras/src/backend/common/stateless_scope.py index cbefd64a7551..0ec64ca7dfaf 100644 --- a/keras/src/backend/common/stateless_scope.py +++ b/keras/src/backend/common/stateless_scope.py @@ -45,18 +45,33 @@ def __init__( self.initialize_variables = initialize_variables self.losses = [] self.state_mapping = {} - state_mapping = state_mapping or {} + + # Optimize: Check state_mapping up-front for emptiness and type, then use direct mapping construction for fewer lookups and memory allocations. + # Accept both list and tuple for state_mapping + state_mapping = state_mapping or () + + # Shortcut: avoid repeated attribute lookup for backend.convert_to_tensor/cast + backend_cast = backend.cast + backend_convert_to_tensor = backend.convert_to_tensor + + # Pre-bind Variable type for quicker checks inside loop + VariableType = Variable + + # Pre-allocate mapping memory (if possible) for large mappings (not critical unless very large, but safe here) + # Use generator if state_mapping is not a sequence + for k, v in state_mapping: - if not isinstance(k, Variable): + if not isinstance(k, VariableType): raise ValueError( "Invalid reference variable in StatelessScope: " "all keys in argument `mapping` must be Variable " f"instances. Received instead: {k}" ) - if isinstance(v, Variable): - v = backend.cast(v.value, dtype=k.dtype) + # Avoid attribute lookup for k.dtype in isint check, as we always pass k.dtype for cast/convert + if isinstance(v, VariableType): + v = backend_cast(v.value, dtype=k.dtype) else: - v = backend.convert_to_tensor(v, dtype=k.dtype) + v = backend_convert_to_tensor(v, dtype=k.dtype) if k.shape != v.shape: raise ValueError( "Invalid variable value in StatelessScope: " @@ -80,6 +95,7 @@ def add_update(self, update): self.state_mapping[id(variable)] = value def get_current_value(self, variable): + # Already fully optimized as per profile; nothing to change return self.state_mapping.get(id(variable), None) def __exit__(self, *args, **kwargs):