Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 95 additions & 13 deletions src/snowflake/snowpark/_internal/type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1324,22 +1324,36 @@ def parse_struct_field_list(fields_str: str) -> Optional[StructType]:

def split_top_level_comma_fields(s: str) -> List[str]:
"""
Splits 's' by commas not enclosed in matching brackets.
Splits 's' by commas not enclosed in matching brackets or inside quoted
identifiers.

Example: "int, array<long>, decimal(10,2)" => ["int", "array<long>", "decimal(10,2)"].

Quoted-identifier-aware: ``,`` ``(`` ``)`` ``<`` ``>`` characters that
appear inside a ``"..."`` quoted span (e.g. an OBJECT field name like
``"a, b"`` or ``"x<y>z"``) are skipped over and do not affect the
bracket counter or split positions. Snowflake uses ``""`` as the only
in-band escape inside a quoted identifier.
"""
parts = []
bracket_depth = 0
start_idx = 0
for i, c in enumerate(s):
if c in ["<", "("]:
i = 0
while i < len(s):
c = s[i]
if c == '"':
i = _scan_quoted_identifier(s, i)
continue
if c in ("<", "("):
bracket_depth += 1
elif c in [">", ")"]:
elif c in (">", ")"):
bracket_depth -= 1
if bracket_depth < 0:
raise ValueError(f"Mismatched bracket in '{s}'.")
elif c == "," and bracket_depth == 0:
parts.append(s[start_idx:i].strip())
start_idx = i + 1
i += 1
parts.append(s[start_idx:].strip())
return parts

Expand Down Expand Up @@ -1406,6 +1420,60 @@ def _lookup_simple_type(name: str, original: str) -> DataType:
raise ValueError(f"'{original}' is not a supported type")


def _scan_quoted_identifier(s: str, start: int) -> int:
"""Return the index just past a quoted identifier that begins at ``s[start] == '"'``.

Snowflake's identifier grammar (``SFSqlLexer.g`` ``QuotedString`` rule) allows
any character inside ``"..."`` and uses ``""`` as the only in-band escape for a
literal ``"``. The canonical inverse is
``SqlIdentifierUtils.java::quote()`` which doubles every embedded ``"``
and nothing else.

Raises ``ValueError`` if the closing quote is missing.

Precondition: ``s[start] == '"'``. All current callers guard on this; we
do not re-check here because asserts are stripped under ``python -O`` and
promoting to ``raise`` would be overkill for a private helper.
"""
i = start + 1
while i < len(s):
if s[i] == '"':
if i + 1 < len(s) and s[i + 1] == '"':
i += 2 # escaped "" inside the name; keep scanning
continue
return i + 1 # index just past the closing quote
Comment on lines +1440 to +1444
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: prefer writing like this so it's slightly easier to tell what this check is doing

Suggested change
if s[i] == '"':
if i + 1 < len(s) and s[i + 1] == '"':
i += 2 # escaped "" inside the name; keep scanning
continue
return i + 1 # index just past the closing quote
if s[i:i + 1] == '""': # check for a "" escape sequence in the name
i += 2
continue
elif s[i] == '"':
# found closing quote, return the index just past it
return i + 1 # index just past the closing quote

i += 1
raise ValueError(f"Unterminated quoted identifier in: {s!r}")


def _split_object_field(field_def: str) -> Tuple[str, str]:
"""Split a single OBJECT field definition into ``(name_token, remainder)``.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it possible that a malformed field_def like "a NUM"BER reach here? Or is this already handled in the upstream?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's possible. Added a test that we raise exception with clear error message.


Quoted-identifier-aware:
``foo NUMBER`` -> (``foo``, ``NUMBER``)
``"col with space" NUMBER`` -> (``"col with space"``, ``NUMBER``)
``"a, b" NUMBER`` -> (``"a, b"``, ``NUMBER``)

The returned ``name_token`` still carries any surrounding quotes so the
caller can decide whether to unquote (via ``_strip_quoted_identifier``)
while preserving the raw form for diagnostics.
"""
field_def = field_def.lstrip()
if not field_def:
raise ValueError("Empty OBJECT field definition")
if field_def[0] == '"':
end = _scan_quoted_identifier(field_def, 0)
name_token = field_def[:end]
remainder = field_def[end:].lstrip()
if not remainder:
raise ValueError(f"Cannot parse OBJECT field definition: {field_def!r}")
return name_token, remainder
parts = field_def.split(None, 1)
if len(parts) != 2:
raise ValueError(f"Cannot parse OBJECT field definition: {field_def!r}")
return parts[0], parts[1]


def _extract_paren_content(type_str: str) -> Optional[Tuple[str, str]]:
"""Extract the base keyword and content inside matching parentheses.

Expand All @@ -1418,20 +1486,30 @@ def _extract_paren_content(type_str: str) -> Optional[Tuple[str, str]]:
backend (``INFER_SCHEMA``), so we fail loudly rather than silently
degrade to ``VariantType``.

Quoted-identifier-aware: ``(`` and ``)`` characters appearing inside a
``"..."`` quoted name (``OBJECT("a(b)c" TEXT)``) are skipped over and do
not affect the depth counter.

E.g. "OBJECT(city VARCHAR, zip NUMBER(38,0))" -> ("OBJECT", "city VARCHAR, zip NUMBER(38,0)")
"""
paren_idx = type_str.find("(")
if paren_idx == -1:
return None
base = type_str[:paren_idx].strip()
depth = 0
for i in range(paren_idx, len(type_str)):
if type_str[i] == "(":
i = paren_idx
while i < len(type_str):
c = type_str[i]
if c == '"':
i = _scan_quoted_identifier(type_str, i)
continue
if c == "(":
depth += 1
elif type_str[i] == ")":
elif c == ")":
depth -= 1
if depth == 0:
return base, type_str[paren_idx + 1 : i]
i += 1
raise ValueError(f"Unbalanced parentheses in type string: '{type_str}'")


Expand Down Expand Up @@ -1499,13 +1577,17 @@ def _sf_type_to_type_object(type_str: str) -> DataType:
# SQL grammar for OBJECT types; raise so backend bugs or
# malformed input surface loudly.
raise ValueError(f"Empty field in OBJECT type: '{type_str}'")
parts = field_def.split(None, 1)
if len(parts) != 2:
raise ValueError(f"Cannot parse OBJECT field definition: '{field_def}'")
field_name = parts[0]
type_part, nullable = extract_nullable_keyword(parts[1])
# Quoted-identifier-aware split so OBJECT field names containing
# spaces, commas, parens, or other non-bare characters survive
# round-trip through INFER_SCHEMA. The name token is passed to
# ``StructField`` with its surrounding quotes intact so
# ``ColumnIdentifier``'s ``ALREADY_QUOTED`` branch preserves
# mixed-case names verbatim (without quotes the bare-identifier
# rule would case-fold ``"Foo"`` to ``"FOO"``).
name_token, type_remainder = _split_object_field(field_def)
type_part, nullable = extract_nullable_keyword(type_remainder)
field_type = _sf_type_to_type_object(type_part)
struct_fields.append(StructField(field_name, field_type, nullable=nullable))
struct_fields.append(StructField(name_token, field_type, nullable=nullable))
return StructType(struct_fields, structured=True)

precision_scale = get_number_precision_scale(type_str)
Expand Down
Loading
Loading