Skip to content

Commit 624a5e5

Browse files
version that works
1 parent 5a79088 commit 624a5e5

File tree

4 files changed

+385
-18
lines changed

4 files changed

+385
-18
lines changed

bxa/xspec/solver.py

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -170,12 +170,24 @@ def log_likelihood(self, params):
170170
def run(
171171
self, sampler_kwargs={'resume': 'overwrite'}, run_kwargs={'Lepsilon': 0.1},
172172
speed="safe", resume=None, n_live_points=None,
173-
frac_remain=None, Lepsilon=0.1, evidence_tolerance=None
173+
frac_remain=None, Lepsilon=0.1, evidence_tolerance=None,
174+
stepsampler_kwargs=None,
174175
):
175176
"""Run nested sampling with ultranest.
176177
177-
:sampler_kwargs: arguments passed to ReactiveNestedSampler (see ultranest documentation)
178-
:run_kwargs: arguments passed to ReactiveNestedSampler.run() (see ultranest documentation)
178+
:param sampler_kwargs: arguments passed to ReactiveNestedSampler (see ultranest documentation)
179+
:param run_kwargs: arguments passed to ReactiveNestedSampler.run() (see ultranest documentation)
180+
:param stepsampler_kwargs: dictionary, which contains the following keys:
181+
`'initial_max_ncalls'` (int), for example 40000. This sets the initial sampling
182+
with MLFriends before switching to a step sampler. Setting this to zero may help resuming.
183+
all other arguments are passed directly to `ultranest.stepsampler.SliceSampler`,
184+
including arguments such as max_nsteps, nsteps, and
185+
`'generate_direction'` (function), which can be for example `ultranest.stepsampler.generate_mixture_random_direction`:
186+
the name of the proposal function in `ultranest.stepsampler` passed to `ultranest.stepsampler.SliceSampler`.
187+
For partial parameter updates to consider slow and fast variables, use `ultranest.stepsampler.SpeedVariableGenerator`.
188+
A recommended configuration is:
189+
`stepsampler_kwargs=dict(generate_direction=ultranest.stepsampler.generate_mixture_random_direction,
190+
initial_max_ncalls=40000, nsteps=100, max_nsteps=1000, region_filter=False)`.
179191
180192
The following arguments are also available directly for backward compatibility:
181193
@@ -184,6 +196,11 @@ def run(
184196
:param evidence_tolerance: sets run_kwargs['dlogz']
185197
:param Lepsilon: sets run_kwargs['Lepsilon']
186198
:param frac_remain: sets run_kwargs['frac_remain']
199+
:param speed: 'safe' (default), uses MLFriends algorithm of UltraNest, 'auto':
200+
corresponds to `stepsampler_kwargs=dict(generate_direction=ultranest.stepsampler.generate_mixture_random_direction,
201+
initial_max_ncalls=40000, nsteps=1000, max_nsteps=1000, adaptive_nsteps='move-distance')`,
202+
an integer setting nsteps, corresponds to
203+
`stepsampler_kwargs=dict(generate_direction=ultranest.stepsampler.generate_mixture_random_direction, nsteps=speed)`.
187204
"""
188205

189206
# run nested sampling
@@ -210,23 +227,22 @@ def run(
210227
log_dir=self.outputfiles_basename,
211228
vectorized=self.vectorized, **sampler_kwargs)
212229

