@@ -25,6 +25,11 @@
 from marigold.marigold import MarigoldPipeline
 # pix2pix/merge net imports
 from pix2pix.options.test_options import TestOptions
+# depth_anything_v2
+try:
+    from depth_anything_v2 import DepthAnythingV2
+except ImportError:
+    print('depth_anything_v2 import failed... somehow')
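+# If the import failed, DepthAnythingV2 stays undefined and loading model
+# types 12-14 below will raise a NameError.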
 
 # Our code
 from src.misc import *
@@ -80,6 +85,8 @@ def load_models(self, model_type, device: torch.device, boost: bool, tiling_mode
             model_dir = "./models/leres"
         if model_type == 11:
             model_dir = "./models/depth_anything"
+        if model_type in [12, 13, 14]:
+            model_dir = "./models/depth_anything_v2"
 
         # create paths to model if not present
         os.makedirs(model_dir, exist_ok=True)
@@ -227,6 +234,19 @@ def load_models(self, model_type, device: torch.device, boost: bool, tiling_mode
                 "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth")
 
             model.load_state_dict(torch.load(model_path))
+        elif model_type in [12, 13, 14]:  # depth_anything_v2 small, base, large
+            letter = {12: 's', 13: 'b', 14: 'l'}[model_type]
+            word = {12: 'Small', 13: 'Base', 14: 'Large'}[model_type]
+            model_path = f"{model_dir}/depth_anything_v2_vit{letter}.pth"
+            ensure_file_downloaded(model_path,
+                f"https://huggingface.co/depth-anything/Depth-Anything-V2-{word}/resolve/main/depth_anything_v2_vit{letter}.pth")
+            model_configs = {'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
+                             'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
+                             'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
+                             'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}}
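+            # These four encoder configs mirror the reference configurations in
+            # the upstream Depth-Anything-V2 repository; 'vitg' is listed for
+            # completeness, but no checkpoint is downloaded for it here.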
+            model = DepthAnythingV2(**model_configs[f'vit{letter}'])
+            model.load_state_dict(torch.load(model_path, map_location='cpu'))
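+            # map_location='cpu' keeps the checkpoint weights on the CPU until
+            # model.to(device) is called further down.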
+        # 15 is reserved for Depth Anything V2 Giant
 
         if tiling_mode:
             def flatten(el):
@@ -250,6 +270,9 @@ def flatten(el):
         # TODO: Fix for zoedepth_n - it completely trips and generates black images
         if model_type in [1, 2, 3, 4, 5, 6, 8, 9, 11] and not boost:
             model = model.half()
+        if model_type in [12, 13, 14]:
+            model.depth_head.half()
+            model.pretrained.half()
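+            # Unlike model.half() above, only the DINOv2 encoder
+            # (model.pretrained) and the DPT head (model.depth_head) are cast
+            # to fp16, both in place.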
         model.to(device)  # to correct device
 
         self.depth_model = model
@@ -291,7 +314,10 @@ def get_default_net_size(model_type):
             8: [384, 768],
             9: [384, 512],
             10: [768, 768],
-            11: [518, 518]
+            11: [518, 518],
+            12: [518, 518],
+            13: [518, 518],
+            14: [518, 518]
         }
         if model_type in sizes:
             return sizes[model_type]
@@ -350,6 +376,8 @@ def get_raw_prediction(self, input, net_width, net_height):
                                                   self.marigold_ensembles, self.marigold_steps)
             elif self.depth_model_type == 11:
                 raw_prediction = estimatedepthanything(img, self.depth_model, net_width, net_height)
+            elif self.depth_model_type in [12, 13, 14]:
+                raw_prediction = estimatedepthanything_v2(img, self.depth_model, net_width, net_height)
             else:
                 raw_prediction = estimateboost(img, self.depth_model, self.depth_model_type, self.pix2pix_model,
                                                self.boost_rmax)
@@ -499,6 +527,20 @@ def estimatedepthanything(image, model, w, h):
     return depth.cpu().numpy()
 
 
+def estimatedepthanything_v2(image, model, w, h):
+    import torch.nn.functional as F
+    # This is an awkward re-conversion, but I believe it should not impact quality
+    img = cv2.cvtColor((image * 255.1).astype('uint8'), cv2.COLOR_BGR2RGB)
+    with torch.no_grad():
+        # Compare to: model.infer_image(img, w)
+        image, (h, w) = model.image2tensor(img, w)
+        # Cast the input to the model's dtype (the reference parameter chosen here is arbitrary)
+        image_casted = image.type_as(model.pretrained.blocks[0].norm1.weight.data)
+        depth = model.forward(image_casted).type_as(image)
+        depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0]
+    return depth.cpu().numpy()
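+# Example call (assuming the 518x518 default net size declared in
+# get_default_net_size above):
+#   raw_prediction = estimatedepthanything_v2(img, model, 518, 518)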
+
+
 class ImageandPatchs:
     def __init__(self, root_dir, name, patchsinfo, rgb_image, scale=1):
         self.root_dir = root_dir
@@ -720,6 +762,8 @@ def estimateboost(img, model, model_type, pix2pixmodel, whole_size_threshold):
         net_receptive_field_size = 512
     elif model_type == 11:  # depth_anything
         net_receptive_field_size = 518
+    elif model_type in [12, 13, 14]:  # depth_anything_v2
+        net_receptive_field_size = 518
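+        # 518 is the native inference size for both Depth Anything versions
+        # (a multiple of the ViT patch size of 14).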
     else:  # other midas  # TODO Marigold support
         net_receptive_field_size = 384
     patch_netsize = 2 * net_receptive_field_size
@@ -995,6 +1039,8 @@ def singleestimate(img, msize, model, net_type):
         return estimatemarigold(img, model, msize, msize)
     elif net_type == 11:
         return estimatedepthanything(img, model, msize, msize)
+    elif net_type in [12, 13, 14]:
+        return estimatedepthanything_v2(img, model, msize, msize)
     elif net_type >= 7:
         # np to PIL
         return estimatezoedepth(Image.fromarray(np.uint8(img * 255)).convert('RGB'), model, msize, msize)