@@ -523,6 +523,7 @@ def resize(
                 size=(height, width),
             )
             image = self.pt_to_numpy(image)
+
         return image
 
     def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
@@ -838,6 +839,137 @@ def apply_overlay(
         return image
 
 
+class InpaintProcessor(ConfigMixin):
+    """
+    Image processor for inpainting: prepares both the image and its mask.
+    """
+
+    config_name = CONFIG_NAME
+
+    @register_to_config
+    def __init__(
+        self,
+        do_resize: bool = True,
+        vae_scale_factor: int = 8,
+        vae_latent_channels: int = 4,
+        resample: str = "lanczos",
+        reducing_gap: Optional[int] = None,
+        do_normalize: bool = True,
+        do_binarize: bool = False,
+        do_convert_grayscale: bool = False,
+        mask_do_normalize: bool = False,
+        mask_do_binarize: bool = True,
+        mask_do_convert_grayscale: bool = True,
+    ):
+        super().__init__()
+
+        self._image_processor = VaeImageProcessor(
+            do_resize=do_resize,
+            vae_scale_factor=vae_scale_factor,
+            vae_latent_channels=vae_latent_channels,
+            resample=resample,
+            reducing_gap=reducing_gap,
+            do_normalize=do_normalize,
+            do_binarize=do_binarize,
+            do_convert_grayscale=do_convert_grayscale,
+        )
+        self._mask_processor = VaeImageProcessor(
+            do_resize=do_resize,
+            vae_scale_factor=vae_scale_factor,
+            vae_latent_channels=vae_latent_channels,
+            resample=resample,
+            reducing_gap=reducing_gap,
+            do_normalize=mask_do_normalize,
+            do_binarize=mask_do_binarize,
+            do_convert_grayscale=mask_do_convert_grayscale,
+        )
+
+    def preprocess(
+        self,
+        image: PIL.Image.Image,
+        mask: Optional[PIL.Image.Image] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        padding_mask_crop: Optional[int] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, dict]]:
+        """
+        Preprocess the image and, when provided, the mask. Returns the processed
+        image tensor, the processed mask tensor, and a dict of kwargs to forward
+        to `postprocess`; if `mask` is None, only the image tensor is returned.
+        """
+        if mask is None and padding_mask_crop is not None:
+            raise ValueError("mask must be provided if padding_mask_crop is provided")
+
+        # if mask is None, same behavior as the regular image processor
+        if mask is None:
+            return self._image_processor.preprocess(image, height=height, width=width)
+
+        if padding_mask_crop is not None:
+            crops_coords = self._image_processor.get_crop_region(mask, width, height, pad=padding_mask_crop)
+            resize_mode = "fill"
+        else:
+            crops_coords = None
+            resize_mode = "default"
+
+        processed_image = self._image_processor.preprocess(
+            image,
+            height=height,
+            width=width,
+            crops_coords=crops_coords,
+            resize_mode=resize_mode,
+        )
+
+        processed_mask = self._mask_processor.preprocess(
+            mask,
+            height=height,
+            width=width,
+            resize_mode=resize_mode,
+            crops_coords=crops_coords,
+        )
+
+        if crops_coords is not None:
+            postprocessing_kwargs = {
+                "crops_coords": crops_coords,
+                "original_image": image,
+                "original_mask": mask,
+            }
+        else:
+            postprocessing_kwargs = {
+                "crops_coords": None,
+                "original_image": None,
+                "original_mask": None,
+            }
+
+        return processed_image, processed_mask, postprocessing_kwargs
+
+    def postprocess(
+        self,
+        image: torch.Tensor,
+        output_type: str = "pil",
+        original_image: Optional[PIL.Image.Image] = None,
+        original_mask: Optional[PIL.Image.Image] = None,
+        crops_coords: Optional[Tuple[int, int, int, int]] = None,
+    ) -> Union[List[PIL.Image.Image], np.ndarray, torch.Tensor]:
+        """
+        Postprocess the image and optionally apply the mask overlay.
+        """
+        image = self._image_processor.postprocess(
+            image,
+            output_type=output_type,
+        )
+        # optionally apply the mask overlay
+        if crops_coords is not None and (original_image is None or original_mask is None):
+            raise ValueError("original_image and original_mask must be provided if crops_coords is provided")
+
+        elif crops_coords is not None and output_type != "pil":
+            raise ValueError("output_type must be 'pil' if crops_coords is provided")
+
+        elif crops_coords is not None:
+            image = [
+                self._image_processor.apply_overlay(original_mask, original_image, i, crops_coords) for i in image
+            ]
+
+        return image
+
+
 class VaeImageProcessorLDM3D(VaeImageProcessor):
     """
     Image processor for VAE LDM3D.
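
A minimal usage sketch of the round trip the new class enables, assuming `InpaintProcessor` is importable from `diffusers.image_processor` as defined in this file; the file names, the 512x512 target size, and the `denoised` placeholder standing in for a real denoising loop are illustrative, not part of the diff:

import PIL.Image

from diffusers.image_processor import InpaintProcessor

processor = InpaintProcessor()

init_image = PIL.Image.open("dog.png").convert("RGB")
mask_image = PIL.Image.open("dog_mask.png").convert("L")

# Preprocess image and mask together. padding_mask_crop crops both to the
# masked region (plus padding) and records the coordinates needed to paste
# the generated crop back into the original image.
processed_image, processed_mask, postprocessing_kwargs = processor.preprocess(
    init_image,
    mask=mask_image,
    height=512,
    width=512,
    padding_mask_crop=32,
)

# A real pipeline would denoise here; reuse the preprocessed tensor as a stand-in.
denoised = processed_image

# Convert back to PIL (a list of images for output_type="pil") and paste the
# crop onto the original image using the kwargs returned by preprocess():
# crops_coords, original_image, original_mask.
result = processor.postprocess(denoised, output_type="pil", **postprocessing_kwargs)

If mask is None, preprocess simply defers to the regular image path and returns a single tensor, so callers that are not inpainting can reuse the same processor.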