213-
if speed == "safe":
214-
pass
215-
elif speed == "auto":
216-
region_filter = run_kwargs.pop('region_filter', True)
217-
self.sampler.run(max_ncalls=40000, **run_kwargs)
218-
219-
self.sampler.stepsampler = ultranest.stepsampler.SliceSampler(
220-
nsteps=1000,
230+
if speed == 'auto' and stepsampler_kwargs is None:
231+
stepsampler_kwargs = dict(
221232
generate_direction=ultranest.stepsampler.generate_mixture_random_direction,
222-
adaptive_nsteps='move-distance', region_filter=region_filter
223-
)
224-
else:
225-
self.sampler.stepsampler = ultranest.stepsampler.SliceSampler(
233+
initial_max_ncalls=40000, nsteps=1000, max_nsteps=1000, adaptive_nsteps='move-distance')
234+
elif speed not in ('auto', 'safe'):
235+
assert stepsampler_kwargs is None, 'do not set both speed and stepsampler_kwargs'
236+
stepsampler_kwargs = dict(
226237
generate_direction=ultranest.stepsampler.generate_mixture_random_direction,
227-
nsteps=speed,
228-
adaptive_nsteps=False,
229-
region_filter=False)
238+
initial_max_ncalls=0, nsteps=int(speed))
239+
240+
if stepsampler_kwargs is not None:
241+
initial_max_ncalls = stepsampler_kwargs.pop('initial_max_ncalls', 0)
242+
if initial_max_ncalls > 0:
243+
self.sampler.run(max_ncalls=initial_max_ncalls, **run_kwargs)
244+
245+
self.sampler.stepsampler = ultranest.stepsampler.SliceSampler(**stepsampler_kwargs)
230246

