Skip to content

Commit dd1576f

Browse files
authored
Merge pull request #411 from erusseil/extended_PR
2 parents 42ed8af + dacebdc commit dd1576f

File tree

3 files changed

+223
-107
lines changed

3 files changed

+223
-107
lines changed

light-curve/light_curve/light_curve_py/features/rainbow/bolometric.py

Lines changed: 155 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,19 @@
1+
import math
12
from abc import abstractmethod
23
from dataclasses import dataclass
34
from typing import Dict, List, Union
45

56
import numpy as np
7+
from scipy.special import lambertw
68

7-
__all__ = ["bolometric_terms", "BaseBolometricTerm", "SigmoidBolometricTerm", "BazinBolometricTerm"]
9+
__all__ = [
10+
"bolometric_terms",
11+
"BaseBolometricTerm",
12+
"SigmoidBolometricTerm",
13+
"BazinBolometricTerm",
14+
"LinexpBolometricTerm",
15+
"DoublexpBolometricTerm",
16+
]
817

918

1019
@dataclass()
@@ -186,7 +195,152 @@ def peak_time(t0, amplitude, rise_time, fall_time):
186195
return t0 + np.log(fall_time / rise_time) * rise_time * fall_time / (rise_time + fall_time)
187196

188197

198+
@dataclass()
199+
class LinexpBolometricTerm(BaseBolometricTerm):
200+
"""Linexp function, symmetric form. Generated using a prototype version of Multi-view
201+
Symbolic Regression (Russeil et al. 2024, https://arxiv.org/abs/2402.04298) on
202+
a SLSN ZTF light curve (https://ztf.snad.space/dr17/view/821207100004043)"""
203+
204+
@staticmethod
205+
def parameter_names():
206+
return ["reference_time", "amplitude", "rise_time"]
207+
208+
@staticmethod
209+
def parameter_scalings():
210+
return ["time", "flux", "timescale"]
211+
212+
@staticmethod
213+
def value(t, t0, amplitude, rise_time):
214+
dt = t0 - t
215+
protected_rise = math.copysign(max(1e-5, abs(rise_time)), rise_time)
216+
217+
# Coefficient to make peak amplitude equal to unity
218+
scale = 1 / (protected_rise * np.exp(-1))
219+
220+
power = -dt / protected_rise
221+
power = np.where(power > 100, 100, power)
222+
result = amplitude * scale * dt * np.exp(power)
223+
224+
return result
225+
226+
@staticmethod
227+
def initial_guesses(t, m, sigma, band):
228+
229+
A = np.ptp(m)
230+
231+
# Compute points after or before maximum
232+
peak_time = t[np.argmax(m)]
233+
after = t[-1] - peak_time
234+
before = peak_time - t[0]
235+
236+
# Peak position as weighted centroid of everything above zero
237+
idx = m > np.median(m)
238+
# Weighted centroid sigma
239+
dt = np.sqrt(np.sum((t[idx] - peak_time) ** 2 * m[idx] / sigma[idx]) / np.sum(m[idx] / sigma[idx]))
240+
# Empirical conversion of sigma to rise/rise times
241+
rise_time = dt / 2
242+
rise_time = rise_time if before >= after else -rise_time
243+
244+
initial = {}
245+
# Reference of linexp correspond to the moment where flux == 0
246+
initial["reference_time"] = peak_time + rise_time
247+
initial["amplitude"] = A
248+
initial["rise_time"] = rise_time
249+
250+
return initial
251+
252+
@staticmethod
253+
def limits(t, m, sigma, band):
254+
t_amplitude = np.ptp(t)
255+
m_amplitude = np.ptp(m)
256+
257+
limits = {}
258+
limits["reference_time"] = (np.min(t) - 10 * t_amplitude, np.max(t) + 10 * t_amplitude)
259+
limits["amplitude"] = (0, 10 * m_amplitude)
260+
limits["rise_time"] = (-10 * t_amplitude, 10 * t_amplitude)
261+
262+
return limits
263+
264+
@staticmethod
265+
def peak_time(t0, amplitude, rise_time):
266+
return t0 - rise_time
267+
268+
269+
@dataclass()
270+
class DoublexpBolometricTerm(BaseBolometricTerm):
271+
"""Doublexp function generated using Multi-view Symbolic Regression on ZTF SNIa light curves
272+
Russeil et al. 2024, https://arxiv.org/abs/2402.04298"""
273+
274+
@staticmethod
275+
def parameter_names():
276+
return ["reference_time", "amplitude", "time1", "time2", "p"]
277+
278+
@staticmethod
279+
def parameter_scalings():
280+
return ["time", "flux", "timescale", "timescale", "None"]
281+
282+
@staticmethod
283+
def value(t, t0, amplitude, time1, time2, p):
284+
dt = t - t0
285+
286+
result = np.zeros_like(dt)
287+
288+
# To avoid numerical overflows
289+
maxp = 20
290+
A = -(dt / time1) * (p - np.exp(-(dt / time2)))
291+
A = np.where(A > maxp, maxp, A)
292+
293+
result = amplitude * np.exp(A)
294+
295+
return result
296+
297+
@staticmethod
298+
def initial_guesses(t, m, sigma, band):
299+
A = np.ptp(m)
300+
301+
# Naive peak position from the highest point
302+
t0 = t[np.argmax(m)]
303+
# Peak position as weighted centroid of everything above zero
304+
idx = m > np.median(m)
305+
# t0 = np.sum(t[idx] * m[idx] / sigma[idx]) / np.sum(m[idx] / sigma[idx])
306+
# Weighted centroid sigma
307+
dt = np.sqrt(np.sum((t[idx] - t0) ** 2 * m[idx] / sigma[idx]) / np.sum(m[idx] / sigma[idx]))
308+
309+
# Empirical conversion of sigma to rise/fall times
310+
time1 = 10 * dt
311+
time2 = 10 * dt
312+
313+
initial = {}
314+
initial["reference_time"] = t0
315+
initial["amplitude"] = A
316+
initial["time1"] = time1
317+
initial["time2"] = time2
318+
initial["p"] = 1
319+
320+
return initial
321+
322+
@staticmethod
323+
def limits(t, m, sigma, band):
324+
t_amplitude = np.ptp(t)
325+
m_amplitude = np.ptp(m)
326+
327+
limits = {}
328+
limits["reference_time"] = (np.min(t) - 10 * t_amplitude, np.max(t) + 10 * t_amplitude)
329+
limits["amplitude"] = (0.0, 10 * m_amplitude)
330+
limits["time1"] = (1e-1, 2 * t_amplitude)
331+
limits["time2"] = (1e-1, 2 * t_amplitude)
332+
limits["p"] = (0, 100)
333+
334+
return limits
335+
336+
@staticmethod
337+
def peak_time(t0, p):
338+
return t0 + np.real(-lambertw(p * np.exp(1)) + 1)
339+
340+
189341
bolometric_terms = {
190342
"sigmoid": SigmoidBolometricTerm,
191343
"bazin": BazinBolometricTerm,
344+
"linexp": LinexpBolometricTerm,
345+
"doublexp": DoublexpBolometricTerm,
192346
}

