 import itertools
 import os
 import platform
+import shutil
 import sys
 import warnings
 from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
 from importlib.metadata import PackageNotFoundError, version
 from json import JSONDecodeError
 from math import ceil
-from typing import (
-    TYPE_CHECKING,
-    ForwardRef,
-    Optional,
-    get_args,
-)
+from typing import TYPE_CHECKING, ForwardRef, Optional, get_args
 from urllib.parse import quote, urljoin
 
+import pyarrow as pa
+import pyarrow.dataset as ds
 import requests
+from deltalake import DeltaTable, QueryBuilder, convert_to_deltalake
 from emmet.core.utils import jsanitize
 from pydantic import BaseModel, create_model
 from requests.adapters import HTTPAdapter
 from urllib3.util.retry import Retry
 
 from mp_api.client.core.settings import MAPIClientSettings
-from mp_api.client.core.utils import load_json, validate_ids
+from mp_api.client.core.utils import MPDataset, load_json, validate_ids
 
 try:
     import boto3
@@ -71,6 +70,7 @@ class BaseRester:
     document_model: type[BaseModel] | None = None
     supports_versions: bool = False
     primary_key: str = "material_id"
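+    # whether this rester's collection is also published as a Delta Lake table on S3;
+    # if True, full-dataset queries are materialized locally in _query_resource below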
+    delta_backed: bool = False
 
     def __init__(
         self,
@@ -85,6 +85,8 @@ def __init__(
         timeout: int = 20,
         headers: dict | None = None,
         mute_progress_bars: bool = SETTINGS.MUTE_PROGRESS_BARS,
+        local_dataset_cache: str | os.PathLike = SETTINGS.LOCAL_DATASET_CACHE,
+        force_renew: bool = False,
     ):
         """Initialize the REST API helper class.
 
@@ -116,6 +118,9 @@ def __init__(
             timeout: Time in seconds to wait until a request timeout error is thrown
             headers: Custom headers for localhost connections.
             mute_progress_bars: Whether to disable progress bars.
+            local_dataset_cache: Target directory for downloading full datasets. Defaults
+                to 'materialsproject_datasets' in the user's home directory.
+            force_renew: Whether to overwrite an existing local dataset.
         """
         # TODO: think about how to migrate from PMG_MAPI_KEY
         self.api_key = api_key or os.getenv("MP_API_KEY")
@@ -129,6 +134,8 @@ def __init__(
         self.timeout = timeout
         self.headers = headers or {}
         self.mute_progress_bars = mute_progress_bars
+        self.local_dataset_cache = local_dataset_cache
+        self.force_renew = force_renew
         self.db_version = BaseRester._get_database_version(self.endpoint)
 
         if self.suffix:
@@ -212,7 +219,7 @@ def _get_database_version(endpoint):
         remains unchanged and available for querying via its task_id.
 
         The database version is set as a date in the format YYYY_MM_DD,
         where "_DD" may be optional. An additional numerical or `postN` suffix
         might be added if multiple releases happen on the same day.
 
         Returns: database version as a string
@@ -356,10 +363,7 @@ def _patch_resource(
             raise MPRestError(str(ex))
 
     def _query_open_data(
-        self,
-        bucket: str,
-        key: str,
-        decoder: Callable | None = None,
+        self, bucket: str, key: str, decoder: Callable | None = None
     ) -> tuple[list[dict] | list[bytes], int]:
         """Query and deserialize Materials Project AWS open data s3 buckets.
 
@@ -463,6 +467,12 @@ def _query_resource(
                     url += "/"
 
             if query_s3:
+                pbar_message = (  # type: ignore
+                    f"Retrieving {self.document_model.__name__} documents"  # type: ignore
+                    if self.document_model is not None
+                    else "Retrieving documents"
+                )
+
                 db_version = self.db_version.replace(".", "-")
                 if "/" not in self.suffix:
                     suffix = self.suffix
@@ -474,9 +484,14 @@ def _query_resource(
                     suffix = suffix.replace("_", "-")
 
                 # Check if user has access to GNoMe
+                # temporarily suppress tqdm while checking GNoMe access
+                re_enable = not self.mute_progress_bars
+                self.mute_progress_bars = True
                 has_gnome_access = bool(
                     self._submit_requests(
-                        url=urljoin(self.endpoint, "materials/summary/"),
+                        url=urljoin(
+                            "https://api.materialsproject.org/", "materials/summary/"
+                        ),
                         criteria={
                             "batch_id": "gnome_r2scan_statics",
                             "_fields": "material_id",
@@ -489,21 +504,147 @@ def _query_resource(
                     .get("meta", {})
                     .get("total_doc", 0)
                 )
+                self.mute_progress_bars = not re_enable
 
-                # Paginate over all entries in the bucket.
-                # TODO: change when a subset of entries needed from DB
                 if "tasks" in suffix:
-                    bucket_suffix, prefix = "parsed", "tasks_atomate2"
+                    bucket_suffix, prefix = ("parsed", "core/tasks/")
                 else:
                     bucket_suffix = "build"
                     prefix = f"collections/{db_version}/{suffix}"
 
-                # only include prefixes accessible to user
-                # i.e. append `batch_id=others/core` to `prefix`
-                if not has_gnome_access:
-                    prefix += "/batch_id=others"
-
                 bucket = f"materialsproject-{bucket_suffix}"
+
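+                # Delta-backed endpoints: download the full dataset into the local cache
+                # and return it as an MPDataset rather than paginating S3 objects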
+                if self.delta_backed:
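+                    # mirror the S3 bucket/prefix layout under the local dataset cache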
+                    target_path = (
+                        self.local_dataset_cache + f"/{bucket_suffix}/{prefix}"
+                    )
+                    os.makedirs(target_path, exist_ok=True)
+
+                    if DeltaTable.is_deltatable(target_path):
+                        if self.force_renew:
+                            shutil.rmtree(target_path)
+                            warnings.warn(
+                                f"Regenerating {suffix} dataset at {target_path}...",
+                                MPLocalDatasetWarning,
+                            )
+                            os.makedirs(target_path, exist_ok=True)
+                        else:
+                            warnings.warn(
+                                f"Dataset for {suffix} already exists at {target_path}. "
+                                "Delete or move the existing dataset, or re-run the query "
+                                "with MPRester(force_renew=True).",
+                                MPLocalDatasetWarning,
+                            )
+
+                            return {
+                                "data": MPDataset(
+                                    path=target_path,
+                                    document_model=self.document_model,
+                                    use_document_model=self.use_document_model,
+                                )
+                            }
+
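+                    # open the public Delta table with unsigned (anonymous) S3 requests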
+                    tbl = DeltaTable(
+                        f"s3a://{bucket}/{prefix}",
+                        storage_options={
+                            "AWS_SKIP_SIGNATURE": "true",
+                            "AWS_REGION": "us-east-1",
+                        },
+                    )
+
+                    controlled_batch_str = ",".join(
+                        [f"'{tag}'" for tag in SETTINGS.ACCESS_CONTROLLED_BATCH_IDS]
+                    )
+
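+                    # restrict the query to publicly accessible batches unless the user has GNoMe access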
+                    predicate = (
+                        " WHERE batch_id NOT IN ("  # don't delete leading space
+                        + controlled_batch_str
+                        + ")"
+                        if not has_gnome_access
+                        else ""
+                    )
+
+                    builder = QueryBuilder().register("tbl", tbl)
+
+                    # Setup progress bar
+                    num_docs_needed = pa.table(
+                        builder.execute("SELECT COUNT(*) FROM tbl").read_all()
+                    )[0][0].as_py()
+
+                    # TODO: Update tasks (+ others?) resource to have emmet-api BatchIdQuery operator
+                    #   -> need to modify BatchIdQuery operator to handle root level
+                    #      batch_id, not only builder_meta.batch_id
+                    # if not has_gnome_access:
+                    #     num_docs_needed = self.count(
+                    #         {"batch_id_neq_any": SETTINGS.ACCESS_CONTROLLED_BATCH_IDS}
+                    #     )
+
+                    pbar = (
+                        tqdm(
+                            desc=pbar_message,
+                            total=num_docs_needed,
+                        )
+                        if not self.mute_progress_bars
+                        else None
+                    )
+
+                    iterator = builder.execute("SELECT * FROM tbl" + predicate)
+
+                    file_options = ds.ParquetFileFormat().make_write_options(
+                        compression="zstd"
+                    )
+
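+                    # flush accumulated record batches to zstd-compressed parquet files under target_path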
+                    def _flush(accumulator, group):
+                        ds.write_dataset(
+                            accumulator,
+                            base_dir=target_path,
+                            format="parquet",
+                            basename_template=f"group-{group}-"
+                            + "part-{i}.zstd.parquet",
+                            existing_data_behavior="overwrite_or_ignore",
+                            max_rows_per_group=1024,
+                            file_options=file_options,
+                        )
+
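+                    # stream pages from the Delta table, flushing to disk whenever
+                    # DATASET_FLUSH_THRESHOLD rows have accumulated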
+                    group = 1
+                    size = 0
+                    accumulator = []
+                    for page in iterator:
+                        # arro3 rb to pyarrow rb for compat w/ pyarrow ds writer
+                        accumulator.append(pa.record_batch(page))
+                        page_size = page.num_rows
+                        size += page_size
+
+                        if pbar is not None:
+                            pbar.update(page_size)
+
+                        if size >= SETTINGS.DATASET_FLUSH_THRESHOLD:
+                            _flush(accumulator, group)
+                            group += 1
+                            size = 0
+                            accumulator = []
+
+                    if accumulator:
+                        _flush(accumulator, group + 1)
+
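+                    # register the written parquet files in place as a Delta Lake table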
+                    convert_to_deltalake(target_path)
+
+                    warnings.warn(
+                        f"Dataset for {suffix} written to {target_path}. It is recommended to optimize "
+                        "the table according to your usage patterns prior to running intensive workloads, "
+                        "see: https://delta-io.github.io/delta-rs/delta-lake-best-practices/#optimizing-table-layout",
+                        MPLocalDatasetWarning,
+                    )
+
+                    return {
+                        "data": MPDataset(
+                            path=target_path,
+                            document_model=self.document_model,
+                            use_document_model=self.use_document_model,
+                        )
+                    }
+
+                # Paginate over all entries in the bucket.
+                # TODO: change when a subset of entries needed from DB
                 paginator = self.s3_client.get_paginator("list_objects_v2")
                 pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
 
@@ -540,11 +681,6 @@ def _query_resource(
                 }
 
                 # Setup progress bar
-                pbar_message = (  # type: ignore
-                    f"Retrieving {self.document_model.__name__} documents"  # type: ignore
-                    if self.document_model is not None
-                    else "Retrieving documents"
-                )
                 num_docs_needed = int(self.count())
                 pbar = (
                     tqdm(
@@ -1372,3 +1508,7 @@ class MPRestError(Exception):
 
 class MPRestWarning(Warning):
     """Raised when a query is malformed but interpretable."""
+
+
+class MPLocalDatasetWarning(Warning):
+    """Raised when unrecoverable actions are performed on a local dataset."""