diff --git a/custom_components/mbtalive/config_flow.py b/custom_components/mbtalive/config_flow.py index 6c464f0..53fa00a 100644 --- a/custom_components/mbtalive/config_flow.py +++ b/custom_components/mbtalive/config_flow.py @@ -1,4 +1,5 @@ import logging +import re import voluptuous as vol from typing import Any, Dict, Optional @@ -14,6 +15,16 @@ _LOGGER = logging.getLogger(__name__) +# --- Custom validator for train number --- +def validate_train_number(value: str | None) -> str | None: + """Validate that train number is 2–4 alphanumeric characters or empty.""" + if value in ("", None): + return None # normalize empty + value = str(value).strip().upper() + if not re.fullmatch(r"[A-Z0-9]{2,4}", value): + raise vol.Invalid("Train number must be 2–4 letters/numbers or empty") + return value + def get_user_schema(default_api_key: str = "") -> vol.Schema: return vol.Schema({ @@ -21,10 +32,7 @@ def get_user_schema(default_api_key: str = "") -> vol.Schema: vol.Required("arrive_at", default=""): str, vol.Required("api_key", default=default_api_key): vol.All(str, vol.Length(min=32, max=32)), vol.Optional("max_trips", default=2): int, - vol.Optional("train", default=""): vol.Any( - vol.All(str, vol.Length(min=3, max=3)), - "" - ) + vol.Optional("train", default=""): validate_train_number, }) @@ -45,6 +53,10 @@ async def async_step_user( arrive_at = user_input.get("arrive_at") api_key = user_input.get("api_key", "").strip() train = user_input.get("train", "") + # Normalize empty train to None + if train in ("", None): + train = None + user_input["train"] = None if not api_key: errors["api_key"] = "required"