light-curve/light_curve/light_curve_py/features/rainbow/temperature.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -87,24 +87,24 @@ class SigmoidTemperatureTerm(BaseTemperatureTerm):
8787

8888
@staticmethod
8989
def parameter_names():
90-
return ["reference_time", "Tmin", "Tmax", "k_sig"]
90+
return ["reference_time", "Tmin", "Tmax", "t_color"]
9191

9292
@staticmethod
9393
def parameter_scalings():
9494
return ["time", None, None, "timescale"]
9595

9696
@staticmethod
97-
def value(t, t0, temp_min, temp_max, k_sig):
97+
def value(t, t0, temp_min, temp_max, t_color):
9898
dt = t - t0
9999

100100
# To avoid numerical overflows, let's only compute the exponent not too far from t0
101-
idx1 = dt <= -100 * k_sig
102-
idx2 = (dt > -100 * k_sig) & (dt < 100 * k_sig)
103-
idx3 = dt >= 100 * k_sig
101+
idx1 = dt <= -100 * t_color
102+
idx2 = (dt > -100 * t_color) & (dt < 100 * t_color)
103+
idx3 = dt >= 100 * t_color
104104

105105
result = np.zeros(len(dt))
106106
result[idx1] = temp_max
107-
result[idx2] = temp_min + (temp_max - temp_min) / (1.0 + np.exp(dt[idx2] / k_sig))
107+
result[idx2] = temp_min + (temp_max - temp_min) / (1.0 + np.exp(dt[idx2] / t_color))
108108
result[idx3] = temp_min
109109

