diff --git a/src/snowflake/snowpark/catalog.py b/src/snowflake/snowpark/catalog.py index 4572809ea2..e92e546a99 100644 --- a/src/snowflake/snowpark/catalog.py +++ b/src/snowflake/snowpark/catalog.py @@ -2,15 +2,28 @@ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. # +from abc import ABC, abstractmethod +from ctypes import ArgumentError import re -from typing import List, Optional, Union +from typing import ( + List, + Optional, + Union, + TYPE_CHECKING, +) + +from snowflake.snowpark import context +from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted +from snowflake.snowpark.exceptions import SnowparkSQLException, NotFoundError try: from snowflake.core import Root # type: ignore from snowflake.core.database import Database # type: ignore - from snowflake.core.exceptions import NotFoundError + from snowflake.core.database._generated.models import Database as ModelDatabase # type: ignore + from snowflake.core.exceptions import NotFoundError as CoreNotFoundError # type: ignore from snowflake.core.procedure import Procedure from snowflake.core.schema import Schema # type: ignore + from snowflake.core.schema._generated.models import Schema as ModelSchema # type: ignore from snowflake.core.table import Table, TableColumn from snowflake.core.user_defined_function import UserDefinedFunction from snowflake.core.view import View @@ -19,11 +32,827 @@ "Missing optional dependency: 'snowflake.core'." ) from e # pragma: no cover +from snowflake.snowpark._internal.type_utils import ( + convert_sp_to_sf_type, + type_string_to_type_object, +) +from snowflake.snowpark.functions import lit, parse_json +from snowflake.snowpark.types import DataType + +if TYPE_CHECKING: + from snowflake.snowpark.session import Session + +# Cap for SHOW AS RESOURCE DATABASES / SCHEMAS in the SQL backend (SCOS; avoids +# oversized result sets when accounts have very many databases or schemas). +_SHOW_AS_RESOURCE_LIMIT = 10000 + + +class _CatalogBackend(ABC): + """Internal catalog implementation selected by compatibility mode and SQL base flag.""" + + def __init__(self, catalog: "Catalog") -> None: + self._catalog = catalog + + @abstractmethod + def list_databases( + self, + *, + pattern: Optional[str] = None, + like: Optional[str] = None, + ) -> List[Database]: + raise NotImplementedError( + "_CatalogBackend.list_databases must be implemented by a concrete subclass." + ) + + @abstractmethod + def list_schemas( + self, + *, + database: Optional[Union[str, Database]] = None, + pattern: Optional[str] = None, + like: Optional[str] = None, + ) -> List[Schema]: + raise NotImplementedError( + "_CatalogBackend.list_schemas must be implemented by a concrete subclass." + ) + + @abstractmethod + def get_database(self, database: str) -> Database: + raise NotImplementedError( + "_CatalogBackend.get_database must be implemented by a concrete subclass." + ) + + @abstractmethod + def get_schema( + self, schema: str, *, database: Optional[Union[str, Database]] = None + ) -> Schema: + raise NotImplementedError( + "_CatalogBackend.get_schema must be implemented by a concrete subclass." + ) + + @abstractmethod + def get_table( + self, + table_name: str, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> Union[Table, View]: + raise NotImplementedError( + "_CatalogBackend.get_table must be implemented by a concrete subclass." + ) + + @abstractmethod + def get_view( + self, + view_name: str, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> View: + raise NotImplementedError( + "_CatalogBackend.get_view must be implemented by a concrete subclass." + ) + + @abstractmethod + def get_procedure( + self, + procedure_name: str, + arg_types: List[DataType], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> Procedure: + raise NotImplementedError( + "_CatalogBackend.get_procedure must be implemented by a concrete subclass." + ) + + @abstractmethod + def get_user_defined_function( + self, + udf_name: str, + arg_types: List[DataType], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> UserDefinedFunction: + raise NotImplementedError( + "_CatalogBackend.get_user_defined_function must be implemented by a concrete subclass." + ) + + @abstractmethod + def database_exists(self, database: Union[str, Database]) -> bool: + raise NotImplementedError( + "_CatalogBackend.database_exists must be implemented by a concrete subclass." + ) + + @abstractmethod + def schema_exists( + self, + schema: Union[str, Schema], + *, + database: Optional[Union[str, Database]] = None, + ) -> bool: + raise NotImplementedError( + "_CatalogBackend.schema_exists must be implemented by a concrete subclass." + ) + + @abstractmethod + def table_exists( + self, + table: Union[str, Table], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + raise NotImplementedError( + "_CatalogBackend.table_exists must be implemented by a concrete subclass." + ) + + @abstractmethod + def view_exists( + self, + view: Union[str, View], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + raise NotImplementedError( + "_CatalogBackend.view_exists must be implemented by a concrete subclass." + ) + + @abstractmethod + def procedure_exists( + self, + procedure: Union[str, Procedure], + arg_types: Optional[List[DataType]] = None, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + raise NotImplementedError( + "_CatalogBackend.procedure_exists must be implemented by a concrete subclass." + ) + + @abstractmethod + def user_defined_function_exists( + self, + udf: Union[str, UserDefinedFunction], + arg_types: Optional[List[DataType]] = None, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + raise NotImplementedError( + "_CatalogBackend.user_defined_function_exists must be implemented by a " + "concrete subclass." + ) + + @abstractmethod + def drop_database(self, database: Union[str, Database]) -> None: + raise NotImplementedError( + "_CatalogBackend.drop_database must be implemented by a concrete subclass." + ) + + @abstractmethod + def drop_schema( + self, + schema: Union[str, Schema], + *, + database: Optional[Union[str, Database]] = None, + ) -> None: + raise NotImplementedError( + "_CatalogBackend.drop_schema must be implemented by a concrete subclass." + ) + + @abstractmethod + def drop_table( + self, + table: Union[str, Table], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> None: + raise NotImplementedError( + "_CatalogBackend.drop_table must be implemented by a concrete subclass." + ) + + @abstractmethod + def drop_view( + self, + view: Union[str, View], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> None: + raise NotImplementedError( + "_CatalogBackend.drop_view must be implemented by a concrete subclass." + ) + + +class _SqlCatalogBackend(_CatalogBackend): + def list_databases( + self, + *, + pattern: Optional[str] = None, + like: Optional[str] = None, + ) -> List[Database]: + c = self._catalog + like_str = f"LIKE '{like}'" if like else "" + df = c._session.sql( + f"SHOW AS RESOURCE DATABASES {like_str} LIMIT {_SHOW_AS_RESOURCE_LIMIT}" + ) + if pattern: + c._initialize_regex_udf() + assert c._python_regex_udf is not None # pyright + df = df.filter( + c._python_regex_udf(lit(pattern), parse_json('"As Resource"')["name"]) + ) + + return list( + map( + lambda row: Database._from_model(ModelDatabase.from_json(str(row[0]))), + df.collect(), + ) + ) + + def list_schemas( + self, + *, + database: Optional[Union[str, Database]] = None, + pattern: Optional[str] = None, + like: Optional[str] = None, + ) -> List[Schema]: + c = self._catalog + db_name = c._parse_database(database) + like_str = f"LIKE '{like}'" if like else "" + df = c._session.sql( + f"SHOW AS RESOURCE SCHEMAS {like_str} IN {db_name} LIMIT {_SHOW_AS_RESOURCE_LIMIT}" + ) + if pattern: + c._initialize_regex_udf() + assert c._python_regex_udf is not None # pyright + df = df.filter( + c._python_regex_udf(lit(pattern), parse_json('"As Resource"')["name"]) + ) + + return list( + map( + lambda row: Schema._from_model(ModelSchema.from_json(str(row[0]))), + df.collect(), + ) + ) + + def get_database(self, database: str) -> Database: + try: + return self.list_databases(like=unquote_if_quoted(database))[0] + except IndexError: + raise NotFoundError(f"Database with name {database} could not be found") + + def get_schema( + self, schema: str, *, database: Optional[Union[str, Database]] = None + ) -> Schema: + c = self._catalog + db_name = c._parse_database(database) + try: + return self.list_schemas(database=db_name, like=unquote_if_quoted(schema))[ + 0 + ] + except ( + IndexError, + SnowparkSQLException, + ): + raise NotFoundError( + f"Schema with name {schema} could not be found in database '{db_name}'" + ) + + def get_table( + self, + table_name: str, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> Union[Table, View]: + c = self._catalog + db_name = c._parse_database(database) + schema_name = c._parse_schema(schema) + like_arg = unquote_if_quoted(table_name) + tables = c.list_tables(database=db_name, schema=schema_name, like=like_arg) + views: List[View] = [] + if not tables: + views = c.list_views(database=db_name, schema=schema_name, like=like_arg) + if tables: + return tables[0] + if views: + return views[0] + raise NotFoundError( + f"Table with name {table_name} could not be found in schema '{db_name}.{schema_name}'" + ) + + def get_view( + self, + view_name: str, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> View: + c = self._catalog + db_name = c._parse_database(database) + schema_name = c._parse_schema(schema) + try: + return c.list_views( + database=db_name, + schema=schema_name, + like=unquote_if_quoted(view_name), + )[0] + except IndexError: + raise NotFoundError( + f"View with name {view_name} could not be found in schema '{db_name}.{schema_name}'" + ) + + def get_procedure( + self, + procedure_name: str, + arg_types: List[DataType], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> Procedure: + c = self._catalog + db_name = c._parse_database(database) + schema_name = c._parse_schema(schema) + procedure_id = c._parse_function_or_procedure(procedure_name, arg_types) + + try: + procedures = c._session.sql( + f"DESCRIBE AS RESOURCE PROCEDURE {db_name}.{schema_name}.{procedure_id}" + ).collect() + return Procedure.from_json(str(procedures[0][0])) + except ( + IndexError, + SnowparkSQLException, + ): + raise NotFoundError( + f"Procedure with name {procedure_name} and arguments {arg_types} could not be found in schema '{db_name}.{schema_name}'" + ) + + def get_user_defined_function( + self, + udf_name: str, + arg_types: List[DataType], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> UserDefinedFunction: + c = self._catalog + db_name = c._parse_database(database) + schema_name = c._parse_schema(schema) + function_id = c._parse_function_or_procedure(udf_name, arg_types) + + try: + rows = c._session.sql( + f"DESCRIBE AS RESOURCE FUNCTION {db_name}.{schema_name}.{function_id}" + ).collect() + return UserDefinedFunction.from_json(str(rows[0][0])) + except ( + IndexError, + SnowparkSQLException, + ): + raise NotFoundError( + f"Function with name {udf_name} and arguments {arg_types} could not be found in schema '{db_name}.{schema_name}'" + ) + + def database_exists(self, database: Union[str, Database]) -> bool: + c = self._catalog + db_name = c._parse_database(database) + try: + self.get_database(db_name) + return True + except NotFoundError: + return False + + def schema_exists( + self, + schema: Union[str, Schema], + *, + database: Optional[Union[str, Database]] = None, + ) -> bool: + c = self._catalog + db_name = c._parse_database(database, schema) + schema_name = c._parse_schema(schema) + try: + self.get_schema(schema=schema_name, database=db_name) + return True + except NotFoundError: + return False + + def table_exists( + self, + table: Union[str, Table], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + c = self._catalog + db_name = c._parse_database(database, table) + schema_name = c._parse_schema(schema, table) + table_name = table if isinstance(table, str) else table.name + try: + self.get_table(table_name=table_name, database=db_name, schema=schema_name) + return True + except NotFoundError: + return False + + def view_exists( + self, + view: Union[str, View], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + c = self._catalog + db_name = c._parse_database(database, view) + schema_name = c._parse_schema(schema, view) + view_name = view if isinstance(view, str) else view.name + try: + self.get_view(view_name=view_name, database=db_name, schema=schema_name) + return True + except NotFoundError: + return False + + def procedure_exists( + self, + procedure: Union[str, Procedure], + arg_types: Optional[List[DataType]] = None, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + try: + if isinstance(procedure, Procedure): + if arg_types is not None or database is not None or schema is not None: + raise ArgumentError( + "When provided procedure is a Procedure class no other arguments can be provided" + ) + database = procedure.database_name + schema = procedure.schema_name + arg_types = [ + type_string_to_type_object(a.datatype) for a in procedure.arguments + ] + procedure = procedure.name + self.get_procedure( + procedure_name=procedure, + arg_types=arg_types, + database=database, + schema=schema, + ) + return True + except NotFoundError: + return False + + def user_defined_function_exists( + self, + udf: Union[str, UserDefinedFunction], + arg_types: Optional[List[DataType]] = None, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + try: + if isinstance(udf, UserDefinedFunction): + if arg_types is not None or database is not None or schema is not None: + raise ArgumentError( + "When provided udf is a UserDefinedFunction class no other arguments can be provided" + ) + database = udf.database_name + schema = udf.schema_name + arg_types = [ + type_string_to_type_object(a.datatype) for a in udf.arguments + ] + udf = udf.name + self.get_user_defined_function( + udf_name=udf, + arg_types=arg_types, + database=database, + schema=schema, + ) + return True + except NotFoundError: + return False + + def drop_database(self, database: Union[str, Database]) -> None: + c = self._catalog + db_name = c._parse_database(database) + c._session.sql(f"DROP DATABASE {db_name}").collect() + + def drop_schema( + self, + schema: Union[str, Schema], + *, + database: Optional[Union[str, Database]] = None, + ) -> None: + c = self._catalog + db_name = c._parse_database(database, schema) + schema_name = c._parse_schema(schema) + c._session.sql(f"DROP SCHEMA {db_name}.{schema_name}").collect() + + def drop_table( + self, + table: Union[str, Table], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> None: + c = self._catalog + db_name = c._parse_database(database, table) + schema_name = c._parse_schema(schema, table) + table_name = table if isinstance(table, str) else table.name + c._session.sql(f"DROP TABLE {db_name}.{schema_name}.{table_name}").collect() + + def drop_view( + self, + view: Union[str, View], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> None: + c = self._catalog + db_name = c._parse_database(database, view) + schema_name = c._parse_schema(schema, view) + view_name = view if isinstance(view, str) else view.name + c._session.sql(f"DROP VIEW {db_name}.{schema_name}.{view_name}").collect() + + +class _RestCatalogBackend(_CatalogBackend): + def __init__(self, catalog: "Catalog") -> None: + super().__init__(catalog) + self._root = Root(catalog._session) + + def list_databases( + self, + *, + pattern: Optional[str] = None, + like: Optional[str] = None, + ) -> List[Database]: + it = self._root.databases.iter(like=like) + if pattern: + it = filter(lambda x: re.match(pattern, x.name), it) + return list(it) + + def list_schemas( + self, + *, + database: Optional[Union[str, Database]] = None, + pattern: Optional[str] = None, + like: Optional[str] = None, + ) -> List[Schema]: + db_name = self._catalog._parse_database(database) + it = self._root.databases[db_name].schemas.iter(like=like) + if pattern: + it = filter(lambda x: re.match(pattern, x.name), it) + return list(it) + + def get_database(self, database: str) -> Database: + return self._root.databases[database].fetch() + + def get_schema( + self, schema: str, *, database: Optional[Union[str, Database]] = None + ) -> Schema: + db_name = self._catalog._parse_database(database) + return self._root.databases[db_name].schemas[schema].fetch() -import snowflake.snowpark -from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type -from snowflake.snowpark.functions import lit, parse_json -from snowflake.snowpark.types import DataType + def get_table( + self, + table_name: str, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> Union[Table, View]: + c = self._catalog + db_name = c._parse_database(database) + schema_name = c._parse_schema(schema) + return ( + self._root.databases[db_name] + .schemas[schema_name] + .tables[table_name] + .fetch() + ) + + def get_view( + self, + view_name: str, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> View: + c = self._catalog + db_name = c._parse_database(database) + schema_name = c._parse_schema(schema) + return ( + self._root.databases[db_name].schemas[schema_name].views[view_name].fetch() + ) + + def get_procedure( + self, + procedure_name: str, + arg_types: List[DataType], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> Procedure: + c = self._catalog + db_name = c._parse_database(database) + schema_name = c._parse_schema(schema) + procedure_id = c._parse_function_or_procedure(procedure_name, arg_types) + return ( + self._root.databases[db_name] + .schemas[schema_name] + .procedures[procedure_id] + .fetch() + ) + + def get_user_defined_function( + self, + udf_name: str, + arg_types: List[DataType], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> UserDefinedFunction: + c = self._catalog + db_name = c._parse_database(database) + schema_name = c._parse_schema(schema) + function_id = c._parse_function_or_procedure(udf_name, arg_types) + return ( + self._root.databases[db_name] + .schemas[schema_name] + .user_defined_functions[function_id] + .fetch() + ) + + def database_exists(self, database: Union[str, Database]) -> bool: + c = self._catalog + db_name = c._parse_database(database) + try: + self._root.databases[db_name].fetch() + return True + except CoreNotFoundError: + return False + + def schema_exists( + self, + schema: Union[str, Schema], + *, + database: Optional[Union[str, Database]] = None, + ) -> bool: + c = self._catalog + db_name = c._parse_database(database, schema) + schema_name = c._parse_schema(schema) + try: + self._root.databases[db_name].schemas[schema_name].fetch() + return True + except CoreNotFoundError: + return False + + def table_exists( + self, + table: Union[str, Table], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + c = self._catalog + db_name = c._parse_database(database, table) + schema_name = c._parse_schema(schema, table) + table_name = table if isinstance(table, str) else table.name + try: + self._root.databases[db_name].schemas[schema_name].tables[ + table_name + ].fetch() + return True + except CoreNotFoundError: + return False + + def view_exists( + self, + view: Union[str, View], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + c = self._catalog + db_name = c._parse_database(database, view) + schema_name = c._parse_schema(schema, view) + view_name = view if isinstance(view, str) else view.name + try: + self._root.databases[db_name].schemas[schema_name].views[view_name].fetch() + return True + except CoreNotFoundError: + return False + + def procedure_exists( + self, + procedure: Union[str, Procedure], + arg_types: Optional[List[DataType]] = None, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + c = self._catalog + try: + if isinstance(procedure, Procedure): + if arg_types is not None or database is not None or schema is not None: + raise ArgumentError( + "When provided procedure is a Procedure class no other arguments can be provided" + ) + database = procedure.database_name + schema = procedure.schema_name + arg_types = [ + type_string_to_type_object(a.datatype) for a in procedure.arguments + ] + procedure = procedure.name + db_name = c._parse_database(database, procedure) + schema_name = c._parse_schema(schema, procedure) + procedure_id = c._parse_function_or_procedure(procedure, arg_types) + self._root.databases[db_name].schemas[schema_name].procedures[ + procedure_id + ].fetch() + return True + except CoreNotFoundError: + return False + + def user_defined_function_exists( + self, + udf: Union[str, UserDefinedFunction], + arg_types: Optional[List[DataType]] = None, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + c = self._catalog + try: + if isinstance(udf, UserDefinedFunction): + if arg_types is not None or database is not None or schema is not None: + raise ArgumentError( + "When provided udf is a UserDefinedFunction class no other arguments can be provided" + ) + database = udf.database_name + schema = udf.schema_name + arg_types = [ + type_string_to_type_object(a.datatype) for a in udf.arguments + ] + udf = udf.name + db_name = c._parse_database(database, udf) + schema_name = c._parse_schema(schema, udf) + function_id = c._parse_function_or_procedure(udf, arg_types) + self._root.databases[db_name].schemas[schema_name].user_defined_functions[ + function_id + ].fetch() + return True + except CoreNotFoundError: + return False + + def drop_database(self, database: Union[str, Database]) -> None: + c = self._catalog + db_name = c._parse_database(database) + self._root.databases[db_name].drop() + + def drop_schema( + self, + schema: Union[str, Schema], + *, + database: Optional[Union[str, Database]] = None, + ) -> None: + c = self._catalog + db_name = c._parse_database(database, schema) + schema_name = c._parse_schema(schema) + self._root.databases[db_name].schemas[schema_name].drop() + + def drop_table( + self, + table: Union[str, Table], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> None: + c = self._catalog + db_name = c._parse_database(database, table) + schema_name = c._parse_schema(schema, table) + table_name = table if isinstance(table, str) else table.name + self._root.databases[db_name].schemas[schema_name].tables[table_name].drop() + + def drop_view( + self, + view: Union[str, View], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> None: + c = self._catalog + db_name = c._parse_database(database, view) + schema_name = c._parse_schema(schema, view) + view_name = view if isinstance(view, str) else view.name + self._root.databases[db_name].schemas[schema_name].views[view_name].drop() class Catalog: @@ -32,14 +861,17 @@ class Catalog: views, functions, etc. """ - def __init__(self, session: "snowflake.snowpark.session.Session") -> None: # type: ignore + def __init__(self, session: "Session", *, _use_sql_base: bool = True) -> None: self._session = session - self._root = Root(session) self._python_regex_udf = None + if context._is_snowpark_connect_compatible_mode and _use_sql_base: + self._backend: _CatalogBackend = _SqlCatalogBackend(self) + else: + self._backend = _RestCatalogBackend(self) def _parse_database( self, - database: Optional[Union[str, Database]], + database: object, model_obj: Optional[ Union[str, Schema, Table, View, Procedure, UserDefinedFunction] ] = None, @@ -66,7 +898,7 @@ def _parse_database( def _parse_schema( self, - schema: Optional[Union[str, Schema]], + schema: object, model_obj: Optional[ Union[str, Table, View, Procedure, UserDefinedFunction] ] = None, @@ -138,13 +970,9 @@ def _list_objects( f"SHOW AS RESOURCE {object_name} {like_str} IN {db_name}.{schema_name} -- catalog api" ) if pattern: - # initialize udf self._initialize_regex_udf() assert self._python_regex_udf is not None # pyright - # The result of SHOW AS RESOURCE query is a json string which contains - # key 'name' to store the name of the object. We parse json for the returned - # result and apply the filter on name. df = df.filter( self._python_regex_udf( lit(pattern), parse_json('"As Resource"')["name"] @@ -153,7 +981,6 @@ def _list_objects( return list(map(lambda row: object_class.from_json(row[0]), df.collect())) - # List methods def list_databases( self, *, @@ -166,11 +993,7 @@ def list_databases( pattern: the python regex pattern of name to match. Defaults to None. like: the sql style pattern for name to match. Default to None. """ - iter = self._root.databases.iter(like=like) - if pattern: - iter = filter(lambda x: re.match(pattern, x.name), iter) - - return list(iter) + return self._backend.list_databases(pattern=pattern, like=like) def list_schemas( self, @@ -187,11 +1010,7 @@ def list_schemas( pattern: the python regex pattern of name to match. Defaults to None. like: the sql style pattern for name to match. Default to None. """ - db_name = self._parse_database(database) - iter = self._root.databases[db_name].schemas.iter(like=like) - if pattern: - iter = filter(lambda x: re.match(pattern, x.name), iter) - return list(iter) + return self._backend.list_schemas(database=database, pattern=pattern, like=like) def list_tables( self, @@ -318,7 +1137,6 @@ def list_user_defined_functions( like=like, ) - # get methods def get_current_database(self) -> Optional[str]: """Get the current database.""" return self._session.get_current_database() @@ -329,14 +1147,13 @@ def get_current_schema(self) -> Optional[str]: def get_database(self, database: str) -> Database: """Name of the database to get""" - return self._root.databases[database].fetch() + return self._backend.get_database(database) def get_schema( self, schema: str, *, database: Optional[Union[str, Database]] = None ) -> Schema: """Name of the schema to get.""" - db_name = self._parse_database(database) - return self._root.databases[db_name].schemas[schema].fetch() + return self._backend.get_schema(schema, database=database) def get_table( self, @@ -344,23 +1161,22 @@ def get_table( *, database: Optional[Union[str, Database]] = None, schema: Optional[Union[str, Schema]] = None, - ) -> Table: - """Get the table by name in given database and schema. If database or schema are not - provided, get the table in the current database and schema. + ) -> Union[Table, View]: + """Get the table or permanent view by name in the given database and schema. + + If database or schema are not provided, resolve the name in the current database + and schema. Matches :meth:`pyspark.sql.Catalog.getTable`, which returns metadata + for base tables and for views. + + When ``context._is_snowpark_connect_compatible_mode`` is False (legacy REST path), + only base tables are returned; use :meth:`get_view` for views. Args: - table_name: name of the table. + table_name: name of the table or view. database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - db_name = self._parse_database(database) - schema_name = self._parse_schema(schema) - return ( - self._root.databases[db_name] - .schemas[schema_name] - .tables[table_name] - .fetch() - ) + return self._backend.get_table(table_name, database=database, schema=schema) def get_view( self, @@ -377,11 +1193,7 @@ def get_view( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - db_name = self._parse_database(database) - schema_name = self._parse_schema(schema) - return ( - self._root.databases[db_name].schemas[schema_name].views[view_name].fetch() - ) + return self._backend.get_view(view_name, database=database, schema=schema) def get_procedure( self, @@ -400,14 +1212,8 @@ def get_procedure( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - db_name = self._parse_database(database) - schema_name = self._parse_schema(schema) - procedure_id = self._parse_function_or_procedure(procedure_name, arg_types) - return ( - self._root.databases[db_name] - .schemas[schema_name] - .procedures[procedure_id] - .fetch() + return self._backend.get_procedure( + procedure_name, arg_types, database=database, schema=schema ) def get_user_defined_function( @@ -428,17 +1234,10 @@ def get_user_defined_function( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - db_name = self._parse_database(database) - schema_name = self._parse_schema(schema) - function_id = self._parse_function_or_procedure(udf_name, arg_types) - return ( - self._root.databases[db_name] - .schemas[schema_name] - .user_defined_functions[function_id] - .fetch() + return self._backend.get_user_defined_function( + udf_name, arg_types, database=database, schema=schema ) - # set methods def set_current_database(self, database: Union[str, Database]) -> None: """Set the current default database for the session. @@ -457,19 +1256,13 @@ def set_current_schema(self, schema: Union[str, Schema]) -> None: schema_name = self._parse_schema(schema) self._session.use_schema(schema_name) - # exists methods def database_exists(self, database: Union[str, Database]) -> bool: """Check if the given database exists. Args: database: database name or ``Database`` object. """ - db_name = self._parse_database(database) - try: - self._root.databases[db_name].fetch() - return True - except NotFoundError: - return False + return self._backend.database_exists(database) def schema_exists( self, @@ -484,13 +1277,7 @@ def schema_exists( schema: schema name or ``Schema`` object. database: database name or ``Database`` object. Defaults to None. """ - db_name = self._parse_database(database, schema) - schema_name = self._parse_schema(schema) - try: - self._root.databases[db_name].schemas[schema_name].fetch() - return True - except NotFoundError: - return False + return self._backend.schema_exists(schema, database=database) def table_exists( self, @@ -507,16 +1294,7 @@ def table_exists( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - db_name = self._parse_database(database, table) - schema_name = self._parse_schema(schema, table) - table_name = table if isinstance(table, str) else table.name - try: - self._root.databases[db_name].schemas[schema_name].tables[ - table_name - ].fetch() - return True - except NotFoundError: - return False + return self._backend.table_exists(table, database=database, schema=schema) def view_exists( self, @@ -533,14 +1311,7 @@ def view_exists( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - db_name = self._parse_database(database, view) - schema_name = self._parse_schema(schema, view) - view_name = view if isinstance(view, str) else view.name - try: - self._root.databases[db_name].schemas[schema_name].views[view_name].fetch() - return True - except NotFoundError: - return False + return self._backend.view_exists(view, database=database, schema=schema) def procedure_exists( self, @@ -559,17 +1330,9 @@ def procedure_exists( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - db_name = self._parse_database(database, procedure) - schema_name = self._parse_schema(schema, procedure) - procedure_id = self._parse_function_or_procedure(procedure, arg_types) - - try: - self._root.databases[db_name].schemas[schema_name].procedures[ - procedure_id - ].fetch() - return True - except NotFoundError: - return False + return self._backend.procedure_exists( + procedure, arg_types, database=database, schema=schema + ) def user_defined_function_exists( self, @@ -590,27 +1353,17 @@ def user_defined_function_exists( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - db_name = self._parse_database(database, udf) - schema_name = self._parse_schema(schema, udf) - function_id = self._parse_function_or_procedure(udf, arg_types) - - try: - self._root.databases[db_name].schemas[schema_name].user_defined_functions[ - function_id - ].fetch() - return True - except NotFoundError: - return False + return self._backend.user_defined_function_exists( + udf, arg_types, database=database, schema=schema + ) - # drop methods def drop_database(self, database: Union[str, Database]) -> None: """Drop the given database. Args: database: database name or ``Database`` object. """ - db_name = self._parse_database(database) - self._root.databases[db_name].drop() + return self._backend.drop_database(database) def drop_schema( self, @@ -625,9 +1378,7 @@ def drop_schema( schema: schema name or ``Schema`` object. database: database name or ``Database`` object. Defaults to None. """ - db_name = self._parse_database(database, schema) - schema_name = self._parse_schema(schema) - self._root.databases[db_name].schemas[schema_name].drop() + return self._backend.drop_schema(schema, database=database) def drop_table( self, @@ -644,11 +1395,7 @@ def drop_table( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - db_name = self._parse_database(database, table) - schema_name = self._parse_schema(schema, table) - table_name = table if isinstance(table, str) else table.name - - self._root.databases[db_name].schemas[schema_name].tables[table_name].drop() + return self._backend.drop_table(table, database=database, schema=schema) def drop_view( self, @@ -665,11 +1412,7 @@ def drop_view( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - db_name = self._parse_database(database, view) - schema_name = self._parse_schema(schema, view) - view_name = view if isinstance(view, str) else view.name - - self._root.databases[db_name].schemas[schema_name].views[view_name].drop() + return self._backend.drop_view(view, database=database, schema=schema) # aliases listDatabases = list_databases diff --git a/src/snowflake/snowpark/exceptions.py b/src/snowflake/snowpark/exceptions.py index 1142e9545e..d31fe178a6 100644 --- a/src/snowflake/snowpark/exceptions.py +++ b/src/snowflake/snowpark/exceptions.py @@ -283,3 +283,9 @@ class SnowparkInvalidObjectNameException(SnowparkGeneralException): """ pass + + +class NotFoundError(SnowparkClientException): + """Raised when we encounter an object is not found.""" + + pass diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 5803cc8329..ca2b15ee9f 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -254,6 +254,7 @@ _session_management_lock = RLock() _active_sessions: Set["Session"] = set() +_USE_SQL_BASE_OPTION_KEY = "_use_sql_base" _PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING = ( "PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS" ) @@ -463,6 +464,13 @@ def __init__(self) -> None: self._app_name = None self._format_json = None + @staticmethod + def _connection_options( + options: Dict[str, Union[int, str]] + ) -> Dict[str, Union[int, str]]: + # Internal Session-only options must not be forwarded to connector connect(**kwargs). + return {k: v for k, v in options.items() if k != _USE_SQL_BASE_OPTION_KEY} + def _remove_config(self, key: str) -> "Session.SessionBuilder": """Only used in test.""" self._options.pop(key, None) @@ -569,8 +577,11 @@ def _create_internal( # Set paramstyle to qmark by default to be consistent with previous behavior if "paramstyle" not in self._options: self._options["paramstyle"] = "qmark" + connection_options = self._connection_options(self._options) new_session = Session( - ServerConnection({}, conn) if conn else ServerConnection(self._options), + ServerConnection({}, conn) + if conn + else ServerConnection(connection_options), self._options, ) @@ -628,6 +639,8 @@ def __init__( """ self.version = get_version() self._session_stage = None + options = options or {} + self._use_sql_base = options.pop(_USE_SQL_BASE_OPTION_KEY, True) if isinstance(conn, MockServerConnection): self._udf_registration = MockUDFRegistration(self) @@ -848,7 +861,7 @@ def __init__( _PYTHON_SNOWPARK_COLLECT_TELEMETRY_AT_CRITICAL_PATH_VERSION ) ) - self._conf = self.RuntimeConfig(self, options or {}) + self._conf = self.RuntimeConfig(self, options) self._runtime_version_from_requirement: str = None self._temp_table_auto_cleaner: TempTableAutoCleaner = TempTableAutoCleaner(self) self._sp_profiler = StoredProcedureProfiler(session=self) @@ -961,7 +974,7 @@ def catalog(self): external_feature_name="Session.catalog", raise_error=NotImplementedError, ) - self._catalog = Catalog(self) + self._catalog = Catalog(self, _use_sql_base=self._use_sql_base) return self._catalog def close(self) -> None: diff --git a/tests/conftest.py b/tests/conftest.py index bfd9ab8f78..bca98bd3b3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,6 +14,8 @@ from snowflake.snowpark._internal.utils import warning_dict from .ast.conftest import default_unparser_path +pytest_plugins = ("tests.integ.test_catalog",) + logging.getLogger("snowflake.connector").setLevel(logging.ERROR) excluded_frontend_files = [ diff --git a/tests/integ/test_catalog.py b/tests/integ/test_catalog.py index 2628b0f8f9..0827e132dd 100644 --- a/tests/integ/test_catalog.py +++ b/tests/integ/test_catalog.py @@ -1,32 +1,26 @@ # # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. # +"""Catalog integration tests and shared fixtures. + +Mode-agnostic tests (same behavior for SQL and REST catalog backends) live in +this module. Backend-specific tests are in ``test_catalog_sql_mode.py`` and +``test_catalog_rest_mode.py``, which reuse the fixtures defined here via +``pytest_plugins`` in ``conftest.py``. +""" -from unittest.mock import patch import uuid +from unittest.mock import patch + import pytest -from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted from snowflake.snowpark.catalog import Catalog +from snowflake.snowpark.context import _DEFAULT_ARTIFACT_REPOSITORY from snowflake.snowpark.session import Session from snowflake.snowpark.types import IntegerType -from snowflake.core.exceptions import APIError -from snowflake.snowpark.context import _DEFAULT_ARTIFACT_REPOSITORY - - -pytestmark = [ - pytest.mark.xfail( - "config.getoption('local_testing_mode', default=False)", - reason="deepcopy is not supported and required by local testing", - run=False, - ), - pytest.mark.xfail( - raises=APIError, - reason="Failure due to warehouse overload", - ), -] CATALOG_TEMP_OBJECT_PREFIX = "SP_CATALOG_TEMP" +DOES_NOT_EXIST_PATTERN = "does_not_exist_.*" def get_temp_name(type: str) -> str: @@ -186,34 +180,13 @@ def temp_udf2(session, temp_db1, temp_schema1): ) -DOES_NOT_EXIST_PATTERN = "does_not_exist_.*" - - -def test_list_db(session, temp_db1, temp_db2): - catalog: Catalog = session.catalog - db_list = catalog.list_databases(pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_DB_*") - assert {db.name for db in db_list} >= {temp_db1, temp_db2} - - db_list = catalog.list_databases(like=f"{CATALOG_TEMP_OBJECT_PREFIX}_DB_%") - assert {db.name for db in db_list} >= {temp_db1, temp_db2} - - -def test_list_schema(session, temp_db1, temp_schema1, temp_schema2): - catalog: Catalog = session.catalog - assert ( - len(catalog.list_databases(pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_SCHEMA_.*")) - == 0 - ) - - schema_list = catalog.list_schemas( - pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_SCHEMA_.*", database=temp_db1 - ) - assert {schema.name for schema in schema_list} >= {temp_schema1, temp_schema2} - - schema_list = catalog.list_schemas( - like=f"{CATALOG_TEMP_OBJECT_PREFIX}_SCHEMA_%", database=temp_db1 - ) - assert {schema.name for schema in schema_list} >= {temp_schema1, temp_schema2} +pytestmark = [ + pytest.mark.xfail( + "config.getoption('local_testing_mode', default=False)", + reason="deepcopy is not supported and required by local testing", + run=False, + ), +] def test_list_tables(session, temp_db1, temp_schema1, temp_table1, temp_table2): @@ -344,48 +317,6 @@ def test_list_udfs(session, temp_db1, temp_schema1, temp_udf1, temp_udf2): assert {udf.name for udf in udf_list} >= {temp_udf1, temp_udf2} -def test_get_db_schema(session): - catalog: Catalog = session.catalog - current_db = session.get_current_database() - current_schema = session.get_current_schema() - assert catalog.get_database(current_db).name == unquote_if_quoted(current_db) - assert catalog.get_schema(current_schema).name == unquote_if_quoted(current_schema) - - -def test_get_table_view(session, temp_db1, temp_schema1, temp_table1, temp_view1): - catalog: Catalog = session.catalog - table = catalog.get_table(temp_table1, database=temp_db1, schema=temp_schema1) - assert table.name == temp_table1 - assert table.database_name == temp_db1 - assert table.schema_name == temp_schema1 - - view = catalog.get_view(temp_view1, database=temp_db1, schema=temp_schema1) - assert view.name == temp_view1 - assert view.database_name == temp_db1 - assert view.schema_name == temp_schema1 - - -@pytest.mark.udf -def test_get_function_procedure_udf( - session, temp_db1, temp_schema1, temp_procedure1, temp_udf1 -): - catalog: Catalog = session.catalog - - procedure = catalog.get_procedure( - temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1 - ) - assert procedure.name == temp_procedure1 - assert procedure.database_name == temp_db1 - assert procedure.schema_name == temp_schema1 - - udf = catalog.get_user_defined_function( - temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1 - ) - assert udf.name == temp_udf1 - assert udf.database_name == temp_db1 - assert udf.schema_name == temp_schema1 - - def test_set_db_schema(session, temp_db1, temp_db2, temp_schema1, temp_schema2): catalog = session.catalog @@ -407,112 +338,6 @@ def test_set_db_schema(session, temp_db1, temp_db2, temp_schema1, temp_schema2): session.use_schema(original_schema) -def test_exists_db_schema(session, temp_db1, temp_schema1): - catalog = session.catalog - assert catalog.database_exists(temp_db1) - assert not catalog.database_exists("does_not_exist") - - assert catalog.schema_exists(temp_schema1, database=temp_db1) - assert not catalog.schema_exists(temp_schema1, database="does_not_exist") - - -def test_exists_table_view(session, temp_db1, temp_schema1, temp_table1, temp_view1): - catalog = session.catalog - db1_obj = catalog._root.databases[temp_db1].fetch() - schema1_obj = catalog._root.databases[temp_db1].schemas[temp_schema1].fetch() - - assert catalog.table_exists(temp_table1, database=temp_db1, schema=temp_schema1) - assert catalog.table_exists(temp_table1, database=db1_obj, schema=schema1_obj) - table = catalog.get_table(temp_table1, database=temp_db1, schema=temp_schema1) - assert catalog.table_exists(table) - assert not catalog.table_exists( - "does_not_exist", database=temp_db1, schema=temp_schema1 - ) - - assert catalog.view_exists(temp_view1, database=temp_db1, schema=temp_schema1) - assert catalog.view_exists(temp_view1, database=db1_obj, schema=schema1_obj) - view = catalog.get_view(temp_view1, database=temp_db1, schema=temp_schema1) - assert catalog.view_exists(view) - assert not catalog.view_exists( - "does_not_exist", database=temp_db1, schema=temp_schema1 - ) - - -@pytest.mark.udf -def test_exists_function_procedure_udf( - session, temp_db1, temp_schema1, temp_procedure1, temp_udf1 -): - catalog = session.catalog - db1_obj = catalog._root.databases[temp_db1].fetch() - schema1_obj = catalog._root.databases[temp_db1].schemas[temp_schema1].fetch() - - assert catalog.procedure_exists( - temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1 - ) - assert catalog.procedure_exists( - temp_procedure1, [IntegerType()], database=db1_obj, schema=schema1_obj - ) - proc = catalog.get_procedure( - temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1 - ) - assert catalog.procedure_exists(proc) - assert not catalog.procedure_exists( - "does_not_exist", [], database=temp_db1, schema=temp_schema1 - ) - - assert catalog.user_defined_function_exists( - temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1 - ) - assert catalog.user_defined_function_exists( - temp_udf1, [IntegerType()], database=db1_obj, schema=schema1_obj - ) - udf = catalog.get_user_defined_function( - temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1 - ) - assert catalog.user_defined_function_exists(udf) - assert not catalog.user_defined_function_exists( - "does_not_exist", [], database=temp_db1, schema=temp_schema1 - ) - - -@pytest.mark.parametrize("use_object", [True, False]) -def test_drop(session, use_object): - catalog = session.catalog - - original_db = session.get_current_database() - original_schema = session.get_current_schema() - try: - temp_db = create_temp_db(session) - temp_schema = create_temp_schema(session, temp_db) - temp_table = create_temp_table(session, temp_db, temp_schema) - temp_view = create_temp_view(session, temp_db, temp_schema) - if use_object: - temp_schema = catalog._root.databases[temp_db].schemas[temp_schema].fetch() - temp_db = catalog._root.databases[temp_db].fetch() - - assert catalog.database_exists(temp_db) - assert catalog.schema_exists(temp_schema, database=temp_db) - assert catalog.table_exists(temp_table, database=temp_db, schema=temp_schema) - assert catalog.view_exists(temp_view, database=temp_db, schema=temp_schema) - - catalog.drop_table(temp_table, database=temp_db, schema=temp_schema) - catalog.drop_view(temp_view, database=temp_db, schema=temp_schema) - - assert not catalog.table_exists( - temp_table, database=temp_db, schema=temp_schema - ) - assert not catalog.view_exists(temp_view, database=temp_db, schema=temp_schema) - - catalog.drop_schema(temp_schema, database=temp_db) - assert not catalog.schema_exists(temp_schema, database=temp_db) - - catalog.drop_database(temp_db) - assert not catalog.database_exists(temp_db) - finally: - session.use_database(original_db) - session.use_schema(original_schema) - - def test_parse_names_negative(session): catalog = session.catalog with pytest.raises( diff --git a/tests/integ/test_catalog_rest_mode.py b/tests/integ/test_catalog_rest_mode.py new file mode 100644 index 0000000000..11b0cc39ed --- /dev/null +++ b/tests/integ/test_catalog_rest_mode.py @@ -0,0 +1,250 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# +"""Catalog integration tests with ``context._is_snowpark_connect_compatible_mode`` False (REST / Root backend). + +Keep this file separate from ``test_catalog_sql_mode.py`` so removing one backend path +deletes only the matching test module. +""" + +import pytest + +from snowflake.core.exceptions import APIError +from snowflake.core.exceptions import NotFoundError as CoreNotFoundError +from snowflake.snowpark import context +from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted +from snowflake.snowpark.catalog import Catalog +from snowflake.snowpark.types import IntegerType +from tests.integ.test_catalog import ( + CATALOG_TEMP_OBJECT_PREFIX, + create_temp_db, + create_temp_schema, + create_temp_table, + create_temp_view, +) + +pytestmark = [ + pytest.mark.xfail( + "config.getoption('local_testing_mode', default=False)", + reason="deepcopy is not supported and required by local testing", + run=False, + ), + pytest.mark.xfail( + raises=APIError, + reason="Failure due to warehouse overload", + ), +] + + +@pytest.fixture(autouse=True, scope="module") +def _catalog_rest_backend_mode(session): + mp = pytest.MonkeyPatch() + mp.setattr(context, "_is_snowpark_connect_compatible_mode", False) + mp.setattr(session, "_catalog", None) + try: + yield + finally: + mp.undo() + session._catalog = None + + +def test_list_db_rest_mode(session, temp_db1, temp_db2): + catalog: Catalog = session.catalog + db_list = catalog.list_databases(pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_DB_*") + assert {db.name for db in db_list} >= {temp_db1, temp_db2} + + db_list = catalog.list_databases(like=f"{CATALOG_TEMP_OBJECT_PREFIX}_DB_%") + assert {db.name for db in db_list} >= {temp_db1, temp_db2} + + +def test_list_schema_rest_mode(session, temp_db1, temp_schema1, temp_schema2): + catalog: Catalog = session.catalog + assert ( + len(catalog.list_databases(pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_SCHEMA_.*")) + == 0 + ) + + schema_list = catalog.list_schemas( + pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_SCHEMA_.*", database=temp_db1 + ) + assert {schema.name for schema in schema_list} >= {temp_schema1, temp_schema2} + + schema_list = catalog.list_schemas( + like=f"{CATALOG_TEMP_OBJECT_PREFIX}_SCHEMA_%", database=temp_db1 + ) + assert {schema.name for schema in schema_list} >= {temp_schema1, temp_schema2} + + +def test_get_db_schema_rest_mode(session): + catalog: Catalog = session.catalog + current_db = session.get_current_database() + current_schema = session.get_current_schema() + assert catalog.get_database(current_db).name == unquote_if_quoted(current_db) + assert catalog.get_schema(current_schema).name == unquote_if_quoted(current_schema) + + +def test_get_database_missing_raises_core_not_found_rest_mode(session): + catalog: Catalog = session.catalog + with pytest.raises(CoreNotFoundError): + catalog.get_database("NONEXISTENT_DB_XYZ_12345") + + +def test_get_view_rest_mode(session, temp_db1, temp_schema1, temp_view1): + catalog: Catalog = session.catalog + view = catalog.get_view(temp_view1, database=temp_db1, schema=temp_schema1) + assert view.name == temp_view1 + assert view.database_name == temp_db1 + assert view.schema_name == temp_schema1 + + +@pytest.mark.udf +def test_get_procedure_rest_mode(session, temp_db1, temp_schema1, temp_procedure1): + catalog: Catalog = session.catalog + procedure = catalog.get_procedure( + temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert procedure.name == temp_procedure1 + assert procedure.database_name == temp_db1 + assert procedure.schema_name == temp_schema1 + + +@pytest.mark.udf +def test_get_user_defined_function_rest_mode( + session, temp_db1, temp_schema1, temp_udf1 +): + catalog: Catalog = session.catalog + udf = catalog.get_user_defined_function( + temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert udf.name == temp_udf1 + assert udf.database_name == temp_db1 + assert udf.schema_name == temp_schema1 + + +def test_database_exists_rest_mode(session, temp_db1): + catalog: Catalog = session.catalog + assert catalog.database_exists(temp_db1) + assert not catalog.database_exists("does_not_exist") + + +def test_get_table_view_rest_mode( + session, temp_db1, temp_schema1, temp_table1, temp_view1 +): + catalog: Catalog = session.catalog + table = catalog.get_table(temp_table1, database=temp_db1, schema=temp_schema1) + assert table.name == temp_table1 + assert table.database_name == temp_db1 + assert table.schema_name == temp_schema1 + + view = catalog.get_view(temp_view1, database=temp_db1, schema=temp_schema1) + assert view.name == temp_view1 + assert view.database_name == temp_db1 + assert view.schema_name == temp_schema1 + + +def test_exists_db_schema_rest_mode(session, temp_db1, temp_schema1): + catalog = session.catalog + assert catalog.database_exists(temp_db1) + assert not catalog.database_exists("does_not_exist") + + assert catalog.schema_exists(temp_schema1, database=temp_db1) + assert not catalog.schema_exists(temp_schema1, database="does_not_exist") + + +def test_exists_table_view_rest_mode( + session, temp_db1, temp_schema1, temp_table1, temp_view1 +): + catalog = session.catalog + db1_obj = catalog.get_database(temp_db1) + schema1_obj = catalog.get_schema(database=temp_db1, schema=temp_schema1) + + assert catalog.table_exists(temp_table1, database=temp_db1, schema=temp_schema1) + assert catalog.table_exists(temp_table1, database=db1_obj, schema=schema1_obj) + table = catalog.get_table(temp_table1, database=temp_db1, schema=temp_schema1) + assert catalog.table_exists(table) + assert not catalog.table_exists( + "does_not_exist", database=temp_db1, schema=temp_schema1 + ) + + assert catalog.view_exists(temp_view1, database=temp_db1, schema=temp_schema1) + assert catalog.view_exists(temp_view1, database=db1_obj, schema=schema1_obj) + view = catalog.get_view(temp_view1, database=temp_db1, schema=temp_schema1) + assert catalog.view_exists(view) + assert not catalog.view_exists( + "does_not_exist", database=temp_db1, schema=temp_schema1 + ) + + +@pytest.mark.udf +def test_exists_function_procedure_udf_rest_mode( + session, temp_db1, temp_schema1, temp_procedure1, temp_udf1 +): + catalog = session.catalog + db1_obj = catalog.get_database(temp_db1) + schema1_obj = catalog.get_schema(temp_schema1, database=temp_db1) + + assert catalog.procedure_exists( + temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert catalog.procedure_exists( + temp_procedure1, [IntegerType()], database=db1_obj, schema=schema1_obj + ) + proc = catalog.get_procedure( + temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert catalog.procedure_exists(proc) + assert not catalog.procedure_exists( + "does_not_exist", [], database=temp_db1, schema=temp_schema1 + ) + + assert catalog.user_defined_function_exists( + temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert catalog.user_defined_function_exists( + temp_udf1, [IntegerType()], database=db1_obj, schema=schema1_obj + ) + udf = catalog.get_user_defined_function( + temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert catalog.user_defined_function_exists(udf) + assert not catalog.user_defined_function_exists( + "does_not_exist", [], database=temp_db1, schema=temp_schema1 + ) + + +@pytest.mark.parametrize("use_object", [True, False]) +def test_drop_rest_mode(session, use_object): + catalog = session.catalog + + original_db = session.get_current_database() + original_schema = session.get_current_schema() + try: + temp_db = create_temp_db(session) + temp_schema = create_temp_schema(session, temp_db) + temp_table = create_temp_table(session, temp_db, temp_schema) + temp_view = create_temp_view(session, temp_db, temp_schema) + if use_object: + temp_schema = catalog.get_schema(temp_schema, database=temp_db) + temp_db = catalog.get_database(temp_db) + + assert catalog.database_exists(temp_db) + assert catalog.schema_exists(temp_schema, database=temp_db) + assert catalog.table_exists(temp_table, database=temp_db, schema=temp_schema) + assert catalog.view_exists(temp_view, database=temp_db, schema=temp_schema) + + catalog.drop_table(temp_table, database=temp_db, schema=temp_schema) + catalog.drop_view(temp_view, database=temp_db, schema=temp_schema) + + assert not catalog.table_exists( + temp_table, database=temp_db, schema=temp_schema + ) + assert not catalog.view_exists(temp_view, database=temp_db, schema=temp_schema) + + catalog.drop_schema(temp_schema, database=temp_db) + assert not catalog.schema_exists(temp_schema, database=temp_db) + + catalog.drop_database(temp_db) + assert not catalog.database_exists(temp_db) + finally: + session.use_database(original_db) + session.use_schema(original_schema) diff --git a/tests/integ/test_catalog_sql_mode.py b/tests/integ/test_catalog_sql_mode.py new file mode 100644 index 0000000000..9a6ed7e10d --- /dev/null +++ b/tests/integ/test_catalog_sql_mode.py @@ -0,0 +1,277 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# +"""Catalog integration tests with ``context._is_snowpark_connect_compatible_mode`` True (SQL backend). + +Keep this file separate from ``test_catalog_rest_mode.py`` so removing one backend path +deletes only the matching test module. +""" + +import pytest + +from snowflake.core.exceptions import NotFoundError as CoreNotFoundError +from snowflake.core.view import View +from snowflake.snowpark import context +from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted +from snowflake.snowpark.catalog import Catalog +from snowflake.snowpark.exceptions import NotFoundError +from snowflake.snowpark.types import IntegerType +from tests.integ.test_catalog import ( + CATALOG_TEMP_OBJECT_PREFIX, + create_temp_db, + create_temp_schema, + create_temp_table, + create_temp_view, +) + +pytestmark = [ + pytest.mark.xfail( + "config.getoption('local_testing_mode', default=False)", + reason="deepcopy is not supported and required by local testing", + run=False, + ), +] + + +@pytest.fixture(autouse=True, scope="module") +def _catalog_sql_backend_mode(session): + mp = pytest.MonkeyPatch() + mp.setattr(context, "_is_snowpark_connect_compatible_mode", True) + mp.setattr(session, "_catalog", None) + try: + yield + finally: + mp.undo() + session._catalog = None + + +def test_list_db_sql_mode(session, temp_db1, temp_db2): + catalog: Catalog = session.catalog + db_list = catalog.list_databases(pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_DB_*") + assert {db.name for db in db_list} >= {temp_db1, temp_db2} + + db_list = catalog.list_databases(like=f"{CATALOG_TEMP_OBJECT_PREFIX}_DB_%") + assert {db.name for db in db_list} >= {temp_db1, temp_db2} + + +def test_list_schema_sql_mode(session, temp_db1, temp_schema1, temp_schema2): + catalog: Catalog = session.catalog + assert ( + len(catalog.list_databases(pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_SCHEMA_.*")) + == 0 + ) + + schema_list = catalog.list_schemas( + pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_SCHEMA_.*", database=temp_db1 + ) + assert {schema.name for schema in schema_list} >= {temp_schema1, temp_schema2} + + schema_list = catalog.list_schemas( + like=f"{CATALOG_TEMP_OBJECT_PREFIX}_SCHEMA_%", database=temp_db1 + ) + assert {schema.name for schema in schema_list} >= {temp_schema1, temp_schema2} + + +def test_get_db_schema_sql_mode(session): + catalog: Catalog = session.catalog + current_db = session.get_current_database() + current_schema = session.get_current_schema() + assert catalog.get_database(current_db).name == unquote_if_quoted(current_db) + assert catalog.get_schema(current_schema).name == unquote_if_quoted(current_schema) + + +def test_get_database_missing_raises_snowpark_not_found_sql_mode(session): + catalog: Catalog = session.catalog + with pytest.raises(NotFoundError, match="could not be found"): + catalog.get_database("NONEXISTENT_DB_XYZ_12345") + + +def test_compat_mode_with_sql_base_disabled_uses_rest_backend(session): + original_use_sql_base = session._use_sql_base + try: + session._use_sql_base = False + session._catalog = None + catalog: Catalog = session.catalog + assert type(catalog._backend).__name__ == "_RestCatalogBackend" + with pytest.raises(CoreNotFoundError): + catalog.get_database("NONEXISTENT_DB_XYZ_12345") + finally: + session._use_sql_base = original_use_sql_base + session._catalog = None + + +def test_compat_mode_with_sql_base_enabled_uses_sql_backend(session): + original_use_sql_base = session._use_sql_base + try: + session._use_sql_base = True + session._catalog = None + catalog: Catalog = session.catalog + assert type(catalog._backend).__name__ == "_SqlCatalogBackend" + finally: + session._use_sql_base = original_use_sql_base + session._catalog = None + + +def test_get_table_resolves_view_sql_mode(session, temp_db1, temp_schema1, temp_view1): + catalog: Catalog = session.catalog + obj = catalog.get_table(temp_view1, database=temp_db1, schema=temp_schema1) + assert isinstance(obj, View) + assert obj.name == temp_view1 + + +def test_table_exists_true_for_view_name_sql_mode( + session, temp_db1, temp_schema1, temp_view1 +): + catalog: Catalog = session.catalog + assert catalog.table_exists(temp_view1, database=temp_db1, schema=temp_schema1) + + +@pytest.mark.udf +def test_get_procedure_sql_mode(session, temp_db1, temp_schema1, temp_procedure1): + catalog: Catalog = session.catalog + procedure = catalog.get_procedure( + temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert procedure.name == temp_procedure1 + assert procedure.database_name == temp_db1 + assert procedure.schema_name == temp_schema1 + + +@pytest.mark.udf +def test_get_user_defined_function_sql_mode(session, temp_db1, temp_schema1, temp_udf1): + catalog: Catalog = session.catalog + udf = catalog.get_user_defined_function( + temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert udf.name == temp_udf1 + assert udf.database_name == temp_db1 + assert udf.schema_name == temp_schema1 + + +def test_database_exists_sql_mode(session, temp_db1): + catalog: Catalog = session.catalog + assert catalog.database_exists(temp_db1) + assert not catalog.database_exists("does_not_exist") + + +def test_get_table_view_sql_mode( + session, temp_db1, temp_schema1, temp_table1, temp_view1 +): + catalog: Catalog = session.catalog + table = catalog.get_table(temp_table1, database=temp_db1, schema=temp_schema1) + assert table.name == temp_table1 + assert table.database_name == temp_db1 + assert table.schema_name == temp_schema1 + + view = catalog.get_view(temp_view1, database=temp_db1, schema=temp_schema1) + assert view.name == temp_view1 + assert view.database_name == temp_db1 + assert view.schema_name == temp_schema1 + + +def test_exists_db_schema_sql_mode(session, temp_db1, temp_schema1): + catalog = session.catalog + assert catalog.database_exists(temp_db1) + assert not catalog.database_exists("does_not_exist") + + assert catalog.schema_exists(temp_schema1, database=temp_db1) + assert not catalog.schema_exists(temp_schema1, database="does_not_exist") + + +def test_exists_table_view_sql_mode( + session, temp_db1, temp_schema1, temp_table1, temp_view1 +): + catalog = session.catalog + db1_obj = catalog.get_database(temp_db1) + schema1_obj = catalog.get_schema(database=temp_db1, schema=temp_schema1) + + assert catalog.table_exists(temp_table1, database=temp_db1, schema=temp_schema1) + assert catalog.table_exists(temp_table1, database=db1_obj, schema=schema1_obj) + table = catalog.get_table(temp_table1, database=temp_db1, schema=temp_schema1) + assert catalog.table_exists(table) + assert not catalog.table_exists( + "does_not_exist", database=temp_db1, schema=temp_schema1 + ) + + assert catalog.view_exists(temp_view1, database=temp_db1, schema=temp_schema1) + assert catalog.view_exists(temp_view1, database=db1_obj, schema=schema1_obj) + view = catalog.get_view(temp_view1, database=temp_db1, schema=temp_schema1) + assert catalog.view_exists(view) + assert not catalog.view_exists( + "does_not_exist", database=temp_db1, schema=temp_schema1 + ) + + +@pytest.mark.udf +def test_exists_function_procedure_udf_sql_mode( + session, temp_db1, temp_schema1, temp_procedure1, temp_udf1 +): + catalog = session.catalog + db1_obj = catalog.get_database(temp_db1) + schema1_obj = catalog.get_schema(temp_schema1, database=temp_db1) + + assert catalog.procedure_exists( + temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert catalog.procedure_exists( + temp_procedure1, [IntegerType()], database=db1_obj, schema=schema1_obj + ) + proc = catalog.get_procedure( + temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert catalog.procedure_exists(proc) + assert not catalog.procedure_exists( + "does_not_exist", [], database=temp_db1, schema=temp_schema1 + ) + + assert catalog.user_defined_function_exists( + temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert catalog.user_defined_function_exists( + temp_udf1, [IntegerType()], database=db1_obj, schema=schema1_obj + ) + udf = catalog.get_user_defined_function( + temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert catalog.user_defined_function_exists(udf) + assert not catalog.user_defined_function_exists( + "does_not_exist", [], database=temp_db1, schema=temp_schema1 + ) + + +@pytest.mark.parametrize("use_object", [True, False]) +def test_drop_sql_mode(session, use_object): + catalog = session.catalog + + original_db = session.get_current_database() + original_schema = session.get_current_schema() + try: + temp_db = create_temp_db(session) + temp_schema = create_temp_schema(session, temp_db) + temp_table = create_temp_table(session, temp_db, temp_schema) + temp_view = create_temp_view(session, temp_db, temp_schema) + if use_object: + temp_schema = catalog.get_schema(temp_schema, database=temp_db) + temp_db = catalog.get_database(temp_db) + + assert catalog.database_exists(temp_db) + assert catalog.schema_exists(temp_schema, database=temp_db) + assert catalog.table_exists(temp_table, database=temp_db, schema=temp_schema) + assert catalog.view_exists(temp_view, database=temp_db, schema=temp_schema) + + catalog.drop_table(temp_table, database=temp_db, schema=temp_schema) + catalog.drop_view(temp_view, database=temp_db, schema=temp_schema) + + assert not catalog.table_exists( + temp_table, database=temp_db, schema=temp_schema + ) + assert not catalog.view_exists(temp_view, database=temp_db, schema=temp_schema) + + catalog.drop_schema(temp_schema, database=temp_db) + assert not catalog.schema_exists(temp_schema, database=temp_db) + + catalog.drop_database(temp_db) + assert not catalog.database_exists(temp_db) + finally: + session.use_database(original_db) + session.use_schema(original_schema) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 0349618659..9377d1ad92 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -117,6 +117,56 @@ def test_used_scoped_temp_object(): assert Session(fake_connection)._use_scoped_temp_objects is False +@pytest.mark.parametrize( + "option_value, expected", + [(None, True), (True, True), (False, False)], +) +def test_session_use_sql_base_from_options(option_value, expected): + fake_connection = mock.create_autospec(ServerConnection) + fake_connection._conn = mock.Mock() + fake_connection._thread_safe_session_enabled = True + fake_connection._get_client_side_session_parameter = ( + lambda x, y: ServerConnection._get_client_side_session_parameter( + fake_connection, x, y + ) + ) + fake_connection._conn._session_parameters = {} + + options = {} if option_value is None else {"_use_sql_base": option_value} + session = Session(fake_connection, options) + assert session._use_sql_base is expected + assert session.conf.get("_use_sql_base") is None + + +@pytest.mark.parametrize( + "option_value, expected_backend_name", + [(True, "_SqlCatalogBackend"), (False, "_RestCatalogBackend")], +) +def test_catalog_backend_selection_from_use_sql_base_option( + option_value, expected_backend_name +): + import snowflake.snowpark.context as ctx + + fake_connection = mock.create_autospec(ServerConnection) + fake_connection._conn = mock.Mock() + fake_connection._thread_safe_session_enabled = True + fake_connection._get_client_side_session_parameter = ( + lambda x, y: ServerConnection._get_client_side_session_parameter( + fake_connection, x, y + ) + ) + fake_connection._conn._session_parameters = {} + fake_connection.get_session_id.return_value = "fake_session_id" + + original_compat = ctx._is_snowpark_connect_compatible_mode + try: + ctx._is_snowpark_connect_compatible_mode = True + session = Session(fake_connection, {"_use_sql_base": option_value}) + assert type(session.catalog._backend).__name__ == expected_backend_name + finally: + ctx._is_snowpark_connect_compatible_mode = original_compat + + def test_close_exception(): fake_connection = mock.create_autospec(ServerConnection) fake_connection._conn = mock.Mock()