Skip to content

Commit 2b7cc8e

Browse files
update optics test (#427)
1 parent 410ce9c commit 2b7cc8e

File tree

1 file changed

+33
-15
lines changed

1 file changed

+33
-15
lines changed

deeptrack/tests/test_optics.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,42 @@
1-
import sys
2-
1+
# Use this only when running the test locally.
2+
# import sys
33
# sys.path.append(".") # Adds the module to path
44

55
import unittest
66

7-
from deeptrack import features
8-
from deeptrack import units_registry as u
9-
from deeptrack import optics
7+
import numpy as np
108

9+
from deeptrack import optics
1110
from deeptrack.scatterers import PointParticle, Sphere
12-
from deeptrack.image import Image
11+
from deeptrack import units_registry as u
1312

13+
from deeptrack.backend import TORCH_AVAILABLE, xp
14+
from deeptrack.tests import BackendTestBase
1415

15-
import numpy as np
16+
if TORCH_AVAILABLE:
17+
import torch
18+
19+
20+
class TestOptics_NumPy(BackendTestBase):
21+
BACKEND = "numpy"
1622

23+
@property
24+
def array_type(self):
25+
if self.BACKEND == "numpy":
26+
return np.ndarray
27+
elif self.BACKEND == "torch":
28+
return torch.Tensor
29+
else:
30+
raise ValueError(f"Unsupported backend: {self.BACKEND}")
1731

18-
class TestOptics(unittest.TestCase):
1932
def test_Microscope(self):
2033
microscope_type = optics.Fluorescence()
2134
scatterer = PointParticle()
2235
microscope = optics.Microscope(
2336
sample=scatterer, objective=microscope_type,
2437
)
2538
output_image = microscope.get(None)
26-
self.assertIsInstance(output_image, np.ndarray)
39+
self.assertIsInstance(output_image, self.array_type)
2740
self.assertEqual(output_image.shape, (128, 128, 1))
2841

2942
def test_Optics(self):
@@ -51,7 +64,7 @@ def test_Fluorescence(self):
5164
)
5265
imaged_scatterer = microscope(scatterer)
5366
output_image = imaged_scatterer.resolve()
54-
self.assertIsInstance(output_image, np.ndarray)
67+
self.assertIsInstance(output_image, self.array_type)
5568
self.assertEqual(microscope.NA(), 0.7)
5669
self.assertEqual(output_image.shape, (64, 64, 1))
5770

@@ -74,7 +87,7 @@ def test_Brightfield(self):
7487
)
7588
imaged_scatterer = microscope(scatterer)
7689
output_image = imaged_scatterer.resolve()
77-
self.assertIsInstance(output_image, np.ndarray)
90+
self.assertIsInstance(output_image, self.array_type)
7891
self.assertEqual(output_image.shape, (64, 64, 1))
7992

8093
def test_Holography(self):
@@ -96,7 +109,7 @@ def test_Holography(self):
96109
)
97110
imaged_scatterer = microscope(scatterer)
98111
output_image = imaged_scatterer.resolve()
99-
self.assertIsInstance(output_image, np.ndarray)
112+
self.assertIsInstance(output_image, self.array_type)
100113
self.assertEqual(output_image.shape, (64, 64, 1))
101114

102115
def test_ISCAT(self):
@@ -119,7 +132,7 @@ def test_ISCAT(self):
119132
imaged_scatterer = microscope(scatterer)
120133
output_image = imaged_scatterer.resolve()
121134
self.assertEqual(microscope.illumination_angle(), 3.141592653589793)
122-
self.assertIsInstance(output_image, np.ndarray)
135+
self.assertIsInstance(output_image, self.array_type)
123136
self.assertEqual(output_image.shape, (64, 64, 1))
124137

125138
def test_Darkfield(self):
@@ -142,7 +155,7 @@ def test_Darkfield(self):
142155
imaged_scatterer = microscope(scatterer)
143156
output_image = imaged_scatterer.resolve()
144157
self.assertEqual(microscope.illumination_angle(), 1.5707963267948966)
145-
self.assertIsInstance(output_image, np.ndarray)
158+
self.assertIsInstance(output_image, self.array_type)
146159
self.assertEqual(output_image.shape, (64, 64, 1))
147160

148161
def test_IlluminationGradient(self):
@@ -166,7 +179,7 @@ def test_IlluminationGradient(self):
166179
)
167180
imaged_scatterer = microscope(scatterer)
168181
output_image = imaged_scatterer.resolve()
169-
self.assertIsInstance(output_image, np.ndarray)
182+
self.assertIsInstance(output_image, self.array_type)
170183
self.assertEqual(output_image.shape, (64, 64, 1))
171184

172185
def test_upscale_fluorescence(self):
@@ -237,6 +250,11 @@ def test_upscale_brightfield(self):
237250
).mean() # Mean absolute error
238251
self.assertLess(error, 0.01)
239252

253+
# TODO: Extending the test and setting the backend to torch
254+
# @unittest.skipUnless(TORCH_AVAILABLE, "PyTorch is not installed.")
255+
# class TestOptics_PyTorch(TestOptics_NumPy):
256+
# BACKEND = "torch"
257+
# pass
240258

241259
if __name__ == "__main__":
242260
unittest.main()

0 commit comments

Comments
 (0)