Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,009 changes: 876 additions & 133 deletions src/snowflake/snowpark/catalog.py

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions src/snowflake/snowpark/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,3 +283,9 @@ class SnowparkInvalidObjectNameException(SnowparkGeneralException):
"""

pass


class NotFoundError(SnowparkClientException):
"""Raised when we encounter an object is not found."""

pass
19 changes: 16 additions & 3 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@

_session_management_lock = RLock()
_active_sessions: Set["Session"] = set()
_USE_SQL_BASE_OPTION_KEY = "_use_sql_base"
_PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING = (
"PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS"
)
Expand Down Expand Up @@ -463,6 +464,13 @@ def __init__(self) -> None:
self._app_name = None
self._format_json = None

@staticmethod
def _connection_options(
options: Dict[str, Union[int, str]]
) -> Dict[str, Union[int, str]]:
# Internal Session-only options must not be forwarded to connector connect(**kwargs).
return {k: v for k, v in options.items() if k != _USE_SQL_BASE_OPTION_KEY}

def _remove_config(self, key: str) -> "Session.SessionBuilder":
"""Only used in test."""
self._options.pop(key, None)
Expand Down Expand Up @@ -569,8 +577,11 @@ def _create_internal(
# Set paramstyle to qmark by default to be consistent with previous behavior
if "paramstyle" not in self._options:
self._options["paramstyle"] = "qmark"
connection_options = self._connection_options(self._options)
new_session = Session(
ServerConnection({}, conn) if conn else ServerConnection(self._options),
ServerConnection({}, conn)
if conn
else ServerConnection(connection_options),
self._options,
)

Expand Down Expand Up @@ -628,6 +639,8 @@ def __init__(
"""
self.version = get_version()
self._session_stage = None
options = options or {}
self._use_sql_base = options.pop(_USE_SQL_BASE_OPTION_KEY, True)

if isinstance(conn, MockServerConnection):
self._udf_registration = MockUDFRegistration(self)
Expand Down Expand Up @@ -848,7 +861,7 @@ def __init__(
_PYTHON_SNOWPARK_COLLECT_TELEMETRY_AT_CRITICAL_PATH_VERSION
)
)
self._conf = self.RuntimeConfig(self, options or {})
self._conf = self.RuntimeConfig(self, options)
self._runtime_version_from_requirement: str = None
self._temp_table_auto_cleaner: TempTableAutoCleaner = TempTableAutoCleaner(self)
self._sp_profiler = StoredProcedureProfiler(session=self)
Expand Down Expand Up @@ -961,7 +974,7 @@ def catalog(self):
external_feature_name="Session.catalog",
raise_error=NotImplementedError,
)
self._catalog = Catalog(self)
self._catalog = Catalog(self, _use_sql_base=self._use_sql_base)
return self._catalog

def close(self) -> None:
Expand Down
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from snowflake.snowpark._internal.utils import warning_dict
from .ast.conftest import default_unparser_path

pytest_plugins = ("tests.integ.test_catalog",)

logging.getLogger("snowflake.connector").setLevel(logging.ERROR)

excluded_frontend_files = [
Expand Down
211 changes: 18 additions & 193 deletions tests/integ/test_catalog.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,26 @@
#
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
#
"""Catalog integration tests and shared fixtures.

Mode-agnostic tests (same behavior for SQL and REST catalog backends) live in
this module. Backend-specific tests are in ``test_catalog_sql_mode.py`` and
``test_catalog_rest_mode.py``, which reuse the fixtures defined here via
``pytest_plugins`` in ``conftest.py``.
"""

from unittest.mock import patch
import uuid
from unittest.mock import patch

import pytest

from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted
from snowflake.snowpark.catalog import Catalog
from snowflake.snowpark.context import _DEFAULT_ARTIFACT_REPOSITORY
from snowflake.snowpark.session import Session
from snowflake.snowpark.types import IntegerType
from snowflake.core.exceptions import APIError
from snowflake.snowpark.context import _DEFAULT_ARTIFACT_REPOSITORY


pytestmark = [
pytest.mark.xfail(
"config.getoption('local_testing_mode', default=False)",
reason="deepcopy is not supported and required by local testing",
run=False,
),
pytest.mark.xfail(
raises=APIError,
reason="Failure due to warehouse overload",
),
]

CATALOG_TEMP_OBJECT_PREFIX = "SP_CATALOG_TEMP"
DOES_NOT_EXIST_PATTERN = "does_not_exist_.*"


def get_temp_name(type: str) -> str:
Expand Down Expand Up @@ -186,34 +180,13 @@ def temp_udf2(session, temp_db1, temp_schema1):
)


