Skip to content

Commit 8e98dcf

Browse files
Jammy2211Jammy2211
authored andcommitted
simplify numba import
1 parent 28bbc29 commit 8e98dcf

File tree

3 files changed

+15
-43
lines changed

3 files changed

+15
-43
lines changed

autoarray/inversion/inversion/imaging/w_tilde.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def __init__(
5353
import numba
5454
except ModuleNotFoundError:
5555
raise exc.InversionException(
56-
"Inversion functionality (linear light profiles, pixelized reconstructions) is "
56+
"Inversion w-tilde functionality (pixelized reconstructions) is "
5757
"disabled if numba is not installed.\n\n"
5858
"This is because the run-times without numba are too slow.\n\n"
5959
"Please install numba, which is described at the following web page:\n\n"

autoarray/inversion/inversion/interferometer/w_tilde.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,17 @@ def __init__(
5353
the simultaneous linear equations are combined and solved simultaneously.
5454
"""
5555

56+
try:
57+
import numba
58+
except ModuleNotFoundError:
59+
raise exc.InversionException(
60+
"Inversion w-tilde functionality (pixelized reconstructions) is "
61+
"disabled if numba is not installed.\n\n"
62+
"This is because the run-times without numba are too slow.\n\n"
63+
"Please install numba, which is described at the following web page:\n\n"
64+
"https://pyautolens.readthedocs.io/en/latest/installation/overview.html"
65+
)
66+
5667
self.w_tilde = w_tilde
5768
self.w_tilde.check_noise_map(noise_map=dataset.noise_map)
5869

autoarray/numba_util.py

Lines changed: 3 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,10 @@
1-
import os
2-
from functools import wraps
31
import logging
4-
import time
5-
from typing import Callable
62

73
from autoconf import conf
84

9-
from autoarray import exc
105

116
logger = logging.getLogger(__name__)
127

13-
"""
14-
Depending on if we're using a super computer, we want two different numba decorators:
15-
16-
If on laptop:
17-
18-
@numba.jit(nopython=True, cache=True, parallel=False)
19-
20-
If on super computer:
21-
22-
@numba.jit(nopython=True, cache=False, parallel=True)
23-
"""
24-
258
try:
269
nopython = conf.instance["general"]["numba"]["nopython"]
2710
cache = conf.instance["general"]["numba"]["cache"]
@@ -31,41 +14,19 @@
3114
cache = True
3215
parallel = False
3316

34-
try:
35-
if os.environ.get("USE_JAX") == "1":
36-
1
37-
else:
38-
import numba
39-
40-
except ModuleNotFoundError:
41-
logger.warning(
42-
f"\n******************************************************************************\n"
43-
f"Numba is not being used, either because it is disabled in `config/general.yaml` "
44-
f"or because it is not installed.\n\n. "
45-
f"This will lead to slow performance.\n\n. "
46-
f"Install numba as described at the following webpage for improved performance. \n"
47-
f"https://pyautolens.readthedocs.io/en/latest/installation/numba.html \n"
48-
f"********************************************************************************"
49-
)
50-
5117

5218
def jit(nopython=nopython, cache=cache, parallel=parallel):
53-
def wrapper(func):
54-
try:
55-
use_numba = conf.instance["general"]["numba"]["use_numba"]
56-
57-
if not use_numba:
58-
return func
5919

60-
except KeyError:
61-
pass
20+
def wrapper(func):
6221

6322
try:
23+
6424
import numba
6525

6626
return numba.jit(func, nopython=nopython, cache=cache, parallel=parallel)
6727

6828
except ModuleNotFoundError:
29+
6930
return func
7031

7132
return wrapper

0 commit comments

Comments
 (0)