Skip to content

Commit b21141a

Browse files
committed
Added wildcard signal handling while counting multiples in get_multiples_count() + in check_signals_allowlist().
Changed request_signals from list to set in order to use `issubset()` function
1 parent b625635 commit b21141a

File tree

1 file changed

+24
-9
lines changed

1 file changed

+24
-9
lines changed

src/server/_limiter.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,15 @@
88
from ._config import RATE_LIMIT, RATELIMIT_STORAGE_URL, REDIS_HOST, REDIS_PASSWORD
99
from ._exceptions import ValidationFailedException
1010
from ._params import extract_dates, extract_integers, extract_strings, parse_source_signal_sets
11-
from ._security import _is_public_route, current_user, require_api_key, show_no_api_key_warning, resolve_auth_token, ERROR_MSG_RATE_LIMIT, ERROR_MSG_MULTIPLES
11+
from ._security import (
12+
_is_public_route,
13+
current_user,
14+
require_api_key,
15+
show_no_api_key_warning,
16+
resolve_auth_token,
17+
ERROR_MSG_RATE_LIMIT,
18+
ERROR_MSG_MULTIPLES,
19+
)
1220

1321

1422
def deduct_on_success(response: Response) -> bool:
@@ -52,8 +60,9 @@ def get_multiples_count(request):
5260
if "window" in request.args.keys():
5361
multiple_selection_allowed -= 1
5462
for k, v in request.args.items():
55-
if v == "*":
63+
if "*" in v:
5664
multiple_selection_allowed -= 1
65+
continue
5766
try:
5867
vals = multiples.get(k)(k)
5968
if len(vals) >= 2:
@@ -70,17 +79,23 @@ def get_multiples_count(request):
7079

7180
def check_signals_allowlist(request):
7281
signals_allowlist = {":".join(ss_pair) for ss_pair in DashboardSignals().srcsig_list()}
73-
request_signals = []
74-
request_args = request.args.keys()
75-
if "signal" in request_args or "signals" in request_args:
82+
request_signals = set()
83+
try:
7684
source_signal_sets = parse_source_signal_sets()
7785
for source_signal in source_signal_sets:
78-
if isinstance(source_signal.signal, list):
79-
for signal in source_signal.signal:
80-
request_signals.append(f"{source_signal.source}:{signal}")
86+
# source_signal.signal is expected to be eiter list or bool:
87+
# in case of bool, we have wildcard signal -> return False as there are no chances that
88+
# all signals from given source will be whitelisted
89+
# in case of list, we have list of signals
90+
if isinstance(source_signal.signal, bool):
91+
return False
92+
for signal in source_signal.signal:
93+
request_signals.add(f"{source_signal.source}:{signal}")
94+
except ValidationFailedException:
95+
return False
8196
if len(request_signals) == 0:
8297
return False
83-
return all([signal in signals_allowlist for signal in request_signals])
98+
return set(request_signals).issubset(signals_allowlist)
8499

85100

86101
def _resolve_tracking_key() -> str:

0 commit comments

Comments
 (0)