231247
self.sampler.run(**run_kwargs)
232248
self.sampler.print_results()
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
"""
2+
Example of doing BXA in X-spec
3+
"""
4+
import bxa.xspec as bxa
5+
from xspec import Fit, Plot, Model, AllData, AllModels
6+
import scipy.stats
7+
from numpy import log10
8+
9+
def concatenate_transformations(list_of_transformations, list_of_prior_functions):
10+
"""
11+
Concatenates transformation functions and prior functions.
12+
"""
13+
def combined_prior(cube):
14+
params = cube.copy()
15+
i = 0
16+
for transformations, prior in zip(list_of_transformations, list_of_prior_functions):
17+
# transform the first block, which has this many parameters:
18+
nparams = len(transformations)
19+
# apply the responsible prior to these:
20+
params[i:i+nparams] = prior(cube[i:i+nparams])
21+
i += nparams
22+
return params
23+
24+
# return the combined function and full list of transform functions
25+
combined_transformations = []
26+
for transformations in list_of_transformations:
27+
combined_transformations += transformations
28+
return combined_prior, combined_transformations
29+
30+
def create_gaussian_hierarchical_prior_function(mean_min, mean_max, sigma_max, participating_parameters, sigma_min=0, name='hierarchical'):
31+
"""Hierarchical Gaussian prior.
32+
33+
Flat priors on the hyper-parameters: mean and sigma.
34+
This avoids a strong funnel (compared to a log-uniform sigma prior)
35+
and is recommended by Gelman et al.
36+
"""
37+
38+
upper_transformations = [
39+
dict(skip_setting=True, name='mean-%s' % name),
40+
dict(skip_setting=True, name='sigma-%s' % name),
41+
]
42+
lower_transformations = [bxa.create_uniform_prior_for(m, par) for m, par in participating_parameters]
43+
44+
def gaussian_hierarchical_prior(cube):
45+
params = cube.copy()
46+
mean = params[0] = (mean_max - mean_min) * cube[0] + mean_min
47+
sigma = params[1] = (sigma_max - sigma_min) * cube[1] + sigma_min
48+
49+
rv = scipy.stats.norm(mean, sigma)
50+
for i, (_, par) in enumerate(participating_parameters, start=2):
51+
pval, pdelta, pmin, pbottom, ptop, pmax = par.values
52+
params[i] = max(pmin, min(pmax, rv.ppf(cube[i])))
53+
return params
54+
55+
return gaussian_hierarchical_prior, upper_transformations + lower_transformations
56+
57+
def create_loggaussian_hierarchical_prior_function(log_mean_min, log_mean_max, sigma_max, participating_parameters, sigma_min=0, name='hierarchical'):
58+
"""Hierarchical Gaussian prior on the log of parameters.
59+
60+
Flat priors on the hyper-parameters: mean and sigma.
61+
This avoids a strong funnel (compared to a log-uniform sigma prior)
62+
and is recommended by Gelman et al.
63+
"""
64+
upper_transformations = [
65+
dict(skip_setting=True, name='log_mean-%s' % name),
66+
dict(skip_setting=True, name='sigma-%s' % name),
67+
]
68+
lower_transformations = [bxa.create_loguniform_prior_for(m, par) for m, par in participating_parameters]
69+
70+
def loggaussian_hierarchical_prior(cube):
71+
params = cube.copy()
72+
log_mean = params[0] = (log_mean_max - log_mean_min) * cube[0] + log_mean_min
73+
sigma = params[1] = (sigma_max - sigma_min) * cube[1] + sigma_min
74+
75+
rv = scipy.stats.norm(log_mean, sigma)
76+
for i, (_, par) in enumerate(participating_parameters, start=2):
77+
pval, pdelta, pmin, pbottom, ptop, pmax = par.values
78+
params[i] = max(log10(pmin), min(log10(pmax), rv.ppf(cube[i])))
79+
return params
80+
81+
return loggaussian_hierarchical_prior, upper_transformations + lower_transformations
82+
83+
Fit.statMethod = 'cstat'
84+
Plot.xAxis = 'keV'
85+
86+
print("setting up main model")
87+
Model("wabs*pow")
88+
nH_transformations = []
89+
nH_parameters = []
90+
norm_transformations = []
91+
norm_parameters = []
92+
93+
for i, filename in enumerate(['sim1.fak', 'sim2.fak', 'sim3.fak'], start=1):
94+
print("loading..", filename)
95+
AllData("%d:%d %s" % (i, i, filename))
96+
s1 = AllData(i)
97+
s1.ignore("**-0.2, 8.0-**")
98+
print("setting up model")
99+
m1 = AllModels(i)
100+
m1.wabs.nH.values = ",,0.01,0.01,1000,1000"
101+
m1.powerlaw.norm.values = ",,1e-10,1e-10,1e1,1e1"
102+
#m1.powerlaw.norm.values = ",,1e-10,1e-10,1e1,1e1"
103+
#m1.powerlaw.PhoIndex.values = ",,1,1,3,3"
104+
if i != 1:
105+
m1.powerlaw.PhoIndex.link = '=%d' % AllModels(1).powerlaw.PhoIndex.index
106+
nH_transformations.append(bxa.create_uniform_prior_for( m1, m1.wabs.nH))
107+
nH_parameters.append((m1, m1.wabs.nH))
108+
norm_transformations.append(bxa.create_uniform_prior_for(m1, m1.powerlaw.norm))
109+
norm_parameters.append((m1, m1.powerlaw.norm))
110+
111+
print("setting up multi-level priors")
112+
# define multi-level prior for communicating parameters
113+
hierarchical_nH_prior, hierarchical_nH_transformations = create_loggaussian_hierarchical_prior_function(-2, 2, 2, nH_parameters, name='NH')
114+
hierarchical_norm_prior, hierarchical_norm_transformations = create_loggaussian_hierarchical_prior_function(-4, 0, 2, norm_parameters, name='norm')
115+
116+
# set parameters on non-communicating parameters
117+
print("setting up non-communicating parameters")
118+
mref = AllModels(1)
119+
mref.powerlaw.PhoIndex.values = ",,1,1,3,3"
120+
transformations = [
121+
bxa.create_gaussian_prior_for(mref, mref.powerlaw.PhoIndex, 1.95, 0.15),
122+
]
123+
AllModels.show()
124+
AllData.show()
125+
126+
plain_prior = bxa.create_prior_function(transformations)
127+
128+
combined_prior, combined_transformations = concatenate_transformations(
129+
[transformations, hierarchical_nH_transformations, hierarchical_norm_transformations],
130+
[plain_prior, hierarchical_nH_prior, hierarchical_norm_prior]
131+
)
132+
133+
print('running analysis ...')
134+
# where to store intermediate and final results? this is the prefix used
135+
solver = bxa.BXASolver(
136+
transformations=combined_transformations,
137+
prior_function=combined_prior,
138+
outputfiles_basename='dblhierarchical/',
139+
)
140+
import ultranest.stepsampler
141+
results = solver.run(resume=True,
142+
run_kwargs=dict(frac_remain=0.5),
143+
stepsampler_kwargs=dict(
144+
generate_direction=ultranest.stepsampler.generate_mixture_random_direction,
145+
initial_max_ncalls=100000, nsteps=100))
146+
147+
#results = solver.run(resume=True, speed='auto',
148+
# stepsampler_kwargs=dict(
149+
# generate_direction=ultranest.stepsampler.generate_mixture_random_direction,
150+
# initial_max_ncalls=40000, nsteps=1000, max_nsteps=1000, adaptive_nsteps='move-distance'))
151+
152+
print('running analysis ... done!')
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""
2+
Example of doing BXA in X-spec
3+
"""
4+
import bxa.xspec as bxa
5+
from xspec import Fit, Plot, Model, AllData, AllModels
6+
import scipy.stats
7+
from numpy import log10
8+
9+
Fit.statMethod = 'cstat'
10+
Plot.xAxis = 'keV'
11+
12+
print("setting up main model")
13+
Model("wabs*pow")
14+
transformations = []
15+
parameters = []
16+
17+
for i, filename in enumerate(['sim1.fak', 'sim2.fak', 'sim3.fak'], start=1):
18+
print("loading..", filename)
19+
AllData("%d:%d %s" % (i, i, filename))
20+
s1 = AllData(i)
21+
s1.ignore("**-0.2, 8.0-**")
22+
print("setting up model")
23+
m1 = AllModels(i)
24+
m1.wabs.nH.values = ",,0.01,0.01,1000,1000"
25+
m1.powerlaw.norm.values = ",,1e-10,1e-10,1e1,1e1"
26+
#m1.powerlaw.norm.values = ",,1e-10,1e-10,1e1,1e1"
27+
if i != 1:
28+
m1.powerlaw.PhoIndex.link = '=%d' % AllModels(1).powerlaw.PhoIndex.index
29+
transformations += [
30+
bxa.create_loguniform_prior_for(m1, m1.wabs.nH),
31+
bxa.create_loguniform_prior_for(m1, m1.powerlaw.norm)]
32+
parameters += [
33+
(m1, m1.wabs.nH),
34+
(m1, m1.powerlaw.norm)]
35+
36+
# set parameters on non-communicating parameters
37+
print("setting up non-communicating parameters")
38+
mref = AllModels(1)
39+
mref.powerlaw.PhoIndex.values = ",,1,1,3,3"
40+
transformations += [
41+
bxa.create_gaussian_prior_for(mref, mref.powerlaw.PhoIndex, 1.95, 0.15),
42+
]
43+
AllModels.show()
44+
AllData.show()
45+
46+
prior = bxa.create_prior_function(transformations)
47+
48+
print('running analysis ...')
49+
# where to store intermediate and final results? this is the prefix used
50+
solver = bxa.BXASolver(
51+
transformations=transformations,
52+
prior_function=prior,
53+
outputfiles_basename='independent/',
54+
)
55+
import ultranest.stepsampler
56+
results = solver.run(resume=True,
57+
run_kwargs=dict(frac_remain=0.5),
58+
stepsampler_kwargs=dict(
59+
generate_direction=ultranest.stepsampler.generate_mixture_random_direction,
60+
initial_max_ncalls=200000, nsteps=100))
61+
62+
#results = solver.run(resume=True, speed='auto',
63+
# stepsampler_kwargs=dict(
64+
# generate_direction=ultranest.stepsampler.generate_mixture_random_direction,
65+
# initial_max_ncalls=40000, nsteps=1000, max_nsteps=1000, adaptive_nsteps='move-distance'))
66+
67+
print('running analysis ... done!')

0 commit comments

Comments
 (0)