diff --git a/src/snowflake/snowpark/_internal/type_utils.py b/src/snowflake/snowpark/_internal/type_utils.py index dd816a6ba6..c016e0f0f0 100644 --- a/src/snowflake/snowpark/_internal/type_utils.py +++ b/src/snowflake/snowpark/_internal/type_utils.py @@ -1324,22 +1324,36 @@ def parse_struct_field_list(fields_str: str) -> Optional[StructType]: def split_top_level_comma_fields(s: str) -> List[str]: """ - Splits 's' by commas not enclosed in matching brackets. + Splits 's' by commas not enclosed in matching brackets or inside quoted + identifiers. + Example: "int, array, decimal(10,2)" => ["int", "array", "decimal(10,2)"]. + + Quoted-identifier-aware: ``,`` ``(`` ``)`` ``<`` ``>`` characters that + appear inside a ``"..."`` quoted span (e.g. an OBJECT field name like + ``"a, b"`` or ``"xz"``) are skipped over and do not affect the + bracket counter or split positions. Snowflake uses ``""`` as the only + in-band escape inside a quoted identifier. """ parts = [] bracket_depth = 0 start_idx = 0 - for i, c in enumerate(s): - if c in ["<", "("]: + i = 0 + while i < len(s): + c = s[i] + if c == '"': + i = _scan_quoted_identifier(s, i) + continue + if c in ("<", "("): bracket_depth += 1 - elif c in [">", ")"]: + elif c in (">", ")"): bracket_depth -= 1 if bracket_depth < 0: raise ValueError(f"Mismatched bracket in '{s}'.") elif c == "," and bracket_depth == 0: parts.append(s[start_idx:i].strip()) start_idx = i + 1 + i += 1 parts.append(s[start_idx:].strip()) return parts @@ -1406,6 +1420,60 @@ def _lookup_simple_type(name: str, original: str) -> DataType: raise ValueError(f"'{original}' is not a supported type") +def _scan_quoted_identifier(s: str, start: int) -> int: + """Return the index just past a quoted identifier that begins at ``s[start] == '"'``. + + Snowflake's identifier grammar (``SFSqlLexer.g`` ``QuotedString`` rule) allows + any character inside ``"..."`` and uses ``""`` as the only in-band escape for a + literal ``"``. 
The canonical inverse is + ``SqlIdentifierUtils.java::quote()`` which doubles every embedded ``"`` + and nothing else. + + Raises ``ValueError`` if the closing quote is missing. + + Precondition: ``s[start] == '"'``. All current callers guard on this; we + do not re-check here because asserts are stripped under ``python -O`` and + promoting to ``raise`` would be overkill for a private helper. + """ + i = start + 1 + while i < len(s): + if s[i] == '"': + if i + 1 < len(s) and s[i + 1] == '"': + i += 2 # escaped "" inside the name; keep scanning + continue + return i + 1 # index just past the closing quote + i += 1 + raise ValueError(f"Unterminated quoted identifier in: {s!r}") + + +def _split_object_field(field_def: str) -> Tuple[str, str]: + """Split a single OBJECT field definition into ``(name_token, remainder)``. + + Quoted-identifier-aware: + ``foo NUMBER`` -> (``foo``, ``NUMBER``) + ``"col with space" NUMBER`` -> (``"col with space"``, ``NUMBER``) + ``"a, b" NUMBER`` -> (``"a, b"``, ``NUMBER``) + + The returned ``name_token`` still carries any surrounding quotes so the + caller can decide whether to unquote (via ``_strip_quoted_identifier``) + while preserving the raw form for diagnostics. + """ + field_def = field_def.lstrip() + if not field_def: + raise ValueError("Empty OBJECT field definition") + if field_def[0] == '"': + end = _scan_quoted_identifier(field_def, 0) + name_token = field_def[:end] + remainder = field_def[end:].lstrip() + if not remainder: + raise ValueError(f"Cannot parse OBJECT field definition: {field_def!r}") + return name_token, remainder + parts = field_def.split(None, 1) + if len(parts) != 2: + raise ValueError(f"Cannot parse OBJECT field definition: {field_def!r}") + return parts[0], parts[1] + + def _extract_paren_content(type_str: str) -> Optional[Tuple[str, str]]: """Extract the base keyword and content inside matching parentheses. 
@@ -1418,6 +1486,10 @@ def _extract_paren_content(type_str: str) -> Optional[Tuple[str, str]]: backend (``INFER_SCHEMA``), so we fail loudly rather than silently degrade to ``VariantType``. + Quoted-identifier-aware: ``(`` and ``)`` characters appearing inside a + ``"..."`` quoted name (``OBJECT("a(b)c" TEXT)``) are skipped over and do + not affect the depth counter. + E.g. "OBJECT(city VARCHAR, zip NUMBER(38,0))" -> ("OBJECT", "city VARCHAR, zip NUMBER(38,0)") """ paren_idx = type_str.find("(") @@ -1425,13 +1497,19 @@ def _extract_paren_content(type_str: str) -> Optional[Tuple[str, str]]: return None base = type_str[:paren_idx].strip() depth = 0 - for i in range(paren_idx, len(type_str)): - if type_str[i] == "(": + i = paren_idx + while i < len(type_str): + c = type_str[i] + if c == '"': + i = _scan_quoted_identifier(type_str, i) + continue + if c == "(": depth += 1 - elif type_str[i] == ")": + elif c == ")": depth -= 1 if depth == 0: return base, type_str[paren_idx + 1 : i] + i += 1 raise ValueError(f"Unbalanced parentheses in type string: '{type_str}'") @@ -1499,13 +1577,17 @@ def _sf_type_to_type_object(type_str: str) -> DataType: # SQL grammar for OBJECT types; raise so backend bugs or # malformed input surface loudly. raise ValueError(f"Empty field in OBJECT type: '{type_str}'") - parts = field_def.split(None, 1) - if len(parts) != 2: - raise ValueError(f"Cannot parse OBJECT field definition: '{field_def}'") - field_name = parts[0] - type_part, nullable = extract_nullable_keyword(parts[1]) + # Quoted-identifier-aware split so OBJECT field names containing + # spaces, commas, parens, or other non-bare characters survive + # round-trip through INFER_SCHEMA. The name token is passed to + # ``StructField`` with its surrounding quotes intact so + # ``ColumnIdentifier``'s ``ALREADY_QUOTED`` branch preserves + # mixed-case names verbatim (without quotes the bare-identifier + # rule would case-fold ``"Foo"`` to ``"FOO"``). 
+ name_token, type_remainder = _split_object_field(field_def) + type_part, nullable = extract_nullable_keyword(type_remainder) field_type = _sf_type_to_type_object(type_part) - struct_fields.append(StructField(field_name, field_type, nullable=nullable)) + struct_fields.append(StructField(name_token, field_type, nullable=nullable)) return StructType(struct_fields, structured=True) precision_scale = get_number_precision_scale(type_str) diff --git a/tests/unit/test_dataframe_reader_type_parsing.py b/tests/unit/test_dataframe_reader_type_parsing.py index 6522d43ee2..9e2882972a 100644 --- a/tests/unit/test_dataframe_reader_type_parsing.py +++ b/tests/unit/test_dataframe_reader_type_parsing.py @@ -10,7 +10,10 @@ from snowflake.snowpark._internal.type_utils import ( _extract_paren_content, _parse_structured_type_str, + _scan_quoted_identifier, _sf_type_to_type_object, + _split_object_field, + split_top_level_comma_fields, ) from snowflake.snowpark.types import ( ArrayType, @@ -83,6 +86,284 @@ def test_trailing_text_after_close(self): assert result == ("ARRAY", "VARCHAR") +# --------------------------------------------------------------------------- +# Quoted-identifier helpers (SNOW-3440288 follow-up) +# +# These pin the quote-aware behavior added to support OBJECT field names that +# require quoting in Snowflake's identifier rules (mixed case, reserved words, +# names containing characters outside ``[A-Za-z0-9_$]``, etc.). +# +# The grammar pinned here matches the Snowflake server lexer +# (``GlobalServices/modules/sql/sql-antlr/.../SFSqlLexer.g`` ``QuotedString`` +# rule) and the canonical ``SqlIdentifierUtils.java::quote()`` helper: +# ``""`` is the only in-band escape for a literal ``"``. 
# ---------------------------------------------------------------------------


class TestScanQuotedIdentifier:
    """Pins where ``_scan_quoted_identifier`` stops and when it raises."""

    def test_simple_quoted_name(self):
        s = '"foo" rest'
        assert _scan_quoted_identifier(s, 0) == 5  # index just past closing "

    def test_escaped_quote_inside(self):
        # "a""b" is a 6-char span (positions 0-5); index just past it is 6
        s = '"a""b" rest'
        assert _scan_quoted_identifier(s, 0) == 6

    def test_unterminated_raises(self):
        with pytest.raises(ValueError, match="Unterminated quoted identifier"):
            _scan_quoted_identifier('"foo', 0)

    def test_only_escaped_then_unterminated_raises(self):
        # "a"" — the trailing "" is consumed as an escape pair, so no closing
        # quote remains.
        with pytest.raises(ValueError, match="Unterminated quoted identifier"):
            _scan_quoted_identifier('"a""', 0)


class TestSplitObjectField:
    """Pins the ``(name_token, remainder)`` split of one OBJECT field."""

    def test_bare_name_split_on_whitespace(self):
        assert _split_object_field("foo NUMBER") == ("foo", "NUMBER")

    def test_quoted_name_with_internal_space(self):
        assert _split_object_field('"col with space" NUMBER') == (
            '"col with space"',
            "NUMBER",
        )

    def test_quoted_name_with_internal_comma(self):
        assert _split_object_field('"a, b" NUMBER') == ('"a, b"', "NUMBER")

    def test_quoted_name_with_escaped_quote(self):
        assert _split_object_field('"a""b" NUMBER') == ('"a""b"', "NUMBER")

    def test_quoted_name_with_no_type_raises(self):
        with pytest.raises(ValueError, match="Cannot parse OBJECT field definition"):
            _split_object_field('"foo"')

    def test_bare_name_with_no_type_raises(self):
        with pytest.raises(ValueError, match="Cannot parse OBJECT field definition"):
            _split_object_field("foo")

    def test_empty_string_raises(self):
        with pytest.raises(ValueError, match="Empty OBJECT field definition"):
            _split_object_field("")

    def test_whitespace_only_raises(self):
        with pytest.raises(ValueError, match="Empty OBJECT field definition"):
            _split_object_field(" \t ")


class TestSplitTopLevelCommaFieldsQuoteAware:
    """Pins that quoted spans are opaque to the top-level comma splitter."""

    def test_comma_inside_quoted_name_not_a_split(self):
        # Single field whose name contains a comma — must not split.
        assert split_top_level_comma_fields('"a, b" NUMBER') == ['"a, b" NUMBER']

    def test_paren_inside_quoted_name_not_a_bracket(self):
        # The bracket counter must not be confused by ( ) inside a quoted name.
        assert split_top_level_comma_fields('"p(q)r" NUMBER, c TEXT') == [
            '"p(q)r" NUMBER',
            "c TEXT",
        ]

    def test_angle_bracket_inside_quoted_name(self):
        # The name must really contain < and >, otherwise the angle-bracket
        # branch of the depth counter is never exercised by this test.
        assert split_top_level_comma_fields('"x<y>z" NUMBER, c TEXT') == [
            '"x<y>z" NUMBER',
            "c TEXT",
        ]

    def test_escaped_quote_inside_quoted_name(self):
        assert split_top_level_comma_fields('"a""b" NUMBER, c TEXT') == [
            '"a""b" NUMBER',
            "c TEXT",
        ]

    def test_mixed_quoted_and_bare_fields(self):
        assert split_top_level_comma_fields(
            '"first-name" TEXT, age NUMBER(10,0), "col with space" NUMBER'
        ) == [
            '"first-name" TEXT',
            "age NUMBER(10,0)",
            '"col with space" NUMBER',
        ]


class TestExtractParenContentQuoteAware:
    """Pins that parentheses inside quoted names do not affect paren matching."""

    def test_paren_inside_quoted_name(self):
        # The outer ( ) must be matched correctly; inner ( ) inside a quoted
        # name must not affect depth.
        assert _extract_paren_content('OBJECT("a(b)c" TEXT)') == (
            "OBJECT",
            '"a(b)c" TEXT',
        )

    def test_close_paren_inside_quoted_name(self):
        # A lone ) inside a quoted name must not close the outer paren early.
        assert _extract_paren_content('OBJECT(")foo" TEXT)') == (
            "OBJECT",
            '")foo" TEXT',
        )


class TestSfTypeToTypeObjectQuotedNames:
    """End-to-end pins for the parser quote-awareness fix.

    Field-name expectations match what ``StructField`` returns after running
    the (possibly stripped) name through ``ColumnIdentifier`` /
    ``quote_name``: pure-uppercase ASCII names are emitted bare; everything
    else is emitted with quotes (and ``""``-escaped where the unquoted form
    contained a ``"``). This is the same convention as the existing pre-quote
    tests above (``"city"`` → ``"CITY"``).
    """

    # --- happy path / case preservation ---

    def test_bare_lowercase(self):
        result = _sf_type_to_type_object("OBJECT(foo NUMBER)")
        assert isinstance(result, StructType)
        assert result.fields[0].name == "FOO"

    def test_quoted_mixed_case_preserved(self):
        result = _sf_type_to_type_object('OBJECT("Foo" NUMBER)')
        assert isinstance(result, StructType)
        assert result.fields[0].name == '"Foo"'

    def test_pure_uppercase_quoted_equals_bare(self):
        # `OBJECT("FOO" NUMBER)` and `OBJECT(FOO NUMBER)` must produce
        # identical StructType results — pin so a future "optimization"
        # that case-folds the unquoted form differently is caught.
        quoted = _sf_type_to_type_object('OBJECT("FOO" NUMBER)')
        bare = _sf_type_to_type_object("OBJECT(FOO NUMBER)")
        assert quoted == bare
        assert quoted.fields[0].name == "FOO"

    def test_two_quoted_siblings(self):
        result = _sf_type_to_type_object('OBJECT("Foo" NUMBER, "Bar baz" TEXT)')
        assert [f.name for f in result.fields] == ['"Foo"', '"Bar baz"']

    # --- field names that need quoting ---

    def test_reserved_word_as_quoted_name(self):
        result = _sf_type_to_type_object('OBJECT("select" NUMBER)')
        assert result.fields[0].name == '"select"'

    def test_hyphenated_name(self):
        result = _sf_type_to_type_object('OBJECT("first-name" TEXT)')
        assert result.fields[0].name == '"first-name"'

    def test_name_with_space(self):
        # Pre-patch this would fail because field_def.split(None, 1) would
        # split the quoted name on its internal whitespace.
        result = _sf_type_to_type_object('OBJECT("col with space" NUMBER)')
        assert result.fields[0].name == '"col with space"'

    def test_name_with_dollar(self):
        result = _sf_type_to_type_object('OBJECT("$weird" NUMBER)')
        assert result.fields[0].name == '"$weird"'

    def test_name_with_backtick(self):
        result = _sf_type_to_type_object('OBJECT("back`tick`" NUMBER)')
        assert result.fields[0].name == '"back`tick`"'

    # --- escape and embedded delimiters in name ---

    def test_escaped_quote_in_name(self):
        # "" inside a quoted identifier represents one literal " character.
        result = _sf_type_to_type_object('OBJECT("a""b" NUMBER)')
        # Round-tripped through quote_name → re-escaped on display.
        assert result.fields[0].name == '"a""b"'

    def test_comma_in_name_does_not_split_fields(self):
        result = _sf_type_to_type_object('OBJECT("a, b" NUMBER, c TEXT)')
        assert len(result.fields) == 2
        assert result.fields[0].name == '"a, b"'
        assert result.fields[1].name == "C"

    def test_parens_in_name_do_not_unbalance(self):
        result = _sf_type_to_type_object('OBJECT("p(q)r" NUMBER)')
        assert result.fields[0].name == '"p(q)r"'

    def test_angle_brackets_in_name(self):
        # The name must really contain < and > for this test to cover the
        # angle-bracket case its name promises.
        result = _sf_type_to_type_object('OBJECT("x<y>z" NUMBER)')
        assert result.fields[0].name == '"x<y>z"'

    # --- nullability boundary (footnote 2) ---
    #
    # Documents the boundary so a future refactor that "simplifies" by
    # calling extract_nullable_keyword on the entire field_def (including
    # the quoted name) is caught — the trailing NOT NULL on the *type* must
    # be stripped, the literal text "NOT NULL" inside the *name* must not.

    def test_quoted_name_ending_in_not_null_text(self):
        result = _sf_type_to_type_object('OBJECT("x NOT NULL" TEXT)')
        assert result.fields[0].name == '"x NOT NULL"'
        assert result.fields[0].nullable is True
        assert result.fields[0].datatype == StringType()

    def test_quoted_name_with_not_null_text_and_non_nullable_type(self):
        result = _sf_type_to_type_object('OBJECT("y NOT NULL" TEXT NOT NULL)')
        assert result.fields[0].name == '"y NOT NULL"'
        assert result.fields[0].nullable is False
        assert result.fields[0].datatype == StringType()

    # --- nesting + container recursion ---

    def test_object_in_object_with_quoted_names(self):
        result = _sf_type_to_type_object('OBJECT("a-b" OBJECT("c d" NUMBER))')
        assert isinstance(result, StructType)
        assert result.fields[0].name == '"a-b"'
        inner = result.fields[0].datatype
        assert isinstance(inner, StructType)
        assert inner.fields[0].name == '"c d"'
        assert inner.fields[0].datatype == DecimalType(38, 0)

    def test_array_of_object_with_quoted_field(self):
        result = _sf_type_to_type_object('ARRAY(OBJECT("first-name" TEXT))')
        assert isinstance(result, ArrayType)
        inner = result.element_type
        assert isinstance(inner, StructType)
        assert inner.fields[0].name == '"first-name"'

    def test_map_value_object_with_quoted_field(self):
        result = _sf_type_to_type_object('MAP(TEXT, OBJECT("first-name" TEXT))')
        assert isinstance(result, MapType)
        inner = result.value_type
        assert isinstance(inner, StructType)
        assert inner.fields[0].name == '"first-name"'

    def test_quoted_name_with_array_not_null(self):
        result = _sf_type_to_type_object(
            'OBJECT("a b" ARRAY(NUMBER NOT NULL) NOT NULL)'
        )
        assert isinstance(result, StructType)
        assert result.fields[0].name == '"a b"'
        assert result.fields[0].nullable is False
        arr = result.fields[0].datatype
        assert isinstance(arr, ArrayType)
        assert arr.contains_null is False

    # --- malformed inputs surface as ValueError, not silent corruption ---
    #
    # Pin the error-surfacing contract for the two adversarial shapes that can
    # actually reach `_split_object_field` after `split_top_level_comma_fields`
    # (which is greedy on `"..."` spans). The parser does *not* validate
    # OBJECT inputs upstream; it relies on INFER_SCHEMA emitting
    # grammar-compliant strings. These tests pin that any deviation raises a
    # clear `ValueError` from the appropriate parse step.

    def test_quoted_name_with_garbage_type_raises_unsupported_type(self):
        # `OBJECT("a NUM"BER)` — `_scan_quoted_identifier` greedily matches
        # `"a NUM"`, leaves `BER` as the type token, and `_sf_type_to_type_object`
        # rejects the unknown type rather than silently producing a struct.
        with pytest.raises(ValueError, match="not a supported type"):
            _sf_type_to_type_object('OBJECT("a NUM"BER)')

    def test_unterminated_quoted_name_raises(self):
        # `OBJECT("a NUMBER)` — no closing `"`. `_scan_quoted_identifier`
        # raises rather than silently consuming the trailing `)`.
        with pytest.raises(ValueError, match="Unterminated quoted identifier"):
            _sf_type_to_type_object('OBJECT("a NUMBER)')


# ---------------------------------------------------------------------------
# _sf_type_to_type_object
# ---------------------------------------------------------------------------