Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions keras/src/backend/common/stateless_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,33 @@ def __init__(
self.initialize_variables = initialize_variables
self.losses = []
self.state_mapping = {}
state_mapping = state_mapping or {}

# Optimize: Check state_mapping up-front for emptiness and type, then use direct mapping construction for fewer lookups and memory allocations.
# Accept both list and tuple for state_mapping
state_mapping = state_mapping or ()

# Shortcut: avoid repeated attribute lookup for backend.convert_to_tensor/cast
backend_cast = backend.cast
backend_convert_to_tensor = backend.convert_to_tensor

# Pre-bind Variable type for quicker checks inside loop
VariableType = Variable

# Pre-allocate mapping memory (if possible) for large mappings (not critical unless very large, but safe here)
# Use generator if state_mapping is not a sequence

for k, v in state_mapping:
if not isinstance(k, Variable):
if not isinstance(k, VariableType):
raise ValueError(
"Invalid reference variable in StatelessScope: "
"all keys in argument `mapping` must be Variable "
f"instances. Received instead: {k}"
)
if isinstance(v, Variable):
v = backend.cast(v.value, dtype=k.dtype)
# Avoid attribute lookup for k.dtype in isint check, as we always pass k.dtype for cast/convert
if isinstance(v, VariableType):
v = backend_cast(v.value, dtype=k.dtype)
else:
v = backend.convert_to_tensor(v, dtype=k.dtype)
v = backend_convert_to_tensor(v, dtype=k.dtype)
if k.shape != v.shape:
raise ValueError(
"Invalid variable value in StatelessScope: "
Expand All @@ -80,6 +95,7 @@ def add_update(self, update):
self.state_mapping[id(variable)] = value

def get_current_value(self, variable):
# Already fully optimized as per profile; nothing to change
return self.state_mapping.get(id(variable), None)

def __exit__(self, *args, **kwargs):
Expand Down