From 3467b7b4b2a6a9f274e8ca413859c0c0c3a7916e Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 21 Apr 2026 10:57:47 -0700 Subject: [PATCH 01/17] reapply change --- CHANGELOG.md | 8 + src/snowflake/snowpark/catalog.py | 229 +++++++++++++++++++-------- src/snowflake/snowpark/exceptions.py | 6 + tests/integ/test_catalog.py | 17 +- 4 files changed, 185 insertions(+), 75 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 77092b7da1..e3e7f005d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # Release History +## 1.51.0 (TBD) + +### Snowpark Python API Updates + +#### Improvements + +- Catalog API now uses SQL commands instead of SnowAPI calls to improve stability. + ## 1.50.0 (TBD) ### Snowpark Python API Updates diff --git a/src/snowflake/snowpark/catalog.py b/src/snowflake/snowpark/catalog.py index 4572809ea2..730ae184a7 100644 --- a/src/snowflake/snowpark/catalog.py +++ b/src/snowflake/snowpark/catalog.py @@ -2,15 +2,24 @@ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. # +from ctypes import ArgumentError import re -from typing import List, Optional, Union +from typing import ( + List, + Optional, + Union, + TYPE_CHECKING, +) + +from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted +from snowflake.snowpark.exceptions import SnowparkSQLException, NotFoundError try: - from snowflake.core import Root # type: ignore from snowflake.core.database import Database # type: ignore - from snowflake.core.exceptions import NotFoundError + from snowflake.core.database._generated.models import Database as ModelDatabase # type: ignore from snowflake.core.procedure import Procedure from snowflake.core.schema import Schema # type: ignore + from snowflake.core.schema._generated.models import Schema as ModelSchema # type: ignore from snowflake.core.table import Table, TableColumn from snowflake.core.user_defined_function import UserDefinedFunction from snowflake.core.view import View @@ -19,27 +28,28 @@ "Missing optional dependency: 'snowflake.core'." ) from e # pragma: no cover - -import snowflake.snowpark -from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type +from snowflake.snowpark._internal.type_utils import ( + convert_sp_to_sf_type, + type_string_to_type_object, +) from snowflake.snowpark.functions import lit, parse_json from snowflake.snowpark.types import DataType +if TYPE_CHECKING: + from snowflake.snowpark.session import Session -class Catalog: """The Catalog class provides methods to interact with and manage the Snowflake objects. It allows users to list, get, and drop various database objects such as databases, schemas, tables, views, functions, etc. """ - def __init__(self, session: "snowflake.snowpark.session.Session") -> None: # type: ignore + def __init__(self, session: "Session") -> None: self._session = session - self._root = Root(session) self._python_regex_udf = None def _parse_database( self, - database: Optional[Union[str, Database]], + database: object, model_obj: Optional[ Union[str, Schema, Table, View, Procedure, UserDefinedFunction] ] = None, @@ -66,7 +76,7 @@ def _parse_database( def _parse_schema( self, - schema: Optional[Union[str, Schema]], + schema: object, model_obj: Optional[ Union[str, Table, View, Procedure, UserDefinedFunction] ] = None, @@ -166,11 +176,28 @@ def list_databases( pattern: the python regex pattern of name to match. Defaults to None. like: the sql style pattern for name to match. Default to None. """ - iter = self._root.databases.iter(like=like) + like_str = f"LIKE '{like}'" if like else "" + df = self._session.sql(f"SHOW AS RESOURCE DATABASES {like_str}") if pattern: - iter = filter(lambda x: re.match(pattern, x.name), iter) + # initialize udf + self._initialize_regex_udf() + assert self._python_regex_udf is not None # pyright - return list(iter) + # The result of SHOW AS RESOURCE query is a json string which contains + # key 'name' to store the name of the object. We parse json for the returned + # result and apply the filter on name. + df = df.filter( + self._python_regex_udf( + lit(pattern), parse_json('"As Resource"')["name"] + ) + ) + + return list( + map( + lambda row: Database._from_model(ModelDatabase.from_json(str(row[0]))), + df.collect(), + ) + ) def list_schemas( self, @@ -188,10 +215,28 @@ def list_schemas( like: the sql style pattern for name to match. Default to None. """ db_name = self._parse_database(database) - iter = self._root.databases[db_name].schemas.iter(like=like) + like_str = f"LIKE '{like}'" if like else "" + df = self._session.sql(f"SHOW AS RESOURCE SCHEMAS {like_str} IN {db_name}") if pattern: - iter = filter(lambda x: re.match(pattern, x.name), iter) - return list(iter) + # initialize udf + self._initialize_regex_udf() + assert self._python_regex_udf is not None # pyright + + # The result of SHOW AS RESOURCE query is a json string which contains + # key 'name' to store the name of the object. We parse json for the returned + # result and apply the filter on name. + df = df.filter( + self._python_regex_udf( + lit(pattern), parse_json('"As Resource"')["name"] + ) + ) + + return list( + map( + lambda row: Schema._from_model(ModelSchema.from_json(str(row[0]))), + df.collect(), + ) + ) def list_tables( self, @@ -329,14 +374,27 @@ def get_current_schema(self) -> Optional[str]: def get_database(self, database: str) -> Database: """Name of the database to get""" - return self._root.databases[database].fetch() + try: + return self.list_databases(like=unquote_if_quoted(database))[0] + except IndexError: + raise NotFoundError(f"Database with name {database} could not be found") def get_schema( self, schema: str, *, database: Optional[Union[str, Database]] = None ) -> Schema: """Name of the schema to get.""" db_name = self._parse_database(database) - return self._root.databases[db_name].schemas[schema].fetch() + try: + return self.list_schemas(database=db_name, like=unquote_if_quoted(schema))[ + 0 + ] + except ( + IndexError, # schema with this name doesn't exist + SnowparkSQLException, # database in which we are looking doesn't exist + ): + raise NotFoundError( + f"Schema with name {schema} could not be found in database '{db_name}'" + ) def get_table( self, @@ -355,12 +413,16 @@ def get_table( """ db_name = self._parse_database(database) schema_name = self._parse_schema(schema) - return ( - self._root.databases[db_name] - .schemas[schema_name] - .tables[table_name] - .fetch() - ) + try: + return self.listTables( + database=db_name, + schema=schema_name, + like=unquote_if_quoted(table_name), + )[0] + except IndexError: + raise NotFoundError( + f"Table with name {table_name} could not be found in schema '{db_name}.{schema_name}'" + ) def get_view( self, @@ -379,9 +441,16 @@ def get_view( """ db_name = self._parse_database(database) schema_name = self._parse_schema(schema) - return ( - self._root.databases[db_name].schemas[schema_name].views[view_name].fetch() - ) + try: + return self.list_views( + database=db_name, + schema=schema_name, + like=unquote_if_quoted(view_name), + )[0] + except IndexError: + raise NotFoundError( + f"View with name {view_name} could not be found in schema '{db_name}.{schema_name}'" + ) def get_procedure( self, @@ -403,12 +472,19 @@ def get_procedure( db_name = self._parse_database(database) schema_name = self._parse_schema(schema) procedure_id = self._parse_function_or_procedure(procedure_name, arg_types) - return ( - self._root.databases[db_name] - .schemas[schema_name] - .procedures[procedure_id] - .fetch() - ) + + try: + procedures = self._session.sql( + f"DESCRIBE AS RESOURCE PROCEDURE {db_name}.{schema_name}.{procedure_id}" + ).collect() + return Procedure.from_json(str(procedures[0][0])) + except ( + IndexError, # when sql returned no results + SnowparkSQLException, # when database, or schema doesn't exist + ): + raise NotFoundError( + f"Procedure with name {procedure_name} and arguments {arg_types} could not be found in schema '{db_name}.{schema_name}'" + ) def get_user_defined_function( self, @@ -431,12 +507,19 @@ def get_user_defined_function( db_name = self._parse_database(database) schema_name = self._parse_schema(schema) function_id = self._parse_function_or_procedure(udf_name, arg_types) - return ( - self._root.databases[db_name] - .schemas[schema_name] - .user_defined_functions[function_id] - .fetch() - ) + + try: + procedures = self._session.sql( + f"DESCRIBE AS RESOURCE FUNCTION {db_name}.{schema_name}.{function_id}" + ).collect() + return UserDefinedFunction.from_json(str(procedures[0][0])) + except ( + IndexError, # when sql returned no results + SnowparkSQLException, # when database, or schema doesn't exist + ): + raise NotFoundError( + f"Function with name {udf_name} and arguments {arg_types} could not be found in schema '{db_name}.{schema_name}'" + ) # set methods def set_current_database(self, database: Union[str, Database]) -> None: @@ -466,7 +549,7 @@ def database_exists(self, database: Union[str, Database]) -> bool: """ db_name = self._parse_database(database) try: - self._root.databases[db_name].fetch() + self.get_database(db_name) return True except NotFoundError: return False @@ -487,7 +570,7 @@ def schema_exists( db_name = self._parse_database(database, schema) schema_name = self._parse_schema(schema) try: - self._root.databases[db_name].schemas[schema_name].fetch() + self.get_schema(schema=schema_name, database=db_name) return True except NotFoundError: return False @@ -511,9 +594,7 @@ def table_exists( schema_name = self._parse_schema(schema, table) table_name = table if isinstance(table, str) else table.name try: - self._root.databases[db_name].schemas[schema_name].tables[ - table_name - ].fetch() + self.get_table(table_name=table_name, database=db_name, schema=schema_name) return True except NotFoundError: return False @@ -537,7 +618,7 @@ def view_exists( schema_name = self._parse_schema(schema, view) view_name = view if isinstance(view, str) else view.name try: - self._root.databases[db_name].schemas[schema_name].views[view_name].fetch() + self.get_view(view_name=view_name, database=db_name, schema=schema_name) return True except NotFoundError: return False @@ -559,14 +640,24 @@ def procedure_exists( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - db_name = self._parse_database(database, procedure) - schema_name = self._parse_schema(schema, procedure) - procedure_id = self._parse_function_or_procedure(procedure, arg_types) - try: - self._root.databases[db_name].schemas[schema_name].procedures[ - procedure_id - ].fetch() + if isinstance(procedure, Procedure): + if arg_types is not None or database is not None or schema is not None: + raise ArgumentError( + "When provided procedure is a Procedure class no other arguments can be provided" + ) + database = procedure.database_name + schema = procedure.schema_name + arg_types = [ + type_string_to_type_object(a.datatype) for a in procedure.arguments + ] + procedure = procedure.name + self.get_procedure( + procedure_name=procedure, + arg_types=arg_types, + database=database, + schema=schema, + ) return True except NotFoundError: return False @@ -590,14 +681,24 @@ def user_defined_function_exists( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - db_name = self._parse_database(database, udf) - schema_name = self._parse_schema(schema, udf) - function_id = self._parse_function_or_procedure(udf, arg_types) - try: - self._root.databases[db_name].schemas[schema_name].user_defined_functions[ - function_id - ].fetch() + if isinstance(udf, UserDefinedFunction): + if arg_types is not None or database is not None or schema is not None: + raise ArgumentError( + "When provided udf is a UserDefinedFunction class no other arguments can be provided" + ) + database = udf.database_name + schema = udf.schema_name + arg_types = [ + type_string_to_type_object(a.datatype) for a in udf.arguments + ] + udf = udf.name + self.get_user_defined_function( + udf_name=udf, + arg_types=arg_types, + database=database, + schema=schema, + ) return True except NotFoundError: return False @@ -610,7 +711,7 @@ def drop_database(self, database: Union[str, Database]) -> None: database: database name or ``Database`` object. """ db_name = self._parse_database(database) - self._root.databases[db_name].drop() + self._session.sql(f"DROP DATABASE {db_name}").collect() def drop_schema( self, @@ -627,7 +728,7 @@ def drop_schema( """ db_name = self._parse_database(database, schema) schema_name = self._parse_schema(schema) - self._root.databases[db_name].schemas[schema_name].drop() + self._session.sql(f"DROP SCHEMA {db_name}.{schema_name}").collect() def drop_table( self, @@ -648,7 +749,7 @@ def drop_table( schema_name = self._parse_schema(schema, table) table_name = table if isinstance(table, str) else table.name - self._root.databases[db_name].schemas[schema_name].tables[table_name].drop() + self._session.sql(f"DROP TABLE {db_name}.{schema_name}.{table_name}").collect() def drop_view( self, @@ -669,7 +770,7 @@ def drop_view( schema_name = self._parse_schema(schema, view) view_name = view if isinstance(view, str) else view.name - self._root.databases[db_name].schemas[schema_name].views[view_name].drop() + self._session.sql(f"DROP VIEW {db_name}.{schema_name}.{view_name}").collect() # aliases listDatabases = list_databases diff --git a/src/snowflake/snowpark/exceptions.py b/src/snowflake/snowpark/exceptions.py index 1142e9545e..d31fe178a6 100644 --- a/src/snowflake/snowpark/exceptions.py +++ b/src/snowflake/snowpark/exceptions.py @@ -283,3 +283,9 @@ class SnowparkInvalidObjectNameException(SnowparkGeneralException): """ pass + + +class NotFoundError(SnowparkClientException): + """Raised when we encounter an object is not found.""" + + pass diff --git a/tests/integ/test_catalog.py b/tests/integ/test_catalog.py index 11643e4005..e8bd173e21 100644 --- a/tests/integ/test_catalog.py +++ b/tests/integ/test_catalog.py @@ -10,7 +10,6 @@ from snowflake.snowpark.catalog import Catalog from snowflake.snowpark.session import Session from snowflake.snowpark.types import IntegerType -from snowflake.core.exceptions import APIError pytestmark = [ @@ -19,10 +18,6 @@ reason="deepcopy is not supported and required by local testing", run=False, ), - pytest.mark.xfail( - raises=APIError, - reason="Failure due to warehouse overload", - ), ] CATALOG_TEMP_OBJECT_PREFIX = "SP_CATALOG_TEMP" @@ -412,8 +407,8 @@ def test_exists_db_schema(session, temp_db1, temp_schema1): def test_exists_table_view(session, temp_db1, temp_schema1, temp_table1, temp_view1): catalog = session.catalog - db1_obj = catalog._root.databases[temp_db1].fetch() - schema1_obj = catalog._root.databases[temp_db1].schemas[temp_schema1].fetch() + db1_obj = catalog.get_database(temp_db1) + schema1_obj = catalog.get_schema(database=temp_db1, schema=temp_schema1) assert catalog.table_exists(temp_table1, database=temp_db1, schema=temp_schema1) assert catalog.table_exists(temp_table1, database=db1_obj, schema=schema1_obj) @@ -437,8 +432,8 @@ def test_exists_function_procedure_udf( session, temp_db1, temp_schema1, temp_procedure1, temp_udf1 ): catalog = session.catalog - db1_obj = catalog._root.databases[temp_db1].fetch() - schema1_obj = catalog._root.databases[temp_db1].schemas[temp_schema1].fetch() + db1_obj = catalog.get_database(temp_db1) + schema1_obj = catalog.get_schema(temp_schema1, database=temp_db1) assert catalog.procedure_exists( temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1 @@ -481,8 +476,8 @@ def test_drop(session, use_object): temp_table = create_temp_table(session, temp_db, temp_schema) temp_view = create_temp_view(session, temp_db, temp_schema) if use_object: - temp_schema = catalog._root.databases[temp_db].schemas[temp_schema].fetch() - temp_db = catalog._root.databases[temp_db].fetch() + temp_schema = catalog.get_schema(temp_schema, database=temp_db) + temp_db = catalog.get_database(temp_db) assert catalog.database_exists(temp_db) assert catalog.schema_exists(temp_schema, database=temp_db) From 143f30a08cbab3086281fef69bb7d35b239be353 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 21 Apr 2026 11:09:41 -0700 Subject: [PATCH 02/17] fix test --- src/snowflake/snowpark/catalog.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/snowflake/snowpark/catalog.py b/src/snowflake/snowpark/catalog.py index 730ae184a7..515091731b 100644 --- a/src/snowflake/snowpark/catalog.py +++ b/src/snowflake/snowpark/catalog.py @@ -38,6 +38,8 @@ if TYPE_CHECKING: from snowflake.snowpark.session import Session + +class Catalog: """The Catalog class provides methods to interact with and manage the Snowflake objects. It allows users to list, get, and drop various database objects such as databases, schemas, tables, views, functions, etc. From e45239a8a1448bc2b347aa7ab8545f4250b75699 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Wed, 22 Apr 2026 15:39:39 -0700 Subject: [PATCH 03/17] avoid regress --- src/snowflake/snowpark/catalog.py | 33 ++++++++++++++++++------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/src/snowflake/snowpark/catalog.py b/src/snowflake/snowpark/catalog.py index 515091731b..8333d85ae4 100644 --- a/src/snowflake/snowpark/catalog.py +++ b/src/snowflake/snowpark/catalog.py @@ -404,27 +404,32 @@ def get_table( *, database: Optional[Union[str, Database]] = None, schema: Optional[Union[str, Schema]] = None, - ) -> Table: - """Get the table by name in given database and schema. If database or schema are not - provided, get the table in the current database and schema. + ) -> Union[Table, View]: + """Get the table or permanent view by name in the given database and schema. + + If database or schema are not provided, resolve the name in the current database + and schema. Matches :meth:`pyspark.sql.Catalog.getTable`, which returns metadata + for base tables and for views. Args: - table_name: name of the table. + table_name: name of the table or view. database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ db_name = self._parse_database(database) schema_name = self._parse_schema(schema) - try: - return self.listTables( - database=db_name, - schema=schema_name, - like=unquote_if_quoted(table_name), - )[0] - except IndexError: - raise NotFoundError( - f"Table with name {table_name} could not be found in schema '{db_name}.{schema_name}'" - ) + like_arg = unquote_if_quoted(table_name) + tables = self.listTables(database=db_name, schema=schema_name, like=like_arg) + views: List[View] = [] + if not tables: + views = self.list_views(database=db_name, schema=schema_name, like=like_arg) + if tables: + return tables[0] + if views: + return views[0] + raise NotFoundError( + f"Table with name {table_name} could not be found in schema '{db_name}.{schema_name}'" + ) def get_view( self, From 0692a09991cfab6b2902ee8dbc5a5848424b3dd1 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Mon, 27 Apr 2026 12:00:48 -0700 Subject: [PATCH 04/17] dual mode of catalog --- src/snowflake/snowpark/catalog.py | 894 ++++++++++++++++++++------ tests/integ/conftest.py | 2 + tests/integ/test_catalog.py | 349 +--------- tests/integ/test_catalog_rest_mode.py | 253 ++++++++ tests/integ/test_catalog_sql_mode.py | 243 +++++++ 5 files changed, 1203 insertions(+), 538 deletions(-) create mode 100644 tests/integ/test_catalog_rest_mode.py create mode 100644 tests/integ/test_catalog_sql_mode.py diff --git a/src/snowflake/snowpark/catalog.py b/src/snowflake/snowpark/catalog.py index 8333d85ae4..0a2af60e58 100644 --- a/src/snowflake/snowpark/catalog.py +++ b/src/snowflake/snowpark/catalog.py @@ -2,6 +2,7 @@ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. # +from abc import ABC, abstractmethod from ctypes import ArgumentError import re from typing import ( @@ -11,12 +12,15 @@ TYPE_CHECKING, ) +from snowflake.snowpark import context from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted from snowflake.snowpark.exceptions import SnowparkSQLException, NotFoundError try: + from snowflake.core import Root # type: ignore from snowflake.core.database import Database # type: ignore from snowflake.core.database._generated.models import Database as ModelDatabase # type: ignore + from snowflake.core.exceptions import NotFoundError as CoreNotFoundError # type: ignore from snowflake.core.procedure import Procedure from snowflake.core.schema import Schema # type: ignore from snowflake.core.schema._generated.models import Schema as ModelSchema # type: ignore @@ -39,6 +43,662 @@ from snowflake.snowpark.session import Session +class _CatalogBackend(ABC): + """Internal catalog implementation selected by ``context._is_snowpark_connect_compatible_mode``.""" + + def __init__(self, catalog: "Catalog") -> None: + self._catalog = catalog + + @abstractmethod + def list_databases( + self, + *, + pattern: Optional[str] = None, + like: Optional[str] = None, + ) -> List[Database]: + pass + + @abstractmethod + def list_schemas( + self, + *, + database: Optional[Union[str, Database]] = None, + pattern: Optional[str] = None, + like: Optional[str] = None, + ) -> List[Schema]: + pass + + @abstractmethod + def get_database(self, database: str) -> Database: + pass + + @abstractmethod + def get_schema( + self, schema: str, *, database: Optional[Union[str, Database]] = None + ) -> Schema: + pass + + @abstractmethod + def get_table( + self, + table_name: str, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> Union[Table, View]: + pass + + @abstractmethod + def get_view( + self, + view_name: str, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> View: + pass + + @abstractmethod + def get_procedure( + self, + procedure_name: str, + arg_types: List[DataType], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> Procedure: + pass + + @abstractmethod + def get_user_defined_function( + self, + udf_name: str, + arg_types: List[DataType], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> UserDefinedFunction: + pass + + @abstractmethod + def database_exists(self, database: Union[str, Database]) -> bool: + pass + + @abstractmethod + def schema_exists( + self, + schema: Union[str, Schema], + *, + database: Optional[Union[str, Database]] = None, + ) -> bool: + pass + + @abstractmethod + def table_exists( + self, + table: Union[str, Table], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + pass + + @abstractmethod + def view_exists( + self, + view: Union[str, View], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + pass + + @abstractmethod + def procedure_exists( + self, + procedure: Union[str, Procedure], + arg_types: Optional[List[DataType]] = None, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + pass + + @abstractmethod + def user_defined_function_exists( + self, + udf: Union[str, UserDefinedFunction], + arg_types: Optional[List[DataType]] = None, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + pass + + +class _SqlCatalogBackend(_CatalogBackend): + def list_databases( + self, + *, + pattern: Optional[str] = None, + like: Optional[str] = None, + ) -> List[Database]: + c = self._catalog + like_str = f"LIKE '{like}'" if like else "" + df = c._session.sql(f"SHOW AS RESOURCE DATABASES {like_str}") + if pattern: + c._initialize_regex_udf() + assert c._python_regex_udf is not None # pyright + df = df.filter( + c._python_regex_udf(lit(pattern), parse_json('"As Resource"')["name"]) + ) + + return list( + map( + lambda row: Database._from_model(ModelDatabase.from_json(str(row[0]))), + df.collect(), + ) + ) + + def list_schemas( + self, + *, + database: Optional[Union[str, Database]] = None, + pattern: Optional[str] = None, + like: Optional[str] = None, + ) -> List[Schema]: + c = self._catalog + db_name = c._parse_database(database) + like_str = f"LIKE '{like}'" if like else "" + df = c._session.sql(f"SHOW AS RESOURCE SCHEMAS {like_str} IN {db_name}") + if pattern: + c._initialize_regex_udf() + assert c._python_regex_udf is not None # pyright + df = df.filter( + c._python_regex_udf(lit(pattern), parse_json('"As Resource"')["name"]) + ) + + return list( + map( + lambda row: Schema._from_model(ModelSchema.from_json(str(row[0]))), + df.collect(), + ) + ) + + def get_database(self, database: str) -> Database: + try: + return self.list_databases(like=unquote_if_quoted(database))[0] + except IndexError: + raise NotFoundError(f"Database with name {database} could not be found") + + def get_schema( + self, schema: str, *, database: Optional[Union[str, Database]] = None + ) -> Schema: + c = self._catalog + db_name = c._parse_database(database) + try: + return self.list_schemas(database=db_name, like=unquote_if_quoted(schema))[ + 0 + ] + except ( + IndexError, + SnowparkSQLException, + ): + raise NotFoundError( + f"Schema with name {schema} could not be found in database '{db_name}'" + ) + + def get_table( + self, + table_name: str, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> Union[Table, View]: + c = self._catalog + db_name = c._parse_database(database) + schema_name = c._parse_schema(schema) + like_arg = unquote_if_quoted(table_name) + tables = c.list_tables(database=db_name, schema=schema_name, like=like_arg) + views: List[View] = [] + if not tables: + views = c.list_views(database=db_name, schema=schema_name, like=like_arg) + if tables: + return tables[0] + if views: + return views[0] + raise NotFoundError( + f"Table with name {table_name} could not be found in schema '{db_name}.{schema_name}'" + ) + + def get_view( + self, + view_name: str, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> View: + c = self._catalog + db_name = c._parse_database(database) + schema_name = c._parse_schema(schema) + try: + return c.list_views( + database=db_name, + schema=schema_name, + like=unquote_if_quoted(view_name), + )[0] + except IndexError: + raise NotFoundError( + f"View with name {view_name} could not be found in schema '{db_name}.{schema_name}'" + ) + + def get_procedure( + self, + procedure_name: str, + arg_types: List[DataType], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> Procedure: + c = self._catalog + db_name = c._parse_database(database) + schema_name = c._parse_schema(schema) + procedure_id = c._parse_function_or_procedure(procedure_name, arg_types) + + try: + procedures = c._session.sql( + f"DESCRIBE AS RESOURCE PROCEDURE {db_name}.{schema_name}.{procedure_id}" + ).collect() + return Procedure.from_json(str(procedures[0][0])) + except ( + IndexError, + SnowparkSQLException, + ): + raise NotFoundError( + f"Procedure with name {procedure_name} and arguments {arg_types} could not be found in schema '{db_name}.{schema_name}'" + ) + + def get_user_defined_function( + self, + udf_name: str, + arg_types: List[DataType], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> UserDefinedFunction: + c = self._catalog + db_name = c._parse_database(database) + schema_name = c._parse_schema(schema) + function_id = c._parse_function_or_procedure(udf_name, arg_types) + + try: + rows = c._session.sql( + f"DESCRIBE AS RESOURCE FUNCTION {db_name}.{schema_name}.{function_id}" + ).collect() + return UserDefinedFunction.from_json(str(rows[0][0])) + except ( + IndexError, + SnowparkSQLException, + ): + raise NotFoundError( + f"Function with name {udf_name} and arguments {arg_types} could not be found in schema '{db_name}.{schema_name}'" + ) + + def database_exists(self, database: Union[str, Database]) -> bool: + c = self._catalog + db_name = c._parse_database(database) + try: + self.get_database(db_name) + return True + except NotFoundError: + return False + + def schema_exists( + self, + schema: Union[str, Schema], + *, + database: Optional[Union[str, Database]] = None, + ) -> bool: + c = self._catalog + db_name = c._parse_database(database, schema) + schema_name = c._parse_schema(schema) + try: + self.get_schema(schema=schema_name, database=db_name) + return True + except NotFoundError: + return False + + def table_exists( + self, + table: Union[str, Table], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + c = self._catalog + db_name = c._parse_database(database, table) + schema_name = c._parse_schema(schema, table) + table_name = table if isinstance(table, str) else table.name + try: + self.get_table(table_name=table_name, database=db_name, schema=schema_name) + return True + except NotFoundError: + return False + + def view_exists( + self, + view: Union[str, View], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + c = self._catalog + db_name = c._parse_database(database, view) + schema_name = c._parse_schema(schema, view) + view_name = view if isinstance(view, str) else view.name + try: + self.get_view(view_name=view_name, database=db_name, schema=schema_name) + return True + except NotFoundError: + return False + + def procedure_exists( + self, + procedure: Union[str, Procedure], + arg_types: Optional[List[DataType]] = None, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + try: + if isinstance(procedure, Procedure): + if arg_types is not None or database is not None or schema is not None: + raise ArgumentError( + "When provided procedure is a Procedure class no other arguments can be provided" + ) + database = procedure.database_name + schema = procedure.schema_name + arg_types = [ + type_string_to_type_object(a.datatype) for a in procedure.arguments + ] + procedure = procedure.name + self.get_procedure( + procedure_name=procedure, + arg_types=arg_types, + database=database, + schema=schema, + ) + return True + except NotFoundError: + return False + + def user_defined_function_exists( + self, + udf: Union[str, UserDefinedFunction], + arg_types: Optional[List[DataType]] = None, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + try: + if isinstance(udf, UserDefinedFunction): + if arg_types is not None or database is not None or schema is not None: + raise ArgumentError( + "When provided udf is a UserDefinedFunction class no other arguments can be provided" + ) + database = udf.database_name + schema = udf.schema_name + arg_types = [ + type_string_to_type_object(a.datatype) for a in udf.arguments + ] + udf = udf.name + self.get_user_defined_function( + udf_name=udf, + arg_types=arg_types, + database=database, + schema=schema, + ) + return True + except NotFoundError: + return False + + +class _RestCatalogBackend(_CatalogBackend): + def __init__(self, catalog: "Catalog") -> None: + super().__init__(catalog) + self._root_obj: Optional[Root] = None + + @property + def _root(self) -> Root: + if self._root_obj is None: + self._root_obj = Root(self._catalog._session) + return self._root_obj + + def list_databases( + self, + *, + pattern: Optional[str] = None, + like: Optional[str] = None, + ) -> List[Database]: + it = self._root.databases.iter(like=like) + if pattern: + it = filter(lambda x: re.match(pattern, x.name), it) + return list(it) + + def list_schemas( + self, + *, + database: Optional[Union[str, Database]] = None, + pattern: Optional[str] = None, + like: Optional[str] = None, + ) -> List[Schema]: + db_name = self._catalog._parse_database(database) + it = self._root.databases[db_name].schemas.iter(like=like) + if pattern: + it = filter(lambda x: re.match(pattern, x.name), it) + return list(it) + + def get_database(self, database: str) -> Database: + return self._root.databases[database].fetch() + + def get_schema( + self, schema: str, *, database: Optional[Union[str, Database]] = None + ) -> Schema: + db_name = self._catalog._parse_database(database) + return self._root.databases[db_name].schemas[schema].fetch() + + def get_table( + self, + table_name: str, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> Union[Table, View]: + c = self._catalog + db_name = c._parse_database(database) + schema_name = c._parse_schema(schema) + return ( + self._root.databases[db_name] + .schemas[schema_name] + .tables[table_name] + .fetch() + ) + + def get_view( + self, + view_name: str, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> View: + c = self._catalog + db_name = c._parse_database(database) + schema_name = c._parse_schema(schema) + return ( + self._root.databases[db_name].schemas[schema_name].views[view_name].fetch() + ) + + def get_procedure( + self, + procedure_name: str, + arg_types: List[DataType], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> Procedure: + c = self._catalog + db_name = c._parse_database(database) + schema_name = c._parse_schema(schema) + procedure_id = c._parse_function_or_procedure(procedure_name, arg_types) + return ( + self._root.databases[db_name] + .schemas[schema_name] + .procedures[procedure_id] + .fetch() + ) + + def get_user_defined_function( + self, + udf_name: str, + arg_types: List[DataType], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> UserDefinedFunction: + c = self._catalog + db_name = c._parse_database(database) + schema_name = c._parse_schema(schema) + function_id = c._parse_function_or_procedure(udf_name, arg_types) + return ( + self._root.databases[db_name] + .schemas[schema_name] + .user_defined_functions[function_id] + .fetch() + ) + + def database_exists(self, database: Union[str, Database]) -> bool: + c = self._catalog + db_name = c._parse_database(database) + try: + self._root.databases[db_name].fetch() + return True + except CoreNotFoundError: + return False + + def schema_exists( + self, + schema: Union[str, Schema], + *, + database: Optional[Union[str, Database]] = None, + ) -> bool: + c = self._catalog + db_name = c._parse_database(database, schema) + schema_name = c._parse_schema(schema) + try: + self._root.databases[db_name].schemas[schema_name].fetch() + return True + except CoreNotFoundError: + return False + + def table_exists( + self, + table: Union[str, Table], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + c = self._catalog + db_name = c._parse_database(database, table) + schema_name = c._parse_schema(schema, table) + table_name = table if isinstance(table, str) else table.name + try: + self._root.databases[db_name].schemas[schema_name].tables[ + table_name + ].fetch() + return True + except CoreNotFoundError: + return False + + def view_exists( + self, + view: Union[str, View], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + c = self._catalog + db_name = c._parse_database(database, view) + schema_name = c._parse_schema(schema, view) + view_name = view if isinstance(view, str) else view.name + try: + self._root.databases[db_name].schemas[schema_name].views[view_name].fetch() + return True + except CoreNotFoundError: + return False + + def procedure_exists( + self, + procedure: Union[str, Procedure], + arg_types: Optional[List[DataType]] = None, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + c = self._catalog + try: + if isinstance(procedure, Procedure): + if arg_types is not None or database is not None or schema is not None: + raise ArgumentError( + "When provided procedure is a Procedure class no other arguments can be provided" + ) + database = procedure.database_name + schema = procedure.schema_name + arg_types = [ + type_string_to_type_object(a.datatype) for a in procedure.arguments + ] + procedure = procedure.name + db_name = c._parse_database(database, procedure) + schema_name = c._parse_schema(schema, procedure) + procedure_id = c._parse_function_or_procedure(procedure, arg_types) + self._root.databases[db_name].schemas[schema_name].procedures[ + procedure_id + ].fetch() + return True + except CoreNotFoundError: + return False + + def user_defined_function_exists( + self, + udf: Union[str, UserDefinedFunction], + arg_types: Optional[List[DataType]] = None, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + c = self._catalog + try: + if isinstance(udf, UserDefinedFunction): + if arg_types is not None or database is not None or schema is not None: + raise ArgumentError( + "When provided udf is a UserDefinedFunction class no other arguments can be provided" + ) + database = udf.database_name + schema = udf.schema_name + arg_types = [ + type_string_to_type_object(a.datatype) for a in udf.arguments + ] + udf = udf.name + db_name = c._parse_database(database, udf) + schema_name = c._parse_schema(schema, udf) + function_id = c._parse_function_or_procedure(udf, arg_types) + self._root.databases[db_name].schemas[schema_name].user_defined_functions[ + function_id + ].fetch() + return True + except CoreNotFoundError: + return False + + class Catalog: """The Catalog class provides methods to interact with and manage the Snowflake objects. It allows users to list, get, and drop various database objects such as databases, schemas, tables, @@ -48,6 +708,15 @@ class Catalog: def __init__(self, session: "Session") -> None: self._session = session self._python_regex_udf = None + self._sql_backend = _SqlCatalogBackend(self) + self._rest_backend: Optional[_RestCatalogBackend] = None + + def _backend(self) -> _CatalogBackend: + if context._is_snowpark_connect_compatible_mode: + return self._sql_backend + if self._rest_backend is None: + self._rest_backend = _RestCatalogBackend(self) + return self._rest_backend def _parse_database( self, @@ -150,13 +819,9 @@ def _list_objects( f"SHOW AS RESOURCE {object_name} {like_str} IN {db_name}.{schema_name} -- catalog api" ) if pattern: - # initialize udf self._initialize_regex_udf() assert self._python_regex_udf is not None # pyright - # The result of SHOW AS RESOURCE query is a json string which contains - # key 'name' to store the name of the object. We parse json for the returned - # result and apply the filter on name. df = df.filter( self._python_regex_udf( lit(pattern), parse_json('"As Resource"')["name"] @@ -165,7 +830,6 @@ def _list_objects( return list(map(lambda row: object_class.from_json(row[0]), df.collect())) - # List methods def list_databases( self, *, @@ -178,28 +842,7 @@ def list_databases( pattern: the python regex pattern of name to match. Defaults to None. like: the sql style pattern for name to match. Default to None. """ - like_str = f"LIKE '{like}'" if like else "" - df = self._session.sql(f"SHOW AS RESOURCE DATABASES {like_str}") - if pattern: - # initialize udf - self._initialize_regex_udf() - assert self._python_regex_udf is not None # pyright - - # The result of SHOW AS RESOURCE query is a json string which contains - # key 'name' to store the name of the object. We parse json for the returned - # result and apply the filter on name. - df = df.filter( - self._python_regex_udf( - lit(pattern), parse_json('"As Resource"')["name"] - ) - ) - - return list( - map( - lambda row: Database._from_model(ModelDatabase.from_json(str(row[0]))), - df.collect(), - ) - ) + return self._backend().list_databases(pattern=pattern, like=like) def list_schemas( self, @@ -216,28 +859,8 @@ def list_schemas( pattern: the python regex pattern of name to match. Defaults to None. like: the sql style pattern for name to match. Default to None. """ - db_name = self._parse_database(database) - like_str = f"LIKE '{like}'" if like else "" - df = self._session.sql(f"SHOW AS RESOURCE SCHEMAS {like_str} IN {db_name}") - if pattern: - # initialize udf - self._initialize_regex_udf() - assert self._python_regex_udf is not None # pyright - - # The result of SHOW AS RESOURCE query is a json string which contains - # key 'name' to store the name of the object. We parse json for the returned - # result and apply the filter on name. - df = df.filter( - self._python_regex_udf( - lit(pattern), parse_json('"As Resource"')["name"] - ) - ) - - return list( - map( - lambda row: Schema._from_model(ModelSchema.from_json(str(row[0]))), - df.collect(), - ) + return self._backend().list_schemas( + database=database, pattern=pattern, like=like ) def list_tables( @@ -365,7 +988,6 @@ def list_user_defined_functions( like=like, ) - # get methods def get_current_database(self) -> Optional[str]: """Get the current database.""" return self._session.get_current_database() @@ -376,27 +998,13 @@ def get_current_schema(self) -> Optional[str]: def get_database(self, database: str) -> Database: """Name of the database to get""" - try: - return self.list_databases(like=unquote_if_quoted(database))[0] - except IndexError: - raise NotFoundError(f"Database with name {database} could not be found") + return self._backend().get_database(database) def get_schema( self, schema: str, *, database: Optional[Union[str, Database]] = None ) -> Schema: """Name of the schema to get.""" - db_name = self._parse_database(database) - try: - return self.list_schemas(database=db_name, like=unquote_if_quoted(schema))[ - 0 - ] - except ( - IndexError, # schema with this name doesn't exist - SnowparkSQLException, # database in which we are looking doesn't exist - ): - raise NotFoundError( - f"Schema with name {schema} could not be found in database '{db_name}'" - ) + return self._backend().get_schema(schema, database=database) def get_table( self, @@ -411,25 +1019,15 @@ def get_table( and schema. Matches :meth:`pyspark.sql.Catalog.getTable`, which returns metadata for base tables and for views. + When ``context._is_snowpark_connect_compatible_mode`` is False (legacy REST path), + only base tables are returned; use :meth:`get_view` for views. + Args: table_name: name of the table or view. database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - db_name = self._parse_database(database) - schema_name = self._parse_schema(schema) - like_arg = unquote_if_quoted(table_name) - tables = self.listTables(database=db_name, schema=schema_name, like=like_arg) - views: List[View] = [] - if not tables: - views = self.list_views(database=db_name, schema=schema_name, like=like_arg) - if tables: - return tables[0] - if views: - return views[0] - raise NotFoundError( - f"Table with name {table_name} could not be found in schema '{db_name}.{schema_name}'" - ) + return self._backend().get_table(table_name, database=database, schema=schema) def get_view( self, @@ -446,18 +1044,7 @@ def get_view( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - db_name = self._parse_database(database) - schema_name = self._parse_schema(schema) - try: - return self.list_views( - database=db_name, - schema=schema_name, - like=unquote_if_quoted(view_name), - )[0] - except IndexError: - raise NotFoundError( - f"View with name {view_name} could not be found in schema '{db_name}.{schema_name}'" - ) + return self._backend().get_view(view_name, database=database, schema=schema) def get_procedure( self, @@ -476,22 +1063,9 @@ def get_procedure( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - db_name = self._parse_database(database) - schema_name = self._parse_schema(schema) - procedure_id = self._parse_function_or_procedure(procedure_name, arg_types) - - try: - procedures = self._session.sql( - f"DESCRIBE AS RESOURCE PROCEDURE {db_name}.{schema_name}.{procedure_id}" - ).collect() - return Procedure.from_json(str(procedures[0][0])) - except ( - IndexError, # when sql returned no results - SnowparkSQLException, # when database, or schema doesn't exist - ): - raise NotFoundError( - f"Procedure with name {procedure_name} and arguments {arg_types} could not be found in schema '{db_name}.{schema_name}'" - ) + return self._backend().get_procedure( + procedure_name, arg_types, database=database, schema=schema + ) def get_user_defined_function( self, @@ -511,24 +1085,10 @@ def get_user_defined_function( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - db_name = self._parse_database(database) - schema_name = self._parse_schema(schema) - function_id = self._parse_function_or_procedure(udf_name, arg_types) - - try: - procedures = self._session.sql( - f"DESCRIBE AS RESOURCE FUNCTION {db_name}.{schema_name}.{function_id}" - ).collect() - return UserDefinedFunction.from_json(str(procedures[0][0])) - except ( - IndexError, # when sql returned no results - SnowparkSQLException, # when database, or schema doesn't exist - ): - raise NotFoundError( - f"Function with name {udf_name} and arguments {arg_types} could not be found in schema '{db_name}.{schema_name}'" - ) + return self._backend().get_user_defined_function( + udf_name, arg_types, database=database, schema=schema + ) - # set methods def set_current_database(self, database: Union[str, Database]) -> None: """Set the current default database for the session. @@ -547,19 +1107,13 @@ def set_current_schema(self, schema: Union[str, Schema]) -> None: schema_name = self._parse_schema(schema) self._session.use_schema(schema_name) - # exists methods def database_exists(self, database: Union[str, Database]) -> bool: """Check if the given database exists. Args: database: database name or ``Database`` object. """ - db_name = self._parse_database(database) - try: - self.get_database(db_name) - return True - except NotFoundError: - return False + return self._backend().database_exists(database) def schema_exists( self, @@ -574,13 +1128,7 @@ def schema_exists( schema: schema name or ``Schema`` object. database: database name or ``Database`` object. Defaults to None. """ - db_name = self._parse_database(database, schema) - schema_name = self._parse_schema(schema) - try: - self.get_schema(schema=schema_name, database=db_name) - return True - except NotFoundError: - return False + return self._backend().schema_exists(schema, database=database) def table_exists( self, @@ -597,14 +1145,7 @@ def table_exists( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - db_name = self._parse_database(database, table) - schema_name = self._parse_schema(schema, table) - table_name = table if isinstance(table, str) else table.name - try: - self.get_table(table_name=table_name, database=db_name, schema=schema_name) - return True - except NotFoundError: - return False + return self._backend().table_exists(table, database=database, schema=schema) def view_exists( self, @@ -621,14 +1162,7 @@ def view_exists( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - db_name = self._parse_database(database, view) - schema_name = self._parse_schema(schema, view) - view_name = view if isinstance(view, str) else view.name - try: - self.get_view(view_name=view_name, database=db_name, schema=schema_name) - return True - except NotFoundError: - return False + return self._backend().view_exists(view, database=database, schema=schema) def procedure_exists( self, @@ -647,27 +1181,9 @@ def procedure_exists( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - try: - if isinstance(procedure, Procedure): - if arg_types is not None or database is not None or schema is not None: - raise ArgumentError( - "When provided procedure is a Procedure class no other arguments can be provided" - ) - database = procedure.database_name - schema = procedure.schema_name - arg_types = [ - type_string_to_type_object(a.datatype) for a in procedure.arguments - ] - procedure = procedure.name - self.get_procedure( - procedure_name=procedure, - arg_types=arg_types, - database=database, - schema=schema, - ) - return True - except NotFoundError: - return False + return self._backend().procedure_exists( + procedure, arg_types, database=database, schema=schema + ) def user_defined_function_exists( self, @@ -688,29 +1204,10 @@ def user_defined_function_exists( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - try: - if isinstance(udf, UserDefinedFunction): - if arg_types is not None or database is not None or schema is not None: - raise ArgumentError( - "When provided udf is a UserDefinedFunction class no other arguments can be provided" - ) - database = udf.database_name - schema = udf.schema_name - arg_types = [ - type_string_to_type_object(a.datatype) for a in udf.arguments - ] - udf = udf.name - self.get_user_defined_function( - udf_name=udf, - arg_types=arg_types, - database=database, - schema=schema, - ) - return True - except NotFoundError: - return False + return self._backend().user_defined_function_exists( + udf, arg_types, database=database, schema=schema + ) - # drop methods def drop_database(self, database: Union[str, Database]) -> None: """Drop the given database. @@ -779,7 +1276,6 @@ def drop_view( self._session.sql(f"DROP VIEW {db_name}.{schema_name}.{view_name}").collect() - # aliases listDatabases = list_databases listSchemas = list_schemas listTables = list_tables diff --git a/tests/integ/conftest.py b/tests/integ/conftest.py index fc1835e923..cbed543fbe 100644 --- a/tests/integ/conftest.py +++ b/tests/integ/conftest.py @@ -30,6 +30,8 @@ RUNNING_ON_GH = os.getenv("GITHUB_ACTIONS") == "true" RUNNING_ON_JENKINS = "JENKINS_HOME" in os.environ +pytest_plugins = ("tests.integ.catalog_integ_common",) + test_dir = os.path.dirname(__file__) test_data_dir = os.path.join(test_dir, "cassettes") diff --git a/tests/integ/test_catalog.py b/tests/integ/test_catalog.py index e8bd173e21..fa8940bd38 100644 --- a/tests/integ/test_catalog.py +++ b/tests/integ/test_catalog.py @@ -1,16 +1,21 @@ # # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. # +"""Mode-agnostic catalog integration tests. + +Only tests whose call paths are identical between the SQL-based and REST-based +catalog backends live here. Backend-specific behavior is covered in +``test_catalog_sql_mode.py`` and ``test_catalog_rest_mode.py``. +""" from unittest.mock import patch -import uuid import pytest -from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted from snowflake.snowpark.catalog import Catalog -from snowflake.snowpark.session import Session -from snowflake.snowpark.types import IntegerType - +from tests.integ.catalog_integ_common import ( + CATALOG_TEMP_OBJECT_PREFIX, + DOES_NOT_EXIST_PATTERN, +) pytestmark = [ pytest.mark.xfail( @@ -20,192 +25,6 @@ ), ] -CATALOG_TEMP_OBJECT_PREFIX = "SP_CATALOG_TEMP" - - -def get_temp_name(type: str) -> str: - return f"{CATALOG_TEMP_OBJECT_PREFIX}_{type}_{uuid.uuid4().hex[:6]}".upper() - - -def create_temp_db(session) -> str: - original_db = session.get_current_database() - temp_db = get_temp_name("DB") - session._run_query(f"create or replace database {temp_db}") - session.use_database(original_db) - return temp_db - - -@pytest.fixture(scope="module") -def temp_db1(session): - temp_db = create_temp_db(session) - yield temp_db - session._run_query(f"drop database if exists {temp_db}") - - -@pytest.fixture(scope="module") -def temp_db2(session): - temp_db = create_temp_db(session) - yield temp_db - session._run_query(f"drop database if exists {temp_db}") - - -def create_temp_schema(session, db: str) -> str: - original_db = session.get_current_database() - original_schema = session.get_current_schema() - temp_schema = get_temp_name("SCHEMA") - session._run_query(f"create or replace schema {db}.{temp_schema}") - - session.use_database(original_db) - session.use_schema(original_schema) - return temp_schema - - -@pytest.fixture(scope="module") -def temp_schema1(session, temp_db1): - temp_schema = create_temp_schema(session, temp_db1) - yield temp_schema - session._run_query(f"drop schema if exists {temp_db1}.{temp_schema}") - - -@pytest.fixture(scope="module") -def temp_schema2(session, temp_db1): - temp_schema = create_temp_schema(session, temp_db1) - yield temp_schema - session._run_query(f"drop schema if exists {temp_db1}.{temp_schema}") - - -def create_temp_table(session, db: str, schema: str) -> str: - temp_table = get_temp_name("TABLE") - session._run_query( - f"create or replace temp table {db}.{schema}.{temp_table} (a int, b string)" - ) - return temp_table - - -@pytest.fixture(scope="module") -def temp_table1(session, temp_db1, temp_schema1): - temp_table = create_temp_table(session, temp_db1, temp_schema1) - yield temp_table - session._run_query(f"drop table if exists {temp_db1}.{temp_schema1}.{temp_table}") - - -@pytest.fixture(scope="module") -def temp_table2(session, temp_db1, temp_schema1): - temp_table = create_temp_table(session, temp_db1, temp_schema1) - yield temp_table - session._run_query(f"drop table if exists {temp_db1}.{temp_schema1}.{temp_table}") - - -def create_temp_view(session, db: str, schema: str) -> str: - temp_schema = get_temp_name("VIEW") - session._run_query( - f"create or replace temp view {db}.{schema}.{temp_schema} as select 1 as a, '2' as b" - ) - return temp_schema - - -@pytest.fixture(scope="module") -def temp_view1(session, temp_db1, temp_schema1): - temp_view = create_temp_view(session, temp_db1, temp_schema1) - yield temp_view - session._run_query(f"drop view if exists {temp_db1}.{temp_schema1}.{temp_view}") - - -@pytest.fixture(scope="module") -def temp_view2(session, temp_db1, temp_schema1): - temp_view = create_temp_view(session, temp_db1, temp_schema1) - yield temp_view - session._run_query(f"drop view if exists {temp_db1}.{temp_schema1}.{temp_view}") - - -def create_temp_procedure(session: Session, db, schema) -> str: - temp_procedure = get_temp_name("PROCEDURE") - session.sproc.register( - lambda _, x: x + 1, - return_type=IntegerType(), - input_types=[IntegerType()], - name=f"{db}.{schema}.{temp_procedure}", - packages=["snowflake-snowpark-python"], - ) - return temp_procedure - - -@pytest.fixture(scope="module") -def temp_procedure1(session, temp_db1, temp_schema1): - temp_procedure = create_temp_procedure(session, temp_db1, temp_schema1) - yield temp_procedure - session._run_query( - f"drop procedure if exists {temp_db1}.{temp_schema1}.{temp_procedure}(int)" - ) - - -@pytest.fixture(scope="module") -def temp_procedure2(session, temp_db1, temp_schema1): - temp_procedure = create_temp_procedure(session, temp_db1, temp_schema1) - yield temp_procedure - session._run_query( - f"drop procedure if exists {temp_db1}.{temp_schema1}.{temp_procedure}(int)" - ) - - -def create_temp_udf(session: Session, db, schema) -> str: - temp_udf = get_temp_name("UDF") - session.udf.register( - lambda x: x + 1, - return_type=IntegerType(), - input_types=[IntegerType()], - name=f"{db}.{schema}.{temp_udf}", - ) - return temp_udf - - -@pytest.fixture(scope="module") -def temp_udf1(session, temp_db1, temp_schema1): - temp_udf = create_temp_udf(session, temp_db1, temp_schema1) - yield temp_udf - session._run_query( - f"drop function if exists {temp_db1}.{temp_schema1}.{temp_udf}(int)" - ) - - -@pytest.fixture(scope="module") -def temp_udf2(session, temp_db1, temp_schema1): - temp_udf = create_temp_udf(session, temp_db1, temp_schema1) - yield temp_udf - session._run_query( - f"drop function if exists {temp_db1}.{temp_schema1}.{temp_udf}(int)" - ) - - -DOES_NOT_EXIST_PATTERN = "does_not_exist_.*" - - -def test_list_db(session, temp_db1, temp_db2): - catalog: Catalog = session.catalog - db_list = catalog.list_databases(pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_DB_*") - assert {db.name for db in db_list} >= {temp_db1, temp_db2} - - db_list = catalog.list_databases(like=f"{CATALOG_TEMP_OBJECT_PREFIX}_DB_%") - assert {db.name for db in db_list} >= {temp_db1, temp_db2} - - -def test_list_schema(session, temp_db1, temp_schema1, temp_schema2): - catalog: Catalog = session.catalog - assert ( - len(catalog.list_databases(pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_SCHEMA_.*")) - == 0 - ) - - schema_list = catalog.list_schemas( - pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_SCHEMA_.*", database=temp_db1 - ) - assert {schema.name for schema in schema_list} >= {temp_schema1, temp_schema2} - - schema_list = catalog.list_schemas( - like=f"{CATALOG_TEMP_OBJECT_PREFIX}_SCHEMA_%", database=temp_db1 - ) - assert {schema.name for schema in schema_list} >= {temp_schema1, temp_schema2} - def test_list_tables(session, temp_db1, temp_schema1, temp_table1, temp_table2): catalog: Catalog = session.catalog @@ -333,48 +152,6 @@ def test_list_udfs(session, temp_db1, temp_schema1, temp_udf1, temp_udf2): assert {udf.name for udf in udf_list} >= {temp_udf1, temp_udf2} -def test_get_db_schema(session): - catalog: Catalog = session.catalog - current_db = session.get_current_database() - current_schema = session.get_current_schema() - assert catalog.get_database(current_db).name == unquote_if_quoted(current_db) - assert catalog.get_schema(current_schema).name == unquote_if_quoted(current_schema) - - -def test_get_table_view(session, temp_db1, temp_schema1, temp_table1, temp_view1): - catalog: Catalog = session.catalog - table = catalog.get_table(temp_table1, database=temp_db1, schema=temp_schema1) - assert table.name == temp_table1 - assert table.database_name == temp_db1 - assert table.schema_name == temp_schema1 - - view = catalog.get_view(temp_view1, database=temp_db1, schema=temp_schema1) - assert view.name == temp_view1 - assert view.database_name == temp_db1 - assert view.schema_name == temp_schema1 - - -@pytest.mark.udf -def test_get_function_procedure_udf( - session, temp_db1, temp_schema1, temp_procedure1, temp_udf1 -): - catalog: Catalog = session.catalog - - procedure = catalog.get_procedure( - temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1 - ) - assert procedure.name == temp_procedure1 - assert procedure.database_name == temp_db1 - assert procedure.schema_name == temp_schema1 - - udf = catalog.get_user_defined_function( - temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1 - ) - assert udf.name == temp_udf1 - assert udf.database_name == temp_db1 - assert udf.schema_name == temp_schema1 - - def test_set_db_schema(session, temp_db1, temp_db2, temp_schema1, temp_schema2): catalog = session.catalog @@ -396,112 +173,6 @@ def test_set_db_schema(session, temp_db1, temp_db2, temp_schema1, temp_schema2): session.use_schema(original_schema) -def test_exists_db_schema(session, temp_db1, temp_schema1): - catalog = session.catalog - assert catalog.database_exists(temp_db1) - assert not catalog.database_exists("does_not_exist") - - assert catalog.schema_exists(temp_schema1, database=temp_db1) - assert not catalog.schema_exists(temp_schema1, database="does_not_exist") - - -def test_exists_table_view(session, temp_db1, temp_schema1, temp_table1, temp_view1): - catalog = session.catalog - db1_obj = catalog.get_database(temp_db1) - schema1_obj = catalog.get_schema(database=temp_db1, schema=temp_schema1) - - assert catalog.table_exists(temp_table1, database=temp_db1, schema=temp_schema1) - assert catalog.table_exists(temp_table1, database=db1_obj, schema=schema1_obj) - table = catalog.get_table(temp_table1, database=temp_db1, schema=temp_schema1) - assert catalog.table_exists(table) - assert not catalog.table_exists( - "does_not_exist", database=temp_db1, schema=temp_schema1 - ) - - assert catalog.view_exists(temp_view1, database=temp_db1, schema=temp_schema1) - assert catalog.view_exists(temp_view1, database=db1_obj, schema=schema1_obj) - view = catalog.get_view(temp_view1, database=temp_db1, schema=temp_schema1) - assert catalog.view_exists(view) - assert not catalog.view_exists( - "does_not_exist", database=temp_db1, schema=temp_schema1 - ) - - -@pytest.mark.udf -def test_exists_function_procedure_udf( - session, temp_db1, temp_schema1, temp_procedure1, temp_udf1 -): - catalog = session.catalog - db1_obj = catalog.get_database(temp_db1) - schema1_obj = catalog.get_schema(temp_schema1, database=temp_db1) - - assert catalog.procedure_exists( - temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1 - ) - assert catalog.procedure_exists( - temp_procedure1, [IntegerType()], database=db1_obj, schema=schema1_obj - ) - proc = catalog.get_procedure( - temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1 - ) - assert catalog.procedure_exists(proc) - assert not catalog.procedure_exists( - "does_not_exist", [], database=temp_db1, schema=temp_schema1 - ) - - assert catalog.user_defined_function_exists( - temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1 - ) - assert catalog.user_defined_function_exists( - temp_udf1, [IntegerType()], database=db1_obj, schema=schema1_obj - ) - udf = catalog.get_user_defined_function( - temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1 - ) - assert catalog.user_defined_function_exists(udf) - assert not catalog.user_defined_function_exists( - "does_not_exist", [], database=temp_db1, schema=temp_schema1 - ) - - -@pytest.mark.parametrize("use_object", [True, False]) -def test_drop(session, use_object): - catalog = session.catalog - - original_db = session.get_current_database() - original_schema = session.get_current_schema() - try: - temp_db = create_temp_db(session) - temp_schema = create_temp_schema(session, temp_db) - temp_table = create_temp_table(session, temp_db, temp_schema) - temp_view = create_temp_view(session, temp_db, temp_schema) - if use_object: - temp_schema = catalog.get_schema(temp_schema, database=temp_db) - temp_db = catalog.get_database(temp_db) - - assert catalog.database_exists(temp_db) - assert catalog.schema_exists(temp_schema, database=temp_db) - assert catalog.table_exists(temp_table, database=temp_db, schema=temp_schema) - assert catalog.view_exists(temp_view, database=temp_db, schema=temp_schema) - - catalog.drop_table(temp_table, database=temp_db, schema=temp_schema) - catalog.drop_view(temp_view, database=temp_db, schema=temp_schema) - - assert not catalog.table_exists( - temp_table, database=temp_db, schema=temp_schema - ) - assert not catalog.view_exists(temp_view, database=temp_db, schema=temp_schema) - - catalog.drop_schema(temp_schema, database=temp_db) - assert not catalog.schema_exists(temp_schema, database=temp_db) - - catalog.drop_database(temp_db) - assert not catalog.database_exists(temp_db) - finally: - session.use_database(original_db) - session.use_schema(original_schema) - - def test_parse_names_negative(session): catalog = session.catalog with pytest.raises( diff --git a/tests/integ/test_catalog_rest_mode.py b/tests/integ/test_catalog_rest_mode.py new file mode 100644 index 0000000000..89de818eab --- /dev/null +++ b/tests/integ/test_catalog_rest_mode.py @@ -0,0 +1,253 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# +"""Catalog integration tests with ``context._is_snowpark_connect_compatible_mode`` False (REST / Root backend). + +Keep this file separate from ``test_catalog_sql_mode.py`` so removing one backend path +deletes only the matching test module. +""" + +import pytest + +from snowflake.core.exceptions import NotFoundError as CoreNotFoundError +from snowflake.snowpark import context +from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted +from snowflake.snowpark.catalog import Catalog +from snowflake.snowpark.types import IntegerType +from tests.integ.catalog_integ_common import ( + CATALOG_TEMP_OBJECT_PREFIX, + create_temp_db, + create_temp_schema, + create_temp_table, + create_temp_view, +) + +pytestmark = [ + pytest.mark.xfail( + "config.getoption('local_testing_mode', default=False)", + reason="deepcopy is not supported and required by local testing", + run=False, + ), +] + + +@pytest.fixture(autouse=True) +def _catalog_rest_backend_mode(monkeypatch): + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", False) + + +def test_list_db_rest_mode(session, temp_db1, temp_db2): + catalog: Catalog = session.catalog + db_list = catalog.list_databases(pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_DB_*") + assert {db.name for db in db_list} >= {temp_db1, temp_db2} + + db_list = catalog.list_databases(like=f"{CATALOG_TEMP_OBJECT_PREFIX}_DB_%") + assert {db.name for db in db_list} >= {temp_db1, temp_db2} + + +def test_list_schema_rest_mode(session, temp_db1, temp_schema1, temp_schema2): + catalog: Catalog = session.catalog + assert ( + len(catalog.list_databases(pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_SCHEMA_.*")) + == 0 + ) + + schema_list = catalog.list_schemas( + pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_SCHEMA_.*", database=temp_db1 + ) + assert {schema.name for schema in schema_list} >= {temp_schema1, temp_schema2} + + schema_list = catalog.list_schemas( + like=f"{CATALOG_TEMP_OBJECT_PREFIX}_SCHEMA_%", database=temp_db1 + ) + assert {schema.name for schema in schema_list} >= {temp_schema1, temp_schema2} + + +def test_get_db_schema_rest_mode(session): + catalog: Catalog = session.catalog + current_db = session.get_current_database() + current_schema = session.get_current_schema() + assert catalog.get_database(current_db).name == unquote_if_quoted(current_db) + assert catalog.get_schema(current_schema).name == unquote_if_quoted(current_schema) + + +def test_get_database_missing_raises_core_not_found_rest_mode(session): + catalog: Catalog = session.catalog + with pytest.raises(CoreNotFoundError): + catalog.get_database("NONEXISTENT_DB_XYZ_12345") + + +def test_get_table_does_not_resolve_view_rest_mode( + session, temp_db1, temp_schema1, temp_view1 +): + catalog: Catalog = session.catalog + with pytest.raises(CoreNotFoundError): + catalog.get_table(temp_view1, database=temp_db1, schema=temp_schema1) + + +def test_get_view_rest_mode(session, temp_db1, temp_schema1, temp_view1): + catalog: Catalog = session.catalog + view = catalog.get_view(temp_view1, database=temp_db1, schema=temp_schema1) + assert view.name == temp_view1 + assert view.database_name == temp_db1 + assert view.schema_name == temp_schema1 + + +def test_table_exists_false_for_view_name_rest_mode( + session, temp_db1, temp_schema1, temp_view1 +): + catalog: Catalog = session.catalog + assert not catalog.table_exists(temp_view1, database=temp_db1, schema=temp_schema1) + + +@pytest.mark.udf +def test_get_procedure_rest_mode(session, temp_db1, temp_schema1, temp_procedure1): + catalog: Catalog = session.catalog + procedure = catalog.get_procedure( + temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert procedure.name == temp_procedure1 + assert procedure.database_name == temp_db1 + assert procedure.schema_name == temp_schema1 + + +@pytest.mark.udf +def test_get_user_defined_function_rest_mode( + session, temp_db1, temp_schema1, temp_udf1 +): + catalog: Catalog = session.catalog + udf = catalog.get_user_defined_function( + temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert udf.name == temp_udf1 + assert udf.database_name == temp_db1 + assert udf.schema_name == temp_schema1 + + +def test_database_exists_rest_mode(session, temp_db1): + catalog: Catalog = session.catalog + assert catalog.database_exists(temp_db1) + assert not catalog.database_exists("does_not_exist") + + +def test_get_table_view_rest_mode( + session, temp_db1, temp_schema1, temp_table1, temp_view1 +): + catalog: Catalog = session.catalog + table = catalog.get_table(temp_table1, database=temp_db1, schema=temp_schema1) + assert table.name == temp_table1 + assert table.database_name == temp_db1 + assert table.schema_name == temp_schema1 + + view = catalog.get_view(temp_view1, database=temp_db1, schema=temp_schema1) + assert view.name == temp_view1 + assert view.database_name == temp_db1 + assert view.schema_name == temp_schema1 + + +def test_exists_db_schema_rest_mode(session, temp_db1, temp_schema1): + catalog = session.catalog + assert catalog.database_exists(temp_db1) + assert not catalog.database_exists("does_not_exist") + + assert catalog.schema_exists(temp_schema1, database=temp_db1) + assert not catalog.schema_exists(temp_schema1, database="does_not_exist") + + +def test_exists_table_view_rest_mode( + session, temp_db1, temp_schema1, temp_table1, temp_view1 +): + catalog = session.catalog + db1_obj = catalog.get_database(temp_db1) + schema1_obj = catalog.get_schema(database=temp_db1, schema=temp_schema1) + + assert catalog.table_exists(temp_table1, database=temp_db1, schema=temp_schema1) + assert catalog.table_exists(temp_table1, database=db1_obj, schema=schema1_obj) + table = catalog.get_table(temp_table1, database=temp_db1, schema=temp_schema1) + assert catalog.table_exists(table) + assert not catalog.table_exists( + "does_not_exist", database=temp_db1, schema=temp_schema1 + ) + + assert catalog.view_exists(temp_view1, database=temp_db1, schema=temp_schema1) + assert catalog.view_exists(temp_view1, database=db1_obj, schema=schema1_obj) + view = catalog.get_view(temp_view1, database=temp_db1, schema=temp_schema1) + assert catalog.view_exists(view) + assert not catalog.view_exists( + "does_not_exist", database=temp_db1, schema=temp_schema1 + ) + + +@pytest.mark.udf +def test_exists_function_procedure_udf_rest_mode( + session, temp_db1, temp_schema1, temp_procedure1, temp_udf1 +): + catalog = session.catalog + db1_obj = catalog.get_database(temp_db1) + schema1_obj = catalog.get_schema(temp_schema1, database=temp_db1) + + assert catalog.procedure_exists( + temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert catalog.procedure_exists( + temp_procedure1, [IntegerType()], database=db1_obj, schema=schema1_obj + ) + proc = catalog.get_procedure( + temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert catalog.procedure_exists(proc) + assert not catalog.procedure_exists( + "does_not_exist", [], database=temp_db1, schema=temp_schema1 + ) + + assert catalog.user_defined_function_exists( + temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert catalog.user_defined_function_exists( + temp_udf1, [IntegerType()], database=db1_obj, schema=schema1_obj + ) + udf = catalog.get_user_defined_function( + temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert catalog.user_defined_function_exists(udf) + assert not catalog.user_defined_function_exists( + "does_not_exist", [], database=temp_db1, schema=temp_schema1 + ) + + +@pytest.mark.parametrize("use_object", [True, False]) +def test_drop_rest_mode(session, use_object): + catalog = session.catalog + + original_db = session.get_current_database() + original_schema = session.get_current_schema() + try: + temp_db = create_temp_db(session) + temp_schema = create_temp_schema(session, temp_db) + temp_table = create_temp_table(session, temp_db, temp_schema) + temp_view = create_temp_view(session, temp_db, temp_schema) + if use_object: + temp_schema = catalog.get_schema(temp_schema, database=temp_db) + temp_db = catalog.get_database(temp_db) + + assert catalog.database_exists(temp_db) + assert catalog.schema_exists(temp_schema, database=temp_db) + assert catalog.table_exists(temp_table, database=temp_db, schema=temp_schema) + assert catalog.view_exists(temp_view, database=temp_db, schema=temp_schema) + + catalog.drop_table(temp_table, database=temp_db, schema=temp_schema) + catalog.drop_view(temp_view, database=temp_db, schema=temp_schema) + + assert not catalog.table_exists( + temp_table, database=temp_db, schema=temp_schema + ) + assert not catalog.view_exists(temp_view, database=temp_db, schema=temp_schema) + + catalog.drop_schema(temp_schema, database=temp_db) + assert not catalog.schema_exists(temp_schema, database=temp_db) + + catalog.drop_database(temp_db) + assert not catalog.database_exists(temp_db) + finally: + session.use_database(original_db) + session.use_schema(original_schema) diff --git a/tests/integ/test_catalog_sql_mode.py b/tests/integ/test_catalog_sql_mode.py new file mode 100644 index 0000000000..b7d166b211 --- /dev/null +++ b/tests/integ/test_catalog_sql_mode.py @@ -0,0 +1,243 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# +"""Catalog integration tests with ``context._is_snowpark_connect_compatible_mode`` True (SQL backend). + +Keep this file separate from ``test_catalog_rest_mode.py`` so removing one backend path +deletes only the matching test module. +""" + +import pytest + +from snowflake.core.view import View +from snowflake.snowpark import context +from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted +from snowflake.snowpark.catalog import Catalog +from snowflake.snowpark.exceptions import NotFoundError +from snowflake.snowpark.types import IntegerType +from tests.integ.catalog_integ_common import ( + CATALOG_TEMP_OBJECT_PREFIX, + create_temp_db, + create_temp_schema, + create_temp_table, + create_temp_view, +) + +pytestmark = [ + pytest.mark.xfail( + "config.getoption('local_testing_mode', default=False)", + reason="deepcopy is not supported and required by local testing", + run=False, + ), +] + + +@pytest.fixture(autouse=True) +def _catalog_sql_backend_mode(monkeypatch): + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + + +def test_list_db_sql_mode(session, temp_db1, temp_db2): + catalog: Catalog = session.catalog + db_list = catalog.list_databases(pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_DB_*") + assert {db.name for db in db_list} >= {temp_db1, temp_db2} + + db_list = catalog.list_databases(like=f"{CATALOG_TEMP_OBJECT_PREFIX}_DB_%") + assert {db.name for db in db_list} >= {temp_db1, temp_db2} + + +def test_list_schema_sql_mode(session, temp_db1, temp_schema1, temp_schema2): + catalog: Catalog = session.catalog + assert ( + len(catalog.list_databases(pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_SCHEMA_.*")) + == 0 + ) + + schema_list = catalog.list_schemas( + pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_SCHEMA_.*", database=temp_db1 + ) + assert {schema.name for schema in schema_list} >= {temp_schema1, temp_schema2} + + schema_list = catalog.list_schemas( + like=f"{CATALOG_TEMP_OBJECT_PREFIX}_SCHEMA_%", database=temp_db1 + ) + assert {schema.name for schema in schema_list} >= {temp_schema1, temp_schema2} + + +def test_get_db_schema_sql_mode(session): + catalog: Catalog = session.catalog + current_db = session.get_current_database() + current_schema = session.get_current_schema() + assert catalog.get_database(current_db).name == unquote_if_quoted(current_db) + assert catalog.get_schema(current_schema).name == unquote_if_quoted(current_schema) + + +def test_get_database_missing_raises_snowpark_not_found_sql_mode(session): + catalog: Catalog = session.catalog + with pytest.raises(NotFoundError, match="could not be found"): + catalog.get_database("NONEXISTENT_DB_XYZ_12345") + + +def test_get_table_resolves_view_sql_mode(session, temp_db1, temp_schema1, temp_view1): + catalog: Catalog = session.catalog + obj = catalog.get_table(temp_view1, database=temp_db1, schema=temp_schema1) + assert isinstance(obj, View) + assert obj.name == temp_view1 + + +def test_table_exists_true_for_view_name_sql_mode( + session, temp_db1, temp_schema1, temp_view1 +): + catalog: Catalog = session.catalog + assert catalog.table_exists(temp_view1, database=temp_db1, schema=temp_schema1) + + +@pytest.mark.udf +def test_get_procedure_sql_mode(session, temp_db1, temp_schema1, temp_procedure1): + catalog: Catalog = session.catalog + procedure = catalog.get_procedure( + temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert procedure.name == temp_procedure1 + assert procedure.database_name == temp_db1 + assert procedure.schema_name == temp_schema1 + + +@pytest.mark.udf +def test_get_user_defined_function_sql_mode(session, temp_db1, temp_schema1, temp_udf1): + catalog: Catalog = session.catalog + udf = catalog.get_user_defined_function( + temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert udf.name == temp_udf1 + assert udf.database_name == temp_db1 + assert udf.schema_name == temp_schema1 + + +def test_database_exists_sql_mode(session, temp_db1): + catalog: Catalog = session.catalog + assert catalog.database_exists(temp_db1) + assert not catalog.database_exists("does_not_exist") + + +def test_get_table_view_sql_mode( + session, temp_db1, temp_schema1, temp_table1, temp_view1 +): + catalog: Catalog = session.catalog + table = catalog.get_table(temp_table1, database=temp_db1, schema=temp_schema1) + assert table.name == temp_table1 + assert table.database_name == temp_db1 + assert table.schema_name == temp_schema1 + + view = catalog.get_view(temp_view1, database=temp_db1, schema=temp_schema1) + assert view.name == temp_view1 + assert view.database_name == temp_db1 + assert view.schema_name == temp_schema1 + + +def test_exists_db_schema_sql_mode(session, temp_db1, temp_schema1): + catalog = session.catalog + assert catalog.database_exists(temp_db1) + assert not catalog.database_exists("does_not_exist") + + assert catalog.schema_exists(temp_schema1, database=temp_db1) + assert not catalog.schema_exists(temp_schema1, database="does_not_exist") + + +def test_exists_table_view_sql_mode( + session, temp_db1, temp_schema1, temp_table1, temp_view1 +): + catalog = session.catalog + db1_obj = catalog.get_database(temp_db1) + schema1_obj = catalog.get_schema(database=temp_db1, schema=temp_schema1) + + assert catalog.table_exists(temp_table1, database=temp_db1, schema=temp_schema1) + assert catalog.table_exists(temp_table1, database=db1_obj, schema=schema1_obj) + table = catalog.get_table(temp_table1, database=temp_db1, schema=temp_schema1) + assert catalog.table_exists(table) + assert not catalog.table_exists( + "does_not_exist", database=temp_db1, schema=temp_schema1 + ) + + assert catalog.view_exists(temp_view1, database=temp_db1, schema=temp_schema1) + assert catalog.view_exists(temp_view1, database=db1_obj, schema=schema1_obj) + view = catalog.get_view(temp_view1, database=temp_db1, schema=temp_schema1) + assert catalog.view_exists(view) + assert not catalog.view_exists( + "does_not_exist", database=temp_db1, schema=temp_schema1 + ) + + +@pytest.mark.udf +def test_exists_function_procedure_udf_sql_mode( + session, temp_db1, temp_schema1, temp_procedure1, temp_udf1 +): + catalog = session.catalog + db1_obj = catalog.get_database(temp_db1) + schema1_obj = catalog.get_schema(temp_schema1, database=temp_db1) + + assert catalog.procedure_exists( + temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert catalog.procedure_exists( + temp_procedure1, [IntegerType()], database=db1_obj, schema=schema1_obj + ) + proc = catalog.get_procedure( + temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert catalog.procedure_exists(proc) + assert not catalog.procedure_exists( + "does_not_exist", [], database=temp_db1, schema=temp_schema1 + ) + + assert catalog.user_defined_function_exists( + temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert catalog.user_defined_function_exists( + temp_udf1, [IntegerType()], database=db1_obj, schema=schema1_obj + ) + udf = catalog.get_user_defined_function( + temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert catalog.user_defined_function_exists(udf) + assert not catalog.user_defined_function_exists( + "does_not_exist", [], database=temp_db1, schema=temp_schema1 + ) + + +@pytest.mark.parametrize("use_object", [True, False]) +def test_drop_sql_mode(session, use_object): + catalog = session.catalog + + original_db = session.get_current_database() + original_schema = session.get_current_schema() + try: + temp_db = create_temp_db(session) + temp_schema = create_temp_schema(session, temp_db) + temp_table = create_temp_table(session, temp_db, temp_schema) + temp_view = create_temp_view(session, temp_db, temp_schema) + if use_object: + temp_schema = catalog.get_schema(temp_schema, database=temp_db) + temp_db = catalog.get_database(temp_db) + + assert catalog.database_exists(temp_db) + assert catalog.schema_exists(temp_schema, database=temp_db) + assert catalog.table_exists(temp_table, database=temp_db, schema=temp_schema) + assert catalog.view_exists(temp_view, database=temp_db, schema=temp_schema) + + catalog.drop_table(temp_table, database=temp_db, schema=temp_schema) + catalog.drop_view(temp_view, database=temp_db, schema=temp_schema) + + assert not catalog.table_exists( + temp_table, database=temp_db, schema=temp_schema + ) + assert not catalog.view_exists(temp_view, database=temp_db, schema=temp_schema) + + catalog.drop_schema(temp_schema, database=temp_db) + assert not catalog.schema_exists(temp_schema, database=temp_db) + + catalog.drop_database(temp_db) + assert not catalog.database_exists(temp_db) + finally: + session.use_database(original_db) + session.use_schema(original_schema) From f6108350b0fffd05fd5aa7c3bac8ee335cab437b Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Mon, 27 Apr 2026 15:21:27 -0700 Subject: [PATCH 05/17] init back end at catalog init --- src/snowflake/snowpark/catalog.py | 49 ++++++++++----------------- tests/integ/test_catalog_rest_mode.py | 13 +++++-- tests/integ/test_catalog_sql_mode.py | 13 +++++-- 3 files changed, 38 insertions(+), 37 deletions(-) diff --git a/src/snowflake/snowpark/catalog.py b/src/snowflake/snowpark/catalog.py index 0a2af60e58..fedf1672ac 100644 --- a/src/snowflake/snowpark/catalog.py +++ b/src/snowflake/snowpark/catalog.py @@ -466,13 +466,7 @@ def user_defined_function_exists( class _RestCatalogBackend(_CatalogBackend): def __init__(self, catalog: "Catalog") -> None: super().__init__(catalog) - self._root_obj: Optional[Root] = None - - @property - def _root(self) -> Root: - if self._root_obj is None: - self._root_obj = Root(self._catalog._session) - return self._root_obj + self._root = Root(catalog._session) def list_databases( self, @@ -708,15 +702,10 @@ class Catalog: def __init__(self, session: "Session") -> None: self._session = session self._python_regex_udf = None - self._sql_backend = _SqlCatalogBackend(self) - self._rest_backend: Optional[_RestCatalogBackend] = None - - def _backend(self) -> _CatalogBackend: if context._is_snowpark_connect_compatible_mode: - return self._sql_backend - if self._rest_backend is None: - self._rest_backend = _RestCatalogBackend(self) - return self._rest_backend + self._backend: _CatalogBackend = _SqlCatalogBackend(self) + else: + self._backend = _RestCatalogBackend(self) def _parse_database( self, @@ -842,7 +831,7 @@ def list_databases( pattern: the python regex pattern of name to match. Defaults to None. like: the sql style pattern for name to match. Default to None. """ - return self._backend().list_databases(pattern=pattern, like=like) + return self._backend.list_databases(pattern=pattern, like=like) def list_schemas( self, @@ -859,9 +848,7 @@ def list_schemas( pattern: the python regex pattern of name to match. Defaults to None. like: the sql style pattern for name to match. Default to None. """ - return self._backend().list_schemas( - database=database, pattern=pattern, like=like - ) + return self._backend.list_schemas(database=database, pattern=pattern, like=like) def list_tables( self, @@ -998,13 +985,13 @@ def get_current_schema(self) -> Optional[str]: def get_database(self, database: str) -> Database: """Name of the database to get""" - return self._backend().get_database(database) + return self._backend.get_database(database) def get_schema( self, schema: str, *, database: Optional[Union[str, Database]] = None ) -> Schema: """Name of the schema to get.""" - return self._backend().get_schema(schema, database=database) + return self._backend.get_schema(schema, database=database) def get_table( self, @@ -1027,7 +1014,7 @@ def get_table( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - return self._backend().get_table(table_name, database=database, schema=schema) + return self._backend.get_table(table_name, database=database, schema=schema) def get_view( self, @@ -1044,7 +1031,7 @@ def get_view( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - return self._backend().get_view(view_name, database=database, schema=schema) + return self._backend.get_view(view_name, database=database, schema=schema) def get_procedure( self, @@ -1063,7 +1050,7 @@ def get_procedure( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - return self._backend().get_procedure( + return self._backend.get_procedure( procedure_name, arg_types, database=database, schema=schema ) @@ -1085,7 +1072,7 @@ def get_user_defined_function( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - return self._backend().get_user_defined_function( + return self._backend.get_user_defined_function( udf_name, arg_types, database=database, schema=schema ) @@ -1113,7 +1100,7 @@ def database_exists(self, database: Union[str, Database]) -> bool: Args: database: database name or ``Database`` object. """ - return self._backend().database_exists(database) + return self._backend.database_exists(database) def schema_exists( self, @@ -1128,7 +1115,7 @@ def schema_exists( schema: schema name or ``Schema`` object. database: database name or ``Database`` object. Defaults to None. """ - return self._backend().schema_exists(schema, database=database) + return self._backend.schema_exists(schema, database=database) def table_exists( self, @@ -1145,7 +1132,7 @@ def table_exists( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - return self._backend().table_exists(table, database=database, schema=schema) + return self._backend.table_exists(table, database=database, schema=schema) def view_exists( self, @@ -1162,7 +1149,7 @@ def view_exists( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - return self._backend().view_exists(view, database=database, schema=schema) + return self._backend.view_exists(view, database=database, schema=schema) def procedure_exists( self, @@ -1181,7 +1168,7 @@ def procedure_exists( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - return self._backend().procedure_exists( + return self._backend.procedure_exists( procedure, arg_types, database=database, schema=schema ) @@ -1204,7 +1191,7 @@ def user_defined_function_exists( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - return self._backend().user_defined_function_exists( + return self._backend.user_defined_function_exists( udf, arg_types, database=database, schema=schema ) diff --git a/tests/integ/test_catalog_rest_mode.py b/tests/integ/test_catalog_rest_mode.py index 89de818eab..be4bba4a39 100644 --- a/tests/integ/test_catalog_rest_mode.py +++ b/tests/integ/test_catalog_rest_mode.py @@ -31,9 +31,16 @@ ] -@pytest.fixture(autouse=True) -def _catalog_rest_backend_mode(monkeypatch): - monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", False) +@pytest.fixture(autouse=True, scope="module") +def _catalog_rest_backend_mode(session): + mp = pytest.MonkeyPatch() + mp.setattr(context, "_is_snowpark_connect_compatible_mode", False) + mp.setattr(session, "_catalog", None) + try: + yield + finally: + mp.undo() + session._catalog = None def test_list_db_rest_mode(session, temp_db1, temp_db2): diff --git a/tests/integ/test_catalog_sql_mode.py b/tests/integ/test_catalog_sql_mode.py index b7d166b211..ac97b435ce 100644 --- a/tests/integ/test_catalog_sql_mode.py +++ b/tests/integ/test_catalog_sql_mode.py @@ -32,9 +32,16 @@ ] -@pytest.fixture(autouse=True) -def _catalog_sql_backend_mode(monkeypatch): - monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) +@pytest.fixture(autouse=True, scope="module") +def _catalog_sql_backend_mode(session): + mp = pytest.MonkeyPatch() + mp.setattr(context, "_is_snowpark_connect_compatible_mode", True) + mp.setattr(session, "_catalog", None) + try: + yield + finally: + mp.undo() + session._catalog = None def test_list_db_sql_mode(session, temp_db1, temp_db2): From 674d83171cada68968dd7aaaa8266030a53813d3 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 28 Apr 2026 10:06:31 -0700 Subject: [PATCH 06/17] push missed test fixture --- tests/integ/catalog_integ_common.py | 170 ++++++++++++++++++++++++++++ 1 file changed, 170 insertions(+) create mode 100644 tests/integ/catalog_integ_common.py diff --git a/tests/integ/catalog_integ_common.py b/tests/integ/catalog_integ_common.py new file mode 100644 index 0000000000..bb459ac8a6 --- /dev/null +++ b/tests/integ/catalog_integ_common.py @@ -0,0 +1,170 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# +"""Shared pytest fixtures for catalog integration tests (see ``test_catalog*.py``).""" + +import uuid + +import pytest + +from snowflake.snowpark.session import Session +from snowflake.snowpark.types import IntegerType + +CATALOG_TEMP_OBJECT_PREFIX = "SP_CATALOG_TEMP" + + +def get_temp_name(type: str) -> str: + return f"{CATALOG_TEMP_OBJECT_PREFIX}_{type}_{uuid.uuid4().hex[:6]}".upper() + + +def create_temp_db(session) -> str: + original_db = session.get_current_database() + temp_db = get_temp_name("DB") + session._run_query(f"create or replace database {temp_db}") + session.use_database(original_db) + return temp_db + + +@pytest.fixture(scope="module") +def temp_db1(session): + temp_db = create_temp_db(session) + yield temp_db + session._run_query(f"drop database if exists {temp_db}") + + +@pytest.fixture(scope="module") +def temp_db2(session): + temp_db = create_temp_db(session) + yield temp_db + session._run_query(f"drop database if exists {temp_db}") + + +def create_temp_schema(session, db: str) -> str: + original_db = session.get_current_database() + original_schema = session.get_current_schema() + temp_schema = get_temp_name("SCHEMA") + session._run_query(f"create or replace schema {db}.{temp_schema}") + + session.use_database(original_db) + session.use_schema(original_schema) + return temp_schema + + +@pytest.fixture(scope="module") +def temp_schema1(session, temp_db1): + temp_schema = create_temp_schema(session, temp_db1) + yield temp_schema + session._run_query(f"drop schema if exists {temp_db1}.{temp_schema}") + + +@pytest.fixture(scope="module") +def temp_schema2(session, temp_db1): + temp_schema = create_temp_schema(session, temp_db1) + yield temp_schema + session._run_query(f"drop schema if exists {temp_db1}.{temp_schema}") + + +def create_temp_table(session, db: str, schema: str) -> str: + temp_table = get_temp_name("TABLE") + session._run_query( + f"create or replace temp table {db}.{schema}.{temp_table} (a int, b string)" + ) + return temp_table + + +@pytest.fixture(scope="module") +def temp_table1(session, temp_db1, temp_schema1): + temp_table = create_temp_table(session, temp_db1, temp_schema1) + yield temp_table + session._run_query(f"drop table if exists {temp_db1}.{temp_schema1}.{temp_table}") + + +@pytest.fixture(scope="module") +def temp_table2(session, temp_db1, temp_schema1): + temp_table = create_temp_table(session, temp_db1, temp_schema1) + yield temp_table + session._run_query(f"drop table if exists {temp_db1}.{temp_schema1}.{temp_table}") + + +def create_temp_view(session, db: str, schema: str) -> str: + temp_schema = get_temp_name("VIEW") + session._run_query( + f"create or replace temp view {db}.{schema}.{temp_schema} as select 1 as a, '2' as b" + ) + return temp_schema + + +@pytest.fixture(scope="module") +def temp_view1(session, temp_db1, temp_schema1): + temp_view = create_temp_view(session, temp_db1, temp_schema1) + yield temp_view + session._run_query(f"drop view if exists {temp_db1}.{temp_schema1}.{temp_view}") + + +@pytest.fixture(scope="module") +def temp_view2(session, temp_db1, temp_schema1): + temp_view = create_temp_view(session, temp_db1, temp_schema1) + yield temp_view + session._run_query(f"drop view if exists {temp_db1}.{temp_schema1}.{temp_view}") + + +def create_temp_procedure(session: Session, db, schema) -> str: + temp_procedure = get_temp_name("PROCEDURE") + session.sproc.register( + lambda _, x: x + 1, + return_type=IntegerType(), + input_types=[IntegerType()], + name=f"{db}.{schema}.{temp_procedure}", + packages=["snowflake-snowpark-python"], + ) + return temp_procedure + + +@pytest.fixture(scope="module") +def temp_procedure1(session, temp_db1, temp_schema1): + temp_procedure = create_temp_procedure(session, temp_db1, temp_schema1) + yield temp_procedure + session._run_query( + f"drop procedure if exists {temp_db1}.{temp_schema1}.{temp_procedure}(int)" + ) + + +@pytest.fixture(scope="module") +def temp_procedure2(session, temp_db1, temp_schema1): + temp_procedure = create_temp_procedure(session, temp_db1, temp_schema1) + yield temp_procedure + session._run_query( + f"drop procedure if exists {temp_db1}.{temp_schema1}.{temp_procedure}(int)" + ) + + +def create_temp_udf(session: Session, db, schema) -> str: + temp_udf = get_temp_name("UDF") + session.udf.register( + lambda x: x + 1, + return_type=IntegerType(), + input_types=[IntegerType()], + name=f"{db}.{schema}.{temp_udf}", + ) + return temp_udf + + +@pytest.fixture(scope="module") +def temp_udf1(session, temp_db1, temp_schema1): + temp_udf = create_temp_udf(session, temp_db1, temp_schema1) + yield temp_udf + session._run_query( + f"drop function if exists {temp_db1}.{temp_schema1}.{temp_udf}(int)" + ) + + +@pytest.fixture(scope="module") +def temp_udf2(session, temp_db1, temp_schema1): + temp_udf = create_temp_udf(session, temp_db1, temp_schema1) + yield temp_udf + session._run_query( + f"drop function if exists {temp_db1}.{temp_schema1}.{temp_udf}(int)" + ) + + +DOES_NOT_EXIST_PATTERN = "does_not_exist_.*" From 5933508fc926c9b7ad49f7f750007069cc3e7192 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 28 Apr 2026 10:29:54 -0700 Subject: [PATCH 07/17] fix lint --- tests/integ/test_catalog.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/integ/test_catalog.py b/tests/integ/test_catalog.py index 801ecf5c36..115aabf36a 100644 --- a/tests/integ/test_catalog.py +++ b/tests/integ/test_catalog.py @@ -16,9 +16,6 @@ CATALOG_TEMP_OBJECT_PREFIX, DOES_NOT_EXIST_PATTERN, ) -from snowflake.snowpark.session import Session -from snowflake.snowpark.types import IntegerType -from snowflake.core.exceptions import APIError from snowflake.snowpark.context import _DEFAULT_ARTIFACT_REPOSITORY From 02ec3b8ab1c005b67920b519fd665201fb4d7526 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 28 Apr 2026 11:40:12 -0700 Subject: [PATCH 08/17] remove changelog and use notimplementederror --- CHANGELOG.md | 8 ----- src/snowflake/snowpark/catalog.py | 57 +++++++++++++++++++++++-------- 2 files changed, 43 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a0e2e8b621..959915450f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,13 +1,5 @@ # Release History -## 1.51.0 (TBD) - -### Snowpark Python API Updates - -#### Improvements - -- Catalog API now uses SQL commands instead of SnowAPI calls to improve stability. - ## 1.50.0 (2026-04-23) ### Snowpark Python API Updates diff --git a/src/snowflake/snowpark/catalog.py b/src/snowflake/snowpark/catalog.py index fedf1672ac..4078ebc340 100644 --- a/src/snowflake/snowpark/catalog.py +++ b/src/snowflake/snowpark/catalog.py @@ -56,7 +56,9 @@ def list_databases( pattern: Optional[str] = None, like: Optional[str] = None, ) -> List[Database]: - pass + raise NotImplementedError( + "_CatalogBackend.list_databases must be implemented by a concrete subclass." + ) @abstractmethod def list_schemas( @@ -66,17 +68,23 @@ def list_schemas( pattern: Optional[str] = None, like: Optional[str] = None, ) -> List[Schema]: - pass + raise NotImplementedError( + "_CatalogBackend.list_schemas must be implemented by a concrete subclass." + ) @abstractmethod def get_database(self, database: str) -> Database: - pass + raise NotImplementedError( + "_CatalogBackend.get_database must be implemented by a concrete subclass." + ) @abstractmethod def get_schema( self, schema: str, *, database: Optional[Union[str, Database]] = None ) -> Schema: - pass + raise NotImplementedError( + "_CatalogBackend.get_schema must be implemented by a concrete subclass." + ) @abstractmethod def get_table( @@ -86,7 +94,9 @@ def get_table( database: Optional[Union[str, Database]] = None, schema: Optional[Union[str, Schema]] = None, ) -> Union[Table, View]: - pass + raise NotImplementedError( + "_CatalogBackend.get_table must be implemented by a concrete subclass." + ) @abstractmethod def get_view( @@ -96,7 +106,9 @@ def get_view( database: Optional[Union[str, Database]] = None, schema: Optional[Union[str, Schema]] = None, ) -> View: - pass + raise NotImplementedError( + "_CatalogBackend.get_view must be implemented by a concrete subclass." + ) @abstractmethod def get_procedure( @@ -107,7 +119,9 @@ def get_procedure( database: Optional[Union[str, Database]] = None, schema: Optional[Union[str, Schema]] = None, ) -> Procedure: - pass + raise NotImplementedError( + "_CatalogBackend.get_procedure must be implemented by a concrete subclass." + ) @abstractmethod def get_user_defined_function( @@ -118,11 +132,15 @@ def get_user_defined_function( database: Optional[Union[str, Database]] = None, schema: Optional[Union[str, Schema]] = None, ) -> UserDefinedFunction: - pass + raise NotImplementedError( + "_CatalogBackend.get_user_defined_function must be implemented by a concrete subclass." + ) @abstractmethod def database_exists(self, database: Union[str, Database]) -> bool: - pass + raise NotImplementedError( + "_CatalogBackend.database_exists must be implemented by a concrete subclass." + ) @abstractmethod def schema_exists( @@ -131,7 +149,9 @@ def schema_exists( *, database: Optional[Union[str, Database]] = None, ) -> bool: - pass + raise NotImplementedError( + "_CatalogBackend.schema_exists must be implemented by a concrete subclass." + ) @abstractmethod def table_exists( @@ -141,7 +161,9 @@ def table_exists( database: Optional[Union[str, Database]] = None, schema: Optional[Union[str, Schema]] = None, ) -> bool: - pass + raise NotImplementedError( + "_CatalogBackend.table_exists must be implemented by a concrete subclass." + ) @abstractmethod def view_exists( @@ -151,7 +173,9 @@ def view_exists( database: Optional[Union[str, Database]] = None, schema: Optional[Union[str, Schema]] = None, ) -> bool: - pass + raise NotImplementedError( + "_CatalogBackend.view_exists must be implemented by a concrete subclass." + ) @abstractmethod def procedure_exists( @@ -162,7 +186,9 @@ def procedure_exists( database: Optional[Union[str, Database]] = None, schema: Optional[Union[str, Schema]] = None, ) -> bool: - pass + raise NotImplementedError( + "_CatalogBackend.procedure_exists must be implemented by a concrete subclass." + ) @abstractmethod def user_defined_function_exists( @@ -173,7 +199,10 @@ def user_defined_function_exists( database: Optional[Union[str, Database]] = None, schema: Optional[Union[str, Schema]] = None, ) -> bool: - pass + raise NotImplementedError( + "_CatalogBackend.user_defined_function_exists must be implemented by a " + "concrete subclass." + ) class _SqlCatalogBackend(_CatalogBackend): From c78333c8f1f348c86149859eaf64ccd3acb56e53 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 28 Apr 2026 14:31:08 -0700 Subject: [PATCH 09/17] move fixture back to test_catalog --- tests/integ/catalog_integ_common.py | 170 ------------------------- tests/integ/conftest.py | 2 +- tests/integ/test_catalog.py | 174 ++++++++++++++++++++++++-- tests/integ/test_catalog_rest_mode.py | 2 +- tests/integ/test_catalog_sql_mode.py | 2 +- 5 files changed, 169 insertions(+), 181 deletions(-) delete mode 100644 tests/integ/catalog_integ_common.py diff --git a/tests/integ/catalog_integ_common.py b/tests/integ/catalog_integ_common.py deleted file mode 100644 index bb459ac8a6..0000000000 --- a/tests/integ/catalog_integ_common.py +++ /dev/null @@ -1,170 +0,0 @@ -# -# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. -# -"""Shared pytest fixtures for catalog integration tests (see ``test_catalog*.py``).""" - -import uuid - -import pytest - -from snowflake.snowpark.session import Session -from snowflake.snowpark.types import IntegerType - -CATALOG_TEMP_OBJECT_PREFIX = "SP_CATALOG_TEMP" - - -def get_temp_name(type: str) -> str: - return f"{CATALOG_TEMP_OBJECT_PREFIX}_{type}_{uuid.uuid4().hex[:6]}".upper() - - -def create_temp_db(session) -> str: - original_db = session.get_current_database() - temp_db = get_temp_name("DB") - session._run_query(f"create or replace database {temp_db}") - session.use_database(original_db) - return temp_db - - -@pytest.fixture(scope="module") -def temp_db1(session): - temp_db = create_temp_db(session) - yield temp_db - session._run_query(f"drop database if exists {temp_db}") - - -@pytest.fixture(scope="module") -def temp_db2(session): - temp_db = create_temp_db(session) - yield temp_db - session._run_query(f"drop database if exists {temp_db}") - - -def create_temp_schema(session, db: str) -> str: - original_db = session.get_current_database() - original_schema = session.get_current_schema() - temp_schema = get_temp_name("SCHEMA") - session._run_query(f"create or replace schema {db}.{temp_schema}") - - session.use_database(original_db) - session.use_schema(original_schema) - return temp_schema - - -@pytest.fixture(scope="module") -def temp_schema1(session, temp_db1): - temp_schema = create_temp_schema(session, temp_db1) - yield temp_schema - session._run_query(f"drop schema if exists {temp_db1}.{temp_schema}") - - -@pytest.fixture(scope="module") -def temp_schema2(session, temp_db1): - temp_schema = create_temp_schema(session, temp_db1) - yield temp_schema - session._run_query(f"drop schema if exists {temp_db1}.{temp_schema}") - - -def create_temp_table(session, db: str, schema: str) -> str: - temp_table = get_temp_name("TABLE") - session._run_query( - f"create or replace temp table {db}.{schema}.{temp_table} (a int, b string)" - ) - return temp_table - - -@pytest.fixture(scope="module") -def temp_table1(session, temp_db1, temp_schema1): - temp_table = create_temp_table(session, temp_db1, temp_schema1) - yield temp_table - session._run_query(f"drop table if exists {temp_db1}.{temp_schema1}.{temp_table}") - - -@pytest.fixture(scope="module") -def temp_table2(session, temp_db1, temp_schema1): - temp_table = create_temp_table(session, temp_db1, temp_schema1) - yield temp_table - session._run_query(f"drop table if exists {temp_db1}.{temp_schema1}.{temp_table}") - - -def create_temp_view(session, db: str, schema: str) -> str: - temp_schema = get_temp_name("VIEW") - session._run_query( - f"create or replace temp view {db}.{schema}.{temp_schema} as select 1 as a, '2' as b" - ) - return temp_schema - - -@pytest.fixture(scope="module") -def temp_view1(session, temp_db1, temp_schema1): - temp_view = create_temp_view(session, temp_db1, temp_schema1) - yield temp_view - session._run_query(f"drop view if exists {temp_db1}.{temp_schema1}.{temp_view}") - - -@pytest.fixture(scope="module") -def temp_view2(session, temp_db1, temp_schema1): - temp_view = create_temp_view(session, temp_db1, temp_schema1) - yield temp_view - session._run_query(f"drop view if exists {temp_db1}.{temp_schema1}.{temp_view}") - - -def create_temp_procedure(session: Session, db, schema) -> str: - temp_procedure = get_temp_name("PROCEDURE") - session.sproc.register( - lambda _, x: x + 1, - return_type=IntegerType(), - input_types=[IntegerType()], - name=f"{db}.{schema}.{temp_procedure}", - packages=["snowflake-snowpark-python"], - ) - return temp_procedure - - -@pytest.fixture(scope="module") -def temp_procedure1(session, temp_db1, temp_schema1): - temp_procedure = create_temp_procedure(session, temp_db1, temp_schema1) - yield temp_procedure - session._run_query( - f"drop procedure if exists {temp_db1}.{temp_schema1}.{temp_procedure}(int)" - ) - - -@pytest.fixture(scope="module") -def temp_procedure2(session, temp_db1, temp_schema1): - temp_procedure = create_temp_procedure(session, temp_db1, temp_schema1) - yield temp_procedure - session._run_query( - f"drop procedure if exists {temp_db1}.{temp_schema1}.{temp_procedure}(int)" - ) - - -def create_temp_udf(session: Session, db, schema) -> str: - temp_udf = get_temp_name("UDF") - session.udf.register( - lambda x: x + 1, - return_type=IntegerType(), - input_types=[IntegerType()], - name=f"{db}.{schema}.{temp_udf}", - ) - return temp_udf - - -@pytest.fixture(scope="module") -def temp_udf1(session, temp_db1, temp_schema1): - temp_udf = create_temp_udf(session, temp_db1, temp_schema1) - yield temp_udf - session._run_query( - f"drop function if exists {temp_db1}.{temp_schema1}.{temp_udf}(int)" - ) - - -@pytest.fixture(scope="module") -def temp_udf2(session, temp_db1, temp_schema1): - temp_udf = create_temp_udf(session, temp_db1, temp_schema1) - yield temp_udf - session._run_query( - f"drop function if exists {temp_db1}.{temp_schema1}.{temp_udf}(int)" - ) - - -DOES_NOT_EXIST_PATTERN = "does_not_exist_.*" diff --git a/tests/integ/conftest.py b/tests/integ/conftest.py index 15feb15fed..01edcf170b 100644 --- a/tests/integ/conftest.py +++ b/tests/integ/conftest.py @@ -31,7 +31,7 @@ RUNNING_ON_GH = os.getenv("GITHUB_ACTIONS") == "true" RUNNING_ON_JENKINS = "JENKINS_HOME" in os.environ -pytest_plugins = ("tests.integ.catalog_integ_common",) +pytest_plugins = ("tests.integ.test_catalog",) test_dir = os.path.dirname(__file__) test_data_dir = os.path.join(test_dir, "cassettes") diff --git a/tests/integ/test_catalog.py b/tests/integ/test_catalog.py index 115aabf36a..b0fef8e6c1 100644 --- a/tests/integ/test_catalog.py +++ b/tests/integ/test_catalog.py @@ -1,22 +1,180 @@ # # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. # -"""Mode-agnostic catalog integration tests. +"""Catalog integration tests and shared fixtures. -Only tests whose call paths are identical between the SQL-based and REST-based -catalog backends live here. Backend-specific behavior is covered in -``test_catalog_sql_mode.py`` and ``test_catalog_rest_mode.py``. +Mode-agnostic tests (same behavior for SQL and REST catalog backends) live in +this module. Backend-specific tests are in ``test_catalog_sql_mode.py`` and +``test_catalog_rest_mode.py``, which reuse the fixtures defined here via +``pytest_plugins`` in ``conftest.py``. """ +import uuid from unittest.mock import patch + import pytest from snowflake.snowpark.catalog import Catalog -from tests.integ.catalog_integ_common import ( - CATALOG_TEMP_OBJECT_PREFIX, - DOES_NOT_EXIST_PATTERN, -) from snowflake.snowpark.context import _DEFAULT_ARTIFACT_REPOSITORY +from snowflake.snowpark.session import Session +from snowflake.snowpark.types import IntegerType + +CATALOG_TEMP_OBJECT_PREFIX = "SP_CATALOG_TEMP" +DOES_NOT_EXIST_PATTERN = "does_not_exist_.*" + + +def get_temp_name(type: str) -> str: + return f"{CATALOG_TEMP_OBJECT_PREFIX}_{type}_{uuid.uuid4().hex[:6]}".upper() + + +def create_temp_db(session) -> str: + original_db = session.get_current_database() + temp_db = get_temp_name("DB") + session._run_query(f"create or replace database {temp_db}") + session.use_database(original_db) + return temp_db + + +@pytest.fixture(scope="module") +def temp_db1(session): + temp_db = create_temp_db(session) + yield temp_db + session._run_query(f"drop database if exists {temp_db}") + + +@pytest.fixture(scope="module") +def temp_db2(session): + temp_db = create_temp_db(session) + yield temp_db + session._run_query(f"drop database if exists {temp_db}") + + +def create_temp_schema(session, db: str) -> str: + original_db = session.get_current_database() + original_schema = session.get_current_schema() + temp_schema = get_temp_name("SCHEMA") + session._run_query(f"create or replace schema {db}.{temp_schema}") + + session.use_database(original_db) + session.use_schema(original_schema) + return temp_schema + + +@pytest.fixture(scope="module") +def temp_schema1(session, temp_db1): + temp_schema = create_temp_schema(session, temp_db1) + yield temp_schema + session._run_query(f"drop schema if exists {temp_db1}.{temp_schema}") + + +@pytest.fixture(scope="module") +def temp_schema2(session, temp_db1): + temp_schema = create_temp_schema(session, temp_db1) + yield temp_schema + session._run_query(f"drop schema if exists {temp_db1}.{temp_schema}") + + +def create_temp_table(session, db: str, schema: str) -> str: + temp_table = get_temp_name("TABLE") + session._run_query( + f"create or replace temp table {db}.{schema}.{temp_table} (a int, b string)" + ) + return temp_table + + +@pytest.fixture(scope="module") +def temp_table1(session, temp_db1, temp_schema1): + temp_table = create_temp_table(session, temp_db1, temp_schema1) + yield temp_table + session._run_query(f"drop table if exists {temp_db1}.{temp_schema1}.{temp_table}") + + +@pytest.fixture(scope="module") +def temp_table2(session, temp_db1, temp_schema1): + temp_table = create_temp_table(session, temp_db1, temp_schema1) + yield temp_table + session._run_query(f"drop table if exists {temp_db1}.{temp_schema1}.{temp_table}") + + +def create_temp_view(session, db: str, schema: str) -> str: + temp_schema = get_temp_name("VIEW") + session._run_query( + f"create or replace temp view {db}.{schema}.{temp_schema} as select 1 as a, '2' as b" + ) + return temp_schema + + +@pytest.fixture(scope="module") +def temp_view1(session, temp_db1, temp_schema1): + temp_view = create_temp_view(session, temp_db1, temp_schema1) + yield temp_view + session._run_query(f"drop view if exists {temp_db1}.{temp_schema1}.{temp_view}") + + +@pytest.fixture(scope="module") +def temp_view2(session, temp_db1, temp_schema1): + temp_view = create_temp_view(session, temp_db1, temp_schema1) + yield temp_view + session._run_query(f"drop view if exists {temp_db1}.{temp_schema1}.{temp_view}") + + +def create_temp_procedure(session: Session, db, schema) -> str: + temp_procedure = get_temp_name("PROCEDURE") + session.sproc.register( + lambda _, x: x + 1, + return_type=IntegerType(), + input_types=[IntegerType()], + name=f"{db}.{schema}.{temp_procedure}", + packages=["snowflake-snowpark-python"], + ) + return temp_procedure + + +@pytest.fixture(scope="module") +def temp_procedure1(session, temp_db1, temp_schema1): + temp_procedure = create_temp_procedure(session, temp_db1, temp_schema1) + yield temp_procedure + session._run_query( + f"drop procedure if exists {temp_db1}.{temp_schema1}.{temp_procedure}(int)" + ) + + +@pytest.fixture(scope="module") +def temp_procedure2(session, temp_db1, temp_schema1): + temp_procedure = create_temp_procedure(session, temp_db1, temp_schema1) + yield temp_procedure + session._run_query( + f"drop procedure if exists {temp_db1}.{temp_schema1}.{temp_procedure}(int)" + ) + + +def create_temp_udf(session: Session, db, schema) -> str: + temp_udf = get_temp_name("UDF") + session.udf.register( + lambda x: x + 1, + return_type=IntegerType(), + input_types=[IntegerType()], + name=f"{db}.{schema}.{temp_udf}", + ) + return temp_udf + + +@pytest.fixture(scope="module") +def temp_udf1(session, temp_db1, temp_schema1): + temp_udf = create_temp_udf(session, temp_db1, temp_schema1) + yield temp_udf + session._run_query( + f"drop function if exists {temp_db1}.{temp_schema1}.{temp_udf}(int)" + ) + + +@pytest.fixture(scope="module") +def temp_udf2(session, temp_db1, temp_schema1): + temp_udf = create_temp_udf(session, temp_db1, temp_schema1) + yield temp_udf + session._run_query( + f"drop function if exists {temp_db1}.{temp_schema1}.{temp_udf}(int)" + ) pytestmark = [ diff --git a/tests/integ/test_catalog_rest_mode.py b/tests/integ/test_catalog_rest_mode.py index be4bba4a39..e7d33932ca 100644 --- a/tests/integ/test_catalog_rest_mode.py +++ b/tests/integ/test_catalog_rest_mode.py @@ -14,7 +14,7 @@ from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted from snowflake.snowpark.catalog import Catalog from snowflake.snowpark.types import IntegerType -from tests.integ.catalog_integ_common import ( +from tests.integ.test_catalog import ( CATALOG_TEMP_OBJECT_PREFIX, create_temp_db, create_temp_schema, diff --git a/tests/integ/test_catalog_sql_mode.py b/tests/integ/test_catalog_sql_mode.py index ac97b435ce..eff8bd8d91 100644 --- a/tests/integ/test_catalog_sql_mode.py +++ b/tests/integ/test_catalog_sql_mode.py @@ -15,7 +15,7 @@ from snowflake.snowpark.catalog import Catalog from snowflake.snowpark.exceptions import NotFoundError from snowflake.snowpark.types import IntegerType -from tests.integ.catalog_integ_common import ( +from tests.integ.test_catalog import ( CATALOG_TEMP_OBJECT_PREFIX, create_temp_db, create_temp_schema, From 2afd73802e36984b0564540e6c5372f6286788e0 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 28 Apr 2026 19:29:58 -0700 Subject: [PATCH 10/17] fix test --- tests/conftest.py | 2 ++ tests/integ/conftest.py | 2 -- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index bfd9ab8f78..bca98bd3b3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,6 +14,8 @@ from snowflake.snowpark._internal.utils import warning_dict from .ast.conftest import default_unparser_path +pytest_plugins = ("tests.integ.test_catalog",) + logging.getLogger("snowflake.connector").setLevel(logging.ERROR) excluded_frontend_files = [ diff --git a/tests/integ/conftest.py b/tests/integ/conftest.py index 01edcf170b..0ac231a396 100644 --- a/tests/integ/conftest.py +++ b/tests/integ/conftest.py @@ -31,8 +31,6 @@ RUNNING_ON_GH = os.getenv("GITHUB_ACTIONS") == "true" RUNNING_ON_JENKINS = "JENKINS_HOME" in os.environ -pytest_plugins = ("tests.integ.test_catalog",) - test_dir = os.path.dirname(__file__) test_data_dir = os.path.join(test_dir, "cassettes") From 50d0c32320f2d760b6067c51d07cf33f03321a41 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Wed, 29 Apr 2026 10:10:40 -0700 Subject: [PATCH 11/17] add limit 10000 in sql base(scos only) --- src/snowflake/snowpark/catalog.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/snowflake/snowpark/catalog.py b/src/snowflake/snowpark/catalog.py index 4078ebc340..dadcfe31ac 100644 --- a/src/snowflake/snowpark/catalog.py +++ b/src/snowflake/snowpark/catalog.py @@ -42,6 +42,10 @@ if TYPE_CHECKING: from snowflake.snowpark.session import Session +# Cap for SHOW AS RESOURCE DATABASES / SCHEMAS in the SQL backend (SCOS; avoids +# oversized result sets when accounts have very many databases or schemas). +_SHOW_AS_RESOURCE_LIMIT = 10000 + class _CatalogBackend(ABC): """Internal catalog implementation selected by ``context._is_snowpark_connect_compatible_mode``.""" @@ -214,7 +218,9 @@ def list_databases( ) -> List[Database]: c = self._catalog like_str = f"LIKE '{like}'" if like else "" - df = c._session.sql(f"SHOW AS RESOURCE DATABASES {like_str}") + df = c._session.sql( + f"SHOW AS RESOURCE DATABASES {like_str} LIMIT {_SHOW_AS_RESOURCE_LIMIT}" + ) if pattern: c._initialize_regex_udf() assert c._python_regex_udf is not None # pyright @@ -239,7 +245,9 @@ def list_schemas( c = self._catalog db_name = c._parse_database(database) like_str = f"LIKE '{like}'" if like else "" - df = c._session.sql(f"SHOW AS RESOURCE SCHEMAS {like_str} IN {db_name}") + df = c._session.sql( + f"SHOW AS RESOURCE SCHEMAS {like_str} IN {db_name} LIMIT {_SHOW_AS_RESOURCE_LIMIT}" + ) if pattern: c._initialize_regex_udf() assert c._python_regex_udf is not None # pyright From 9e89ca4525bcd253ff59ff5b17e6838349596312 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Wed, 29 Apr 2026 11:36:19 -0700 Subject: [PATCH 12/17] remove wrong test --- tests/integ/test_catalog_rest_mode.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/tests/integ/test_catalog_rest_mode.py b/tests/integ/test_catalog_rest_mode.py index e7d33932ca..7082f808b3 100644 --- a/tests/integ/test_catalog_rest_mode.py +++ b/tests/integ/test_catalog_rest_mode.py @@ -84,14 +84,6 @@ def test_get_database_missing_raises_core_not_found_rest_mode(session): catalog.get_database("NONEXISTENT_DB_XYZ_12345") -def test_get_table_does_not_resolve_view_rest_mode( - session, temp_db1, temp_schema1, temp_view1 -): - catalog: Catalog = session.catalog - with pytest.raises(CoreNotFoundError): - catalog.get_table(temp_view1, database=temp_db1, schema=temp_schema1) - - def test_get_view_rest_mode(session, temp_db1, temp_schema1, temp_view1): catalog: Catalog = session.catalog view = catalog.get_view(temp_view1, database=temp_db1, schema=temp_schema1) @@ -100,13 +92,6 @@ def test_get_view_rest_mode(session, temp_db1, temp_schema1, temp_view1): assert view.schema_name == temp_schema1 -def test_table_exists_false_for_view_name_rest_mode( - session, temp_db1, temp_schema1, temp_view1 -): - catalog: Catalog = session.catalog - assert not catalog.table_exists(temp_view1, database=temp_db1, schema=temp_schema1) - - @pytest.mark.udf def test_get_procedure_rest_mode(session, temp_db1, temp_schema1, temp_procedure1): catalog: Catalog = session.catalog From ef16ee85c71164167a82639e08f3137b01c9deca Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Thu, 30 Apr 2026 16:45:56 -0700 Subject: [PATCH 13/17] address comments --- src/snowflake/snowpark/catalog.py | 144 +++++++++++++++++++++++--- tests/integ/test_catalog.py | 3 + tests/integ/test_catalog_rest_mode.py | 5 + 3 files changed, 137 insertions(+), 15 deletions(-) diff --git a/src/snowflake/snowpark/catalog.py b/src/snowflake/snowpark/catalog.py index dadcfe31ac..af23658c92 100644 --- a/src/snowflake/snowpark/catalog.py +++ b/src/snowflake/snowpark/catalog.py @@ -208,6 +208,47 @@ def user_defined_function_exists( "concrete subclass." ) + @abstractmethod + def drop_database(self, database: Union[str, Database]) -> None: + raise NotImplementedError( + "_CatalogBackend.drop_database must be implemented by a concrete subclass." + ) + + @abstractmethod + def drop_schema( + self, + schema: Union[str, Schema], + *, + database: Optional[Union[str, Database]] = None, + ) -> None: + raise NotImplementedError( + "_CatalogBackend.drop_schema must be implemented by a concrete subclass." + ) + + @abstractmethod + def drop_table( + self, + table: Union[str, Table], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> None: + raise NotImplementedError( + "_CatalogBackend.drop_table must be implemented by a concrete subclass." + ) + + @abstractmethod + def drop_view( + self, + view: Union[str, View], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> None: + raise NotImplementedError( + "_CatalogBackend.drop_view must be implemented by a concrete subclass." + ) + class _SqlCatalogBackend(_CatalogBackend): def list_databases( @@ -499,6 +540,48 @@ def user_defined_function_exists( except NotFoundError: return False + def drop_database(self, database: Union[str, Database]) -> None: + c = self._catalog + db_name = c._parse_database(database) + c._session.sql(f"DROP DATABASE {db_name}").collect() + + def drop_schema( + self, + schema: Union[str, Schema], + *, + database: Optional[Union[str, Database]] = None, + ) -> None: + c = self._catalog + db_name = c._parse_database(database, schema) + schema_name = c._parse_schema(schema) + c._session.sql(f"DROP SCHEMA {db_name}.{schema_name}").collect() + + def drop_table( + self, + table: Union[str, Table], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> None: + c = self._catalog + db_name = c._parse_database(database, table) + schema_name = c._parse_schema(schema, table) + table_name = table if isinstance(table, str) else table.name + c._session.sql(f"DROP TABLE {db_name}.{schema_name}.{table_name}").collect() + + def drop_view( + self, + view: Union[str, View], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> None: + c = self._catalog + db_name = c._parse_database(database, view) + schema_name = c._parse_schema(schema, view) + view_name = view if isinstance(view, str) else view.name + c._session.sql(f"DROP VIEW {db_name}.{schema_name}.{view_name}").collect() + class _RestCatalogBackend(_CatalogBackend): def __init__(self, catalog: "Catalog") -> None: @@ -729,6 +812,48 @@ def user_defined_function_exists( except CoreNotFoundError: return False + def drop_database(self, database: Union[str, Database]) -> None: + c = self._catalog + db_name = c._parse_database(database) + self._root.databases[db_name].drop() + + def drop_schema( + self, + schema: Union[str, Schema], + *, + database: Optional[Union[str, Database]] = None, + ) -> None: + c = self._catalog + db_name = c._parse_database(database, schema) + schema_name = c._parse_schema(schema) + self._root.databases[db_name].schemas[schema_name].drop() + + def drop_table( + self, + table: Union[str, Table], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> None: + c = self._catalog + db_name = c._parse_database(database, table) + schema_name = c._parse_schema(schema, table) + table_name = table if isinstance(table, str) else table.name + self._root.databases[db_name].schemas[schema_name].tables[table_name].drop() + + def drop_view( + self, + view: Union[str, View], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> None: + c = self._catalog + db_name = c._parse_database(database, view) + schema_name = c._parse_schema(schema, view) + view_name = view if isinstance(view, str) else view.name + self._root.databases[db_name].schemas[schema_name].views[view_name].drop() + class Catalog: """The Catalog class provides methods to interact with and manage the Snowflake objects. @@ -1238,8 +1363,7 @@ def drop_database(self, database: Union[str, Database]) -> None: Args: database: database name or ``Database`` object. """ - db_name = self._parse_database(database) - self._session.sql(f"DROP DATABASE {db_name}").collect() + return self._backend.drop_database(database) def drop_schema( self, @@ -1254,9 +1378,7 @@ def drop_schema( schema: schema name or ``Schema`` object. database: database name or ``Database`` object. Defaults to None. """ - db_name = self._parse_database(database, schema) - schema_name = self._parse_schema(schema) - self._session.sql(f"DROP SCHEMA {db_name}.{schema_name}").collect() + return self._backend.drop_schema(schema, database=database) def drop_table( self, @@ -1273,11 +1395,7 @@ def drop_table( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - db_name = self._parse_database(database, table) - schema_name = self._parse_schema(schema, table) - table_name = table if isinstance(table, str) else table.name - - self._session.sql(f"DROP TABLE {db_name}.{schema_name}.{table_name}").collect() + return self._backend.drop_table(table, database=database, schema=schema) def drop_view( self, @@ -1294,11 +1412,7 @@ def drop_view( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - db_name = self._parse_database(database, view) - schema_name = self._parse_schema(schema, view) - view_name = view if isinstance(view, str) else view.name - - self._session.sql(f"DROP VIEW {db_name}.{schema_name}.{view_name}").collect() + return self._backend.drop_view(view, database=database, schema=schema) listDatabases = list_databases listSchemas = list_schemas diff --git a/tests/integ/test_catalog.py b/tests/integ/test_catalog.py index b0fef8e6c1..0827e132dd 100644 --- a/tests/integ/test_catalog.py +++ b/tests/integ/test_catalog.py @@ -54,6 +54,9 @@ def create_temp_schema(session, db: str) -> str: original_schema = session.get_current_schema() temp_schema = get_temp_name("SCHEMA") session._run_query(f"create or replace schema {db}.{temp_schema}") + session.sql( + f"ALTER SCHEMA SET DEFAULT_PYTHON_ARTIFACT_REPOSITORY = {_DEFAULT_ARTIFACT_REPOSITORY}" + ).collect() session.use_database(original_db) session.use_schema(original_schema) diff --git a/tests/integ/test_catalog_rest_mode.py b/tests/integ/test_catalog_rest_mode.py index 7082f808b3..11b0cc39ed 100644 --- a/tests/integ/test_catalog_rest_mode.py +++ b/tests/integ/test_catalog_rest_mode.py @@ -9,6 +9,7 @@ import pytest +from snowflake.core.exceptions import APIError from snowflake.core.exceptions import NotFoundError as CoreNotFoundError from snowflake.snowpark import context from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted @@ -28,6 +29,10 @@ reason="deepcopy is not supported and required by local testing", run=False, ), + pytest.mark.xfail( + raises=APIError, + reason="Failure due to warehouse overload", + ), ] From d9dad80fc00d6bef0ec210769b2ebe25028508f4 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Thu, 30 Apr 2026 16:47:48 -0700 Subject: [PATCH 14/17] restore comment --- src/snowflake/snowpark/catalog.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/snowflake/snowpark/catalog.py b/src/snowflake/snowpark/catalog.py index af23658c92..35025d8397 100644 --- a/src/snowflake/snowpark/catalog.py +++ b/src/snowflake/snowpark/catalog.py @@ -1414,6 +1414,7 @@ def drop_view( """ return self._backend.drop_view(view, database=database, schema=schema) + # aliases listDatabases = list_databases listSchemas = list_schemas listTables = list_tables From 68650b11faed2808f4ddbc65fa36deac32baf622 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Fri, 1 May 2026 12:58:08 -0700 Subject: [PATCH 15/17] parameter protection --- src/snowflake/snowpark/catalog.py | 6 +++--- src/snowflake/snowpark/session.py | 5 ++++- tests/integ/test_catalog_sql_mode.py | 14 ++++++++++++++ tests/unit/test_session.py | 19 +++++++++++++++++++ 4 files changed, 40 insertions(+), 4 deletions(-) diff --git a/src/snowflake/snowpark/catalog.py b/src/snowflake/snowpark/catalog.py index 35025d8397..e92e546a99 100644 --- a/src/snowflake/snowpark/catalog.py +++ b/src/snowflake/snowpark/catalog.py @@ -48,7 +48,7 @@ class _CatalogBackend(ABC): - """Internal catalog implementation selected by ``context._is_snowpark_connect_compatible_mode``.""" + """Internal catalog implementation selected by compatibility mode and SQL base flag.""" def __init__(self, catalog: "Catalog") -> None: self._catalog = catalog @@ -861,10 +861,10 @@ class Catalog: views, functions, etc. """ - def __init__(self, session: "Session") -> None: + def __init__(self, session: "Session", *, _use_sql_base: bool = True) -> None: self._session = session self._python_regex_udf = None - if context._is_snowpark_connect_compatible_mode: + if context._is_snowpark_connect_compatible_mode and _use_sql_base: self._backend: _CatalogBackend = _SqlCatalogBackend(self) else: self._backend = _RestCatalogBackend(self) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 5803cc8329..2ec57288e6 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -628,6 +628,9 @@ def __init__( """ self.version = get_version() self._session_stage = None + self._use_sql_base = self._conn._get_client_side_session_parameter( + "SNOWPARK_CONNECT_CATALOG_USE_SQL_BASE", True + ) if isinstance(conn, MockServerConnection): self._udf_registration = MockUDFRegistration(self) @@ -961,7 +964,7 @@ def catalog(self): external_feature_name="Session.catalog", raise_error=NotImplementedError, ) - self._catalog = Catalog(self) + self._catalog = Catalog(self, _use_sql_base=self._use_sql_base) return self._catalog def close(self) -> None: diff --git a/tests/integ/test_catalog_sql_mode.py b/tests/integ/test_catalog_sql_mode.py index eff8bd8d91..e253b7396a 100644 --- a/tests/integ/test_catalog_sql_mode.py +++ b/tests/integ/test_catalog_sql_mode.py @@ -9,6 +9,7 @@ import pytest +from snowflake.core.exceptions import NotFoundError as CoreNotFoundError from snowflake.core.view import View from snowflake.snowpark import context from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted @@ -85,6 +86,19 @@ def test_get_database_missing_raises_snowpark_not_found_sql_mode(session): catalog.get_database("NONEXISTENT_DB_XYZ_12345") +def test_compat_mode_with_sql_base_disabled_uses_rest_backend(session): + original_use_sql_base = session._use_sql_base + try: + session._use_sql_base = False + session._catalog = None + catalog: Catalog = session.catalog + with pytest.raises(CoreNotFoundError): + catalog.get_database("NONEXISTENT_DB_XYZ_12345") + finally: + session._use_sql_base = original_use_sql_base + session._catalog = None + + def test_get_table_resolves_view_sql_mode(session, temp_db1, temp_schema1, temp_view1): catalog: Catalog = session.catalog obj = catalog.get_table(temp_view1, database=temp_db1, schema=temp_schema1) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 0349618659..fdcba3abac 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -117,6 +117,25 @@ def test_used_scoped_temp_object(): assert Session(fake_connection)._use_scoped_temp_objects is False +@pytest.mark.parametrize( + "parameter_value", + [True, False], +) +def test_session_use_sql_base_from_session_parameter(parameter_value): + fake_connection = mock.create_autospec(ServerConnection) + fake_connection._conn = mock.Mock() + fake_connection._thread_safe_session_enabled = True + fake_connection._get_client_side_session_parameter = mock.Mock( + side_effect=lambda name, default: parameter_value + if name == "SNOWPARK_CONNECT_CATALOG_USE_SQL_BASE" + else default + ) + fake_connection._conn._session_parameters = {} + + session = Session(fake_connection) + assert session._use_sql_base is parameter_value + + def test_close_exception(): fake_connection = mock.create_autospec(ServerConnection) fake_connection._conn = mock.Mock() From bb2ef8e459a1947545cc5c454373806e329cc71d Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Fri, 1 May 2026 15:49:04 -0700 Subject: [PATCH 16/17] parameter protection --- src/snowflake/snowpark/session.py | 20 +++++++++++++++----- tests/unit/test_session.py | 19 ++++++++++--------- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 2ec57288e6..ca2b15ee9f 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -254,6 +254,7 @@ _session_management_lock = RLock() _active_sessions: Set["Session"] = set() +_USE_SQL_BASE_OPTION_KEY = "_use_sql_base" _PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING = ( "PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS" ) @@ -463,6 +464,13 @@ def __init__(self) -> None: self._app_name = None self._format_json = None + @staticmethod + def _connection_options( + options: Dict[str, Union[int, str]] + ) -> Dict[str, Union[int, str]]: + # Internal Session-only options must not be forwarded to connector connect(**kwargs). + return {k: v for k, v in options.items() if k != _USE_SQL_BASE_OPTION_KEY} + def _remove_config(self, key: str) -> "Session.SessionBuilder": """Only used in test.""" self._options.pop(key, None) @@ -569,8 +577,11 @@ def _create_internal( # Set paramstyle to qmark by default to be consistent with previous behavior if "paramstyle" not in self._options: self._options["paramstyle"] = "qmark" + connection_options = self._connection_options(self._options) new_session = Session( - ServerConnection({}, conn) if conn else ServerConnection(self._options), + ServerConnection({}, conn) + if conn + else ServerConnection(connection_options), self._options, ) @@ -628,9 +639,8 @@ def __init__( """ self.version = get_version() self._session_stage = None - self._use_sql_base = self._conn._get_client_side_session_parameter( - "SNOWPARK_CONNECT_CATALOG_USE_SQL_BASE", True - ) + options = options or {} + self._use_sql_base = options.pop(_USE_SQL_BASE_OPTION_KEY, True) if isinstance(conn, MockServerConnection): self._udf_registration = MockUDFRegistration(self) @@ -851,7 +861,7 @@ def __init__( _PYTHON_SNOWPARK_COLLECT_TELEMETRY_AT_CRITICAL_PATH_VERSION ) ) - self._conf = self.RuntimeConfig(self, options or {}) + self._conf = self.RuntimeConfig(self, options) self._runtime_version_from_requirement: str = None self._temp_table_auto_cleaner: TempTableAutoCleaner = TempTableAutoCleaner(self) self._sp_profiler = StoredProcedureProfiler(session=self) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index fdcba3abac..4bf067979e 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -118,22 +118,23 @@ def test_used_scoped_temp_object(): @pytest.mark.parametrize( - "parameter_value", - [True, False], + "option_value, expected", + [(True, True), (False, False)], ) -def test_session_use_sql_base_from_session_parameter(parameter_value): +def test_session_use_sql_base_from_options(option_value, expected): fake_connection = mock.create_autospec(ServerConnection) fake_connection._conn = mock.Mock() fake_connection._thread_safe_session_enabled = True - fake_connection._get_client_side_session_parameter = mock.Mock( - side_effect=lambda name, default: parameter_value - if name == "SNOWPARK_CONNECT_CATALOG_USE_SQL_BASE" - else default + fake_connection._get_client_side_session_parameter = ( + lambda x, y: ServerConnection._get_client_side_session_parameter( + fake_connection, x, y + ) ) fake_connection._conn._session_parameters = {} - session = Session(fake_connection) - assert session._use_sql_base is parameter_value + session = Session(fake_connection, {"_use_sql_base": option_value}) + assert session._use_sql_base is expected + assert session.conf.get("_use_sql_base") is None def test_close_exception(): From e0097c9030075d0bb07cde109ccde65fae56b8c5 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Fri, 1 May 2026 15:59:33 -0700 Subject: [PATCH 17/17] add test --- tests/integ/test_catalog_sql_mode.py | 13 +++++++++++ tests/unit/test_session.py | 34 ++++++++++++++++++++++++++-- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/tests/integ/test_catalog_sql_mode.py b/tests/integ/test_catalog_sql_mode.py index e253b7396a..9a6ed7e10d 100644 --- a/tests/integ/test_catalog_sql_mode.py +++ b/tests/integ/test_catalog_sql_mode.py @@ -92,6 +92,7 @@ def test_compat_mode_with_sql_base_disabled_uses_rest_backend(session): session._use_sql_base = False session._catalog = None catalog: Catalog = session.catalog + assert type(catalog._backend).__name__ == "_RestCatalogBackend" with pytest.raises(CoreNotFoundError): catalog.get_database("NONEXISTENT_DB_XYZ_12345") finally: @@ -99,6 +100,18 @@ def test_compat_mode_with_sql_base_disabled_uses_rest_backend(session): session._catalog = None +def test_compat_mode_with_sql_base_enabled_uses_sql_backend(session): + original_use_sql_base = session._use_sql_base + try: + session._use_sql_base = True + session._catalog = None + catalog: Catalog = session.catalog + assert type(catalog._backend).__name__ == "_SqlCatalogBackend" + finally: + session._use_sql_base = original_use_sql_base + session._catalog = None + + def test_get_table_resolves_view_sql_mode(session, temp_db1, temp_schema1, temp_view1): catalog: Catalog = session.catalog obj = catalog.get_table(temp_view1, database=temp_db1, schema=temp_schema1) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 4bf067979e..9377d1ad92 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -119,7 +119,7 @@ def test_used_scoped_temp_object(): @pytest.mark.parametrize( "option_value, expected", - [(True, True), (False, False)], + [(None, True), (True, True), (False, False)], ) def test_session_use_sql_base_from_options(option_value, expected): fake_connection = mock.create_autospec(ServerConnection) @@ -132,11 +132,41 @@ def test_session_use_sql_base_from_options(option_value, expected): ) fake_connection._conn._session_parameters = {} - session = Session(fake_connection, {"_use_sql_base": option_value}) + options = {} if option_value is None else {"_use_sql_base": option_value} + session = Session(fake_connection, options) assert session._use_sql_base is expected assert session.conf.get("_use_sql_base") is None +@pytest.mark.parametrize( + "option_value, expected_backend_name", + [(True, "_SqlCatalogBackend"), (False, "_RestCatalogBackend")], +) +def test_catalog_backend_selection_from_use_sql_base_option( + option_value, expected_backend_name +): + import snowflake.snowpark.context as ctx + + fake_connection = mock.create_autospec(ServerConnection) + fake_connection._conn = mock.Mock() + fake_connection._thread_safe_session_enabled = True + fake_connection._get_client_side_session_parameter = ( + lambda x, y: ServerConnection._get_client_side_session_parameter( + fake_connection, x, y + ) + ) + fake_connection._conn._session_parameters = {} + fake_connection.get_session_id.return_value = "fake_session_id" + + original_compat = ctx._is_snowpark_connect_compatible_mode + try: + ctx._is_snowpark_connect_compatible_mode = True + session = Session(fake_connection, {"_use_sql_base": option_value}) + assert type(session.catalog._backend).__name__ == expected_backend_name + finally: + ctx._is_snowpark_connect_compatible_mode = original_compat + + def test_close_exception(): fake_connection = mock.create_autospec(ServerConnection) fake_connection._conn = mock.Mock()