DOES_NOT_EXIST_PATTERN = "does_not_exist_.*"


def test_list_db(session, temp_db1, temp_db2):
catalog: Catalog = session.catalog
db_list = catalog.list_databases(pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_DB_*")
assert {db.name for db in db_list} >= {temp_db1, temp_db2}

db_list = catalog.list_databases(like=f"{CATALOG_TEMP_OBJECT_PREFIX}_DB_%")
assert {db.name for db in db_list} >= {temp_db1, temp_db2}


def test_list_schema(session, temp_db1, temp_schema1, temp_schema2):
catalog: Catalog = session.catalog
assert (
len(catalog.list_databases(pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_SCHEMA_.*"))
== 0
)

schema_list = catalog.list_schemas(
pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_SCHEMA_.*", database=temp_db1
)
assert {schema.name for schema in schema_list} >= {temp_schema1, temp_schema2}

schema_list = catalog.list_schemas(
like=f"{CATALOG_TEMP_OBJECT_PREFIX}_SCHEMA_%", database=temp_db1
)
assert {schema.name for schema in schema_list} >= {temp_schema1, temp_schema2}
pytestmark = [
pytest.mark.xfail(
"config.getoption('local_testing_mode', default=False)",
reason="deepcopy is not supported and required by local testing",
run=False,
),
]


def test_list_tables(session, temp_db1, temp_schema1, temp_table1, temp_table2):
Expand Down Expand Up @@ -344,48 +317,6 @@ def test_list_udfs(session, temp_db1, temp_schema1, temp_udf1, temp_udf2):
assert {udf.name for udf in udf_list} >= {temp_udf1, temp_udf2}


def test_get_db_schema(session):
catalog: Catalog = session.catalog
current_db = session.get_current_database()
current_schema = session.get_current_schema()
assert catalog.get_database(current_db).name == unquote_if_quoted(current_db)
assert catalog.get_schema(current_schema).name == unquote_if_quoted(current_schema)


def test_get_table_view(session, temp_db1, temp_schema1, temp_table1, temp_view1):
catalog: Catalog = session.catalog
table = catalog.get_table(temp_table1, database=temp_db1, schema=temp_schema1)
assert table.name == temp_table1
assert table.database_name == temp_db1
assert table.schema_name == temp_schema1

view = catalog.get_view(temp_view1, database=temp_db1, schema=temp_schema1)
assert view.name == temp_view1
assert view.database_name == temp_db1
assert view.schema_name == temp_schema1


@pytest.mark.udf
def test_get_function_procedure_udf(
session, temp_db1, temp_schema1, temp_procedure1, temp_udf1
):
catalog: Catalog = session.catalog

procedure = catalog.get_procedure(
temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1
)
assert procedure.name == temp_procedure1
assert procedure.database_name == temp_db1
assert procedure.schema_name == temp_schema1

udf = catalog.get_user_defined_function(
temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1
)
assert udf.name == temp_udf1
assert udf.database_name == temp_db1
assert udf.schema_name == temp_schema1


def test_set_db_schema(session, temp_db1, temp_db2, temp_schema1, temp_schema2):
catalog = session.catalog

Expand All @@ -407,112 +338,6 @@ def test_set_db_schema(session, temp_db1, temp_db2, temp_schema1, temp_schema2):
session.use_schema(original_schema)


def test_exists_db_schema(session, temp_db1, temp_schema1):
catalog = session.catalog
assert catalog.database_exists(temp_db1)
assert not catalog.database_exists("does_not_exist")

assert catalog.schema_exists(temp_schema1, database=temp_db1)
assert not catalog.schema_exists(temp_schema1, database="does_not_exist")


def test_exists_table_view(session, temp_db1, temp_schema1, temp_table1, temp_view1):
catalog = session.catalog
db1_obj = catalog._root.databases[temp_db1].fetch()
schema1_obj = catalog._root.databases[temp_db1].schemas[temp_schema1].fetch()

assert catalog.table_exists(temp_table1, database=temp_db1, schema=temp_schema1)
assert catalog.table_exists(temp_table1, database=db1_obj, schema=schema1_obj)
table = catalog.get_table(temp_table1, database=temp_db1, schema=temp_schema1)
assert catalog.table_exists(table)
assert not catalog.table_exists(
"does_not_exist", database=temp_db1, schema=temp_schema1
)

assert catalog.view_exists(temp_view1, database=temp_db1, schema=temp_schema1)
assert catalog.view_exists(temp_view1, database=db1_obj, schema=schema1_obj)
view = catalog.get_view(temp_view1, database=temp_db1, schema=temp_schema1)
assert catalog.view_exists(view)
assert not catalog.view_exists(
"does_not_exist", database=temp_db1, schema=temp_schema1
)


@pytest.mark.udf
def test_exists_function_procedure_udf(
session, temp_db1, temp_schema1, temp_procedure1, temp_udf1
):
catalog = session.catalog
db1_obj = catalog._root.databases[temp_db1].fetch()
schema1_obj = catalog._root.databases[temp_db1].schemas[temp_schema1].fetch()

assert catalog.procedure_exists(
temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1
)
assert catalog.procedure_exists(
temp_procedure1, [IntegerType()], database=db1_obj, schema=schema1_obj
)
proc = catalog.get_procedure(
temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1
)
assert catalog.procedure_exists(proc)
assert not catalog.procedure_exists(
"does_not_exist", [], database=temp_db1, schema=temp_schema1
)

assert catalog.user_defined_function_exists(
temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1
)
assert catalog.user_defined_function_exists(
temp_udf1, [IntegerType()], database=db1_obj, schema=schema1_obj
)
udf = catalog.get_user_defined_function(
temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1
)
assert catalog.user_defined_function_exists(udf)
assert not catalog.user_defined_function_exists(
"does_not_exist", [], database=temp_db1, schema=temp_schema1
)


@pytest.mark.parametrize("use_object", [True, False])
def test_drop(session, use_object):
catalog = session.catalog

original_db = session.get_current_database()
original_schema = session.get_current_schema()
try:
temp_db = create_temp_db(session)
temp_schema = create_temp_schema(session, temp_db)
temp_table = create_temp_table(session, temp_db, temp_schema)
temp_view = create_temp_view(session, temp_db, temp_schema)
if use_object:
temp_schema = catalog._root.databases[temp_db].schemas[temp_schema].fetch()
temp_db = catalog._root.databases[temp_db].fetch()

assert catalog.database_exists(temp_db)
assert catalog.schema_exists(temp_schema, database=temp_db)
assert catalog.table_exists(temp_table, database=temp_db, schema=temp_schema)
assert catalog.view_exists(temp_view, database=temp_db, schema=temp_schema)

catalog.drop_table(temp_table, database=temp_db, schema=temp_schema)
catalog.drop_view(temp_view, database=temp_db, schema=temp_schema)

assert not catalog.table_exists(
temp_table, database=temp_db, schema=temp_schema
)
assert not catalog.view_exists(temp_view, database=temp_db, schema=temp_schema)

catalog.drop_schema(temp_schema, database=temp_db)
assert not catalog.schema_exists(temp_schema, database=temp_db)

catalog.drop_database(temp_db)
assert not catalog.database_exists(temp_db)
finally:
session.use_database(original_db)
session.use_schema(original_schema)


def test_parse_names_negative(session):
catalog = session.catalog
with pytest.raises(
Expand Down
Loading
Loading