110110
return result
@@ -114,7 +114,7 @@ def initial_guesses(t, m, sigma, band):
114114
initial = {}
115115
initial["Tmin"] = 7000.0
116116
initial["Tmax"] = 10000.0
117-
initial["k_sig"] = 1.0
117+
initial["t_color"] = 1.0
118118

119119
return initial
120120

@@ -125,7 +125,7 @@ def limits(t, m, sigma, band):
125125
limits = {}
126126
limits["Tmin"] = (1e3, 2e6) # K
127127
limits["Tmax"] = (1e3, 2e6) # K
128-
limits["k_sig"] = (1e-4, 10 * t_amplitude)
128+
limits["t_color"] = (1e-4, 10 * t_amplitude)
129129

130130
return limits
131131

@@ -136,24 +136,24 @@ class DelayedSigmoidTemperatureTerm(BaseTemperatureTerm):
136136

137137
@staticmethod
138138
def parameter_names():
139-
return ["reference_time", "Tmin", "Tmax", "k_sig", "t_delay"]
139+
return ["reference_time", "Tmin", "Tmax", "t_color", "t_delay"]
140140

141141
@staticmethod
142142
def parameter_scalings():
143143
return ["time", None, None, "timescale", "timescale"]
144144

145145
@staticmethod
146-
def value(t, t0, Tmin, Tmax, k_sig, t_delay):
146+
def value(t, t0, Tmin, Tmax, t_color, t_delay):
147147
dt = t - t0 - t_delay
148148

149149
# To avoid numerical overflows, let's only compute the exponent not too far from t0
150-
idx1 = dt <= -100 * k_sig
151-
idx2 = (dt > -100 * k_sig) & (dt < 100 * k_sig)
152-
idx3 = dt >= 100 * k_sig
150+
idx1 = dt <= -100 * t_color
151+
idx2 = (dt > -100 * t_color) & (dt < 100 * t_color)
152+
idx3 = dt >= 100 * t_color
153153

154154
result = np.zeros(len(dt))
155155
result[idx1] = Tmax
156-
result[idx2] = Tmin + (Tmax - Tmin) / (1.0 + np.exp(dt[idx2] / k_sig))
156+
result[idx2] = Tmin + (Tmax - Tmin) / (1.0 + np.exp(dt[idx2] / t_color))
157157
result[idx3] = Tmin
158158

159159
return result
@@ -163,7 +163,7 @@ def initial_guesses(t, m, sigma, band):
163163
initial = {}
164164
initial["Tmin"] = 7000.0
165165
initial["Tmax"] = 10000.0
166-
initial["k_sig"] = 1.0
166+
initial["t_color"] = 1.0
167167
initial["t_delay"] = 0.0
168168

169169
return initial
@@ -175,7 +175,7 @@ def limits(t, m, sigma, band):
175175
limits = {}
176176
limits["Tmin"] = (1e3, 2e6) # K
177177
limits["Tmax"] = (1e3, 2e6) # K
178-
limits["k_sig"] = (1e-4, 10 * t_amplitude)
178+
limits["t_color"] = (1e-4, 10 * t_amplitude)
179179
limits["t_delay"] = (-t_amplitude, t_amplitude)
180180

181181
return limits

0 commit comments

Comments
 (0)