diff --git a/examples/cProfile.ipynb b/examples/cProfile.ipynb new file mode 100644 index 0000000..93031fc --- /dev/null +++ b/examples/cProfile.ipynb @@ -0,0 +1,283 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "638c82eb-3757-4fe9-b1e8-0da9b736fa2d", + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "6ac6cc3c-e07c-47da-a97f-a087621390ef", + "metadata": {}, + "outputs": [], + "source": [ + "import LFPy\n", + "import lfpykit\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "15b63f94-ebcd-42bf-8e80-9761def92965", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "total number of segments: 3078\n" + ] + } + ], + "source": [ + "# LFPy.Cell parameters\n", + "cellParameters = {\n", + " 'morphology': 'L5_Mainen96_LFPy.hoc', # morphology file\n", + " 'v_init': -65, # initial voltage\n", + " 'cm': 1.0, # membrane capacitance\n", + " 'Ra': 150, # axial resistivity\n", + " 'passive': True, # insert passive channels\n", + " 'passive_parameters': {\"g_pas\": 1. / 3E4,\n", + " \"e_pas\": -65}, # passive params\n", + " 'dt': 2**-4, # simulation time res\n", + " 'nsegs_method': 'lambda_f', # discretization rule\n", + " 'lambda_f': 1000 # frequency (Hz)\n", + "}\n", + "\n", + "# create LFPy.Cell instance\n", + "cell = LFPy.Cell(**cellParameters)\n", + "cell.set_rotation(x=4.98919, y=-4.33261, z=0.)\n", + "\n", + "print(f'total number of segments: {cell.totnsegs}')\n", + "\n", + "# parameters for line source potential\n", + "el_params = dict(\n", + " x = np.linspace(0, 1000, 1001),\n", + " y = np.zeros(1001),\n", + " z = np.zeros(1001),\n", + " sigma = 0.3\n", + ")\n", + "\n", + "cell.simulate(rec_imem=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "3b8e2ec8-c8a4-48a6-81ef-b7dcee36c836", + "metadata": {}, + "outputs": [], + "source": [ + "# create line-source potential predictor\n", + "lsp = lfpykit.LineSourcePotential(cell, **el_params)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "79185fd5-9485-425c-8727-146c793126d9", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "OMP: Info #273: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " \n", + "*** Profile printout saved to text file 'prun0'.\n" + ] + } + ], + "source": [ + "%%prun -s cumulative -q -l 50 -T prun0\n", + "for i in range(100):\n", + " lsp.get_transformation_matrix()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "9fcd45b6-5349-4718-be60-ab8befe8dab0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " 2670927 function calls (2651191 primitive calls) in 4.223 seconds\n", + "\n", + " Ordered by: cumulative time\n", + " List reduced from 1402 to 50 due to restriction <50>\n", + "\n", + " ncalls tottime percall cumtime percall filename:lineno(function)\n", + " 70/1 0.000 0.000 4.250 4.250 {built-in method builtins.exec}\n", + " 100 0.004 0.000 3.933 0.039 models.py:441(get_transformation_matrix)\n", + " 100 2.328 0.023 2.332 0.023 models.py:463(_get_transform)\n", + " 100 0.001 0.000 1.304 0.013 dispatcher.py:388(_compile_for_args)\n", + " 100 0.002 0.000 1.283 0.013 dispatcher.py:915(compile)\n", + " 100 0.000 0.000 1.269 0.013 caching.py:639(load_overload)\n", + "61409/61309 0.865 0.000 1.258 0.000 ffi.py:149(__call__)\n", + " 100 0.000 0.000 1.167 0.012 caching.py:650(_load_overload)\n", + " 100 0.000 0.000 1.116 0.011 caching.py:404(rebuild)\n", + " 100 0.001 0.000 1.116 0.011 compiler.py:210(_rebuild)\n", + " 100 0.000 0.000 1.088 0.011 codegen.py:1158(unserialize_library)\n", + " 100 0.000 0.000 1.087 0.011 codegen.py:926(_unserialize)\n", + " 100 0.000 0.000 0.509 0.005 module.py:29(parse_bitcode)\n", + " 200 0.002 0.000 0.509 0.003 codegen.py:1088(_load_defined_symbols)\n", + " 400 0.014 0.000 0.501 0.001 codegen.py:1092()\n", + " 21100 0.008 0.000 0.315 0.000 ffi.py:356(__del__)\n", + " 21100 0.009 0.000 0.305 0.000 ffi.py:313(close)\n", + " 143 0.001 0.000 0.286 0.002 decorators.py:189(wrapper)\n", + " 100 0.001 0.000 0.282 0.003 module.py:76(_dispose)\n", + " 100 0.000 0.000 0.269 0.003 dispatcher.py:862(enable_caching)\n", + " 100 0.001 0.000 0.269 0.003 caching.py:610(__init__)\n", + " 100 0.001 0.000 0.267 0.003 caching.py:336(__init__)\n", + " 100 0.000 0.000 0.265 0.003 caching.py:186(from_function)\n", + " 100 0.000 0.000 0.260 0.003 caching.py:116(ensure_cache_path)\n", + " 100 0.000 0.000 0.256 0.003 tempfile.py:575(TemporaryFile)\n", + " 100 0.000 0.000 0.248 0.002 tempfile.py:244(_mkstemp_inner)\n", + " 100 0.245 0.002 0.245 0.002 {built-in method posix.open}\n", + " 61409 0.022 0.000 0.209 0.000 ffi.py:73(__exit__)\n", + " 61409 0.015 0.000 0.182 0.000 base.py:1260(exit_fn)\n", + " 61409 0.021 0.000 0.181 0.000 ffi.py:67(__enter__)\n", + " 20900 0.011 0.000 0.181 0.000 module.py:213(__next__)\n", + " 123018 0.042 0.000 0.171 0.000 event.py:193(broadcast)\n", + " 61509 0.033 0.000 0.168 0.000 event.py:388(end_event)\n", + " 20100 0.016 0.000 0.161 0.000 value.py:206(is_declaration)\n", + " 61409 0.014 0.000 0.155 0.000 base.py:1257(enter_fn)\n", + " 61509 0.032 0.000 0.141 0.000 event.py:374(start_event)\n", + " 15600 0.011 0.000 0.125 0.000 value.py:143(name)\n", + " 122818 0.045 0.000 0.121 0.000 event.py:227(notify)\n", + " 15000 0.009 0.000 0.114 0.000 module.py:233(_next)\n", + " 270/100 0.000 0.000 0.102 0.001 base.py:269(refresh)\n", + "1473/1436 0.003 0.000 0.082 0.000 :1022(_find_and_load)\n", + " 123018 0.037 0.000 0.073 0.000 event.py:84(__init__)\n", + "4640/4459 0.003 0.000 0.070 0.000 :1053(_handle_fromlist)\n", + " 73/36 0.000 0.000 0.066 0.002 :987(_find_and_load_unlocked)\n", + " 127/36 0.000 0.000 0.064 0.002 :233(_call_with_frames_removed)\n", + " 72/36 0.000 0.000 0.064 0.002 :664(_load_unlocked)\n", + " 69/35 0.000 0.000 0.063 0.002 :877(exec_module)\n", + " 270 0.003 0.000 0.062 0.000 cpu.py:69(load_additional_registries)\n", + " 100 0.000 0.000 0.058 0.001 codegen.py:782(_finalize_final_module)\n", + " 52/33 0.000 0.000 0.057 0.002 {built-in method builtins.__import__}\n" + ] + } + ], + "source": [ + "print(open('prun0', 'r').read())" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "045d95d5-f58a-4ee3-a591-ea2bf241269e", + "metadata": {}, + "outputs": [], + "source": [ + "# create point-source potential predictor\n", + "psp = lfpykit.PointSourcePotential(cell, **el_params)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "a6d382dd-d4c6-4bab-9147-3cd5b189a700", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " \n", + "*** Profile printout saved to text file 'prun1'.\n" + ] + } + ], + "source": [ + "%%prun -s cumulative -q -l 20 -T prun1\n", + "for i in range(100):\n", + " psp.get_transformation_matrix()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "2531c6fd-935f-4907-89a9-aef29e73f277", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " 3203403 function calls in 8.508 seconds\n", + "\n", + " Ordered by: cumulative time\n", + "\n", + " ncalls tottime percall cumtime percall filename:lineno(function)\n", + " 1 0.000 0.000 8.508 8.508 {built-in method builtins.exec}\n", + " 1 0.010 0.010 8.508 8.508 :1()\n", + " 100 0.232 0.002 8.498 0.085 models.py:279(get_transformation_matrix)\n", + " 100100 1.279 0.000 8.266 0.000 lfpcalc.py:604(calc_lfp_pointsource)\n", + " 300300 0.065 0.000 6.647 0.000 {method 'mean' of 'numpy.ndarray' objects}\n", + " 300300 0.806 0.000 6.582 0.000 _methods.py:162(_mean)\n", + " 300300 5.484 0.000 5.484 0.000 {method 'reduce' of 'numpy.ufunc' objects}\n", + " 100100 0.340 0.000 0.340 0.000 lfpcalc.py:686(_check_rlimit_point)\n", + " 300300 0.210 0.000 0.245 0.000 _methods.py:66(_count_reduce_items)\n", + " 600600 0.026 0.000 0.026 0.000 {built-in method builtins.issubclass}\n", + " 600600 0.026 0.000 0.026 0.000 {built-in method builtins.isinstance}\n", + " 300300 0.019 0.000 0.019 0.000 {built-in method numpy.core._multiarray_umath.normalize_axis_index}\n", + " 300300 0.011 0.000 0.011 0.000 {built-in method numpy.asanyarray}\n", + " 100 0.000 0.000 0.000 0.000 {built-in method numpy.empty}\n", + " 1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}\n" + ] + } + ], + "source": [ + "print(open('prun1', 'r').read())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "281b24fe-12ed-40eb-a56e-2983c2a8806a", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/lfpykit/lfpcalc.py b/lfpykit/lfpcalc.py index 2fc66aa..09eec78 100644 --- a/lfpykit/lfpcalc.py +++ b/lfpykit/lfpcalc.py @@ -14,6 +14,8 @@ """ +from numba import njit +import numba import numpy as np @@ -352,8 +354,15 @@ def _anisotropic_line_source_case_iiii(a, b, c): np.arcsinh(b / np.sqrt(4 * a * c - b * b))) -def calc_lfp_linesource(cell_x, cell_y, cell_z, - x, y, z, sigma, r_limit): +@njit(nogil=True, cache=True, fastmath=False) +def calc_lfp_linesource(cell_x: numba.double[:, :], + cell_y: numba.double[:, :], + cell_z: numba.double[:, :], + x: numba.double, + y: numba.double, + z: numba.double, + sigma: numba.double, + r_limit: numba.double[:]): """Calculate electric field potential using the line-source method, all segments treated as line sources. @@ -385,36 +394,27 @@ def calc_lfp_linesource(cell_x, cell_y, cell_z, zstart = cell_z[:, 0] zend = cell_z[:, 1] - return _calc_lfp_linesource( - xstart, - xend, - ystart, - yend, - zstart, - zend, - x, - y, - z, - sigma, - r_limit) - - -def _calc_lfp_linesource(xstart, - xend, - ystart, - yend, - zstart, - zend, - x, - y, - z, - sigma, - r_limit): + return _calc_lfp_linesource(xstart, xend, ystart, yend, zstart, zend, + x, y, z, sigma, r_limit) + + +@njit(nogil=True, cache=True, fastmath=False) +def _calc_lfp_linesource(xstart: numba.double[:], + xend: numba.double[:], + ystart: numba.double[:], + yend: numba.double[:], + zstart: numba.double[:], + zend: numba.double[:], + x: numba.double, + y: numba.double, + z: numba.double, + sigma: numba.double, + r_limit: numba.double[:]): deltaS = _deltaS_calc(xstart, xend, ystart, yend, zstart, zend) h = _h_calc(xstart, xend, ystart, yend, zstart, zend, deltaS, x, y, z) r2 = _r2_calc(xend, yend, zend, x, y, z, h) - too_close_idxs = np.where(r2 < r_limit * r_limit)[0] + too_close_idxs = r2 < (r_limit * r_limit) r2[too_close_idxs] = r_limit[too_close_idxs]**2 l_ = h + deltaS @@ -426,11 +426,11 @@ def _calc_lfp_linesource(xstart, mapping = np.zeros(xstart.size) # case i, h < 0, l < 0, see Eq. C.13 in Gary Holt's thesis, 1998. - [i] = np.where(hnegi & lnegi) + i = hnegi & lnegi # case ii, h < 0, l >= 0 - [ii] = np.where(hnegi & lposi) + ii = hnegi & lposi # case iii, h >= 0, l >= 0 - [iii] = np.where(hposi & lposi) + iii = hposi & lposi mapping[i] = _linesource_calc_case1(l_[i], r2[i], h[i]) mapping[ii] = _linesource_calc_case2(l_[ii], r2[ii], h[ii]) @@ -517,39 +517,43 @@ def calc_lfp_root_as_point(cell_x, cell_y, cell_z, x, y, z, sigma, r_limit, return 1 / (4 * np.pi * sigma * deltaS) * mapping -def _linesource_calc_case1(l_i, - r2_i, - h_i): +@njit(nogil=True, cache=True, fastmath=True) +def _linesource_calc_case1(l_i: numba.double[:], + r2_i: numba.double[:], + h_i: numba.double[:]): """Calculates linesource contribution for case i""" bb = np.sqrt(h_i * h_i + r2_i) - h_i cc = np.sqrt(l_i * l_i + r2_i) - l_i return np.log(bb / cc) -def _linesource_calc_case2(l_ii, - r2_ii, - h_ii): +@njit(nogil=True, cache=True, fastmath=True) +def _linesource_calc_case2(l_ii: numba.double[:], + r2_ii: numba.double[:], + h_ii: numba.double[:]): """Calculates linesource contribution for case ii""" bb = np.sqrt(h_ii * h_ii + r2_ii) - h_ii cc = (l_ii + np.sqrt(l_ii * l_ii + r2_ii)) / r2_ii return np.log(bb * cc) -def _linesource_calc_case3(l_iii, - r2_iii, - h_iii): +@njit(nogil=True, cache=True, fastmath=True) +def _linesource_calc_case3(l_iii: numba.double[:], + r2_iii: numba.double[:], + h_iii: numba.double[:]): """Calculates linesource contribution for case iii""" bb = np.sqrt(l_iii * l_iii + r2_iii) + l_iii cc = np.sqrt(h_iii * h_iii + r2_iii) + h_iii return np.log(bb / cc) -def _deltaS_calc(xstart, - xend, - ystart, - yend, - zstart, - zend): +@njit(nogil=True, cache=True, fastmath=True) +def _deltaS_calc(xstart: numba.double[:], + xend: numba.double[:], + ystart: numba.double[:], + yend: numba.double[:], + zstart: numba.double[:], + zend: numba.double[:]): """Returns length of each segment""" deltaS = np.sqrt((xstart - xend)**2 + (ystart - yend)**2 + @@ -557,16 +561,17 @@ def _deltaS_calc(xstart, return deltaS -def _h_calc(xstart, - xend, - ystart, - yend, - zstart, - zend, - deltaS, - x, - y, - z): +@njit(nogil=True, cache=True, fastmath=True) +def _h_calc(xstart: numba.double[:], + xend: numba.double[:], + ystart: numba.double[:], + yend: numba.double[:], + zstart: numba.double[:], + zend: numba.double[:], + deltaS: numba.double[:], + x: numba.double, + y: numba.double, + z: numba.double): """Subroutine used by calc_lfp_*()""" ccX = (x - xend) * (xend - xstart) ccY = (y - yend) * (yend - ystart) @@ -576,13 +581,14 @@ def _h_calc(xstart, return cc / deltaS -def _r2_calc(xend, - yend, - zend, - x, - y, - z, - h): +@njit(nogil=True, cache=True, fastmath=True) +def _r2_calc(xend: numba.double[:], + yend: numba.double[:], + zend: numba.double[:], + x: numba.double, + y: numba.double, + z: numba.double, + h: numba.double[:]): """Subroutine used by calc_lfp_*()""" r2 = (xend - x)**2 + (yend - y)**2 + (zend - z)**2 - h**2 return np.abs(r2) @@ -676,7 +682,6 @@ def calc_lfp_pointsource_anisotropic(cell_x, cell_y, cell_z, mapping = 1 / (4 * np.pi * sigma_r) return mapping - def _check_rlimit_point(r2, r_limit): """Correct r2 so that r2 >= r_limit**2 for all values""" inds = r2 < r_limit * r_limit diff --git a/lfpykit/models.py b/lfpykit/models.py index 2877d98..7c94bd4 100644 --- a/lfpykit/models.py +++ b/lfpykit/models.py @@ -13,6 +13,7 @@ GNU General Public License for more details. """ +import numba import sys from copy import deepcopy import numpy as np @@ -459,10 +460,17 @@ def get_transformation_matrix(self): else: r_limit = self.cell.d / 2 - def _get_transform(cell_x, cell_y, cell_z, - x, y, z, sigma, r_limit): + @numba.njit(nogil=True, cache=True, fastmath=False, parallel=True) + def _get_transform(cell_x: numba.double[:, :], + cell_y: numba.double[:, :], + cell_z: numba.double[:, :], + x: numba.double, + y: numba.double, + z: numba.double, + sigma: numba.double, + r_limit: numba.double[:]): M = np.empty((x.size, cell_x.shape[0])) - for j in range(x.size): + for j in numba.prange(x.size): M[j, :] = lfpcalc.calc_lfp_linesource(cell_x=cell_x, cell_y=cell_y, cell_z=cell_z, diff --git a/lfpykit/tests/test_lfpcalc.py b/lfpykit/tests/test_lfpcalc.py index bf0d8da..0b3a6f3 100644 --- a/lfpykit/tests/test_lfpcalc.py +++ b/lfpykit/tests/test_lfpcalc.py @@ -51,10 +51,10 @@ def test_lfpcalc_calc_lfp_pointsource_00(self): np.testing.assert_equal(1. / (4 * np.pi * sigma), lfpcalc.calc_lfp_pointsource( cell.x, cell.y, cell.z, - x=0.5, y=0, z=1, - sigma=sigma, - r_limit=cell.d / 2 - )) + x=0.5, y=0, z=1, + sigma=sigma, + r_limit=cell.d / 2 + )) def test_lfpcalc_calc_lfp_pointsource_moi_00(self): """ @@ -68,19 +68,17 @@ def test_lfpcalc_calc_lfp_pointsource_moi_00(self): steps = 20 cell = DummyCell(np.array([[h / 2, h / 2]])) - in_vivo = lfpcalc.calc_lfp_pointsource( - cell.x, cell.y, cell.z, - x=0.5, y=0, z=1, sigma=sigma_T, - r_limit=cell.d / 2) - in_vitro = lfpcalc.calc_lfp_pointsource_moi( - cell.x, cell.y, cell.z, - x=0.5, y=0, z=1, - sigma_T=sigma_T, - sigma_G=sigma_G, - sigma_S=sigma_S, - r_limit=cell.d / 2, - h=h, - steps=steps) + in_vivo = lfpcalc.calc_lfp_pointsource(cell.x, cell.y, cell.z, + x=0.5, y=0, z=1, sigma=sigma_T, + r_limit=cell.d / 2) + in_vitro = lfpcalc.calc_lfp_pointsource_moi(cell.x, cell.y, cell.z, + x=0.5, y=0, z=1, + sigma_T=sigma_T, + sigma_G=sigma_G, + sigma_S=sigma_S, + r_limit=cell.d / 2, + h=h, + steps=steps) np.testing.assert_equal(in_vivo, in_vitro) @@ -316,6 +314,7 @@ def test_lfpcalc_calc_lfp_pointsource_moi_saline_effect(self): h=h, steps=steps) + without_saline = lfpcalc.calc_lfp_pointsource_moi( cell.x, cell.y, cell.z, x=0, y=0, z=0, diff --git a/requirements.txt b/requirements.txt index 33eb3ef..ade43ad 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ # pip requirements file numpy>=1.15.2 +numba scipy sympy MEAutility diff --git a/setup.py b/setup.py index 5616e24..e4bc64f 100644 --- a/setup.py +++ b/setup.py @@ -44,6 +44,7 @@ python_requires='>=3.6', install_requires=[ 'numpy>=1.15.2', + 'numba', 'scipy', 'meautility'], package_data={'lfpykit': [os.path.join('tests', '*.npz'),