1- import sys
2-
1+ # Use this only when running the test locally.
2+ # import sys
33# sys.path.append(".") # Adds the module to path
44
55import unittest
66
7- from deeptrack import features
8- from deeptrack import units_registry as u
9- from deeptrack import optics
7+ import numpy as np
108
9+ from deeptrack import optics
1110from deeptrack .scatterers import PointParticle , Sphere
12- from deeptrack . image import Image
11+ from deeptrack import units_registry as u
1312
13+ from deeptrack .backend import TORCH_AVAILABLE , xp
14+ from deeptrack .tests import BackendTestBase
1415
15- import numpy as np
16+ if TORCH_AVAILABLE :
17+ import torch
18+
19+
20+ class TestOptics_NumPy (BackendTestBase ):
21+ BACKEND = "numpy"
1622
23+ @property
24+ def array_type (self ):
25+ if self .BACKEND == "numpy" :
26+ return np .ndarray
27+ elif self .BACKEND == "torch" :
28+ return torch .Tensor
29+ else :
30+ raise ValueError (f"Unsupported backend: { self .BACKEND } " )
1731
18- class TestOptics (unittest .TestCase ):
1932 def test_Microscope (self ):
2033 microscope_type = optics .Fluorescence ()
2134 scatterer = PointParticle ()
2235 microscope = optics .Microscope (
2336 sample = scatterer , objective = microscope_type ,
2437 )
2538 output_image = microscope .get (None )
26- self .assertIsInstance (output_image , np . ndarray )
39+ self .assertIsInstance (output_image , self . array_type )
2740 self .assertEqual (output_image .shape , (128 , 128 , 1 ))
2841
2942 def test_Optics (self ):
@@ -51,7 +64,7 @@ def test_Fluorescence(self):
5164 )
5265 imaged_scatterer = microscope (scatterer )
5366 output_image = imaged_scatterer .resolve ()
54- self .assertIsInstance (output_image , np . ndarray )
67+ self .assertIsInstance (output_image , self . array_type )
5568 self .assertEqual (microscope .NA (), 0.7 )
5669 self .assertEqual (output_image .shape , (64 , 64 , 1 ))
5770
@@ -74,7 +87,7 @@ def test_Brightfield(self):
7487 )
7588 imaged_scatterer = microscope (scatterer )
7689 output_image = imaged_scatterer .resolve ()
77- self .assertIsInstance (output_image , np . ndarray )
90+ self .assertIsInstance (output_image , self . array_type )
7891 self .assertEqual (output_image .shape , (64 , 64 , 1 ))
7992
8093 def test_Holography (self ):
@@ -96,7 +109,7 @@ def test_Holography(self):
96109 )
97110 imaged_scatterer = microscope (scatterer )
98111 output_image = imaged_scatterer .resolve ()
99- self .assertIsInstance (output_image , np . ndarray )
112+ self .assertIsInstance (output_image , self . array_type )
100113 self .assertEqual (output_image .shape , (64 , 64 , 1 ))
101114
102115 def test_ISCAT (self ):
@@ -119,7 +132,7 @@ def test_ISCAT(self):
119132 imaged_scatterer = microscope (scatterer )
120133 output_image = imaged_scatterer .resolve ()
121134 self .assertEqual (microscope .illumination_angle (), 3.141592653589793 )
122- self .assertIsInstance (output_image , np . ndarray )
135+ self .assertIsInstance (output_image , self . array_type )
123136 self .assertEqual (output_image .shape , (64 , 64 , 1 ))
124137
125138 def test_Darkfield (self ):
@@ -142,7 +155,7 @@ def test_Darkfield(self):
142155 imaged_scatterer = microscope (scatterer )
143156 output_image = imaged_scatterer .resolve ()
144157 self .assertEqual (microscope .illumination_angle (), 1.5707963267948966 )
145- self .assertIsInstance (output_image , np . ndarray )
158+ self .assertIsInstance (output_image , self . array_type )
146159 self .assertEqual (output_image .shape , (64 , 64 , 1 ))
147160
148161 def test_IlluminationGradient (self ):
@@ -166,7 +179,7 @@ def test_IlluminationGradient(self):
166179 )
167180 imaged_scatterer = microscope (scatterer )
168181 output_image = imaged_scatterer .resolve ()
169- self .assertIsInstance (output_image , np . ndarray )
182+ self .assertIsInstance (output_image , self . array_type )
170183 self .assertEqual (output_image .shape , (64 , 64 , 1 ))
171184
172185 def test_upscale_fluorescence (self ):
@@ -237,6 +250,11 @@ def test_upscale_brightfield(self):
237250 ).mean () # Mean absolute error
238251 self .assertLess (error , 0.01 )
239252
253+ # TODO: Extending the test and setting the backend to torch
254+ # @unittest.skipUnless(TORCH_AVAILABLE, "PyTorch is not installed.")
255+ # class TestOptics_PyTorch(TestOptics_NumPy):
256+ # BACKEND = "torch"
257+ # pass
240258
241259if __name__ == "__main__" :
242260 unittest .main ()
0 commit comments