11from abc import abstractmethod
2- from copy import deepcopy
32from dataclasses import dataclass
43from typing import Dict , List , Tuple
54
1110from light_curve .light_curve_py .features .rainbow ._parameters import create_parameters_class
1211from light_curve .light_curve_py .features .rainbow ._scaler import MultiBandScaler , Scaler
1312from light_curve .light_curve_py .minuit_lsq import LeastSquares
13+ from light_curve .light_curve_py .minuit_ml import MaximumLikelihood
1414
1515__all__ = ["BaseRainbowFit" ]
1616
@@ -121,6 +121,9 @@ def _check_iminuit():
121121 if LeastSquares is None :
122122 raise ImportError (IMINUIT_IMPORT_ERROR )
123123
124+ if MaximumLikelihood is None :
125+ raise ImportError (IMINUIT_IMPORT_ERROR )
126+
124127 try :
125128 try :
126129 from packaging .version import parse as parse_version
@@ -144,48 +147,65 @@ def temp_func(self, t, params):
144147 """Temperature evolution function."""
145148 return NotImplementedError
146149
147- @ abstractmethod
148- def _unscale_parameters ( self , params , t_scaler : Scaler , m_scaler : MultiBandScaler ) -> None :
149- """Unscale parameters from internal units, in-place.
150+ def _parameter_scalings ( self ) -> Dict [ str , str ]:
151+ """Rules for scaling/unscaling the parameters"""
152+ rules = {}
150153
151- No baseline parameters are needed to be unscaled.
152- """
153- return NotImplementedError
154+ if self .with_baseline :
155+ for band_name in self .bands .names :
156+ baseline_name = self .p .baseline_parameter_name (band_name )
157+ rules [baseline_name ] = "baseline"
154158
155- def _unscale_errors (self , errors , t_scaler : Scaler , m_scaler : MultiBandScaler ) -> None :
156- """Unscale parameter errors from internal units, in-place.
159+ return rules
157160
158- No baseline parameters are needed to be unscaled.
159- """
161+ def _parameter_scale (self , name : str , t_scaler : Scaler , m_scaler : MultiBandScaler ) -> float :
162+ """Return the scale factor to be applied to the parameter to unscale it"""
163+ scaling = self ._parameter_scalings ().get (name )
164+ if scaling == "time" or scaling == "timescale" :
165+ return t_scaler .scale
166+ elif scaling == "flux" :
167+ return m_scaler .scale
160168
161- # We need to modify original scalers to only apply the scale, not shifts, to the errors
162- # It should be re-implemented in subclasses for a cleaner way to unscale the errors
163- t_scaler = deepcopy (t_scaler )
164- m_scaler = deepcopy (m_scaler )
165- t_scaler .reset_shift ()
166- m_scaler .reset_shift ()
169+ return 1
167170
168- return self ._unscale_parameters (errors , t_scaler , m_scaler )
171+ def _unscale_parameters (self , params , t_scaler : Scaler , m_scaler : MultiBandScaler ) -> None :
172+ """Unscale parameters from internal units, in-place."""
173+ for name , scaling in self ._parameter_scalings ().items ():
174+ if scaling == "time" :
175+ params [self .p [name ]] = t_scaler .undo_shift_scale (params [self .p [name ]])
169176
170- def _unscale_baseline_parameters ( self , params , m_scaler : MultiBandScaler ) -> None :
171- """Unscale baseline parameters from internal units, in-place.
177+ elif scaling == "timescale" :
178+ params [ self . p [ name ]] = t_scaler . undo_scale ( params [ self . p [ name ]])
172179
173- Must be used only if `with_baseline` is True.
174- """
175- for band_name in self .bands .names :
176- baseline_name = self .p .baseline_parameter_name (band_name )
177- baseline = params [self .p [baseline_name ]]
178- params [self .p [baseline_name ]] = m_scaler .undo_shift_scale_band (baseline , band_name )
180+ elif scaling == "flux" :
181+ params [self .p [name ]] = m_scaler .undo_scale (params [self .p [name ]])
179182
180- def _unscale_baseline_errors (self , errors , m_scaler : MultiBandScaler ) -> None :
181- """Unscale baseline parameters from internal units, in-place.
183+ elif scaling == "baseline" :
184+ band_name = self .p .baseline_band_name (name )
185+ baseline = params [self .p [name ]]
186+ params [self .p [name ]] = m_scaler .undo_shift_scale_band (baseline , band_name )
182187
183- Must be used only if `with_baseline` is True.
184- """
185- for band_name in self .bands .names :
186- baseline_name = self .p .baseline_parameter_name (band_name )
187- baseline = errors [self .p [baseline_name ]]
188- errors [self .p [baseline_name ]] = m_scaler .undo_scale_band (baseline , band_name )
188+ pass
189+
190+ elif scaling is None or scaling .lower () == "none" :
191+ pass
192+
193+ else :
194+ raise ValueError ("Unsupported parameter scaling: " + scaling )
195+
196+ def _unscale_errors (self , errors , t_scaler : Scaler , m_scaler : MultiBandScaler ) -> None :
197+ """Unscale parameter errors from internal units, in-place."""
198+ for name in self .names :
199+ scale = self ._parameter_scale (name , t_scaler , m_scaler )
200+ errors [self .p [name ]] *= scale
201+
202+ def _unscale_covariance (self , cov , t_scaler : Scaler , m_scaler : MultiBandScaler ) -> None :
203+ """Unscale parameter covariance from internal units, in-place."""
204+ for name in self .names :
205+ scale = self ._parameter_scale (name , t_scaler , m_scaler )
206+ i = self .p [name ]
207+ cov [:, i ] *= scale
208+ cov [i , :] *= scale
189209
190210 @staticmethod
191211 def planck_nu (wave_cm , T ):
@@ -283,7 +303,19 @@ def _eval(self, *, t, m, sigma, band):
283303 def _eval_and_fill (self , * , t , m , sigma , band , fill_value ):
284304 return super ()._eval_and_fill (t = t , m = m , sigma = sigma , band = band , fill_value = fill_value )
285305
286- def _eval_and_get_errors (self , * , t , m , sigma , band , print_level = None , get_initial = False ):
306+ def _eval_and_get_errors (
307+ self ,
308+ * ,
309+ t ,
310+ m ,
311+ sigma ,
312+ band ,
313+ upper_mask = None ,
314+ get_initial = False ,
315+ return_covariance = False ,
316+ print_level = None ,
317+ debug = False ,
318+ ):
287319 # Initialize data scalers
288320 t_scaler = Scaler .from_time (t )
289321 m_scaler = MultiBandScaler .from_flux (m , band , with_baseline = self .with_baseline )
@@ -311,19 +343,51 @@ def _eval_and_get_errors(self, *, t, m, sigma, band, print_level=None, get_initi
311343 initial_guesses = self ._initial_guesses (t , m , sigma , band )
312344 limits = self ._limits (t , m , sigma , band )
313345
314- least_squares = LeastSquares (
346+ # least_squares = LeastSquares(
347+ cost_function = MaximumLikelihood (
315348 model = self ._lsq_model ,
316349 parameters = limits ,
317350 x = (t , band_idx , wave_cm ),
318351 y = m ,
319352 yerror = sigma ,
353+ upper_mask = upper_mask ,
320354 )
321- minuit = self .Minuit (least_squares , name = self .names , ** initial_guesses )
355+ minuit = self .Minuit (cost_function , name = self .names , ** initial_guesses )
322356 # TODO: expose these parameters through function arguments
323357 if print_level is not None :
324358 minuit .print_level = print_level
325- minuit .strategy = 2
326- minuit .migrad (ncall = 10000 , iterate = 10 )
359+ minuit .strategy = 0 # We will need to manually call .hesse() on convergence anyway
360+
361+ # Supposedly it is not the same as just setting iterate=10?..
362+ for i in range (10 ):
363+ minuit .migrad ()
364+
365+ if minuit .valid :
366+ minuit .hesse ()
367+ # hesse() may may drive it invalid
368+ if minuit .valid :
369+ break
370+ else :
371+ # That's what iterate is supposed to do?..
372+ minuit .simplex ()
373+ # FIXME: it may drive the fit valid, but we will not have Hesse run on last iteration
374+
375+ if debug :
376+ # Expose everything we have to outside, unscaled, for easier debugging
377+ self .minuit = minuit
378+ self .mparams = {
379+ "t" : t ,
380+ "band_idx" : band_idx ,
381+ "wave_cm" : wave_cm ,
382+ "m" : m ,
383+ "sigma" : sigma ,
384+ "limits" : limits ,
385+ "upper_mask" : upper_mask ,
386+ "initial_guesses" : initial_guesses ,
387+ "values" : minuit .values ,
388+ "errors" : minuit .errors ,
389+ "covariance" : minuit .covariance ,
390+ }
327391
328392 if not minuit .valid and self .fail_on_divergence and not get_initial :
329393 raise RuntimeError ("Fitting failed" )
@@ -338,15 +402,19 @@ def _eval_and_get_errors(self, *, t, m, sigma, band, print_level=None, get_initi
338402 errors = np .array (minuit .errors )
339403
340404 self ._unscale_parameters (params , t_scaler , m_scaler )
341- if self .with_baseline :
342- self ._unscale_baseline_parameters (params , m_scaler )
343405
344406 # Unscale errors
345407 self ._unscale_errors (errors , t_scaler , m_scaler )
346- if self .with_baseline :
347- self ._unscale_baseline_errors (errors , m_scaler )
348408
349- return np .r_ [params , reduced_chi2 ], errors
409+ return_values = np .r_ [params , reduced_chi2 ], errors
410+
411+ if return_covariance :
412+ # Unscale covaiance
413+ cov = np .array (minuit .covariance )
414+ self ._unscale_covariance (cov , t_scaler , m_scaler )
415+ return_values += (cov ,)
416+
417+ return return_values
350418
351419 def fit_and_get_errors (self , t , m , sigma , band , * , sorted = None , check = True , ** kwargs ):
352420 t , m , sigma , band = self ._normalize_input (t = t , m = m , sigma = sigma , band = band , sorted = sorted , check = check )
0 commit comments