From 0bdab5d697bf1ee3f8cd28c3e4c6f91c0e562c92 Mon Sep 17 00:00:00 2001 From: Konstantin Malanchev Date: Tue, 30 Aug 2022 16:28:09 -0500 Subject: [PATCH 1/9] eval_no_ts_check initial implementation --- src/evaluator.rs | 55 +++++++++---------- src/extractor.rs | 5 +- src/features/amplitude.rs | 4 +- src/features/anderson_darling_normal.rs | 7 ++- src/features/bazin_fit.rs | 1 + src/features/beyond_n_std.rs | 4 +- src/features/bins.rs | 9 ++- src/features/cusum.rs | 6 +- src/features/duration.rs | 4 +- src/features/eta.rs | 6 +- src/features/eta_e.rs | 6 +- src/features/excess_variance.rs | 4 +- src/features/inter_percentile_range.rs | 4 +- src/features/kurtosis.rs | 6 +- src/features/linear_fit.rs | 4 +- src/features/linear_trend.rs | 4 +- src/features/magnitude_percentage_ratio.rs | 10 +--- src/features/maximum_slope.rs | 4 +- src/features/maximum_time_interval.rs | 4 +- src/features/mean.rs | 3 +- src/features/mean_variance.rs | 4 +- src/features/median.rs | 4 +- src/features/median_absolute_deviation.rs | 4 +- .../median_buffer_range_percentage.rs | 4 +- src/features/minimum_time_interval.rs | 4 +- src/features/observation_count.rs | 4 +- src/features/otsu_split.rs | 45 +++++++++------ src/features/percent_amplitude.rs | 4 +- ...percent_difference_magnitude_percentile.rs | 4 +- src/features/periodogram.rs | 5 +- src/features/reduced_chi2.rs | 4 +- src/features/skew.rs | 6 +- src/features/standard_deviation.rs | 4 +- src/features/stetson_k.rs | 6 +- src/features/time_mean.rs | 4 +- src/features/time_standard_deviation.rs | 4 +- src/features/transformed.rs | 11 ++-- src/features/villar_fit.rs | 1 + src/features/weighted_mean.rs | 4 +- src/macros.rs | 12 ++-- src/tests.rs | 39 +++++++++++++ 41 files changed, 196 insertions(+), 131 deletions(-) diff --git a/src/evaluator.rs b/src/evaluator.rs index aa83d6db..de531e0c 100644 --- a/src/evaluator.rs +++ b/src/evaluator.rs @@ -20,6 +20,7 @@ pub struct EvaluatorInfo { pub m_required: bool, pub w_required: bool, pub sorting_required: bool, + pub variability_required: bool, } #[derive(Clone, Debug)] @@ -69,6 +70,11 @@ pub trait EvaluatorInfoTrait { fn is_sorting_required(&self) -> bool { self.get_info().sorting_required } + + /// If feature requires magnitude array elements to be different + fn is_variability_required(&self) -> bool { + self.get_info().variability_required + } } // impl

EvaluatorInfoTrait for P @@ -124,8 +130,14 @@ pub trait FeatureEvaluator: + DeserializeOwned + JsonSchema { + /// Version of [FeatureEvaluator::eval] which can panic for incorrect input + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError>; + /// Vector of feature values or `EvaluatorError` - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError>; + fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { + self.check_ts(ts)?; + self.eval_no_ts_check(ts) + } /// Returns vector of feature values and fill invalid components with given value fn eval_or_fill(&self, ts: &mut TimeSeries, fill_value: T) -> Vec { @@ -135,8 +147,13 @@ pub trait FeatureEvaluator: } } + fn check_ts(&self, ts: &mut TimeSeries) -> Result<(), EvaluatorError> { + self.check_ts_length(ts)?; + self.check_ts_variability(ts) + } + /// Checks if [TimeSeries] has enough points to evaluate the feature - fn check_ts_length(&self, ts: &TimeSeries) -> Result { + fn check_ts_length(&self, ts: &TimeSeries) -> Result<(), EvaluatorError> { let length = ts.lenu(); if length < self.min_ts_length() { Err(EvaluatorError::ShortTimeSeries { @@ -144,35 +161,17 @@ pub trait FeatureEvaluator: minimum: self.min_ts_length(), }) } else { - Ok(length) + Ok(()) } } -} -pub fn get_nonzero_m_std(ts: &mut TimeSeries) -> Result { - let std = ts.m.get_std(); - if std.is_zero() || ts.is_plateau() { - Err(EvaluatorError::FlatTimeSeries) - } else { - Ok(std) - } -} - -pub fn get_nonzero_m_std2(ts: &mut TimeSeries) -> Result { - let std2 = ts.m.get_std2(); - if std2.is_zero() || ts.is_plateau() { - Err(EvaluatorError::FlatTimeSeries) - } else { - Ok(std2) - } -} - -pub fn get_nonzero_reduced_chi2(ts: &mut TimeSeries) -> Result { - let reduced_chi2 = ts.get_m_reduced_chi2(); - if reduced_chi2.is_zero() || ts.is_plateau() { - Err(EvaluatorError::FlatTimeSeries) - } else { - Ok(reduced_chi2) + /// Checks if [TimeSeries] meets variability requirement + fn check_ts_variability(&self, ts: &mut TimeSeries) -> Result<(), EvaluatorError> { + if self.is_variability_required() && ts.is_plateau() { + Err(EvaluatorError::FlatTimeSeries) + } else { + Ok(()) + } } } diff --git a/src/extractor.rs b/src/extractor.rs index e7e5a1a1..8830d503 100644 --- a/src/extractor.rs +++ b/src/extractor.rs @@ -46,6 +46,7 @@ where m_required: features.iter().any(|x| x.is_m_required()), w_required: features.iter().any(|x| x.is_w_required()), sorting_required: features.iter().any(|x| x.is_sorting_required()), + variability_required: features.iter().any(|x| x.is_variability_required()), } .into(); Self { @@ -118,10 +119,10 @@ where T: Float, F: FeatureEvaluator, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let mut vec = Vec::with_capacity(self.size_hint()); for x in &self.features { - vec.extend(x.eval(ts)?); + vec.extend(x.eval_no_ts_check(ts)?); } Ok(vec) } diff --git a/src/features/amplitude.rs b/src/features/amplitude.rs index cad0f5ad..ced5bc9f 100644 --- a/src/features/amplitude.rs +++ b/src/features/amplitude.rs @@ -47,6 +47,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: false, + variability_required: false, ); impl FeatureNamesDescriptionsTrait for Amplitude { @@ -63,8 +64,7 @@ impl FeatureEvaluator for Amplitude where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { Ok(vec![T::half() * (ts.m.get_max() - ts.m.get_min())]) } } diff --git a/src/features/anderson_darling_normal.rs b/src/features/anderson_darling_normal.rs index 189ea786..8bf947a9 100644 --- a/src/features/anderson_darling_normal.rs +++ b/src/features/anderson_darling_normal.rs @@ -46,6 +46,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: false, + variability_required: true, ); impl FeatureNamesDescriptionsTrait for AndersonDarlingNormal { @@ -62,10 +63,10 @@ impl FeatureEvaluator for AndersonDarlingNormal where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - let size = self.check_ts_length(ts)?; - let m_std = get_nonzero_m_std(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { + let size = ts.lenu(); let m_mean = ts.m.get_mean(); + let m_std = ts.m.get_std(); let sum: f64 = ts.m.get_sorted() .as_ref() diff --git a/src/features/bazin_fit.rs b/src/features/bazin_fit.rs index e5ce4b9d..4f8ec064 100644 --- a/src/features/bazin_fit.rs +++ b/src/features/bazin_fit.rs @@ -109,6 +109,7 @@ lazy_info!( m_required: true, w_required: true, sorting_required: true, // improve reproducibility + variability_required: false, ); struct Params<'a, T> { diff --git a/src/features/beyond_n_std.rs b/src/features/beyond_n_std.rs index 6f55e9ab..5c05c971 100644 --- a/src/features/beyond_n_std.rs +++ b/src/features/beyond_n_std.rs @@ -94,6 +94,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: false, + variability_required: false, ); impl Default for BeyondNStd @@ -122,8 +123,7 @@ impl FeatureEvaluator for BeyondNStd where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let m_mean = ts.m.get_mean(); let threshold = ts.m.get_std() * self.nstd; let count_beyond = ts.m.sample.fold(0, |count, &m| { diff --git a/src/features/bins.rs b/src/features/bins.rs index 42c9e1da..b4a1b442 100644 --- a/src/features/bins.rs +++ b/src/features/bins.rs @@ -62,6 +62,7 @@ where m_required: true, w_required: true, sorting_required: true, + variability_required: false, }; Self { properties: EvaluatorProperties { @@ -94,6 +95,7 @@ where self.properties.info.size += feature.size_hint(); self.properties.info.min_ts_length = usize::max(self.properties.info.min_ts_length, feature.min_ts_length()); + self.properties.info.variability_required |= feature.is_variability_required(); self.properties.names.extend( feature .get_names() @@ -135,7 +137,12 @@ where } fn transform_ts(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + if ts.lenu() < self.min_ts_length() { + return Err(EvaluatorError::ShortTimeSeries { + actual: ts.lenu(), + minimum: self.min_ts_length(), + }); + } let (t, m, w): (Vec<_>, Vec<_>, Vec<_>) = ts.t.as_slice() .iter() diff --git a/src/features/cusum.rs b/src/features/cusum.rs index 6c5adcf2..0ece8c23 100644 --- a/src/features/cusum.rs +++ b/src/features/cusum.rs @@ -46,6 +46,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: true, + variability_required: true, ); impl FeatureNamesDescriptionsTrait for Cusum { @@ -62,10 +63,9 @@ impl FeatureEvaluator for Cusum where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; - let m_std = get_nonzero_m_std(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let m_mean = ts.m.get_mean(); + let m_std = ts.m.get_std(); let (_last_cusum, min_cusum, max_cusum) = ts.m.as_slice().iter().fold( (T::zero(), T::infinity(), -T::infinity()), |(mut cusum, min_cusum, max_cusum), &m| { diff --git a/src/features/duration.rs b/src/features/duration.rs index d26a8187..1c4e473c 100644 --- a/src/features/duration.rs +++ b/src/features/duration.rs @@ -39,6 +39,7 @@ lazy_info!( m_required: false, w_required: false, sorting_required: true, + variability_required: false, ); impl FeatureNamesDescriptionsTrait for Duration { @@ -55,8 +56,7 @@ impl FeatureEvaluator for Duration where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { Ok(vec![ts.t.sample[ts.lenu() - 1] - ts.t.sample[0]]) } } diff --git a/src/features/eta.rs b/src/features/eta.rs index 6c424ca9..92f8a03b 100644 --- a/src/features/eta.rs +++ b/src/features/eta.rs @@ -42,6 +42,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: true, + variability_required: true, ); impl FeatureNamesDescriptionsTrait for Eta { @@ -58,9 +59,8 @@ impl FeatureEvaluator for Eta where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; - let m_std2 = get_nonzero_m_std2(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { + let m_std2 = ts.m.get_std2(); let value = ts.m.as_slice() .iter() diff --git a/src/features/eta_e.rs b/src/features/eta_e.rs index b4dac558..d19f6168 100644 --- a/src/features/eta_e.rs +++ b/src/features/eta_e.rs @@ -47,6 +47,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: true, + variability_required: true, ); impl FeatureNamesDescriptionsTrait for EtaE { @@ -63,9 +64,8 @@ impl FeatureEvaluator for EtaE where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; - let m_std2 = get_nonzero_m_std2(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { + let m_std2 = ts.m.get_std2(); let sq_slope_sum = ts.t.as_slice() .iter() diff --git a/src/features/excess_variance.rs b/src/features/excess_variance.rs index ef9a861a..c2ecf68b 100644 --- a/src/features/excess_variance.rs +++ b/src/features/excess_variance.rs @@ -32,6 +32,7 @@ lazy_info!( m_required: true, w_required: true, sorting_required: false, + variability_required: false, ); impl ExcessVariance { @@ -58,8 +59,7 @@ impl FeatureEvaluator for ExcessVariance where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let mean_error2 = ts.w.sample.fold(T::zero(), |sum, w| sum + w.recip()) / ts.lenf(); Ok(vec![ (ts.m.get_std2() - mean_error2) / ts.m.get_mean().powi(2), diff --git a/src/features/inter_percentile_range.rs b/src/features/inter_percentile_range.rs index b43969d5..4f96e4b1 100644 --- a/src/features/inter_percentile_range.rs +++ b/src/features/inter_percentile_range.rs @@ -42,6 +42,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: false, + variability_required: false, ); impl InterPercentileRange { @@ -91,8 +92,7 @@ impl FeatureEvaluator for InterPercentileRange where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let ppf_low = ts.m.get_sorted().ppf(self.quantile); let ppf_high = ts.m.get_sorted().ppf(1.0 - self.quantile); let value = ppf_high - ppf_low; diff --git a/src/features/kurtosis.rs b/src/features/kurtosis.rs index c52286ec..707daf70 100644 --- a/src/features/kurtosis.rs +++ b/src/features/kurtosis.rs @@ -43,6 +43,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: false, + variability_required: true, ); impl FeatureNamesDescriptionsTrait for Kurtosis { @@ -59,10 +60,9 @@ impl FeatureEvaluator for Kurtosis where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; - let m_std2 = get_nonzero_m_std2(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let m_mean = ts.m.get_mean(); + let m_std2 = ts.m.get_std2(); let n = ts.lenf(); let n1 = n + T::one(); let n_1 = n - T::one(); diff --git a/src/features/linear_fit.rs b/src/features/linear_fit.rs index 11831a4f..3e87b7db 100644 --- a/src/features/linear_fit.rs +++ b/src/features/linear_fit.rs @@ -45,6 +45,7 @@ lazy_info!( m_required: true, w_required: true, sorting_required: true, + variability_required: false, ); impl FeatureNamesDescriptionsTrait for LinearFit { @@ -69,8 +70,7 @@ impl FeatureEvaluator for LinearFit where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let result = fit_straight_line(ts, true); Ok(vec![ result.slope, diff --git a/src/features/linear_trend.rs b/src/features/linear_trend.rs index d0916347..42dc57e6 100644 --- a/src/features/linear_trend.rs +++ b/src/features/linear_trend.rs @@ -43,6 +43,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: true, + variability_required: false, ); impl FeatureNamesDescriptionsTrait for LinearTrend { @@ -63,8 +64,7 @@ impl FeatureEvaluator for LinearTrend where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let result = fit_straight_line(ts, false); Ok(vec![ result.slope, diff --git a/src/features/magnitude_percentage_ratio.rs b/src/features/magnitude_percentage_ratio.rs index cde34cbd..11546c2a 100644 --- a/src/features/magnitude_percentage_ratio.rs +++ b/src/features/magnitude_percentage_ratio.rs @@ -41,6 +41,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: false, + variability_required: true, ); impl MagnitudePercentageRatio { @@ -112,18 +113,13 @@ impl FeatureEvaluator for MagnitudePercentageRatio where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let m_sorted = ts.m.get_sorted(); let numerator = m_sorted.ppf(1.0 - self.quantile_numerator) - m_sorted.ppf(self.quantile_numerator); let denumerator = m_sorted.ppf(1.0 - self.quantile_denominator) - m_sorted.ppf(self.quantile_denominator); - if numerator.is_zero() & denumerator.is_zero() { - Err(EvaluatorError::FlatTimeSeries) - } else { - Ok(vec![numerator / denumerator]) - } + Ok(vec![numerator / denumerator]) } } diff --git a/src/features/maximum_slope.rs b/src/features/maximum_slope.rs index b39ac491..84a5cd84 100644 --- a/src/features/maximum_slope.rs +++ b/src/features/maximum_slope.rs @@ -33,6 +33,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: true, + variability_required: false, ); impl MaximumSlope { @@ -57,8 +58,7 @@ impl FeatureEvaluator for MaximumSlope where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let result = ts.t.as_slice() .iter() diff --git a/src/features/maximum_time_interval.rs b/src/features/maximum_time_interval.rs index 8bece763..1a30c00f 100644 --- a/src/features/maximum_time_interval.rs +++ b/src/features/maximum_time_interval.rs @@ -40,6 +40,7 @@ lazy_info!( m_required: false, w_required: false, sorting_required: true, + variability_required: false, ); impl FeatureNamesDescriptionsTrait for MaximumTimeInterval { @@ -56,8 +57,7 @@ impl FeatureEvaluator for MaximumTimeInterval where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let dt = ts.t.as_slice() .iter() diff --git a/src/features/mean.rs b/src/features/mean.rs index d95bf695..ba1b792b 100644 --- a/src/features/mean.rs +++ b/src/features/mean.rs @@ -28,6 +28,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: false, + variability_required: false, ); impl Mean { @@ -54,7 +55,7 @@ impl FeatureEvaluator for Mean where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { self.check_ts_length(ts)?; Ok(vec![ts.m.get_mean()]) } diff --git a/src/features/mean_variance.rs b/src/features/mean_variance.rs index 2712e21b..625118ab 100644 --- a/src/features/mean_variance.rs +++ b/src/features/mean_variance.rs @@ -27,6 +27,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: false, + variability_required: false, ); impl MeanVariance { @@ -53,8 +54,7 @@ impl FeatureEvaluator for MeanVariance where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { Ok(vec![ts.m.get_std() / ts.m.get_mean()]) } } diff --git a/src/features/median.rs b/src/features/median.rs index ed71471a..451c2405 100644 --- a/src/features/median.rs +++ b/src/features/median.rs @@ -27,6 +27,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: false, + variability_required: false, ); impl Median { @@ -53,8 +54,7 @@ impl FeatureEvaluator for Median where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { Ok(vec![ts.m.get_median()]) } } diff --git a/src/features/median_absolute_deviation.rs b/src/features/median_absolute_deviation.rs index d4b6e442..dad9df04 100644 --- a/src/features/median_absolute_deviation.rs +++ b/src/features/median_absolute_deviation.rs @@ -30,6 +30,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: false, + variability_required: false, ); impl MedianAbsoluteDeviation { @@ -56,8 +57,7 @@ impl FeatureEvaluator for MedianAbsoluteDeviation where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let m_median = ts.m.get_median(); let sorted_deviation: SortedArray<_> = ts.m.sample diff --git a/src/features/median_buffer_range_percentage.rs b/src/features/median_buffer_range_percentage.rs index 6e8c9d6f..16b39efd 100644 --- a/src/features/median_buffer_range_percentage.rs +++ b/src/features/median_buffer_range_percentage.rs @@ -38,6 +38,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: false, + variability_required: false, ); impl MedianBufferRangePercentage @@ -102,8 +103,7 @@ impl FeatureEvaluator for MedianBufferRangePercentage where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let m_median = ts.m.get_median(); let amplitude = T::half() * (ts.m.get_max() - ts.m.get_min()); let threshold = self.quantile * amplitude; diff --git a/src/features/minimum_time_interval.rs b/src/features/minimum_time_interval.rs index 24298c4e..a7a6da6c 100644 --- a/src/features/minimum_time_interval.rs +++ b/src/features/minimum_time_interval.rs @@ -40,6 +40,7 @@ lazy_info!( m_required: false, w_required: false, sorting_required: true, + variability_required: false, ); impl FeatureNamesDescriptionsTrait for MinimumTimeInterval { @@ -56,8 +57,7 @@ impl FeatureEvaluator for MinimumTimeInterval where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let dt = ts.t.as_slice() .iter() diff --git a/src/features/observation_count.rs b/src/features/observation_count.rs index c04c8c30..ab2ba097 100644 --- a/src/features/observation_count.rs +++ b/src/features/observation_count.rs @@ -39,6 +39,7 @@ lazy_info!( m_required: false, w_required: false, sorting_required: false, + variability_required: false, ); impl FeatureNamesDescriptionsTrait for ObservationCount { @@ -55,8 +56,7 @@ impl FeatureEvaluator for ObservationCount where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { Ok(vec![ts.lenf()]) } } diff --git a/src/features/otsu_split.rs b/src/features/otsu_split.rs index b0edbcf5..8c3bb73a 100644 --- a/src/features/otsu_split.rs +++ b/src/features/otsu_split.rs @@ -36,6 +36,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: false, + variability_required: true, ); impl OtsuSplit { @@ -47,28 +48,17 @@ impl OtsuSplit { DOC } - pub fn threshold<'a, 'b, T>( + fn threshold_no_ds_check<'a, 'b, T>( ds: &'b mut DataSample<'a, T>, - ) -> Result<(T, ArrayView1<'b, T>, ArrayView1<'b, T>), EvaluatorError> + ) -> (T, ArrayView1<'b, T>, ArrayView1<'b, T>) where 'a: 'b, T: Float, { - if ds.sample.len() < 2 { - return Err(EvaluatorError::ShortTimeSeries { - actual: ds.sample.len(), - minimum: 2, - }); - } - let count = ds.sample.len(); let countf = count.approx().unwrap(); let sorted = ds.get_sorted(); - if sorted.minimum() == sorted.maximum() { - return Err(EvaluatorError::FlatTimeSeries); - } - // size is (count - 1) let cumsum1: Array1<_> = sorted .iter() @@ -110,7 +100,30 @@ impl OtsuSplit { let index = inter_class_variance.argmax().unwrap(); let (lower, upper) = sorted.0.view().split_at(Axis(0), index + 1); - Ok((sorted.0[index + 1], lower, upper)) + (sorted.0[index + 1], lower, upper) + } + + pub fn threshold<'a, 'b, T>( + ds: &'b mut DataSample<'a, T>, + ) -> Result<(T, ArrayView1<'b, T>, ArrayView1<'b, T>), EvaluatorError> + where + 'a: 'b, + T: Float, + { + if ds.sample.len() < 2 { + return Err(EvaluatorError::ShortTimeSeries { + actual: ds.sample.len(), + minimum: 2, + }); + } + + // Sorted array will be cached inside ds, we will reuse it in threshold_no_ds_check + let sorted = ds.get_sorted(); + if sorted.minimum() == sorted.maximum() { + return Err(EvaluatorError::FlatTimeSeries); + } + + return Ok(Self::threshold_no_ds_check(ds)); } } @@ -136,9 +149,7 @@ impl FeatureEvaluator for OtsuSplit where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; - + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let (_, lower, upper) = Self::threshold(&mut ts.m)?; let mut lower: DataSample<_> = lower.into(); let mut upper: DataSample<_> = upper.into(); diff --git a/src/features/percent_amplitude.rs b/src/features/percent_amplitude.rs index cb4f1156..899fb468 100644 --- a/src/features/percent_amplitude.rs +++ b/src/features/percent_amplitude.rs @@ -30,6 +30,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: false, + variability_required: false, ); impl PercentAmplitude { @@ -56,8 +57,7 @@ impl FeatureEvaluator for PercentAmplitude where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let m_min = ts.m.get_min(); let m_max = ts.m.get_max(); let m_median = ts.m.get_median(); diff --git a/src/features/percent_difference_magnitude_percentile.rs b/src/features/percent_difference_magnitude_percentile.rs index 6b4ec505..61b839f5 100644 --- a/src/features/percent_difference_magnitude_percentile.rs +++ b/src/features/percent_difference_magnitude_percentile.rs @@ -38,6 +38,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: false, + variability_required: false, ); impl PercentDifferenceMagnitudePercentile { @@ -94,8 +95,7 @@ impl FeatureEvaluator for PercentDifferenceMagnitudePercentile where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let nominator = ts.m.get_sorted().ppf(1.0 - self.quantile) - ts.m.get_sorted().ppf(self.quantile); let denominator = ts.m.get_median(); diff --git a/src/features/periodogram.rs b/src/features/periodogram.rs index 4a2a0210..9d3b8979 100644 --- a/src/features/periodogram.rs +++ b/src/features/periodogram.rs @@ -49,6 +49,7 @@ impl PeriodogramPeaks { m_required: true, w_required: false, sorting_required: true, + variability_required: false, }; let names = (0..peaks) .flat_map(|i| vec![format!("period_{}", i), format!("period_s_to_n_{}", i)]) @@ -120,8 +121,7 @@ impl FeatureEvaluator for PeriodogramPeaks where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let peak_indices = peak_indices_reverse_sorted(&ts.m.sample); Ok(peak_indices .iter() @@ -310,6 +310,7 @@ where m_required: true, w_required: false, sorting_required: true, + variability_required: false, }; Self { properties: EvaluatorProperties { diff --git a/src/features/reduced_chi2.rs b/src/features/reduced_chi2.rs index 42076043..66eeef98 100644 --- a/src/features/reduced_chi2.rs +++ b/src/features/reduced_chi2.rs @@ -33,6 +33,7 @@ lazy_info!( m_required: true, w_required: true, sorting_required: false, + variability_required: false, ); impl ReducedChi2 { @@ -59,8 +60,7 @@ impl FeatureEvaluator for ReducedChi2 where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { Ok(vec![ts.get_m_reduced_chi2()]) } } diff --git a/src/features/skew.rs b/src/features/skew.rs index c5f41a7e..ab7f2a16 100644 --- a/src/features/skew.rs +++ b/src/features/skew.rs @@ -32,6 +32,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: false, + variability_required: true, ); impl Skew { @@ -58,10 +59,9 @@ impl FeatureEvaluator for Skew where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; - let m_std = get_nonzero_m_std(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let m_mean = ts.m.get_mean(); + let m_std = ts.m.get_std(); let n = ts.lenf(); let n_1 = n - T::one(); let n_2 = n_1 - T::one(); diff --git a/src/features/standard_deviation.rs b/src/features/standard_deviation.rs index 55196f6c..88f5d3fb 100644 --- a/src/features/standard_deviation.rs +++ b/src/features/standard_deviation.rs @@ -32,6 +32,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: false, + variability_required: false, ); impl StandardDeviation { @@ -58,8 +59,7 @@ impl FeatureEvaluator for StandardDeviation where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { Ok(vec![ts.m.get_std()]) } } diff --git a/src/features/stetson_k.rs b/src/features/stetson_k.rs index e70b7138..a28d3c06 100644 --- a/src/features/stetson_k.rs +++ b/src/features/stetson_k.rs @@ -34,6 +34,7 @@ lazy_info!( m_required: true, w_required: true, sorting_required: false, + variability_required: true, ); impl StetsonK { @@ -60,9 +61,8 @@ impl FeatureEvaluator for StetsonK where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; - let chi2 = get_nonzero_reduced_chi2(ts)? * (ts.lenf() - T::one()); + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { + let chi2 = ts.get_m_reduced_chi2() * (ts.lenf() - T::one()); let mean = ts.get_m_weighted_mean(); let value = Zip::from(&ts.m.sample) .and(&ts.w.sample) diff --git a/src/features/time_mean.rs b/src/features/time_mean.rs index 36e0bbf2..d8fc3c2d 100644 --- a/src/features/time_mean.rs +++ b/src/features/time_mean.rs @@ -39,6 +39,7 @@ lazy_info!( m_required: false, w_required: false, sorting_required: false, + variability_required: false, ); impl FeatureNamesDescriptionsTrait for TimeMean { fn get_names(&self) -> Vec<&str> { @@ -53,8 +54,7 @@ impl FeatureEvaluator for TimeMean where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { Ok(vec![ts.t.get_mean()]) } } diff --git a/src/features/time_standard_deviation.rs b/src/features/time_standard_deviation.rs index ece52bc8..3b733419 100644 --- a/src/features/time_standard_deviation.rs +++ b/src/features/time_standard_deviation.rs @@ -39,6 +39,7 @@ lazy_info!( m_required: false, w_required: false, sorting_required: false, + variability_required: false, ); impl FeatureNamesDescriptionsTrait for TimeStandardDeviation { @@ -55,8 +56,7 @@ impl FeatureEvaluator for TimeStandardDeviation where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { Ok(vec![ts.t.get_std()]) } } diff --git a/src/features/transformed.rs b/src/features/transformed.rs index b33ed241..3054c76f 100644 --- a/src/features/transformed.rs +++ b/src/features/transformed.rs @@ -52,6 +52,7 @@ where m_required: feature.is_m_required(), w_required: feature.is_w_required(), sorting_required: feature.is_sorting_required(), + variability_required: feature.is_variability_required(), }; let names = transformer.names(&feature.get_names()); let descriptions = transformer.descriptions(&feature.get_descriptions()); @@ -110,15 +111,17 @@ where F: FeatureEvaluator, Tr: TransformerTrait, { + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { + Ok(self + .transformer + .transform(self.feature.eval_no_ts_check(ts)?)) + } + fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { Ok(self.transformer.transform(self.feature.eval(ts)?)) } // We keep default implementation of eval_or_fill - - fn check_ts_length(&self, ts: &TimeSeries) -> Result { - self.feature.check_ts_length(ts) - } } #[derive(Serialize, Deserialize, JsonSchema)] diff --git a/src/features/villar_fit.rs b/src/features/villar_fit.rs index c9b794a8..36b3990f 100644 --- a/src/features/villar_fit.rs +++ b/src/features/villar_fit.rs @@ -127,6 +127,7 @@ lazy_info!( m_required: true, w_required: true, sorting_required: true, // improve reproducibility + variability_required: false, ); impl FitModelTrait for VillarFit diff --git a/src/features/weighted_mean.rs b/src/features/weighted_mean.rs index f728ffef..dd2bcb86 100644 --- a/src/features/weighted_mean.rs +++ b/src/features/weighted_mean.rs @@ -28,6 +28,7 @@ lazy_info!( m_required: true, w_required: true, sorting_required: false, + variability_required: false, ); impl WeightedMean { @@ -54,8 +55,7 @@ impl FeatureEvaluator for WeightedMean where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { Ok(vec![ts.get_m_weighted_mean()]) } } diff --git a/src/macros.rs b/src/macros.rs index 443279aa..603ba302 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -8,6 +8,7 @@ macro_rules! lazy_info { m_required: $m: expr, w_required: $w: expr, sorting_required: $sort: expr, + variability_required: $var: expr, ) => { lazy_static! { static ref $name: EvaluatorInfo = EvaluatorInfo { @@ -17,6 +18,7 @@ macro_rules! lazy_info { m_required: $m, w_required: $w, sorting_required: $sort, + variability_required: $var, }; } }; @@ -29,6 +31,7 @@ macro_rules! lazy_info { m_required: $m: expr, w_required: $w: expr, sorting_required: $sort: expr, + variability_required: $var: expr, ) => { lazy_info!( $name, @@ -38,6 +41,7 @@ macro_rules! lazy_info { m_required: $m, w_required: $w, sorting_required: $sort, + variability_required: $var, ); impl EvaluatorInfoTrait for $feature { @@ -56,6 +60,7 @@ macro_rules! lazy_info { m_required: $m: expr, w_required: $w: expr, sorting_required: $sort: expr, + variability_required: $var: expr, ) => { lazy_info!( $name, @@ -65,6 +70,7 @@ macro_rules! lazy_info { m_required: $m, w_required: $w, sorting_required: $sort, + variability_required: $var, ); impl EvaluatorInfoTrait for $feature { @@ -80,7 +86,7 @@ macro_rules! lazy_info { /// - `transform_ts(&self, ts: &mut TimeSeries) -> Result, EvaluatorError>` macro_rules! transformer_eval { () => { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let arrays = self.transform_ts(ts)?; let mut new_ts = arrays.ts(); self.feature_extractor.eval(&mut new_ts) @@ -121,9 +127,7 @@ macro_rules! json_schema { /// - declare `const NPARAMS: usize` in your code macro_rules! fit_eval { () => { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; - + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let norm_data = NormalizedData::::from_ts(ts); let (x0, lower, upper) = { diff --git a/src/tests.rs b/src/tests.rs index 544a767a..cd987524 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -119,6 +119,8 @@ pub fn eval_info_tests( .as_ref() .map(check_size); } + + eval_info_variability_required_test(&eval, &t_sorted, &w, &mut rng); } fn eval_info_ts_length_test( @@ -265,6 +267,43 @@ fn eval_info_sorting_required_test( Some(v) } +fn eval_info_variability_required_test( + eval: &Feature, + t: &[f64], + w: &[f64], + rng: &mut StdRng, +) { + assert!( + !eval.is_variability_required() || eval.is_m_required(), + "variability_required is treu, but m_required is false" + ); + + let m = vec![rng.sample::(StandardNormal).abs(); t.len()]; + let mut ts = TimeSeries::new(t, &m, w); + assert_eq!(eval.is_variability_required(), eval.eval(&mut ts).is_err()); + + match ( + std::panic::catch_unwind(|| eval.eval_no_ts_check(&mut TimeSeries::new(t, &m, w))), + eval.is_variability_required(), + ) { + (Ok(_result), true) => {} + // |-- This doesn't work sometimes because of float rounding issues + // v + // (Ok(result), true) => assert!(result + // .map(|v| assert!( + // !v.iter().copied().all(f64::is_finite), + // "{:?} are all finite", + // v + // )) + // .is_err()), + (Ok(result), false) => assert!(result + .map(|v| assert!(v.into_iter().all(f64::is_finite))) + .is_ok()), + (Err(_err), true) => {} + (Err(err), false) => panic!("{:?}", err), + } +} + #[macro_export] macro_rules! serialization_name_test { ($feature_type: ty, $feature_expr: expr) => { From df608c5130f7feb94bd5140e6cc848338f9d33f1 Mon Sep 17 00:00:00 2001 From: Konstantin Malanchev Date: Fri, 12 Aug 2022 12:00:32 -0500 Subject: [PATCH 2/9] Initial implementation of multicolor --- src/error.rs | 35 ++ src/evaluator.rs | 4 +- src/features/periodogram.rs | 1 + src/lib.rs | 3 + src/multicolor.rs | 638 ++++++++++++++++++++++++++++++++++++ 5 files changed, 679 insertions(+), 2 deletions(-) create mode 100644 src/multicolor.rs diff --git a/src/error.rs b/src/error.rs index 38bb58ce..15c6f0a1 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,3 +1,7 @@ +use crate::PassbandTrait; + +use std::collections::BTreeSet; + /// Error returned from [crate::FeatureEvaluator] #[derive(Debug, thiserror::Error, PartialEq, Eq)] pub enum EvaluatorError { @@ -11,10 +15,41 @@ pub enum EvaluatorError { ZeroDivision(&'static str), } +#[derive(Debug, thiserror::Error, PartialEq, Eq)] +pub enum MultiColorEvaluatorError { + #[error("Passband {passband} time-series caused error: {error:?}")] + MonochromeEvaluatorError { + passband: String, + error: EvaluatorError, + }, + + #[error("Wrong passbands {actual:?}, {desired:?} are desired")] + WrongPassbandsError { + actual: BTreeSet, + desired: BTreeSet, + }, +} + +impl MultiColorEvaluatorError { + pub fn wrong_passbands_error<'a, P>( + actual: impl Iterator, + desired: impl Iterator, + ) -> Self + where + P: PassbandTrait + 'a, + { + Self::WrongPassbandsError { + actual: actual.map(|p| p.name().into()).collect(), + desired: desired.map(|p| p.name().into()).collect(), + } + } +} + #[derive(Debug, thiserror::Error, PartialEq, Eq)] pub enum SortedArrayError { #[error("SortedVec constructors accept sorted arrays only")] Unsorted, + #[error("SortedVec constructors accept contiguous arrays only")] NonContiguous, } diff --git a/src/evaluator.rs b/src/evaluator.rs index de531e0c..5c1078c3 100644 --- a/src/evaluator.rs +++ b/src/evaluator.rs @@ -12,7 +12,7 @@ use serde::de::DeserializeOwned; pub use serde::{Deserialize, Serialize}; pub use std::fmt::Debug; -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] pub struct EvaluatorInfo { pub size: usize, pub min_ts_length: usize, @@ -23,7 +23,7 @@ pub struct EvaluatorInfo { pub variability_required: bool, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] pub struct EvaluatorProperties { pub info: EvaluatorInfo, pub names: Vec, diff --git a/src/features/periodogram.rs b/src/features/periodogram.rs index 9d3b8979..407dd93f 100644 --- a/src/features/periodogram.rs +++ b/src/features/periodogram.rs @@ -103,6 +103,7 @@ impl EvaluatorInfoTrait for PeriodogramPeaks { &self.properties.info } } + impl FeatureNamesDescriptionsTrait for PeriodogramPeaks { fn get_names(&self) -> Vec<&str> { self.properties.names.iter().map(String::as_str).collect() diff --git a/src/lib.rs b/src/lib.rs index 1c6cf709..d77b787d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,6 +27,9 @@ pub use float_trait::Float; mod lnerfc; +mod multicolor; +pub use multicolor::*; + mod nl_fit; pub use nl_fit::evaluator::FitFeatureEvaluatorGettersTrait; #[cfg(any(feature = "ceres-source", feature = "ceres-system"))] diff --git a/src/multicolor.rs b/src/multicolor.rs new file mode 100644 index 00000000..fc9a3657 --- /dev/null +++ b/src/multicolor.rs @@ -0,0 +1,638 @@ +use crate::error::MultiColorEvaluatorError; +use crate::evaluator::{ + EvaluatorError, EvaluatorInfo, EvaluatorInfoTrait, EvaluatorProperties, FeatureEvaluator, + FeatureNamesDescriptionsTrait, +}; +use crate::feature::Feature; +use crate::float_trait::Float; +use crate::time_series::TimeSeries; + +use enum_dispatch::enum_dispatch; +use itertools::Itertools; +pub use lazy_static::lazy_static; +pub use schemars::JsonSchema; +pub use serde::{Deserialize, Serialize}; +use std::cmp::Ordering; +use std::collections::{BTreeMap, BTreeSet}; +use std::fmt::Debug; +use std::marker::PhantomData; + +pub trait PassbandTrait: Debug + Clone + Send + Sync + Ord + Serialize + JsonSchema { + fn name(&self) -> &str; +} + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +pub struct MonochromePassband<'a, T> { + pub name: &'a str, + pub wavelength: T, +} + +impl<'a, T> MonochromePassband<'a, T> +where + T: Float, +{ + pub fn new(wavelength: T, name: &'a str) -> Self { + assert!( + wavelength.is_normal(), + "wavelength must be a positive normal number" + ); + assert!( + wavelength.is_sign_positive(), + "wavelength must be a positive normal number" + ); + Self { wavelength, name } + } +} + +impl<'a, T> PartialEq for MonochromePassband<'a, T> +where + T: Float, +{ + fn eq(&self, other: &Self) -> bool { + self.wavelength.eq(&other.wavelength) + } +} + +impl<'a, T> Eq for MonochromePassband<'a, T> where T: Float {} + +impl<'a, T> PartialOrd for MonochromePassband<'a, T> +where + T: Float, +{ + fn partial_cmp(&self, other: &Self) -> Option { + (self.wavelength).partial_cmp(&other.wavelength) + } +} + +impl<'a, T> Ord for MonochromePassband<'a, T> +where + T: Float, +{ + fn cmp(&self, other: &Self) -> Ordering { + self.partial_cmp(other).unwrap() + } +} + +impl<'a, T> PassbandTrait for MonochromePassband<'a, T> +where + T: Float, +{ + fn name(&self) -> &str { + self.name + } +} + +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, JsonSchema)] +pub struct NoPassband {} + +impl PassbandTrait for NoPassband { + fn name(&self) -> &str { + "" + } +} + +pub struct MultiColorTimeSeries<'a, P: PassbandTrait, T: Float>(BTreeMap>); + +#[enum_dispatch] +pub trait MultiColorPassbandSetTrait

+where + P: PassbandTrait, +{ + fn get_passband_set(&self) -> &PassbandSet

; +} + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[serde(bound(deserialize = "P: PassbandTrait + Deserialize<'de>"))] +#[non_exhaustive] +pub enum PassbandSet

+where + P: Ord, +{ + FixedSet(BTreeSet

), + AllAvailable, +} + +impl

From> for PassbandSet

+where + P: Ord, +{ + fn from(value: BTreeSet

) -> Self { + Self::FixedSet(value) + } +} + +#[enum_dispatch] +pub trait MultiColorEvaluator: + FeatureNamesDescriptionsTrait + + EvaluatorInfoTrait + + MultiColorPassbandSetTrait

+ + Clone + + Serialize +where + P: PassbandTrait, + T: Float, +{ + /// Vector of feature values or `EvaluatorError` + fn eval_multicolor( + &self, + mcts: &mut MultiColorTimeSeries, + ) -> Result, MultiColorEvaluatorError>; + + /// Returns vector of feature values and fill invalid components with given value + fn eval_or_fill_multicolor( + &self, + mcts: &mut MultiColorTimeSeries, + fill_value: T, + ) -> Result, MultiColorEvaluatorError> { + self.check_mcts_shape(mcts)?; + Ok(match self.eval_multicolor(mcts) { + Ok(v) => v, + Err(_) => vec![fill_value; self.size_hint()], + }) + } + + fn check_mcts_shape( + &self, + mcts: &MultiColorTimeSeries, + ) -> Result, MultiColorEvaluatorError> { + self.check_mcts_passabands(mcts)?; + self.check_every_ts_length(mcts) + } + + fn check_mcts_passabands( + &self, + mcts: &MultiColorTimeSeries, + ) -> Result<(), MultiColorEvaluatorError> { + match self.get_passband_set() { + PassbandSet::AllAvailable => Ok(()), + PassbandSet::FixedSet(self_passbands) => { + if mcts + .0 + .keys() + .all(|mcts_passband| self_passbands.contains(mcts_passband)) + { + Ok(()) + } else { + Err(MultiColorEvaluatorError::wrong_passbands_error( + mcts.0.keys(), + self_passbands.iter(), + )) + } + } + } + } + + /// Checks if each component of [MultiColorTimeSeries] has enough points to evaluate the feature + fn check_every_ts_length( + &self, + mcts: &MultiColorTimeSeries, + ) -> Result, MultiColorEvaluatorError> { + // Use try_reduce when stabilizes + // https://github.com/rust-lang/rust/issues/87053 + mcts.0 + .iter() + .map(|(passband, ts)| { + let length = ts.lenu(); + if length < self.min_ts_length() { + Err(MultiColorEvaluatorError::MonochromeEvaluatorError { + error: EvaluatorError::ShortTimeSeries { + actual: length, + minimum: self.min_ts_length(), + }, + passband: passband.name().into(), + }) + } else { + Ok((passband.clone(), length)) + } + }) + .collect() + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde( + into = "MultiColorExtractorParameters", + from = "MultiColorExtractorParameters", + bound( + serialize = "P: PassbandTrait, T: Float, MCF: MultiColorEvaluator", + deserialize = "P: PassbandTrait + Deserialize<'de>, T: Float, MCF: MultiColorEvaluator + Deserialize<'de>" + ) +)] +pub struct MultiColorExtractor +where + P: Ord, +{ + features: Vec, + info: Box, + passband_set: PassbandSet

, + phantom: PhantomData, +} + +impl MultiColorExtractor +where + P: PassbandTrait, + T: Float, + MCF: MultiColorEvaluator, +{ + pub fn new(features: Vec) -> Self { + let passband_set = { + let set: BTreeSet<_> = features + .iter() + .filter_map(|f| match f.get_passband_set() { + PassbandSet::AllAvailable => None, + PassbandSet::FixedSet(set) => Some(set), + }) + .flatten() + .cloned() + .collect(); + if set.is_empty() { + PassbandSet::AllAvailable + } else { + PassbandSet::FixedSet(set) + } + }; + + let info = EvaluatorInfo { + size: features.iter().map(|x| x.size_hint()).sum(), + min_ts_length: features + .iter() + .map(|x| x.min_ts_length()) + .max() + .unwrap_or(0), + t_required: features.iter().any(|x| x.is_t_required()), + m_required: features.iter().any(|x| x.is_m_required()), + w_required: features.iter().any(|x| x.is_w_required()), + sorting_required: features.iter().any(|x| x.is_sorting_required()), + } + .into(); + + Self { + features, + passband_set, + info, + phantom: PhantomData, + } + } +} + +impl FeatureNamesDescriptionsTrait for MultiColorExtractor +where + P: Ord, + MCF: FeatureNamesDescriptionsTrait, +{ + /// Get feature names + fn get_names(&self) -> Vec<&str> { + self.features.iter().flat_map(|x| x.get_names()).collect() + } + + /// Get feature descriptions + fn get_descriptions(&self) -> Vec<&str> { + self.features + .iter() + .flat_map(|x| x.get_descriptions()) + .collect() + } +} + +impl EvaluatorInfoTrait for MultiColorExtractor +where + P: Ord, +{ + fn get_info(&self) -> &EvaluatorInfo { + &self.info + } +} + +impl MultiColorPassbandSetTrait

for MultiColorExtractor +where + P: PassbandTrait, +{ + fn get_passband_set(&self) -> &PassbandSet

{ + &self.passband_set + } +} + +impl MultiColorEvaluator for MultiColorExtractor +where + P: PassbandTrait, + T: Float, + MCF: MultiColorEvaluator, +{ + fn eval_multicolor( + &self, + mcts: &mut MultiColorTimeSeries, + ) -> Result, MultiColorEvaluatorError> { + self.check_mcts_passabands(mcts)?; + let mut vec = Vec::with_capacity(self.size_hint()); + for x in &self.features { + vec.extend(x.eval_multicolor(mcts)?); + } + Ok(vec) + } + + fn eval_or_fill_multicolor( + &self, + mcts: &mut MultiColorTimeSeries, + fill_value: T, + ) -> Result, MultiColorEvaluatorError> { + self.check_mcts_passabands(mcts)?; + self.features + .iter() + .map(|x| x.eval_or_fill_multicolor(mcts, fill_value)) + .flatten_ok() + .collect() + } +} + +#[derive(Serialize, Deserialize, JsonSchema)] +#[serde(rename = "MultiColorExtractor")] +struct MultiColorExtractorParameters { + features: Vec, +} + +impl From> for MultiColorExtractorParameters +where + P: PassbandTrait, + T: Float, + MCF: MultiColorEvaluator, +{ + fn from(f: MultiColorExtractor) -> Self { + Self { + features: f.features, + } + } +} + +impl From> for MultiColorExtractor +where + P: PassbandTrait, + T: Float, + MCF: MultiColorEvaluator, +{ + fn from(p: MultiColorExtractorParameters) -> Self { + Self::new(p.features) + } +} + +impl JsonSchema for MultiColorExtractor +where + P: PassbandTrait, + T: Float, + MCF: JsonSchema, +{ + json_schema!(MultiColorExtractorParameters, true); +} + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[serde(bound( + deserialize = "P: PassbandTrait + Deserialize<'de>, T: Float, F: FeatureEvaluator" +))] +pub struct MonochromeFeature +where + P: Ord, +{ + feature: F, + passband_set: PassbandSet

, + properties: Box, + phantom: PhantomData, +} + +impl MonochromeFeature +where + P: PassbandTrait, + T: Float, + F: FeatureEvaluator, +{ + pub fn new(feature: F, passband_set: BTreeSet

) -> Self { + let names = passband_set + .iter() + .cartesian_product(feature.get_names()) + .map(|(passband, name)| format!("{}_{}", name, passband.name())) + .collect(); + let descriptions = passband_set + .iter() + .cartesian_product(feature.get_descriptions()) + .map(|(passband, description)| format!("{}, passband {}", description, passband.name())) + .collect(); + let info = { + let mut info = feature.get_info().clone(); + info.size *= passband_set.len(); + info + }; + Self { + properties: EvaluatorProperties { + info, + names, + descriptions, + } + .into(), + feature, + passband_set: passband_set.into(), + phantom: PhantomData, + } + } +} + +impl FeatureNamesDescriptionsTrait for MonochromeFeature +where + P: Ord, +{ + fn get_names(&self) -> Vec<&str> { + self.properties.names.iter().map(String::as_str).collect() + } + + fn get_descriptions(&self) -> Vec<&str> { + self.properties + .descriptions + .iter() + .map(String::as_str) + .collect() + } +} + +impl EvaluatorInfoTrait for MonochromeFeature +where + P: Ord, +{ + fn get_info(&self) -> &EvaluatorInfo { + &self.properties.info + } +} + +impl MultiColorPassbandSetTrait

for MonochromeFeature +where + P: PassbandTrait, +{ + fn get_passband_set(&self) -> &PassbandSet

{ + &self.passband_set + } +} + +impl MultiColorEvaluator for MonochromeFeature +where + P: PassbandTrait, + T: Float, + F: FeatureEvaluator, +{ + fn eval_multicolor( + &self, + mcts: &mut MultiColorTimeSeries, + ) -> Result, MultiColorEvaluatorError> { + self.check_mcts_passabands(mcts)?; + match &self.passband_set { + PassbandSet::FixedSet(set) => set + .iter() + .map(|passband| { + self.feature.eval(mcts.0.get_mut(passband).expect( + "we checked all needed passbands are in mcts, but we still cannot find one", + )).map_err(|error| MultiColorEvaluatorError::MonochromeEvaluatorError { + passband: passband.name().into(), + error, + }) + }) + .flatten_ok() + .collect(), + PassbandSet::AllAvailable => panic!("passband_set must be FixedSet variant here"), + } + } +} + +#[enum_dispatch(MultiColorEvaluator, FeatureNamesDescriptionsTrait, EvaluatorInfoTrait, MultiColorPassbandSetTrait

)] +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[serde(bound(deserialize = "P: PassbandTrait + Deserialize<'de>, T: Float"))] +#[non_exhaustive] +pub enum MultiColorFeature +where + P: PassbandTrait, + T: Float, +{ + // Extractor + MultiColorExtractor(MultiColorExtractor>), + // Monochrome Features + MonochromeFeature(MonochromeFeature>), + // Features + ColorOfMedian(color_median::ColorOfMedian

), +} + +impl MultiColorFeature +where + P: PassbandTrait, + T: Float, +{ + pub fn from_monochrome_feature(feature: F, passband_set: BTreeSet

) -> Self + where + F: Into>, + { + MonochromeFeature::new(feature.into(), passband_set).into() + } +} + +/// Example of multicolor light-curve feature evaluator +mod color_median { + use super::*; + use crate::{FeatureEvaluator, Median}; + + #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] + #[serde(bound(deserialize = "P: PassbandTrait + Deserialize<'de>"))] + pub struct ColorOfMedian

+ where + P: Ord, + { + passband_set: PassbandSet

, + passbands: [P; 2], + median: Median, + name: String, + description: String, + } + + impl

ColorOfMedian

+ where + P: PassbandTrait, + { + pub fn new(passbands: [P; 2]) -> Self { + let set: BTreeSet<_> = passbands.clone().into(); + Self { + passband_set: set.into(), + name: format!( + "color_median_{}_{}", + passbands[0].name(), + passbands[1].name() + ), + description: format!( + "difference of median magnitudes {}-{}", + passbands[0].name(), + passbands[1].name() + ), + passbands, + median: Median {}, + } + } + } + + lazy_info!( + COLOR_MEDIAN_INFO, + size: 1, + min_ts_length: 1, + t_required: false, + m_required: true, + w_required: false, + sorting_required: false, + ); + + impl

EvaluatorInfoTrait for ColorOfMedian

+ where + P: Ord, + { + fn get_info(&self) -> &EvaluatorInfo { + &COLOR_MEDIAN_INFO + } + } + + impl

FeatureNamesDescriptionsTrait for ColorOfMedian

+ where + P: Ord, + { + fn get_names(&self) -> Vec<&str> { + vec![self.name.as_str()] + } + + fn get_descriptions(&self) -> Vec<&str> { + vec![self.description.as_str()] + } + } + + impl

MultiColorPassbandSetTrait

for ColorOfMedian

+ where + P: PassbandTrait, + { + fn get_passband_set(&self) -> &PassbandSet

{ + &self.passband_set + } + } + + impl MultiColorEvaluator for ColorOfMedian

+ where + P: PassbandTrait, + T: Float, + { + fn eval_multicolor( + &self, + mcts: &mut MultiColorTimeSeries, + ) -> Result, MultiColorEvaluatorError> { + self.check_mcts_passabands(mcts)?; + let mut medians = [T::zero(); 2]; + for (median, passband) in medians.iter_mut().zip(self.passbands.iter()) { + *median = self + .median + .eval(mcts.0.get_mut(passband).expect( + "we checked all needed passbands are in mcts, but we still cannot find one", + )) + .map_err(|error| MultiColorEvaluatorError::MonochromeEvaluatorError { + passband: passband.name().into(), + error, + })?[0] + } + Ok(vec![medians[0] - medians[1]]) + } + } +} From 383e94008f85ab56a69f536fa9faae29fd3216e9 Mon Sep 17 00:00:00 2001 From: Konstantin Malanchev Date: Sun, 28 Aug 2022 16:07:53 -0500 Subject: [PATCH 3/9] Create data module --- src/data/data_sample.rs | 314 +++++++++++++ src/data/mod.rs | 12 + src/data/multi_color_time_series.rs | 30 ++ src/{ => data}/sorted_array.rs | 1 + src/data/time_series.rs | 252 +++++++++++ src/evaluator.rs | 2 +- src/extractor.rs | 2 +- src/feature.rs | 2 +- src/features/median_absolute_deviation.rs | 2 +- src/features/otsu_split.rs | 2 +- src/lib.rs | 8 +- src/multicolor.rs | 14 +- src/nl_fit/data.rs | 2 +- src/nl_fit/evaluator.rs | 2 +- src/periodogram/freq.rs | 2 +- src/periodogram/mod.rs | 4 +- src/periodogram/power_direct.rs | 2 +- src/periodogram/power_fft.rs | 2 +- src/periodogram/power_trait.rs | 2 +- src/prelude.rs | 2 +- src/straight_line_fit.rs | 2 +- src/tests.rs | 2 +- src/time_series.rs | 520 ---------------------- 23 files changed, 633 insertions(+), 550 deletions(-) create mode 100644 src/data/data_sample.rs create mode 100644 src/data/mod.rs create mode 100644 src/data/multi_color_time_series.rs rename src/{ => data}/sorted_array.rs (99%) create mode 100644 src/data/time_series.rs delete mode 100644 src/time_series.rs diff --git a/src/data/data_sample.rs b/src/data/data_sample.rs new file mode 100644 index 00000000..44e5073f --- /dev/null +++ b/src/data/data_sample.rs @@ -0,0 +1,314 @@ +use crate::data::sorted_array::SortedArray; +use crate::float_trait::Float; +use crate::types::CowArray1; + +use conv::prelude::*; +use ndarray::{s, Array1, ArrayView1, Zip}; + +/// A [`TimeSeries`] component +#[derive(Clone, Debug)] +pub struct DataSample<'a, T> +where + T: Float, +{ + pub sample: CowArray1<'a, T>, + sorted: Option>, + min: Option, + max: Option, + mean: Option, + median: Option, + std: Option, + std2: Option, +} + +macro_rules! data_sample_getter { + ($attr: ident, $getter: ident, $func: expr, $method_sorted: ident) => { + // This lint is false-positive in macros + // https://github.com/rust-lang/rust-clippy/issues/1553 + #[allow(clippy::redundant_closure_call)] + pub fn $getter(&mut self) -> T { + match self.$attr { + Some(x) => x, + None => { + self.$attr = Some(match self.sorted.as_ref() { + Some(sorted) => sorted.$method_sorted(), + None => $func(self), + }); + self.$attr.unwrap() + } + } + } + }; + ($attr: ident, $getter: ident, $func: expr) => { + // This lint is false-positive in macros + // https://github.com/rust-lang/rust-clippy/issues/1553 + #[allow(clippy::redundant_closure_call)] + pub fn $getter(&mut self) -> T { + match self.$attr { + Some(x) => x, + None => { + self.$attr = Some($func(self)); + self.$attr.unwrap() + } + } + } + }; +} + +impl<'a, T> DataSample<'a, T> +where + T: Float, +{ + pub fn new(sample: CowArray1<'a, T>) -> Self { + Self { + sample, + sorted: None, + min: None, + max: None, + mean: None, + median: None, + std: None, + std2: None, + } + } + + pub fn as_slice(&mut self) -> &[T] { + if !self.sample.is_standard_layout() { + let owned: Array1<_> = self.sample.iter().copied().collect::>().into(); + self.sample = owned.into(); + } + self.sample.as_slice().unwrap() + } + + pub fn get_sorted(&mut self) -> &SortedArray { + if self.sorted.is_none() { + self.sorted = Some(self.sample.to_vec().into()); + } + self.sorted.as_ref().unwrap() + } + + fn set_min_max(&mut self) { + let (min, max) = + self.sample + .slice(s![1..]) + .fold((self.sample[0], self.sample[0]), |(min, max), &x| { + if x > max { + (min, x) + } else if x < min { + (x, max) + } else { + (min, max) + } + }); + self.min = Some(min); + self.max = Some(max); + } + + data_sample_getter!( + min, + get_min, + |ds: &mut DataSample<'a, T>| { + ds.set_min_max(); + ds.min.unwrap() + }, + minimum + ); + data_sample_getter!( + max, + get_max, + |ds: &mut DataSample<'a, T>| { + ds.set_min_max(); + ds.max.unwrap() + }, + maximum + ); + data_sample_getter!(mean, get_mean, |ds: &mut DataSample<'a, T>| { + ds.sample.mean().expect("time series must be non-empty") + }); + data_sample_getter!(median, get_median, |ds: &mut DataSample<'a, T>| { + ds.get_sorted().median() + }); + data_sample_getter!(std, get_std, |ds: &mut DataSample<'a, T>| { + ds.get_std2().sqrt() + }); + data_sample_getter!(std2, get_std2, |ds: &mut DataSample<'a, T>| { + // Benchmarks show that it is faster than `ndarray::ArrayBase::var(T::one)` + let mean = ds.get_mean(); + ds.sample + .fold(T::zero(), |sum, &x| sum + (x - mean).powi(2)) + / (ds.sample.len() - 1).approx().unwrap() + }); + + pub fn signal_to_noise(&mut self, value: T) -> T { + if self.get_std().is_zero() { + T::zero() + } else { + (value - self.get_mean()) / self.get_std() + } + } + + /// Returns true if all values are equal. Always true for zero- or one- length + pub fn is_all_same(&self) -> bool { + if self.sample.is_empty() { + return true; + } + if self.max.is_some() && self.max == self.min { + return true; + } + if self.std2 == Some(T::zero()) { + return true; + } + if let Some(sorted) = &self.sorted { + return sorted[0] == sorted[sorted.len() - 1]; + } + let x0 = self.sample[0]; + // all() returns true for the empty slice, i.e. single-point time series + Zip::from(self.sample.slice(s![1..])).all(|&x| x == x0) + } +} + +impl<'a, T> From> for DataSample<'a, T> +where + T: Float, +{ + fn from(sorted: SortedArray) -> Self { + let sample = sorted.0.clone().into(); + Self { + sample, + sorted: Some(sorted), + min: None, + max: None, + median: None, + mean: None, + std: None, + std2: None, + } + } +} + +impl<'a, T, Slice: ?Sized> From<&'a Slice> for DataSample<'a, T> +where + T: Float, + Slice: AsRef<[T]>, +{ + fn from(s: &'a Slice) -> Self { + ArrayView1::from(s).into() + } +} + +impl<'a, T> From> for DataSample<'a, T> +where + T: Float, +{ + fn from(v: Vec) -> Self { + Array1::from(v).into() + } +} + +impl<'a, T> From> for DataSample<'a, T> +where + T: Float, +{ + fn from(a: ArrayView1<'a, T>) -> Self { + Self::new(a.into()) + } +} + +impl<'a, T> From> for DataSample<'a, T> +where + T: Float, +{ + fn from(a: Array1) -> Self { + Self::new(a.into()) + } +} + +impl<'a, T> From> for DataSample<'a, T> +where + T: Float, +{ + fn from(a: CowArray1<'a, T>) -> Self { + Self::new(a) + } +} + +#[cfg(test)] +#[allow(clippy::unreadable_literal)] +#[allow(clippy::excessive_precision)] +mod tests { + use super::*; + + use approx::assert_relative_eq; + + macro_rules! data_sample_test { + ($name: ident, $method: ident, $desired: literal, $x: tt $(,)?) => { + #[test] + fn $name() { + let x = $x; + let desired = $desired; + + let mut ds: DataSample<_> = DataSample::from(&x); + assert_relative_eq!(ds.$method(), desired, epsilon = 1e-6); + assert_relative_eq!(ds.$method(), desired, epsilon = 1e-6); + + let mut ds: DataSample<_> = DataSample::from(&x); + ds.get_sorted(); + assert_relative_eq!(ds.$method(), desired, epsilon = 1e-6); + assert_relative_eq!(ds.$method(), desired, epsilon = 1e-6); + } + }; + } + + data_sample_test!( + data_sample_min, + get_min, + -7.79420906, + [3.92948846, 3.28436964, 6.73375373, -7.79420906, -7.23407407], + ); + + data_sample_test!( + data_sample_max, + get_max, + 6.73375373, + [3.92948846, 3.28436964, 6.73375373, -7.79420906, -7.23407407], + ); + + data_sample_test!( + data_sample_mean, + get_mean, + -0.21613426, + [3.92948846, 3.28436964, 6.73375373, -7.79420906, -7.23407407], + ); + + data_sample_test!( + data_sample_median_odd, + get_median, + 3.28436964, + [3.92948846, 3.28436964, 6.73375373, -7.79420906, -7.23407407], + ); + + data_sample_test!( + data_sample_median_even, + get_median, + 5.655794743124782, + [9.47981408, 3.86815751, 9.90299294, -2.986894, 7.44343197, 1.52751816], + ); + + data_sample_test!( + data_sample_std, + get_std, + 6.7900544035968435, + [3.92948846, 3.28436964, 6.73375373, -7.79420906, -7.23407407], + ); + + /// https://github.com/light-curve/light-curve-feature/issues/95 + #[test] + fn std2_overflow() { + const N: usize = (1 << 24) + 2; + // Such a large integer cannot be represented as a float32 + let x = Array1::linspace(0.0_f32, 1.0, N); + let mut ds = DataSample::new(x.into()); + // This should not panic + let _std2 = ds.get_std2(); + } +} diff --git a/src/data/mod.rs b/src/data/mod.rs new file mode 100644 index 00000000..aafaeb99 --- /dev/null +++ b/src/data/mod.rs @@ -0,0 +1,12 @@ +mod data_sample; +pub use data_sample::DataSample; + +mod multi_color_time_series; +pub use multi_color_time_series::MultiColorTimeSeries; + +mod sorted_array; +pub use sorted_array::SortedArray; + +mod time_series; + +pub use time_series::TimeSeries; diff --git a/src/data/multi_color_time_series.rs b/src/data/multi_color_time_series.rs new file mode 100644 index 00000000..3cd14a5a --- /dev/null +++ b/src/data/multi_color_time_series.rs @@ -0,0 +1,30 @@ +use crate::data::TimeSeries; +use crate::float_trait::Float; +use crate::multicolor::PassbandTrait; + +use std::collections::BTreeMap; +use std::ops::{Deref, DerefMut}; + +pub struct MultiColorTimeSeries<'a, P: PassbandTrait, T: Float>(BTreeMap>); + +impl<'a, P: PassbandTrait, T: Float> Deref for MultiColorTimeSeries<'a, P, T> { + type Target = BTreeMap>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl<'a, P: PassbandTrait, T: Float> DerefMut for MultiColorTimeSeries<'a, P, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl<'a, P: PassbandTrait, T: Float> FromIterator<(P, TimeSeries<'a, T>)> + for MultiColorTimeSeries<'a, P, T> +{ + fn from_iter)>>(iter: I) -> Self { + Self(iter.into_iter().collect()) + } +} diff --git a/src/sorted_array.rs b/src/data/sorted_array.rs similarity index 99% rename from src/sorted_array.rs rename to src/data/sorted_array.rs index 917fcf84..3686ff82 100644 --- a/src/sorted_array.rs +++ b/src/data/sorted_array.rs @@ -1,5 +1,6 @@ use crate::error::SortedArrayError; use crate::float_trait::Float; + use conv::prelude::*; use itertools::Itertools; use ndarray::Array1; diff --git a/src/data/time_series.rs b/src/data/time_series.rs new file mode 100644 index 00000000..855e7819 --- /dev/null +++ b/src/data/time_series.rs @@ -0,0 +1,252 @@ +use crate::data::data_sample::DataSample; +use crate::float_trait::Float; + +use conv::prelude::*; +use itertools::Itertools; +#[cfg(test)] +use ndarray::Array1; +use ndarray::Zip; +use ndarray_stats::SummaryStatisticsExt; + +/// Time series object to be put into [Feature](crate::Feature) +/// +/// This struct caches it's properties, like mean magnitude value, etc., that's why mutable +/// reference is required fot feature evaluation +#[derive(Clone, Debug)] +pub struct TimeSeries<'a, T> +where + T: Float, +{ + pub t: DataSample<'a, T>, + pub m: DataSample<'a, T>, + pub w: DataSample<'a, T>, + m_weighted_mean: Option, + m_reduced_chi2: Option, + t_max_m: Option, + t_min_m: Option, + plateau: Option, +} + +macro_rules! time_series_getter { + ($t: ty, $attr: ident, $getter: ident, $func: expr) => { + // This lint is false-positive in macros + // https://github.com/rust-lang/rust-clippy/issues/1553 + #[allow(clippy::redundant_closure_call)] + pub fn $getter(&mut self) -> $t { + match self.$attr { + Some(x) => x, + None => { + self.$attr = Some($func(self)); + self.$attr.unwrap() + } + } + } + }; + + ($attr: ident, $getter: ident, $func: expr) => { + time_series_getter!(T, $attr, $getter, $func); + }; +} + +impl<'a, T> TimeSeries<'a, T> +where + T: Float, +{ + /// Construct `TimeSeries` from array-like objects + /// + /// `t` is time, `m` is magnitude (or flux), `w` is weights. + /// + /// All arrays must have the same length, `t` must increase monotonically. Input arrays could be + /// [`ndarray::Array1`], [`ndarray::ArrayView1`], 1-D [`ndarray::CowArray`], or `&[T]`. Several + /// features assumes that `w` array corresponds to inverse square errors of `m`. + pub fn new( + t: impl Into>, + m: impl Into>, + w: impl Into>, + ) -> Self { + let t = t.into(); + let m = m.into(); + let w = w.into(); + + assert_eq!( + t.sample.len(), + m.sample.len(), + "t and m should have the same size" + ); + assert_eq!( + m.sample.len(), + w.sample.len(), + "m and err should have the same size" + ); + + Self { + t, + m, + w, + m_weighted_mean: None, + m_reduced_chi2: None, + t_max_m: None, + t_min_m: None, + plateau: None, + } + } + + /// Construct [`TimeSeries`] from time and magnitude (flux) + /// + /// It is the same as [`TimeSeries::new`], but sets unity weights. It doesn't recommended to use + /// it for features dependent on weights / observation errors like [`crate::StetsonK`] or + /// [`crate::LinearFit`]. + pub fn new_without_weight( + t: impl Into>, + m: impl Into>, + ) -> Self { + let t = t.into(); + let m = m.into(); + + assert_eq!( + t.sample.len(), + m.sample.len(), + "t and m should have the same size" + ); + + let w = T::array0_unity().broadcast(t.sample.len()).unwrap().into(); + + Self { + t, + m, + w, + m_weighted_mean: None, + m_reduced_chi2: None, + t_max_m: None, + t_min_m: None, + plateau: None, + } + } + + /// Time series length + #[inline] + pub fn lenu(&self) -> usize { + self.t.sample.len() + } + + /// Float approximating time series length + pub fn lenf(&self) -> T { + self.lenu().approx().unwrap() + } + + time_series_getter!( + m_weighted_mean, + get_m_weighted_mean, + |ts: &mut TimeSeries| { ts.m.sample.weighted_mean(&ts.w.sample).unwrap() } + ); + + time_series_getter!(m_reduced_chi2, get_m_reduced_chi2, |ts: &mut TimeSeries< + T, + >| { + let m_weighed_mean = ts.get_m_weighted_mean(); + let m_reduced_chi2 = Zip::from(&ts.m.sample) + .and(&ts.w.sample) + .fold(T::zero(), |chi2, &m, &w| { + chi2 + (m - m_weighed_mean).powi(2) * w + }) + / (ts.lenf() - T::one()); + if m_reduced_chi2.is_zero() { + ts.plateau = Some(true); + } + m_reduced_chi2 + }); + + time_series_getter!(bool, plateau, is_plateau, |ts: &mut TimeSeries| { + ts.m.is_all_same() + }); + + fn set_t_min_max_m(&mut self) { + let (i_min, i_max) = self + .m + .as_slice() + .iter() + .position_minmax() + .into_option() + .expect("time series must be non-empty"); + self.t_min_m = Some(self.t.sample[i_min]); + self.t_max_m = Some(self.t.sample[i_max]); + } + + pub fn get_t_min_m(&mut self) -> T { + if self.t_min_m.is_none() { + self.set_t_min_max_m(); + } + self.t_min_m.unwrap() + } + + pub fn get_t_max_m(&mut self) -> T { + if self.t_max_m.is_none() { + self.set_t_min_max_m(); + } + self.t_max_m.unwrap() + } +} + +// We really don't want it to be public, it is a private helper for test-data functions +#[cfg(test)] +impl<'a, T, D> From<(D, D, D)> for TimeSeries<'a, T> +where + T: Float, + D: Into>, +{ + fn from(v: (D, D, D)) -> Self { + Self::new(v.0, v.1, v.2) + } +} + +#[cfg(test)] +impl<'a, T> From<&'a (Array1, Array1, Array1)> for TimeSeries<'a, T> +where + T: Float, +{ + fn from(v: &'a (Array1, Array1, Array1)) -> Self { + Self::new(v.0.view(), v.1.view(), v.2.view()) + } +} + +#[cfg(test)] +#[allow(clippy::unreadable_literal)] +#[allow(clippy::excessive_precision)] +mod tests { + use super::*; + + use approx::assert_relative_eq; + + #[test] + fn time_series_m_weighted_mean() { + let t: Vec<_> = (0..5).map(|i| i as f64).collect(); + let m = [ + 12.77883145, + 18.89988406, + 17.55633632, + 18.36073996, + 11.83854198, + ]; + let w = [0.1282489, 0.10576467, 0.32102692, 0.12962352, 0.10746144]; + let mut ts = TimeSeries::new(&t, &m, &w); + // np.average(m, weights=w) + let desired = 16.31817047752941; + assert_relative_eq!(ts.get_m_weighted_mean(), desired, epsilon = 1e-6); + } + + #[test] + fn time_series_m_reduced_chi2() { + let t: Vec<_> = (0..5).map(|i| i as f64).collect(); + let m = [ + 12.77883145, + 18.89988406, + 17.55633632, + 18.36073996, + 11.83854198, + ]; + let w = [0.1282489, 0.10576467, 0.32102692, 0.12962352, 0.10746144]; + let mut ts = TimeSeries::new(&t, &m, &w); + let desired = 1.3752251301435465; + assert_relative_eq!(ts.get_m_reduced_chi2(), desired, epsilon = 1e-6); + } +} diff --git a/src/evaluator.rs b/src/evaluator.rs index 5c1078c3..fd3a0c23 100644 --- a/src/evaluator.rs +++ b/src/evaluator.rs @@ -1,6 +1,6 @@ +pub use crate::data::TimeSeries; pub use crate::error::EvaluatorError; pub use crate::float_trait::Float; -pub use crate::time_series::TimeSeries; pub use conv::errors::GeneralError; use enum_dispatch::enum_dispatch; diff --git a/src/extractor.rs b/src/extractor.rs index 8830d503..0eedee08 100644 --- a/src/extractor.rs +++ b/src/extractor.rs @@ -1,8 +1,8 @@ +use crate::data::TimeSeries; use crate::error::EvaluatorError; use crate::evaluator::*; use crate::feature::Feature; use crate::float_trait::Float; -use crate::time_series::TimeSeries; use std::marker::PhantomData; diff --git a/src/feature.rs b/src/feature.rs index 018b7b91..73772698 100644 --- a/src/feature.rs +++ b/src/feature.rs @@ -1,8 +1,8 @@ +use crate::data::TimeSeries; use crate::evaluator::*; use crate::extractor::FeatureExtractor; use crate::features::*; use crate::float_trait::Float; -use crate::time_series::TimeSeries; use crate::transformers::Transformer; use enum_dispatch::enum_dispatch; diff --git a/src/features/median_absolute_deviation.rs b/src/features/median_absolute_deviation.rs index dad9df04..2dae9af9 100644 --- a/src/features/median_absolute_deviation.rs +++ b/src/features/median_absolute_deviation.rs @@ -1,5 +1,5 @@ +use crate::data::SortedArray; use crate::evaluator::*; -use crate::sorted_array::SortedArray; macro_const! { const DOC: &'static str = r" diff --git a/src/features/otsu_split.rs b/src/features/otsu_split.rs index 8c3bb73a..165d72a8 100644 --- a/src/features/otsu_split.rs +++ b/src/features/otsu_split.rs @@ -1,5 +1,5 @@ +use crate::data::DataSample; use crate::evaluator::*; -use crate::time_series::DataSample; use conv::prelude::*; use ndarray::{s, Array1, ArrayView1, Axis, Zip}; use ndarray_stats::QuantileExt; diff --git a/src/lib.rs b/src/lib.rs index d77b787d..ccd8f5bb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,6 +7,9 @@ mod tests; #[macro_use] mod macros; +mod data; +pub use data::{DataSample, TimeSeries}; + mod evaluator; pub use evaluator::{EvaluatorInfoTrait, FeatureEvaluator, FeatureNamesDescriptionsTrait}; @@ -49,8 +52,6 @@ pub use periodogram::{ pub mod prelude; -mod sorted_array; - mod straight_line_fit; #[doc(hidden)] pub use straight_line_fit::fit_straight_line; @@ -62,9 +63,6 @@ mod peak_indices; #[doc(hidden)] pub use peak_indices::peak_indices; -mod time_series; -pub use time_series::{DataSample, TimeSeries}; - mod types; pub use ndarray; diff --git a/src/multicolor.rs b/src/multicolor.rs index fc9a3657..9b069c27 100644 --- a/src/multicolor.rs +++ b/src/multicolor.rs @@ -1,3 +1,4 @@ +use crate::data::MultiColorTimeSeries; use crate::error::MultiColorEvaluatorError; use crate::evaluator::{ EvaluatorError, EvaluatorInfo, EvaluatorInfoTrait, EvaluatorProperties, FeatureEvaluator, @@ -5,7 +6,6 @@ use crate::evaluator::{ }; use crate::feature::Feature; use crate::float_trait::Float; -use crate::time_series::TimeSeries; use enum_dispatch::enum_dispatch; use itertools::Itertools; @@ -91,8 +91,6 @@ impl PassbandTrait for NoPassband { } } -pub struct MultiColorTimeSeries<'a, P: PassbandTrait, T: Float>(BTreeMap>); - #[enum_dispatch] pub trait MultiColorPassbandSetTrait

where @@ -167,14 +165,13 @@ where PassbandSet::AllAvailable => Ok(()), PassbandSet::FixedSet(self_passbands) => { if mcts - .0 .keys() .all(|mcts_passband| self_passbands.contains(mcts_passband)) { Ok(()) } else { Err(MultiColorEvaluatorError::wrong_passbands_error( - mcts.0.keys(), + mcts.keys(), self_passbands.iter(), )) } @@ -189,8 +186,7 @@ where ) -> Result, MultiColorEvaluatorError> { // Use try_reduce when stabilizes // https://github.com/rust-lang/rust/issues/87053 - mcts.0 - .iter() + mcts.iter() .map(|(passband, ts)| { let length = ts.lenu(); if length < self.min_ts_length() { @@ -483,7 +479,7 @@ where PassbandSet::FixedSet(set) => set .iter() .map(|passband| { - self.feature.eval(mcts.0.get_mut(passband).expect( + self.feature.eval(mcts.get_mut(passband).expect( "we checked all needed passbands are in mcts, but we still cannot find one", )).map_err(|error| MultiColorEvaluatorError::MonochromeEvaluatorError { passband: passband.name().into(), @@ -624,7 +620,7 @@ mod color_median { for (median, passband) in medians.iter_mut().zip(self.passbands.iter()) { *median = self .median - .eval(mcts.0.get_mut(passband).expect( + .eval(mcts.get_mut(passband).expect( "we checked all needed passbands are in mcts, but we still cannot find one", )) .map_err(|error| MultiColorEvaluatorError::MonochromeEvaluatorError { diff --git a/src/nl_fit/data.rs b/src/nl_fit/data.rs index c5fba037..8cd6952b 100644 --- a/src/nl_fit/data.rs +++ b/src/nl_fit/data.rs @@ -1,5 +1,5 @@ +use crate::data::{DataSample, TimeSeries}; use crate::float_trait::Float; -use crate::time_series::{DataSample, TimeSeries}; use conv::ConvUtil; use ndarray::Array1; diff --git a/src/nl_fit/evaluator.rs b/src/nl_fit/evaluator.rs index 3000ebb0..1b83e741 100644 --- a/src/nl_fit/evaluator.rs +++ b/src/nl_fit/evaluator.rs @@ -1,6 +1,6 @@ +use crate::data::TimeSeries; use crate::float_trait::Float; use crate::nl_fit::{data::NormalizedData, CurveFitAlgorithm, LikeFloat, LnPrior}; -use crate::time_series::TimeSeries; use schemars::JsonSchema; use serde::de::DeserializeOwned; diff --git a/src/periodogram/freq.rs b/src/periodogram/freq.rs index d5b81662..177f3f60 100644 --- a/src/periodogram/freq.rs +++ b/src/periodogram/freq.rs @@ -1,5 +1,5 @@ +use crate::data::SortedArray; use crate::float_trait::Float; -use crate::sorted_array::SortedArray; use conv::{ConvAsUtil, ConvUtil, RoundToNearest}; use enum_dispatch::enum_dispatch; diff --git a/src/periodogram/mod.rs b/src/periodogram/mod.rs index af7a29df..6bc477ea 100644 --- a/src/periodogram/mod.rs +++ b/src/periodogram/mod.rs @@ -1,7 +1,7 @@ //! Periodogram-related stuff +use crate::data::TimeSeries; use crate::float_trait::Float; -use crate::time_series::TimeSeries; use conv::ConvAsUtil; use enum_dispatch::enum_dispatch; @@ -107,8 +107,8 @@ where mod tests { use super::*; + use crate::data::SortedArray; use crate::peak_indices::peak_indices_reverse_sorted; - use crate::sorted_array::SortedArray; use light_curve_common::{all_close, linspace}; use rand::prelude::*; diff --git a/src/periodogram/power_direct.rs b/src/periodogram/power_direct.rs index 1b659f27..a8020318 100644 --- a/src/periodogram/power_direct.rs +++ b/src/periodogram/power_direct.rs @@ -1,8 +1,8 @@ +use crate::data::TimeSeries; use crate::float_trait::Float; use crate::periodogram::freq::FreqGrid; use crate::periodogram::power_trait::*; use crate::periodogram::recurrent_sin_cos::*; -use crate::time_series::TimeSeries; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; diff --git a/src/periodogram/power_fft.rs b/src/periodogram/power_fft.rs index cb8d8647..40526df9 100644 --- a/src/periodogram/power_fft.rs +++ b/src/periodogram/power_fft.rs @@ -1,8 +1,8 @@ +use crate::data::TimeSeries; use crate::float_trait::Float; use crate::periodogram::fft::*; use crate::periodogram::freq::FreqGrid; use crate::periodogram::power_trait::*; -use crate::time_series::TimeSeries; use conv::{ConvAsUtil, RoundToNearest}; use schemars::JsonSchema; diff --git a/src/periodogram/power_trait.rs b/src/periodogram/power_trait.rs index 0ea6b723..46c271be 100644 --- a/src/periodogram/power_trait.rs +++ b/src/periodogram/power_trait.rs @@ -1,6 +1,6 @@ +use crate::data::TimeSeries; use crate::float_trait::Float; use crate::periodogram::freq::FreqGrid; -use crate::time_series::TimeSeries; use enum_dispatch::enum_dispatch; use std::fmt::Debug; diff --git a/src/prelude.rs b/src/prelude.rs index 704e1d47..0ecb0672 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -1,3 +1,4 @@ +pub use crate::data::TimeSeries; pub use crate::error::EvaluatorError; pub use crate::evaluator::{EvaluatorInfoTrait, FeatureEvaluator, FeatureNamesDescriptionsTrait}; pub use crate::extractor::FeatureExtractor; @@ -5,4 +6,3 @@ pub use crate::feature::Feature; pub use crate::features::*; pub use crate::float_trait::Float; pub use crate::nl_fit::evaluator::*; -pub use crate::time_series::TimeSeries; diff --git a/src/straight_line_fit.rs b/src/straight_line_fit.rs index 7b68cba3..e582dea7 100644 --- a/src/straight_line_fit.rs +++ b/src/straight_line_fit.rs @@ -1,5 +1,5 @@ +use crate::data::TimeSeries; use crate::float_trait::Float; -use crate::time_series::TimeSeries; use ndarray::Zip; diff --git a/src/tests.rs b/src/tests.rs index cd987524..b271391d 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -1,8 +1,8 @@ +pub use crate::data::TimeSeries; pub use crate::evaluator::*; pub use crate::extractor::FeatureExtractor; pub use crate::feature::Feature; pub use crate::float_trait::Float; -pub use crate::time_series::TimeSeries; pub use light_curve_common::{all_close, linspace}; pub use ndarray::{Array1, ArrayView1}; diff --git a/src/time_series.rs b/src/time_series.rs deleted file mode 100644 index b2a99e4d..00000000 --- a/src/time_series.rs +++ /dev/null @@ -1,520 +0,0 @@ -use crate::float_trait::Float; -use crate::sorted_array::SortedArray; -use crate::types::CowArray1; - -use conv::prelude::*; -use itertools::Itertools; -use ndarray::{s, Array1, ArrayView1, Zip}; -use ndarray_stats::SummaryStatisticsExt; - -/// A [`TimeSeries`] component -#[derive(Clone, Debug)] -pub struct DataSample<'a, T> -where - T: Float, -{ - pub sample: CowArray1<'a, T>, - sorted: Option>, - min: Option, - max: Option, - mean: Option, - median: Option, - std: Option, - std2: Option, -} - -macro_rules! data_sample_getter { - ($attr: ident, $getter: ident, $func: expr, $method_sorted: ident) => { - // This lint is false-positive in macros - // https://github.com/rust-lang/rust-clippy/issues/1553 - #[allow(clippy::redundant_closure_call)] - pub fn $getter(&mut self) -> T { - match self.$attr { - Some(x) => x, - None => { - self.$attr = Some(match self.sorted.as_ref() { - Some(sorted) => sorted.$method_sorted(), - None => $func(self), - }); - self.$attr.unwrap() - } - } - } - }; - ($attr: ident, $getter: ident, $func: expr) => { - // This lint is false-positive in macros - // https://github.com/rust-lang/rust-clippy/issues/1553 - #[allow(clippy::redundant_closure_call)] - pub fn $getter(&mut self) -> T { - match self.$attr { - Some(x) => x, - None => { - self.$attr = Some($func(self)); - self.$attr.unwrap() - } - } - } - }; -} - -impl<'a, T> DataSample<'a, T> -where - T: Float, -{ - pub fn new(sample: CowArray1<'a, T>) -> Self { - Self { - sample, - sorted: None, - min: None, - max: None, - mean: None, - median: None, - std: None, - std2: None, - } - } - - pub fn as_slice(&mut self) -> &[T] { - if !self.sample.is_standard_layout() { - let owned: Array1<_> = self.sample.iter().copied().collect::>().into(); - self.sample = owned.into(); - } - self.sample.as_slice().unwrap() - } - - pub fn get_sorted(&mut self) -> &SortedArray { - if self.sorted.is_none() { - self.sorted = Some(self.sample.to_vec().into()); - } - self.sorted.as_ref().unwrap() - } - - fn set_min_max(&mut self) { - let (min, max) = - self.sample - .slice(s![1..]) - .fold((self.sample[0], self.sample[0]), |(min, max), &x| { - if x > max { - (min, x) - } else if x < min { - (x, max) - } else { - (min, max) - } - }); - self.min = Some(min); - self.max = Some(max); - } - - data_sample_getter!( - min, - get_min, - |ds: &mut DataSample<'a, T>| { - ds.set_min_max(); - ds.min.unwrap() - }, - minimum - ); - data_sample_getter!( - max, - get_max, - |ds: &mut DataSample<'a, T>| { - ds.set_min_max(); - ds.max.unwrap() - }, - maximum - ); - data_sample_getter!(mean, get_mean, |ds: &mut DataSample<'a, T>| { - ds.sample.mean().expect("time series must be non-empty") - }); - data_sample_getter!(median, get_median, |ds: &mut DataSample<'a, T>| { - ds.get_sorted().median() - }); - data_sample_getter!(std, get_std, |ds: &mut DataSample<'a, T>| { - ds.get_std2().sqrt() - }); - data_sample_getter!(std2, get_std2, |ds: &mut DataSample<'a, T>| { - // Benchmarks show that it is faster than `ndarray::ArrayBase::var(T::one)` - let mean = ds.get_mean(); - ds.sample - .fold(T::zero(), |sum, &x| sum + (x - mean).powi(2)) - / (ds.sample.len() - 1).approx().unwrap() - }); - - pub fn signal_to_noise(&mut self, value: T) -> T { - if self.get_std().is_zero() { - T::zero() - } else { - (value - self.get_mean()) / self.get_std() - } - } -} - -impl<'a, T, Slice: ?Sized> From<&'a Slice> for DataSample<'a, T> -where - T: Float, - Slice: AsRef<[T]>, -{ - fn from(s: &'a Slice) -> Self { - ArrayView1::from(s).into() - } -} - -impl<'a, T> From> for DataSample<'a, T> -where - T: Float, -{ - fn from(v: Vec) -> Self { - Array1::from(v).into() - } -} - -impl<'a, T> From> for DataSample<'a, T> -where - T: Float, -{ - fn from(a: ArrayView1<'a, T>) -> Self { - Self::new(a.into()) - } -} - -impl<'a, T> From> for DataSample<'a, T> -where - T: Float, -{ - fn from(a: Array1) -> Self { - Self::new(a.into()) - } -} - -impl<'a, T> From> for DataSample<'a, T> -where - T: Float, -{ - fn from(a: CowArray1<'a, T>) -> Self { - Self::new(a) - } -} - -/// Time series object to be put into [Feature](crate::Feature) -/// -/// This struct caches it's properties, like mean magnitude value, etc., that's why mutable -/// reference is required fot feature evaluation -#[derive(Clone, Debug)] -pub struct TimeSeries<'a, T> -where - T: Float, -{ - pub t: DataSample<'a, T>, - pub m: DataSample<'a, T>, - pub w: DataSample<'a, T>, - m_weighted_mean: Option, - m_reduced_chi2: Option, - t_max_m: Option, - t_min_m: Option, - plateau: Option, -} - -macro_rules! time_series_getter { - ($t: ty, $attr: ident, $getter: ident, $func: expr) => { - // This lint is false-positive in macros - // https://github.com/rust-lang/rust-clippy/issues/1553 - #[allow(clippy::redundant_closure_call)] - pub fn $getter(&mut self) -> $t { - match self.$attr { - Some(x) => x, - None => { - self.$attr = Some($func(self)); - self.$attr.unwrap() - } - } - } - }; - - ($attr: ident, $getter: ident, $func: expr) => { - time_series_getter!(T, $attr, $getter, $func); - }; -} - -impl<'a, T> TimeSeries<'a, T> -where - T: Float, -{ - /// Construct `TimeSeries` from array-like objects - /// - /// `t` is time, `m` is magnitude (or flux), `w` is weights. - /// - /// All arrays must have the same length, `t` must increase monotonically. Input arrays could be - /// [`ndarray::Array1`], [`ndarray::ArrayView1`], 1-D [`ndarray::CowArray`], or `&[T]`. Several - /// features assumes that `w` array corresponds to inverse square errors of `m`. - pub fn new( - t: impl Into>, - m: impl Into>, - w: impl Into>, - ) -> Self { - let t = t.into(); - let m = m.into(); - let w = w.into(); - - assert_eq!( - t.sample.len(), - m.sample.len(), - "t and m should have the same size" - ); - assert_eq!( - m.sample.len(), - w.sample.len(), - "m and err should have the same size" - ); - - Self { - t, - m, - w, - m_weighted_mean: None, - m_reduced_chi2: None, - t_max_m: None, - t_min_m: None, - plateau: None, - } - } - - /// Construct [`TimeSeries`] from time and magnitude (flux) - /// - /// It is the same as [`TimeSeries::new`], but sets unity weights. It doesn't recommended to use - /// it for features dependent on weights / observation errors like [`crate::StetsonK`] or - /// [`crate::LinearFit`]. - pub fn new_without_weight( - t: impl Into>, - m: impl Into>, - ) -> Self { - let t = t.into(); - let m = m.into(); - - assert_eq!( - t.sample.len(), - m.sample.len(), - "t and m should have the same size" - ); - - let w = T::array0_unity().broadcast(t.sample.len()).unwrap().into(); - - Self { - t, - m, - w, - m_weighted_mean: None, - m_reduced_chi2: None, - t_max_m: None, - t_min_m: None, - plateau: None, - } - } - - /// Time series length - #[inline] - pub fn lenu(&self) -> usize { - self.t.sample.len() - } - - /// Float approximating time series length - pub fn lenf(&self) -> T { - self.lenu().approx().unwrap() - } - - time_series_getter!( - m_weighted_mean, - get_m_weighted_mean, - |ts: &mut TimeSeries| { ts.m.sample.weighted_mean(&ts.w.sample).unwrap() } - ); - - time_series_getter!(m_reduced_chi2, get_m_reduced_chi2, |ts: &mut TimeSeries< - T, - >| { - let m_weighed_mean = ts.get_m_weighted_mean(); - let m_reduced_chi2 = Zip::from(&ts.m.sample) - .and(&ts.w.sample) - .fold(T::zero(), |chi2, &m, &w| { - chi2 + (m - m_weighed_mean).powi(2) * w - }) - / (ts.lenf() - T::one()); - if m_reduced_chi2.is_zero() { - ts.plateau = Some(true); - } - m_reduced_chi2 - }); - - time_series_getter!(bool, plateau, is_plateau, |ts: &mut TimeSeries| { - if ts.m.max.is_some() && ts.m.max == ts.m.min { - return true; - } - if ts.m.std2 == Some(T::zero()) { - return true; - } - let m0 = ts.m.sample[0]; - // all() returns true for the empty slice, i.e. one-point time series - Zip::from(ts.m.sample.slice(s![1..])).all(|&m| m == m0) - }); - - fn set_t_min_max_m(&mut self) { - let (i_min, i_max) = self - .m - .as_slice() - .iter() - .position_minmax() - .into_option() - .expect("time series must be non-empty"); - self.t_min_m = Some(self.t.sample[i_min]); - self.t_max_m = Some(self.t.sample[i_max]); - } - - pub fn get_t_min_m(&mut self) -> T { - if self.t_min_m.is_none() { - self.set_t_min_max_m(); - } - self.t_min_m.unwrap() - } - - pub fn get_t_max_m(&mut self) -> T { - if self.t_max_m.is_none() { - self.set_t_min_max_m(); - } - self.t_max_m.unwrap() - } -} - -// We really don't want it to be public, it is a private helper for test-util functions -#[cfg(test)] -impl<'a, T, D> From<(D, D, D)> for TimeSeries<'a, T> -where - T: Float, - D: Into>, -{ - fn from(v: (D, D, D)) -> Self { - Self::new(v.0, v.1, v.2) - } -} - -#[cfg(test)] -impl<'a, T> From<&'a (Array1, Array1, Array1)> for TimeSeries<'a, T> -where - T: Float, -{ - fn from(v: &'a (Array1, Array1, Array1)) -> Self { - Self::new(v.0.view(), v.1.view(), v.2.view()) - } -} - -#[cfg(test)] -#[allow(clippy::unreadable_literal)] -#[allow(clippy::excessive_precision)] -mod tests { - use super::*; - - use light_curve_common::all_close; - - macro_rules! data_sample_test { - ($name: ident, $method: ident, $desired: tt, $x: tt $(,)?) => { - #[test] - fn $name() { - let x = $x; - let desired = $desired; - - let mut ds: DataSample<_> = DataSample::from(&x); - all_close(&[ds.$method()], &desired[..], 1e-6); - all_close(&[ds.$method()], &desired[..], 1e-6); - - let mut ds: DataSample<_> = DataSample::from(&x); - ds.get_sorted(); - all_close(&[ds.$method()], &desired[..], 1e-6); - all_close(&[ds.$method()], &desired[..], 1e-6); - } - }; - } - - data_sample_test!( - data_sample_min, - get_min, - [-7.79420906], - [3.92948846, 3.28436964, 6.73375373, -7.79420906, -7.23407407], - ); - - data_sample_test!( - data_sample_max, - get_max, - [6.73375373], - [3.92948846, 3.28436964, 6.73375373, -7.79420906, -7.23407407], - ); - - data_sample_test!( - data_sample_mean, - get_mean, - [-0.21613426], - [3.92948846, 3.28436964, 6.73375373, -7.79420906, -7.23407407], - ); - - data_sample_test!( - data_sample_median_odd, - get_median, - [3.28436964], - [3.92948846, 3.28436964, 6.73375373, -7.79420906, -7.23407407], - ); - - data_sample_test!( - data_sample_median_even, - get_median, - [5.655794743124782], - [9.47981408, 3.86815751, 9.90299294, -2.986894, 7.44343197, 1.52751816], - ); - - data_sample_test!( - data_sample_std, - get_std, - [6.7900544035968435], - [3.92948846, 3.28436964, 6.73375373, -7.79420906, -7.23407407], - ); - - #[test] - fn time_series_m_weighted_mean() { - let t: Vec<_> = (0..5).map(|i| i as f64).collect(); - let m = [ - 12.77883145, - 18.89988406, - 17.55633632, - 18.36073996, - 11.83854198, - ]; - let w = [0.1282489, 0.10576467, 0.32102692, 0.12962352, 0.10746144]; - let mut ts = TimeSeries::new(&t, &m, &w); - // np.average(m, weights=w) - let desired = [16.31817047752941]; - all_close(&[ts.get_m_weighted_mean()], &desired[..], 1e-6); - } - - #[test] - fn time_series_m_reduced_chi2() { - let t: Vec<_> = (0..5).map(|i| i as f64).collect(); - let m = [ - 12.77883145, - 18.89988406, - 17.55633632, - 18.36073996, - 11.83854198, - ]; - let w = [0.1282489, 0.10576467, 0.32102692, 0.12962352, 0.10746144]; - let mut ts = TimeSeries::new(&t, &m, &w); - let desired = [1.3752251301435465]; - all_close(&[ts.get_m_reduced_chi2()], &desired[..], 1e-6); - } - - /// https://github.com/light-curve/light-curve-feature/issues/95 - #[test] - fn time_series_std2_overflow() { - const N: usize = (1 << 24) + 2; - // Such a large integer cannot be represented as a float32 - let x = Array1::linspace(0.0_f32, 1.0, N); - let mut ds = DataSample::new(x.into()); - // This should not panic - let _std2 = ds.get_std2(); - } -} From dc03212161a096b30cd62b7474233510a756f920 Mon Sep 17 00:00:00 2001 From: Konstantin Malanchev Date: Tue, 30 Aug 2022 13:09:57 -0500 Subject: [PATCH 4/9] Refactor multicolor.rs into multiple modules --- src/multicolor.rs | 634 ------------------ src/multicolor/features/color_of_median.rs | 120 ++++ src/multicolor/features/mod.rs | 2 + src/multicolor/mod.rs | 16 + src/multicolor/monochrome_feature.rs | 130 ++++ src/multicolor/multicolor_evaluator.rs | 130 ++++ src/multicolor/multicolor_extractor.rs | 188 ++++++ src/multicolor/multicolor_feature.rs | 45 ++ src/multicolor/passband/dump_passband.rs | 14 + src/multicolor/passband/mod.rs | 8 + .../passband/monochrome_passband.rs | 69 ++ src/multicolor/passband/passband_trait.rs | 7 + 12 files changed, 729 insertions(+), 634 deletions(-) delete mode 100644 src/multicolor.rs create mode 100644 src/multicolor/features/color_of_median.rs create mode 100644 src/multicolor/features/mod.rs create mode 100644 src/multicolor/mod.rs create mode 100644 src/multicolor/monochrome_feature.rs create mode 100644 src/multicolor/multicolor_evaluator.rs create mode 100644 src/multicolor/multicolor_extractor.rs create mode 100644 src/multicolor/multicolor_feature.rs create mode 100644 src/multicolor/passband/dump_passband.rs create mode 100644 src/multicolor/passband/mod.rs create mode 100644 src/multicolor/passband/monochrome_passband.rs create mode 100644 src/multicolor/passband/passband_trait.rs diff --git a/src/multicolor.rs b/src/multicolor.rs deleted file mode 100644 index 9b069c27..00000000 --- a/src/multicolor.rs +++ /dev/null @@ -1,634 +0,0 @@ -use crate::data::MultiColorTimeSeries; -use crate::error::MultiColorEvaluatorError; -use crate::evaluator::{ - EvaluatorError, EvaluatorInfo, EvaluatorInfoTrait, EvaluatorProperties, FeatureEvaluator, - FeatureNamesDescriptionsTrait, -}; -use crate::feature::Feature; -use crate::float_trait::Float; - -use enum_dispatch::enum_dispatch; -use itertools::Itertools; -pub use lazy_static::lazy_static; -pub use schemars::JsonSchema; -pub use serde::{Deserialize, Serialize}; -use std::cmp::Ordering; -use std::collections::{BTreeMap, BTreeSet}; -use std::fmt::Debug; -use std::marker::PhantomData; - -pub trait PassbandTrait: Debug + Clone + Send + Sync + Ord + Serialize + JsonSchema { - fn name(&self) -> &str; -} - -#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] -pub struct MonochromePassband<'a, T> { - pub name: &'a str, - pub wavelength: T, -} - -impl<'a, T> MonochromePassband<'a, T> -where - T: Float, -{ - pub fn new(wavelength: T, name: &'a str) -> Self { - assert!( - wavelength.is_normal(), - "wavelength must be a positive normal number" - ); - assert!( - wavelength.is_sign_positive(), - "wavelength must be a positive normal number" - ); - Self { wavelength, name } - } -} - -impl<'a, T> PartialEq for MonochromePassband<'a, T> -where - T: Float, -{ - fn eq(&self, other: &Self) -> bool { - self.wavelength.eq(&other.wavelength) - } -} - -impl<'a, T> Eq for MonochromePassband<'a, T> where T: Float {} - -impl<'a, T> PartialOrd for MonochromePassband<'a, T> -where - T: Float, -{ - fn partial_cmp(&self, other: &Self) -> Option { - (self.wavelength).partial_cmp(&other.wavelength) - } -} - -impl<'a, T> Ord for MonochromePassband<'a, T> -where - T: Float, -{ - fn cmp(&self, other: &Self) -> Ordering { - self.partial_cmp(other).unwrap() - } -} - -impl<'a, T> PassbandTrait for MonochromePassband<'a, T> -where - T: Float, -{ - fn name(&self) -> &str { - self.name - } -} - -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, JsonSchema)] -pub struct NoPassband {} - -impl PassbandTrait for NoPassband { - fn name(&self) -> &str { - "" - } -} - -#[enum_dispatch] -pub trait MultiColorPassbandSetTrait

-where - P: PassbandTrait, -{ - fn get_passband_set(&self) -> &PassbandSet

; -} - -#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] -#[serde(bound(deserialize = "P: PassbandTrait + Deserialize<'de>"))] -#[non_exhaustive] -pub enum PassbandSet

-where - P: Ord, -{ - FixedSet(BTreeSet

), - AllAvailable, -} - -impl

From> for PassbandSet

-where - P: Ord, -{ - fn from(value: BTreeSet

) -> Self { - Self::FixedSet(value) - } -} - -#[enum_dispatch] -pub trait MultiColorEvaluator: - FeatureNamesDescriptionsTrait - + EvaluatorInfoTrait - + MultiColorPassbandSetTrait

- + Clone - + Serialize -where - P: PassbandTrait, - T: Float, -{ - /// Vector of feature values or `EvaluatorError` - fn eval_multicolor( - &self, - mcts: &mut MultiColorTimeSeries, - ) -> Result, MultiColorEvaluatorError>; - - /// Returns vector of feature values and fill invalid components with given value - fn eval_or_fill_multicolor( - &self, - mcts: &mut MultiColorTimeSeries, - fill_value: T, - ) -> Result, MultiColorEvaluatorError> { - self.check_mcts_shape(mcts)?; - Ok(match self.eval_multicolor(mcts) { - Ok(v) => v, - Err(_) => vec![fill_value; self.size_hint()], - }) - } - - fn check_mcts_shape( - &self, - mcts: &MultiColorTimeSeries, - ) -> Result, MultiColorEvaluatorError> { - self.check_mcts_passabands(mcts)?; - self.check_every_ts_length(mcts) - } - - fn check_mcts_passabands( - &self, - mcts: &MultiColorTimeSeries, - ) -> Result<(), MultiColorEvaluatorError> { - match self.get_passband_set() { - PassbandSet::AllAvailable => Ok(()), - PassbandSet::FixedSet(self_passbands) => { - if mcts - .keys() - .all(|mcts_passband| self_passbands.contains(mcts_passband)) - { - Ok(()) - } else { - Err(MultiColorEvaluatorError::wrong_passbands_error( - mcts.keys(), - self_passbands.iter(), - )) - } - } - } - } - - /// Checks if each component of [MultiColorTimeSeries] has enough points to evaluate the feature - fn check_every_ts_length( - &self, - mcts: &MultiColorTimeSeries, - ) -> Result, MultiColorEvaluatorError> { - // Use try_reduce when stabilizes - // https://github.com/rust-lang/rust/issues/87053 - mcts.iter() - .map(|(passband, ts)| { - let length = ts.lenu(); - if length < self.min_ts_length() { - Err(MultiColorEvaluatorError::MonochromeEvaluatorError { - error: EvaluatorError::ShortTimeSeries { - actual: length, - minimum: self.min_ts_length(), - }, - passband: passband.name().into(), - }) - } else { - Ok((passband.clone(), length)) - } - }) - .collect() - } -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde( - into = "MultiColorExtractorParameters", - from = "MultiColorExtractorParameters", - bound( - serialize = "P: PassbandTrait, T: Float, MCF: MultiColorEvaluator", - deserialize = "P: PassbandTrait + Deserialize<'de>, T: Float, MCF: MultiColorEvaluator + Deserialize<'de>" - ) -)] -pub struct MultiColorExtractor -where - P: Ord, -{ - features: Vec, - info: Box, - passband_set: PassbandSet

, - phantom: PhantomData, -} - -impl MultiColorExtractor -where - P: PassbandTrait, - T: Float, - MCF: MultiColorEvaluator, -{ - pub fn new(features: Vec) -> Self { - let passband_set = { - let set: BTreeSet<_> = features - .iter() - .filter_map(|f| match f.get_passband_set() { - PassbandSet::AllAvailable => None, - PassbandSet::FixedSet(set) => Some(set), - }) - .flatten() - .cloned() - .collect(); - if set.is_empty() { - PassbandSet::AllAvailable - } else { - PassbandSet::FixedSet(set) - } - }; - - let info = EvaluatorInfo { - size: features.iter().map(|x| x.size_hint()).sum(), - min_ts_length: features - .iter() - .map(|x| x.min_ts_length()) - .max() - .unwrap_or(0), - t_required: features.iter().any(|x| x.is_t_required()), - m_required: features.iter().any(|x| x.is_m_required()), - w_required: features.iter().any(|x| x.is_w_required()), - sorting_required: features.iter().any(|x| x.is_sorting_required()), - } - .into(); - - Self { - features, - passband_set, - info, - phantom: PhantomData, - } - } -} - -impl FeatureNamesDescriptionsTrait for MultiColorExtractor -where - P: Ord, - MCF: FeatureNamesDescriptionsTrait, -{ - /// Get feature names - fn get_names(&self) -> Vec<&str> { - self.features.iter().flat_map(|x| x.get_names()).collect() - } - - /// Get feature descriptions - fn get_descriptions(&self) -> Vec<&str> { - self.features - .iter() - .flat_map(|x| x.get_descriptions()) - .collect() - } -} - -impl EvaluatorInfoTrait for MultiColorExtractor -where - P: Ord, -{ - fn get_info(&self) -> &EvaluatorInfo { - &self.info - } -} - -impl MultiColorPassbandSetTrait

for MultiColorExtractor -where - P: PassbandTrait, -{ - fn get_passband_set(&self) -> &PassbandSet

{ - &self.passband_set - } -} - -impl MultiColorEvaluator for MultiColorExtractor -where - P: PassbandTrait, - T: Float, - MCF: MultiColorEvaluator, -{ - fn eval_multicolor( - &self, - mcts: &mut MultiColorTimeSeries, - ) -> Result, MultiColorEvaluatorError> { - self.check_mcts_passabands(mcts)?; - let mut vec = Vec::with_capacity(self.size_hint()); - for x in &self.features { - vec.extend(x.eval_multicolor(mcts)?); - } - Ok(vec) - } - - fn eval_or_fill_multicolor( - &self, - mcts: &mut MultiColorTimeSeries, - fill_value: T, - ) -> Result, MultiColorEvaluatorError> { - self.check_mcts_passabands(mcts)?; - self.features - .iter() - .map(|x| x.eval_or_fill_multicolor(mcts, fill_value)) - .flatten_ok() - .collect() - } -} - -#[derive(Serialize, Deserialize, JsonSchema)] -#[serde(rename = "MultiColorExtractor")] -struct MultiColorExtractorParameters { - features: Vec, -} - -impl From> for MultiColorExtractorParameters -where - P: PassbandTrait, - T: Float, - MCF: MultiColorEvaluator, -{ - fn from(f: MultiColorExtractor) -> Self { - Self { - features: f.features, - } - } -} - -impl From> for MultiColorExtractor -where - P: PassbandTrait, - T: Float, - MCF: MultiColorEvaluator, -{ - fn from(p: MultiColorExtractorParameters) -> Self { - Self::new(p.features) - } -} - -impl JsonSchema for MultiColorExtractor -where - P: PassbandTrait, - T: Float, - MCF: JsonSchema, -{ - json_schema!(MultiColorExtractorParameters, true); -} - -#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] -#[serde(bound( - deserialize = "P: PassbandTrait + Deserialize<'de>, T: Float, F: FeatureEvaluator" -))] -pub struct MonochromeFeature -where - P: Ord, -{ - feature: F, - passband_set: PassbandSet

, - properties: Box, - phantom: PhantomData, -} - -impl MonochromeFeature -where - P: PassbandTrait, - T: Float, - F: FeatureEvaluator, -{ - pub fn new(feature: F, passband_set: BTreeSet

) -> Self { - let names = passband_set - .iter() - .cartesian_product(feature.get_names()) - .map(|(passband, name)| format!("{}_{}", name, passband.name())) - .collect(); - let descriptions = passband_set - .iter() - .cartesian_product(feature.get_descriptions()) - .map(|(passband, description)| format!("{}, passband {}", description, passband.name())) - .collect(); - let info = { - let mut info = feature.get_info().clone(); - info.size *= passband_set.len(); - info - }; - Self { - properties: EvaluatorProperties { - info, - names, - descriptions, - } - .into(), - feature, - passband_set: passband_set.into(), - phantom: PhantomData, - } - } -} - -impl FeatureNamesDescriptionsTrait for MonochromeFeature -where - P: Ord, -{ - fn get_names(&self) -> Vec<&str> { - self.properties.names.iter().map(String::as_str).collect() - } - - fn get_descriptions(&self) -> Vec<&str> { - self.properties - .descriptions - .iter() - .map(String::as_str) - .collect() - } -} - -impl EvaluatorInfoTrait for MonochromeFeature -where - P: Ord, -{ - fn get_info(&self) -> &EvaluatorInfo { - &self.properties.info - } -} - -impl MultiColorPassbandSetTrait

for MonochromeFeature -where - P: PassbandTrait, -{ - fn get_passband_set(&self) -> &PassbandSet

{ - &self.passband_set - } -} - -impl MultiColorEvaluator for MonochromeFeature -where - P: PassbandTrait, - T: Float, - F: FeatureEvaluator, -{ - fn eval_multicolor( - &self, - mcts: &mut MultiColorTimeSeries, - ) -> Result, MultiColorEvaluatorError> { - self.check_mcts_passabands(mcts)?; - match &self.passband_set { - PassbandSet::FixedSet(set) => set - .iter() - .map(|passband| { - self.feature.eval(mcts.get_mut(passband).expect( - "we checked all needed passbands are in mcts, but we still cannot find one", - )).map_err(|error| MultiColorEvaluatorError::MonochromeEvaluatorError { - passband: passband.name().into(), - error, - }) - }) - .flatten_ok() - .collect(), - PassbandSet::AllAvailable => panic!("passband_set must be FixedSet variant here"), - } - } -} - -#[enum_dispatch(MultiColorEvaluator, FeatureNamesDescriptionsTrait, EvaluatorInfoTrait, MultiColorPassbandSetTrait

)] -#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] -#[serde(bound(deserialize = "P: PassbandTrait + Deserialize<'de>, T: Float"))] -#[non_exhaustive] -pub enum MultiColorFeature -where - P: PassbandTrait, - T: Float, -{ - // Extractor - MultiColorExtractor(MultiColorExtractor>), - // Monochrome Features - MonochromeFeature(MonochromeFeature>), - // Features - ColorOfMedian(color_median::ColorOfMedian

), -} - -impl MultiColorFeature -where - P: PassbandTrait, - T: Float, -{ - pub fn from_monochrome_feature(feature: F, passband_set: BTreeSet

) -> Self - where - F: Into>, - { - MonochromeFeature::new(feature.into(), passband_set).into() - } -} - -/// Example of multicolor light-curve feature evaluator -mod color_median { - use super::*; - use crate::{FeatureEvaluator, Median}; - - #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] - #[serde(bound(deserialize = "P: PassbandTrait + Deserialize<'de>"))] - pub struct ColorOfMedian

- where - P: Ord, - { - passband_set: PassbandSet

, - passbands: [P; 2], - median: Median, - name: String, - description: String, - } - - impl

ColorOfMedian

- where - P: PassbandTrait, - { - pub fn new(passbands: [P; 2]) -> Self { - let set: BTreeSet<_> = passbands.clone().into(); - Self { - passband_set: set.into(), - name: format!( - "color_median_{}_{}", - passbands[0].name(), - passbands[1].name() - ), - description: format!( - "difference of median magnitudes {}-{}", - passbands[0].name(), - passbands[1].name() - ), - passbands, - median: Median {}, - } - } - } - - lazy_info!( - COLOR_MEDIAN_INFO, - size: 1, - min_ts_length: 1, - t_required: false, - m_required: true, - w_required: false, - sorting_required: false, - ); - - impl

EvaluatorInfoTrait for ColorOfMedian

- where - P: Ord, - { - fn get_info(&self) -> &EvaluatorInfo { - &COLOR_MEDIAN_INFO - } - } - - impl

FeatureNamesDescriptionsTrait for ColorOfMedian

- where - P: Ord, - { - fn get_names(&self) -> Vec<&str> { - vec![self.name.as_str()] - } - - fn get_descriptions(&self) -> Vec<&str> { - vec![self.description.as_str()] - } - } - - impl

MultiColorPassbandSetTrait

for ColorOfMedian

- where - P: PassbandTrait, - { - fn get_passband_set(&self) -> &PassbandSet

{ - &self.passband_set - } - } - - impl MultiColorEvaluator for ColorOfMedian

- where - P: PassbandTrait, - T: Float, - { - fn eval_multicolor( - &self, - mcts: &mut MultiColorTimeSeries, - ) -> Result, MultiColorEvaluatorError> { - self.check_mcts_passabands(mcts)?; - let mut medians = [T::zero(); 2]; - for (median, passband) in medians.iter_mut().zip(self.passbands.iter()) { - *median = self - .median - .eval(mcts.get_mut(passband).expect( - "we checked all needed passbands are in mcts, but we still cannot find one", - )) - .map_err(|error| MultiColorEvaluatorError::MonochromeEvaluatorError { - passband: passband.name().into(), - error, - })?[0] - } - Ok(vec![medians[0] - medians[1]]) - } - } -} diff --git a/src/multicolor/features/color_of_median.rs b/src/multicolor/features/color_of_median.rs new file mode 100644 index 00000000..8c7c7457 --- /dev/null +++ b/src/multicolor/features/color_of_median.rs @@ -0,0 +1,120 @@ +use crate::data::MultiColorTimeSeries; +use crate::error::MultiColorEvaluatorError; +use crate::evaluator::{ + EvaluatorInfo, EvaluatorInfoTrait, FeatureEvaluator, FeatureNamesDescriptionsTrait, +}; +use crate::features::Median; +use crate::float_trait::Float; +use crate::multicolor::multicolor_evaluator::*; +use crate::multicolor::{PassbandSet, PassbandTrait}; + +pub use lazy_static::lazy_static; +pub use schemars::JsonSchema; +pub use serde::{Deserialize, Serialize}; +use std::collections::BTreeSet; +use std::fmt::Debug; + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[serde(bound(deserialize = "P: PassbandTrait + Deserialize<'de>"))] +pub struct ColorOfMedian

+where + P: Ord, +{ + passband_set: PassbandSet

, + passbands: [P; 2], + median: Median, + name: String, + description: String, +} + +impl

ColorOfMedian

+where + P: PassbandTrait, +{ + pub fn new(passbands: [P; 2]) -> Self { + let set: BTreeSet<_> = passbands.clone().into(); + Self { + passband_set: set.into(), + name: format!( + "color_median_{}_{}", + passbands[0].name(), + passbands[1].name() + ), + description: format!( + "difference of median magnitudes {}-{}", + passbands[0].name(), + passbands[1].name() + ), + passbands, + median: Median {}, + } + } +} + +lazy_info!( + COLOR_MEDIAN_INFO, + size: 1, + min_ts_length: 1, + t_required: false, + m_required: true, + w_required: false, + sorting_required: false, + variability_required: false, +); + +impl

EvaluatorInfoTrait for ColorOfMedian

+where + P: Ord, +{ + fn get_info(&self) -> &EvaluatorInfo { + &COLOR_MEDIAN_INFO + } +} + +impl

FeatureNamesDescriptionsTrait for ColorOfMedian

+where + P: Ord, +{ + fn get_names(&self) -> Vec<&str> { + vec![self.name.as_str()] + } + + fn get_descriptions(&self) -> Vec<&str> { + vec![self.description.as_str()] + } +} + +impl

MultiColorPassbandSetTrait

for ColorOfMedian

+where + P: PassbandTrait, +{ + fn get_passband_set(&self) -> &PassbandSet

{ + &self.passband_set + } +} + +impl MultiColorEvaluator for ColorOfMedian

+where + P: PassbandTrait, + T: Float, +{ + fn eval_multicolor( + &self, + mcts: &mut MultiColorTimeSeries, + ) -> Result, MultiColorEvaluatorError> { + self.check_mcts_passabands(mcts)?; + let mut medians = [T::zero(); 2]; + for (median, passband) in medians.iter_mut().zip(self.passbands.iter()) { + *median = self + .median + .eval(mcts.get_mut(passband).expect( + "we checked all needed passbands are in mcts, but we still cannot find one", + )) + .map_err(|error| MultiColorEvaluatorError::MonochromeEvaluatorError { + passband: passband.name().into(), + error, + })?[0] + } + Ok(vec![medians[0] - medians[1]]) + } +} diff --git a/src/multicolor/features/mod.rs b/src/multicolor/features/mod.rs new file mode 100644 index 00000000..8b634dfb --- /dev/null +++ b/src/multicolor/features/mod.rs @@ -0,0 +1,2 @@ +mod color_of_median; +pub use color_of_median::ColorOfMedian; diff --git a/src/multicolor/mod.rs b/src/multicolor/mod.rs new file mode 100644 index 00000000..fb09f714 --- /dev/null +++ b/src/multicolor/mod.rs @@ -0,0 +1,16 @@ +mod features; + +mod monochrome_feature; +pub use monochrome_feature::MonochromeFeature; + +mod multicolor_evaluator; +pub use multicolor_evaluator::{MultiColorEvaluator, MultiColorPassbandSetTrait, PassbandSet}; + +mod multicolor_extractor; +pub use multicolor_extractor::MultiColorExtractor; + +mod multicolor_feature; +pub use multicolor_feature::MultiColorFeature; + +mod passband; +pub use passband::*; diff --git a/src/multicolor/monochrome_feature.rs b/src/multicolor/monochrome_feature.rs new file mode 100644 index 00000000..5aa5ffeb --- /dev/null +++ b/src/multicolor/monochrome_feature.rs @@ -0,0 +1,130 @@ +use crate::data::MultiColorTimeSeries; +use crate::error::MultiColorEvaluatorError; +use crate::evaluator::{ + EvaluatorInfo, EvaluatorInfoTrait, EvaluatorProperties, FeatureEvaluator, + FeatureNamesDescriptionsTrait, +}; +use crate::float_trait::Float; +use crate::multicolor::multicolor_evaluator::*; + +use itertools::Itertools; +pub use lazy_static::lazy_static; +pub use schemars::JsonSchema; +pub use serde::{Deserialize, Serialize}; +use std::collections::BTreeSet; +use std::fmt::Debug; +use std::marker::PhantomData; + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[serde(bound( + deserialize = "P: PassbandTrait + Deserialize<'de>, T: Float, F: FeatureEvaluator" +))] +pub struct MonochromeFeature +where + P: Ord, +{ + feature: F, + passband_set: PassbandSet

, + properties: Box, + phantom: PhantomData, +} + +impl MonochromeFeature +where + P: PassbandTrait, + T: Float, + F: FeatureEvaluator, +{ + pub fn new(feature: F, passband_set: BTreeSet

) -> Self { + let names = passband_set + .iter() + .cartesian_product(feature.get_names()) + .map(|(passband, name)| format!("{}_{}", name, passband.name())) + .collect(); + let descriptions = passband_set + .iter() + .cartesian_product(feature.get_descriptions()) + .map(|(passband, description)| format!("{}, passband {}", description, passband.name())) + .collect(); + let info = { + let mut info = feature.get_info().clone(); + info.size *= passband_set.len(); + info + }; + Self { + properties: EvaluatorProperties { + info, + names, + descriptions, + } + .into(), + feature, + passband_set: passband_set.into(), + phantom: PhantomData, + } + } +} + +impl FeatureNamesDescriptionsTrait for MonochromeFeature +where + P: Ord, +{ + fn get_names(&self) -> Vec<&str> { + self.properties.names.iter().map(String::as_str).collect() + } + + fn get_descriptions(&self) -> Vec<&str> { + self.properties + .descriptions + .iter() + .map(String::as_str) + .collect() + } +} + +impl EvaluatorInfoTrait for MonochromeFeature +where + P: Ord, +{ + fn get_info(&self) -> &EvaluatorInfo { + &self.properties.info + } +} + +impl MultiColorPassbandSetTrait

for MonochromeFeature +where + P: PassbandTrait, +{ + fn get_passband_set(&self) -> &PassbandSet

{ + &self.passband_set + } +} + +impl MultiColorEvaluator for MonochromeFeature +where + P: PassbandTrait, + T: Float, + F: FeatureEvaluator, +{ + fn eval_multicolor( + &self, + mcts: &mut MultiColorTimeSeries, + ) -> Result, MultiColorEvaluatorError> { + self.check_mcts_passabands(mcts)?; + match &self.passband_set { + PassbandSet::FixedSet(set) => set + .iter() + .map(|passband| { + self.feature.eval(mcts.get_mut(passband).expect( + "we checked all needed passbands are in mcts, but we still cannot find one", + )).map_err(|error| MultiColorEvaluatorError::MonochromeEvaluatorError { + passband: passband.name().into(), + error, + }) + }) + .flatten_ok() + .collect(), + PassbandSet::AllAvailable => panic!("passband_set must be FixedSet variant here"), + } + } +} diff --git a/src/multicolor/multicolor_evaluator.rs b/src/multicolor/multicolor_evaluator.rs new file mode 100644 index 00000000..55a10502 --- /dev/null +++ b/src/multicolor/multicolor_evaluator.rs @@ -0,0 +1,130 @@ +pub use crate::data::MultiColorTimeSeries; +pub use crate::error::MultiColorEvaluatorError; +pub use crate::evaluator::{ + EvaluatorError, EvaluatorInfo, EvaluatorInfoTrait, EvaluatorProperties, FeatureEvaluator, + FeatureNamesDescriptionsTrait, +}; +pub use crate::feature::Feature; +pub use crate::float_trait::Float; +pub use crate::multicolor::PassbandTrait; + +use enum_dispatch::enum_dispatch; +pub use lazy_static::lazy_static; +pub use schemars::JsonSchema; +pub use serde::{Deserialize, Serialize}; +use std::collections::{BTreeMap, BTreeSet}; +use std::fmt::Debug; + +#[enum_dispatch] +pub trait MultiColorPassbandSetTrait

+where + P: PassbandTrait, +{ + fn get_passband_set(&self) -> &PassbandSet

; +} + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[serde(bound(deserialize = "P: PassbandTrait + Deserialize<'de>"))] +#[non_exhaustive] +pub enum PassbandSet

+where + P: Ord, +{ + FixedSet(BTreeSet

), + AllAvailable, +} + +impl

From> for PassbandSet

+where + P: Ord, +{ + fn from(value: BTreeSet

) -> Self { + Self::FixedSet(value) + } +} + +#[enum_dispatch] +pub trait MultiColorEvaluator: + FeatureNamesDescriptionsTrait + + EvaluatorInfoTrait + + MultiColorPassbandSetTrait

+ + Clone + + Serialize +where + P: PassbandTrait, + T: Float, +{ + /// Vector of feature values or `EvaluatorError` + fn eval_multicolor( + &self, + mcts: &mut MultiColorTimeSeries, + ) -> Result, MultiColorEvaluatorError>; + + /// Returns vector of feature values and fill invalid components with given value + fn eval_or_fill_multicolor( + &self, + mcts: &mut MultiColorTimeSeries, + fill_value: T, + ) -> Result, MultiColorEvaluatorError> { + self.check_mcts_shape(mcts)?; + Ok(match self.eval_multicolor(mcts) { + Ok(v) => v, + Err(_) => vec![fill_value; self.size_hint()], + }) + } + + fn check_mcts_shape( + &self, + mcts: &MultiColorTimeSeries, + ) -> Result, MultiColorEvaluatorError> { + self.check_mcts_passabands(mcts)?; + self.check_every_ts_length(mcts) + } + + fn check_mcts_passabands( + &self, + mcts: &MultiColorTimeSeries, + ) -> Result<(), MultiColorEvaluatorError> { + match self.get_passband_set() { + PassbandSet::AllAvailable => Ok(()), + PassbandSet::FixedSet(self_passbands) => { + if mcts + .keys() + .all(|mcts_passband| self_passbands.contains(mcts_passband)) + { + Ok(()) + } else { + Err(MultiColorEvaluatorError::wrong_passbands_error( + mcts.keys(), + self_passbands.iter(), + )) + } + } + } + } + + /// Checks if each component of [MultiColorTimeSeries] has enough points to evaluate the feature + fn check_every_ts_length( + &self, + mcts: &MultiColorTimeSeries, + ) -> Result, MultiColorEvaluatorError> { + // Use try_reduce when stabilizes + // https://github.com/rust-lang/rust/issues/87053 + mcts.iter() + .map(|(passband, ts)| { + let length = ts.lenu(); + if length < self.min_ts_length() { + Err(MultiColorEvaluatorError::MonochromeEvaluatorError { + error: EvaluatorError::ShortTimeSeries { + actual: length, + minimum: self.min_ts_length(), + }, + passband: passband.name().into(), + }) + } else { + Ok((passband.clone(), length)) + } + }) + .collect() + } +} diff --git a/src/multicolor/multicolor_extractor.rs b/src/multicolor/multicolor_extractor.rs new file mode 100644 index 00000000..5fdff815 --- /dev/null +++ b/src/multicolor/multicolor_extractor.rs @@ -0,0 +1,188 @@ +use crate::data::MultiColorTimeSeries; +use crate::error::MultiColorEvaluatorError; +use crate::evaluator::{EvaluatorInfo, EvaluatorInfoTrait, FeatureNamesDescriptionsTrait}; +use crate::float_trait::Float; +use crate::multicolor::multicolor_evaluator::*; + +use itertools::Itertools; +pub use lazy_static::lazy_static; +pub use schemars::JsonSchema; +pub use serde::{Deserialize, Serialize}; +use std::collections::BTreeSet; +use std::fmt::Debug; +use std::marker::PhantomData; + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde( + into = "MultiColorExtractorParameters", + from = "MultiColorExtractorParameters", + bound( + serialize = "P: PassbandTrait, T: Float, MCF: MultiColorEvaluator", + deserialize = "P: PassbandTrait + Deserialize<'de>, T: Float, MCF: MultiColorEvaluator + Deserialize<'de>" + ) +)] +pub struct MultiColorExtractor +where + P: Ord, +{ + features: Vec, + info: Box, + passband_set: PassbandSet

, + phantom: PhantomData, +} + +impl MultiColorExtractor +where + P: PassbandTrait, + T: Float, + MCF: MultiColorEvaluator, +{ + pub fn new(features: Vec) -> Self { + let passband_set = { + let set: BTreeSet<_> = features + .iter() + .filter_map(|f| match f.get_passband_set() { + PassbandSet::AllAvailable => None, + PassbandSet::FixedSet(set) => Some(set), + }) + .flatten() + .cloned() + .collect(); + if set.is_empty() { + PassbandSet::AllAvailable + } else { + PassbandSet::FixedSet(set) + } + }; + + let info = EvaluatorInfo { + size: features.iter().map(|x| x.size_hint()).sum(), + min_ts_length: features + .iter() + .map(|x| x.min_ts_length()) + .max() + .unwrap_or(0), + t_required: features.iter().any(|x| x.is_t_required()), + m_required: features.iter().any(|x| x.is_m_required()), + w_required: features.iter().any(|x| x.is_w_required()), + sorting_required: features.iter().any(|x| x.is_sorting_required()), + variability_required: features.iter().any(|x| x.is_variability_required()), + } + .into(); + + Self { + features, + passband_set, + info, + phantom: PhantomData, + } + } +} + +impl FeatureNamesDescriptionsTrait for MultiColorExtractor +where + P: Ord, + MCF: FeatureNamesDescriptionsTrait, +{ + /// Get feature names + fn get_names(&self) -> Vec<&str> { + self.features.iter().flat_map(|x| x.get_names()).collect() + } + + /// Get feature descriptions + fn get_descriptions(&self) -> Vec<&str> { + self.features + .iter() + .flat_map(|x| x.get_descriptions()) + .collect() + } +} + +impl EvaluatorInfoTrait for MultiColorExtractor +where + P: Ord, +{ + fn get_info(&self) -> &EvaluatorInfo { + &self.info + } +} + +impl MultiColorPassbandSetTrait

for MultiColorExtractor +where + P: PassbandTrait, +{ + fn get_passband_set(&self) -> &PassbandSet

{ + &self.passband_set + } +} + +impl MultiColorEvaluator for MultiColorExtractor +where + P: PassbandTrait, + T: Float, + MCF: MultiColorEvaluator, +{ + fn eval_multicolor( + &self, + mcts: &mut MultiColorTimeSeries, + ) -> Result, MultiColorEvaluatorError> { + self.check_mcts_passabands(mcts)?; + let mut vec = Vec::with_capacity(self.size_hint()); + for x in &self.features { + vec.extend(x.eval_multicolor(mcts)?); + } + Ok(vec) + } + + fn eval_or_fill_multicolor( + &self, + mcts: &mut MultiColorTimeSeries, + fill_value: T, + ) -> Result, MultiColorEvaluatorError> { + self.check_mcts_passabands(mcts)?; + self.features + .iter() + .map(|x| x.eval_or_fill_multicolor(mcts, fill_value)) + .flatten_ok() + .collect() + } +} + +#[derive(Serialize, Deserialize, JsonSchema)] +#[serde(rename = "MultiColorExtractor")] +struct MultiColorExtractorParameters { + features: Vec, +} + +impl From> for MultiColorExtractorParameters +where + P: PassbandTrait, + T: Float, + MCF: MultiColorEvaluator, +{ + fn from(f: MultiColorExtractor) -> Self { + Self { + features: f.features, + } + } +} + +impl From> for MultiColorExtractor +where + P: PassbandTrait, + T: Float, + MCF: MultiColorEvaluator, +{ + fn from(p: MultiColorExtractorParameters) -> Self { + Self::new(p.features) + } +} + +impl JsonSchema for MultiColorExtractor +where + P: PassbandTrait, + T: Float, + MCF: JsonSchema, +{ + json_schema!(MultiColorExtractorParameters, true); +} diff --git a/src/multicolor/multicolor_feature.rs b/src/multicolor/multicolor_feature.rs new file mode 100644 index 00000000..3afb8405 --- /dev/null +++ b/src/multicolor/multicolor_feature.rs @@ -0,0 +1,45 @@ +use crate::data::MultiColorTimeSeries; +use crate::error::MultiColorEvaluatorError; +use crate::evaluator::{EvaluatorInfo, EvaluatorInfoTrait, FeatureNamesDescriptionsTrait}; +use crate::feature::Feature; +use crate::float_trait::Float; +use crate::multicolor::features::ColorOfMedian; +use crate::multicolor::multicolor_evaluator::*; +use crate::multicolor::{MonochromeFeature, MultiColorExtractor}; + +use enum_dispatch::enum_dispatch; +pub use lazy_static::lazy_static; +pub use schemars::JsonSchema; +pub use serde::{Deserialize, Serialize}; +use std::collections::{BTreeMap, BTreeSet}; +use std::fmt::Debug; + +#[enum_dispatch(MultiColorEvaluator, FeatureNamesDescriptionsTrait, EvaluatorInfoTrait, MultiColorPassbandSetTrait

)] +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[serde(bound(deserialize = "P: PassbandTrait + Deserialize<'de>, T: Float"))] +#[non_exhaustive] +pub enum MultiColorFeature +where + P: PassbandTrait, + T: Float, +{ + // Extractor + MultiColorExtractor(MultiColorExtractor>), + // Monochrome Features + MonochromeFeature(MonochromeFeature>), + // Features + ColorOfMedian(ColorOfMedian

), +} + +impl MultiColorFeature +where + P: PassbandTrait, + T: Float, +{ + pub fn from_monochrome_feature(feature: F, passband_set: BTreeSet

) -> Self + where + F: Into>, + { + MonochromeFeature::new(feature.into(), passband_set).into() + } +} diff --git a/src/multicolor/passband/dump_passband.rs b/src/multicolor/passband/dump_passband.rs new file mode 100644 index 00000000..ffd73cde --- /dev/null +++ b/src/multicolor/passband/dump_passband.rs @@ -0,0 +1,14 @@ +use crate::PassbandTrait; + +pub use schemars::JsonSchema; +pub use serde::{Deserialize, Serialize}; +use std::fmt::Debug; + +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, JsonSchema)] +pub struct DumpPassband {} + +impl PassbandTrait for DumpPassband { + fn name(&self) -> &str { + "" + } +} diff --git a/src/multicolor/passband/mod.rs b/src/multicolor/passband/mod.rs new file mode 100644 index 00000000..1fbef55e --- /dev/null +++ b/src/multicolor/passband/mod.rs @@ -0,0 +1,8 @@ +mod monochrome_passband; +pub use monochrome_passband::MonochromePassband; + +mod dump_passband; +pub use dump_passband::DumpPassband; + +mod passband_trait; +pub use passband_trait::PassbandTrait; diff --git a/src/multicolor/passband/monochrome_passband.rs b/src/multicolor/passband/monochrome_passband.rs new file mode 100644 index 00000000..987bc79a --- /dev/null +++ b/src/multicolor/passband/monochrome_passband.rs @@ -0,0 +1,69 @@ +use crate::float_trait::Float; +use crate::multicolor::PassbandTrait; + +pub use lazy_static::lazy_static; +pub use schemars::JsonSchema; +pub use serde::{Deserialize, Serialize}; +use std::cmp::Ordering; +use std::fmt::Debug; + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +pub struct MonochromePassband<'a, T> { + pub name: &'a str, + pub wavelength: T, +} + +impl<'a, T> MonochromePassband<'a, T> +where + T: Float, +{ + pub fn new(wavelength: T, name: &'a str) -> Self { + assert!( + wavelength.is_normal(), + "wavelength must be a positive normal number" + ); + assert!( + wavelength.is_sign_positive(), + "wavelength must be a positive normal number" + ); + Self { wavelength, name } + } +} + +impl<'a, T> PartialEq for MonochromePassband<'a, T> +where + T: Float, +{ + fn eq(&self, other: &Self) -> bool { + self.wavelength.eq(&other.wavelength) + } +} + +impl<'a, T> Eq for MonochromePassband<'a, T> where T: Float {} + +impl<'a, T> PartialOrd for MonochromePassband<'a, T> +where + T: Float, +{ + fn partial_cmp(&self, other: &Self) -> Option { + (self.wavelength).partial_cmp(&other.wavelength) + } +} + +impl<'a, T> Ord for MonochromePassband<'a, T> +where + T: Float, +{ + fn cmp(&self, other: &Self) -> Ordering { + self.partial_cmp(other).unwrap() + } +} + +impl<'a, T> PassbandTrait for MonochromePassband<'a, T> +where + T: Float, +{ + fn name(&self) -> &str { + self.name + } +} diff --git a/src/multicolor/passband/passband_trait.rs b/src/multicolor/passband/passband_trait.rs new file mode 100644 index 00000000..011c972a --- /dev/null +++ b/src/multicolor/passband/passband_trait.rs @@ -0,0 +1,7 @@ +pub use schemars::JsonSchema; +pub use serde::{Deserialize, Serialize}; +use std::fmt::Debug; + +pub trait PassbandTrait: Debug + Clone + Send + Sync + Ord + Serialize + JsonSchema { + fn name(&self) -> &str; +} From ef38a399f1ffd6685f2c9802cf27a1125251b0f8 Mon Sep 17 00:00:00 2001 From: Konstantin Malanchev Date: Mon, 5 Sep 2022 12:45:53 -0500 Subject: [PATCH 5/9] Refactor mcts checking --- src/data/multi_color_time_series.rs | 57 ++++++ src/evaluator.rs | 63 +++--- src/lib.rs | 2 + src/multicolor/features/color_of_median.rs | 7 +- src/multicolor/monochrome_feature.rs | 3 +- src/multicolor/multicolor_evaluator.rs | 222 ++++++++++++++++----- src/multicolor/multicolor_extractor.rs | 4 +- src/multicolor/multicolor_feature.rs | 4 +- 8 files changed, 274 insertions(+), 88 deletions(-) diff --git a/src/data/multi_color_time_series.rs b/src/data/multi_color_time_series.rs index 3cd14a5a..96b28e70 100644 --- a/src/data/multi_color_time_series.rs +++ b/src/data/multi_color_time_series.rs @@ -1,12 +1,69 @@ use crate::data::TimeSeries; use crate::float_trait::Float; use crate::multicolor::PassbandTrait; +use crate::PassbandSet; +use itertools::Either; +use itertools::EitherOrBoth; +use itertools::Itertools; use std::collections::BTreeMap; use std::ops::{Deref, DerefMut}; pub struct MultiColorTimeSeries<'a, P: PassbandTrait, T: Float>(BTreeMap>); +impl<'a, P, T> MultiColorTimeSeries<'a, P, T> +where + P: PassbandTrait, + T: Float, +{ + pub fn new(map: impl Into>>) -> Self { + Self(map.into()) + } + + pub fn iter_passband_set<'slf, 'ps>( + &'slf self, + passband_set: &'ps PassbandSet

, + ) -> impl Iterator>)> + 'ps + where + 'a: 'ps, + 'slf: 'ps, + 'ps: 'slf, + { + match passband_set { + PassbandSet::AllAvailable => Either::Left(self.0.iter().map(|(p, ts)| (p, Some(ts)))), + PassbandSet::FixedSet(set) => Either::Right(set.iter().map(|p| (p, self.0.get(p)))), + } + } + + pub fn iter_passband_set_mut<'slf, 'ps>( + &'slf mut self, + passband_set: &'ps PassbandSet

, + ) -> impl Iterator>)> + 'ps + where + 'a: 'ps, + 'slf: 'ps, + 'ps: 'slf, + { + match passband_set { + PassbandSet::AllAvailable => { + Either::Left(self.0.iter_mut().map(|(p, ts)| (p, Some(ts)))) + } + PassbandSet::FixedSet(set) => Either::Right( + set.iter() + .merge_join_by(self.0.iter_mut(), |p1, (p2, _ts)| p1.cmp(p2)) + .filter_map(|either_or_both| match either_or_both { + // mcts misses required passband + EitherOrBoth::Left(p) => Some((p, None)), + // mcts has some passban passband_set doesn't require + EitherOrBoth::Right(_) => None, + // passbands match + EitherOrBoth::Both(p, (_, ts)) => Some((p, Some(ts))), + }), + ), + } + } +} + impl<'a, P: PassbandTrait, T: Float> Deref for MultiColorTimeSeries<'a, P, T> { type Target = BTreeMap>; diff --git a/src/evaluator.rs b/src/evaluator.rs index fd3a0c23..e21f891f 100644 --- a/src/evaluator.rs +++ b/src/evaluator.rs @@ -75,6 +75,42 @@ pub trait EvaluatorInfoTrait { fn is_variability_required(&self) -> bool { self.get_info().variability_required } + + fn check_ts(&self, ts: &mut TimeSeries) -> Result<(), EvaluatorError> + where + F: Float, + { + self.check_ts_length(ts)?; + self.check_ts_variability(ts) + } + + /// Checks if [TimeSeries] has enough points to evaluate the feature + fn check_ts_length(&self, ts: &TimeSeries) -> Result<(), EvaluatorError> + where + F: Float, + { + let length = ts.lenu(); + if length < self.min_ts_length() { + Err(EvaluatorError::ShortTimeSeries { + actual: length, + minimum: self.min_ts_length(), + }) + } else { + Ok(()) + } + } + + /// Checks if [TimeSeries] meets variability requirement + fn check_ts_variability(&self, ts: &mut TimeSeries) -> Result<(), EvaluatorError> + where + F: Float, + { + if self.is_variability_required() && ts.is_plateau() { + Err(EvaluatorError::FlatTimeSeries) + } else { + Ok(()) + } + } } // impl

EvaluatorInfoTrait for P @@ -146,33 +182,6 @@ pub trait FeatureEvaluator: Err(_) => vec![fill_value; self.size_hint()], } } - - fn check_ts(&self, ts: &mut TimeSeries) -> Result<(), EvaluatorError> { - self.check_ts_length(ts)?; - self.check_ts_variability(ts) - } - - /// Checks if [TimeSeries] has enough points to evaluate the feature - fn check_ts_length(&self, ts: &TimeSeries) -> Result<(), EvaluatorError> { - let length = ts.lenu(); - if length < self.min_ts_length() { - Err(EvaluatorError::ShortTimeSeries { - actual: length, - minimum: self.min_ts_length(), - }) - } else { - Ok(()) - } - } - - /// Checks if [TimeSeries] meets variability requirement - fn check_ts_variability(&self, ts: &mut TimeSeries) -> Result<(), EvaluatorError> { - if self.is_variability_required() && ts.is_plateau() { - Err(EvaluatorError::FlatTimeSeries) - } else { - Ok(()) - } - } } pub trait OwnedArrays diff --git a/src/lib.rs b/src/lib.rs index ccd8f5bb..36c6c79e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,7 @@ #![doc = include_str!("../README.md")] +extern crate core; + #[cfg(test)] #[macro_use] mod tests; diff --git a/src/multicolor/features/color_of_median.rs b/src/multicolor/features/color_of_median.rs index 8c7c7457..3a9e435a 100644 --- a/src/multicolor/features/color_of_median.rs +++ b/src/multicolor/features/color_of_median.rs @@ -52,7 +52,7 @@ where } lazy_info!( - COLOR_MEDIAN_INFO, + COLOR_OF_MEDIAN_INFO, size: 1, min_ts_length: 1, t_required: false, @@ -67,7 +67,7 @@ where P: Ord, { fn get_info(&self) -> &EvaluatorInfo { - &COLOR_MEDIAN_INFO + &COLOR_OF_MEDIAN_INFO } } @@ -98,11 +98,10 @@ where P: PassbandTrait, T: Float, { - fn eval_multicolor( + fn eval_multicolor_no_mcts_check( &self, mcts: &mut MultiColorTimeSeries, ) -> Result, MultiColorEvaluatorError> { - self.check_mcts_passabands(mcts)?; let mut medians = [T::zero(); 2]; for (median, passband) in medians.iter_mut().zip(self.passbands.iter()) { *median = self diff --git a/src/multicolor/monochrome_feature.rs b/src/multicolor/monochrome_feature.rs index 5aa5ffeb..93de2ed3 100644 --- a/src/multicolor/monochrome_feature.rs +++ b/src/multicolor/monochrome_feature.rs @@ -106,11 +106,10 @@ where T: Float, F: FeatureEvaluator, { - fn eval_multicolor( + fn eval_multicolor_no_mcts_check( &self, mcts: &mut MultiColorTimeSeries, ) -> Result, MultiColorEvaluatorError> { - self.check_mcts_passabands(mcts)?; match &self.passband_set { PassbandSet::FixedSet(set) => set .iter() diff --git a/src/multicolor/multicolor_evaluator.rs b/src/multicolor/multicolor_evaluator.rs index 55a10502..8406ebc2 100644 --- a/src/multicolor/multicolor_evaluator.rs +++ b/src/multicolor/multicolor_evaluator.rs @@ -9,10 +9,11 @@ pub use crate::float_trait::Float; pub use crate::multicolor::PassbandTrait; use enum_dispatch::enum_dispatch; +use itertools::Itertools; pub use lazy_static::lazy_static; pub use schemars::JsonSchema; pub use serde::{Deserialize, Serialize}; -use std::collections::{BTreeMap, BTreeSet}; +use std::collections::BTreeSet; use std::fmt::Debug; #[enum_dispatch] @@ -43,6 +44,38 @@ where } } +enum InternalMctsError { + MultiColorEvaluatorError(MultiColorEvaluatorError), + InternalWrongPassbandSet, +} + +impl InternalMctsError { + fn into_multi_color_evaluator_error( + self, + mcts: &MultiColorTimeSeries, + ps: &PassbandSet

, + ) -> MultiColorEvaluatorError + where + P: PassbandTrait, + T: Float, + { + match self { + InternalMctsError::MultiColorEvaluatorError(e) => e, + InternalMctsError::InternalWrongPassbandSet => { + MultiColorEvaluatorError::wrong_passbands_error( + mcts.keys(), + match ps { + PassbandSet::FixedSet(ps) => ps.iter(), + PassbandSet::AllAvailable => { + panic!("PassbandSet cannot be ::AllAvailable here") + } + }, + ) + } + } + } +} + #[enum_dispatch] pub trait MultiColorEvaluator: FeatureNamesDescriptionsTrait @@ -54,11 +87,20 @@ where P: PassbandTrait, T: Float, { + /// Version of [MultiColorEvaluator::eval_multicolor] without basic [MultiColorTimeSeries] checks + fn eval_multicolor_no_mcts_check( + &self, + mcts: &mut MultiColorTimeSeries, + ) -> Result, MultiColorEvaluatorError>; + /// Vector of feature values or `EvaluatorError` fn eval_multicolor( &self, mcts: &mut MultiColorTimeSeries, - ) -> Result, MultiColorEvaluatorError>; + ) -> Result, MultiColorEvaluatorError> { + self.check_mcts(mcts)?; + self.eval_multicolor_no_mcts_check(mcts) + } /// Returns vector of feature values and fill invalid components with given value fn eval_or_fill_multicolor( @@ -66,65 +108,145 @@ where mcts: &mut MultiColorTimeSeries, fill_value: T, ) -> Result, MultiColorEvaluatorError> { - self.check_mcts_shape(mcts)?; Ok(match self.eval_multicolor(mcts) { Ok(v) => v, Err(_) => vec![fill_value; self.size_hint()], }) } - fn check_mcts_shape( + /// Check [MultiColorTimeSeries] to have required passbands and individual [TimeSeries] are valid + fn check_mcts( &self, - mcts: &MultiColorTimeSeries, - ) -> Result, MultiColorEvaluatorError> { - self.check_mcts_passabands(mcts)?; - self.check_every_ts_length(mcts) + mcts: &mut MultiColorTimeSeries, + ) -> Result<(), MultiColorEvaluatorError> { + mcts.iter_passband_set_mut(self.get_passband_set()) + .map(|(p, maybe_ts)| { + maybe_ts + .ok_or(InternalMctsError::InternalWrongPassbandSet) + .and_then(|ts| { + self.check_ts(ts).map_err(|error| { + InternalMctsError::MultiColorEvaluatorError( + MultiColorEvaluatorError::MonochromeEvaluatorError { + error, + passband: p.name().into(), + }, + ) + }) + }) + .map(|_| ()) + }) + .try_collect() + .map_err(|err| err.into_multi_color_evaluator_error(mcts, self.get_passband_set())) } +} - fn check_mcts_passabands( - &self, - mcts: &MultiColorTimeSeries, - ) -> Result<(), MultiColorEvaluatorError> { - match self.get_passband_set() { - PassbandSet::AllAvailable => Ok(()), - PassbandSet::FixedSet(self_passbands) => { - if mcts - .keys() - .all(|mcts_passband| self_passbands.contains(mcts_passband)) - { - Ok(()) - } else { - Err(MultiColorEvaluatorError::wrong_passbands_error( - mcts.keys(), - self_passbands.iter(), - )) - } - } +#[cfg(test)] +#[allow(clippy::unreadable_literal)] +#[allow(clippy::excessive_precision)] +mod tests { + use super::*; + use crate::data::TimeSeries; + use crate::multicolor::MonochromePassband; + + use std::collections::BTreeMap; + + #[derive(Clone, Debug, Serialize)] + struct TestTimeMultiColorFeature { + passband_set: PassbandSet>, + } + + lazy_info!( + TEST_TIME_FEATURE_INFO, + TestTimeMultiColorFeature, + size: 1, + min_ts_length: 1, + t_required: true, + m_required: false, + w_required: false, + sorting_required: true, + variability_required: false, + ); + + impl FeatureNamesDescriptionsTrait for TestTimeMultiColorFeature { + fn get_names(&self) -> Vec<&str> { + vec!["zero"] + } + + fn get_descriptions(&self) -> Vec<&str> { + vec!["zero"] } } - /// Checks if each component of [MultiColorTimeSeries] has enough points to evaluate the feature - fn check_every_ts_length( - &self, - mcts: &MultiColorTimeSeries, - ) -> Result, MultiColorEvaluatorError> { - // Use try_reduce when stabilizes - // https://github.com/rust-lang/rust/issues/87053 - mcts.iter() - .map(|(passband, ts)| { - let length = ts.lenu(); - if length < self.min_ts_length() { - Err(MultiColorEvaluatorError::MonochromeEvaluatorError { - error: EvaluatorError::ShortTimeSeries { - actual: length, - minimum: self.min_ts_length(), - }, - passband: passband.name().into(), - }) - } else { - Ok((passband.clone(), length)) - } - }) - .collect() + impl MultiColorPassbandSetTrait> for TestTimeMultiColorFeature { + fn get_passband_set(&self) -> &PassbandSet> { + &self.passband_set + } + } + + impl MultiColorEvaluator, T> for TestTimeMultiColorFeature + where + T: Float, + { + fn eval_multicolor_no_mcts_check( + &self, + _mcts: &mut MultiColorTimeSeries, T>, + ) -> Result, MultiColorEvaluatorError> { + Ok(vec![T::zero()]) + } + } + + #[test] + fn test_check_mcts_passbands() { + let t = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]; + let m = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]; + let passband_b_capital = MonochromePassband::new(4400e-8, "B"); + let passband_v_capital = MonochromePassband::new(5500e-8, "V"); + let passband_r_capital = MonochromePassband::new(6400e-8, "R"); + let mut mcts = { + let mut passbands = BTreeMap::new(); + passbands.insert( + passband_b_capital.clone(), + TimeSeries::new_without_weight(&t, &m), + ); + passbands.insert( + passband_v_capital.clone(), + TimeSeries::new_without_weight(&t, &m), + ); + MultiColorTimeSeries::new(passbands) + }; + + let feature = TestTimeMultiColorFeature { + passband_set: PassbandSet::AllAvailable, + }; + assert!(feature.eval_multicolor(&mut mcts).is_ok()); + + let feature = TestTimeMultiColorFeature { + passband_set: PassbandSet::FixedSet( + [passband_b_capital.clone(), passband_v_capital.clone()].into(), + ), + }; + assert!(feature.eval_multicolor(&mut mcts).is_ok()); + + let feature = TestTimeMultiColorFeature { + passband_set: PassbandSet::FixedSet([passband_b_capital.clone()].into()), + }; + assert!(feature.eval_multicolor(&mut mcts).is_ok()); + + let feature = TestTimeMultiColorFeature { + passband_set: PassbandSet::FixedSet([passband_r_capital.clone()].into()), + }; + assert!(feature.eval_multicolor(&mut mcts).is_err()); + + let feature = TestTimeMultiColorFeature { + passband_set: PassbandSet::FixedSet( + [ + passband_b_capital.clone(), + passband_r_capital.clone(), + passband_r_capital.clone(), + ] + .into(), + ), + }; + assert!(feature.eval_multicolor(&mut mcts).is_err()); } } diff --git a/src/multicolor/multicolor_extractor.rs b/src/multicolor/multicolor_extractor.rs index 5fdff815..6af4552c 100644 --- a/src/multicolor/multicolor_extractor.rs +++ b/src/multicolor/multicolor_extractor.rs @@ -122,11 +122,10 @@ where T: Float, MCF: MultiColorEvaluator, { - fn eval_multicolor( + fn eval_multicolor_no_mcts_check( &self, mcts: &mut MultiColorTimeSeries, ) -> Result, MultiColorEvaluatorError> { - self.check_mcts_passabands(mcts)?; let mut vec = Vec::with_capacity(self.size_hint()); for x in &self.features { vec.extend(x.eval_multicolor(mcts)?); @@ -139,7 +138,6 @@ where mcts: &mut MultiColorTimeSeries, fill_value: T, ) -> Result, MultiColorEvaluatorError> { - self.check_mcts_passabands(mcts)?; self.features .iter() .map(|x| x.eval_or_fill_multicolor(mcts, fill_value)) diff --git a/src/multicolor/multicolor_feature.rs b/src/multicolor/multicolor_feature.rs index 3afb8405..94245814 100644 --- a/src/multicolor/multicolor_feature.rs +++ b/src/multicolor/multicolor_feature.rs @@ -1,4 +1,4 @@ -use crate::data::MultiColorTimeSeries; +use crate::data::{MultiColorTimeSeries, TimeSeries}; use crate::error::MultiColorEvaluatorError; use crate::evaluator::{EvaluatorInfo, EvaluatorInfoTrait, FeatureNamesDescriptionsTrait}; use crate::feature::Feature; @@ -11,7 +11,7 @@ use enum_dispatch::enum_dispatch; pub use lazy_static::lazy_static; pub use schemars::JsonSchema; pub use serde::{Deserialize, Serialize}; -use std::collections::{BTreeMap, BTreeSet}; +use std::collections::BTreeSet; use std::fmt::Debug; #[enum_dispatch(MultiColorEvaluator, FeatureNamesDescriptionsTrait, EvaluatorInfoTrait, MultiColorPassbandSetTrait

)] From 8cb5eb3d19632212215933f74e6258e1c7d86e74 Mon Sep 17 00:00:00 2001 From: Konstantin Malanchev Date: Tue, 6 Sep 2022 11:03:23 -0500 Subject: [PATCH 6/9] Color features --- src/data/multi_color_time_series.rs | 44 +++++--- src/multicolor/features/color_of_maximum.rs | 106 ++++++++++++++++++++ src/multicolor/features/color_of_median.rs | 17 ++-- src/multicolor/features/color_of_minimum.rs | 106 ++++++++++++++++++++ src/multicolor/features/mod.rs | 6 ++ src/multicolor/multicolor_feature.rs | 4 +- 6 files changed, 259 insertions(+), 24 deletions(-) create mode 100644 src/multicolor/features/color_of_maximum.rs create mode 100644 src/multicolor/features/color_of_minimum.rs diff --git a/src/data/multi_color_time_series.rs b/src/data/multi_color_time_series.rs index 96b28e70..bec58eaf 100644 --- a/src/data/multi_color_time_series.rs +++ b/src/data/multi_color_time_series.rs @@ -11,9 +11,9 @@ use std::ops::{Deref, DerefMut}; pub struct MultiColorTimeSeries<'a, P: PassbandTrait, T: Float>(BTreeMap>); -impl<'a, P, T> MultiColorTimeSeries<'a, P, T> +impl<'a, 'p, P, T> MultiColorTimeSeries<'a, P, T> where - P: PassbandTrait, + P: PassbandTrait + 'p, T: Float, { pub fn new(map: impl Into>>) -> Self { @@ -31,7 +31,7 @@ where { match passband_set { PassbandSet::AllAvailable => Either::Left(self.0.iter().map(|(p, ts)| (p, Some(ts)))), - PassbandSet::FixedSet(set) => Either::Right(set.iter().map(|p| (p, self.0.get(p)))), + PassbandSet::FixedSet(set) => Either::Right(self.iter_matched_passbands(set.iter())), } } @@ -48,20 +48,34 @@ where PassbandSet::AllAvailable => { Either::Left(self.0.iter_mut().map(|(p, ts)| (p, Some(ts)))) } - PassbandSet::FixedSet(set) => Either::Right( - set.iter() - .merge_join_by(self.0.iter_mut(), |p1, (p2, _ts)| p1.cmp(p2)) - .filter_map(|either_or_both| match either_or_both { - // mcts misses required passband - EitherOrBoth::Left(p) => Some((p, None)), - // mcts has some passban passband_set doesn't require - EitherOrBoth::Right(_) => None, - // passbands match - EitherOrBoth::Both(p, (_, ts)) => Some((p, Some(ts))), - }), - ), + PassbandSet::FixedSet(set) => { + Either::Right(self.iter_matched_passbands_mut(set.iter())) + } } } + + pub fn iter_matched_passbands( + &self, + passband_it: impl Iterator, + ) -> impl Iterator>)> { + passband_it.map(|p| (p, self.0.get(p))) + } + + pub fn iter_matched_passbands_mut( + &mut self, + passband_it: impl Iterator, + ) -> impl Iterator>)> { + passband_it + .merge_join_by(self.0.iter_mut(), |p1, (p2, _ts)| p1.cmp(p2)) + .filter_map(|either_or_both| match either_or_both { + // mcts misses required passband + EitherOrBoth::Left(p) => Some((p, None)), + // mcts has some passban passband_set doesn't require + EitherOrBoth::Right(_) => None, + // passbands match + EitherOrBoth::Both(p, (_, ts)) => Some((p, Some(ts))), + }) + } } impl<'a, P: PassbandTrait, T: Float> Deref for MultiColorTimeSeries<'a, P, T> { diff --git a/src/multicolor/features/color_of_maximum.rs b/src/multicolor/features/color_of_maximum.rs new file mode 100644 index 00000000..1c0a0bc3 --- /dev/null +++ b/src/multicolor/features/color_of_maximum.rs @@ -0,0 +1,106 @@ +use crate::data::MultiColorTimeSeries; +use crate::error::MultiColorEvaluatorError; +use crate::evaluator::{EvaluatorInfo, EvaluatorInfoTrait, FeatureNamesDescriptionsTrait}; +use crate::float_trait::Float; +use crate::multicolor::multicolor_evaluator::*; +use crate::multicolor::{PassbandSet, PassbandTrait}; + +pub use lazy_static::lazy_static; +pub use schemars::JsonSchema; +pub use serde::{Deserialize, Serialize}; +use std::collections::BTreeSet; +use std::fmt::Debug; + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[serde(bound(deserialize = "P: PassbandTrait + Deserialize<'de>"))] +pub struct ColorOfMaximum

+where + P: Ord, +{ + passband_set: PassbandSet

, + passbands: [P; 2], + name: String, + description: String, +} + +impl

ColorOfMaximum

+where + P: PassbandTrait, +{ + pub fn new(passbands: [P; 2]) -> Self { + let set: BTreeSet<_> = passbands.clone().into(); + Self { + passband_set: set.into(), + name: format!("color_max_{}_{}", passbands[0].name(), passbands[1].name()), + description: format!( + "difference of maximum value magnitudes {}-{}", + passbands[0].name(), + passbands[1].name() + ), + passbands, + } + } +} + +lazy_info!( + COLOR_OF_MAXIMUM_INFO, + size: 1, + min_ts_length: 1, + t_required: false, + m_required: true, + w_required: false, + sorting_required: false, + variability_required: false, +); + +impl

EvaluatorInfoTrait for ColorOfMaximum

+where + P: Ord, +{ + fn get_info(&self) -> &EvaluatorInfo { + &COLOR_OF_MAXIMUM_INFO + } +} + +impl

FeatureNamesDescriptionsTrait for ColorOfMaximum

+where + P: Ord, +{ + fn get_names(&self) -> Vec<&str> { + vec![self.name.as_str()] + } + + fn get_descriptions(&self) -> Vec<&str> { + vec![self.description.as_str()] + } +} + +impl

MultiColorPassbandSetTrait

for ColorOfMaximum

+where + P: PassbandTrait, +{ + fn get_passband_set(&self) -> &PassbandSet

{ + &self.passband_set + } +} + +impl MultiColorEvaluator for ColorOfMaximum

+where + P: PassbandTrait, + T: Float, +{ + fn eval_multicolor_no_mcts_check( + &self, + mcts: &mut MultiColorTimeSeries, + ) -> Result, MultiColorEvaluatorError> { + let mut maxima = [T::zero(); 2]; + for ((_passband, mcts), maximum) in mcts + .iter_matched_passbands_mut(self.passbands.iter()) + .zip(maxima.iter_mut()) + { + let mcts = mcts.expect("MultiColorTimeSeries must have all required passbands"); + *maximum = mcts.m.get_max() + } + Ok(vec![maxima[0] - maxima[1]]) + } +} diff --git a/src/multicolor/features/color_of_median.rs b/src/multicolor/features/color_of_median.rs index 3a9e435a..9360bc6e 100644 --- a/src/multicolor/features/color_of_median.rs +++ b/src/multicolor/features/color_of_median.rs @@ -103,16 +103,17 @@ where mcts: &mut MultiColorTimeSeries, ) -> Result, MultiColorEvaluatorError> { let mut medians = [T::zero(); 2]; - for (median, passband) in medians.iter_mut().zip(self.passbands.iter()) { - *median = self - .median - .eval(mcts.get_mut(passband).expect( - "we checked all needed passbands are in mcts, but we still cannot find one", - )) - .map_err(|error| MultiColorEvaluatorError::MonochromeEvaluatorError { + for ((passband, mcts), median) in mcts + .iter_matched_passbands_mut(self.passbands.iter()) + .zip(medians.iter_mut()) + { + let mcts = mcts.expect("MultiColorTimeSeries must have all required passbands"); + *median = self.median.eval(mcts).map_err(|error| { + MultiColorEvaluatorError::MonochromeEvaluatorError { passband: passband.name().into(), error, - })?[0] + } + })?[0] } Ok(vec![medians[0] - medians[1]]) } diff --git a/src/multicolor/features/color_of_minimum.rs b/src/multicolor/features/color_of_minimum.rs new file mode 100644 index 00000000..cfa1c449 --- /dev/null +++ b/src/multicolor/features/color_of_minimum.rs @@ -0,0 +1,106 @@ +use crate::data::MultiColorTimeSeries; +use crate::error::MultiColorEvaluatorError; +use crate::evaluator::{EvaluatorInfo, EvaluatorInfoTrait, FeatureNamesDescriptionsTrait}; +use crate::float_trait::Float; +use crate::multicolor::multicolor_evaluator::*; +use crate::multicolor::{PassbandSet, PassbandTrait}; + +pub use lazy_static::lazy_static; +pub use schemars::JsonSchema; +pub use serde::{Deserialize, Serialize}; +use std::collections::BTreeSet; +use std::fmt::Debug; + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[serde(bound(deserialize = "P: PassbandTrait + Deserialize<'de>"))] +pub struct ColorOfMinimum

+where + P: Ord, +{ + passband_set: PassbandSet

, + passbands: [P; 2], + name: String, + description: String, +} + +impl

ColorOfMinimum

+where + P: PassbandTrait, +{ + pub fn new(passbands: [P; 2]) -> Self { + let set: BTreeSet<_> = passbands.clone().into(); + Self { + passband_set: set.into(), + name: format!("color_min_{}_{}", passbands[0].name(), passbands[1].name()), + description: format!( + "difference of minimum value magnitudes {}-{}", + passbands[0].name(), + passbands[1].name() + ), + passbands, + } + } +} + +lazy_info!( + COLOR_OF_MINIMUM_INFO, + size: 1, + min_ts_length: 1, + t_required: false, + m_required: true, + w_required: false, + sorting_required: false, + variability_required: false, +); + +impl

EvaluatorInfoTrait for ColorOfMinimum

+where + P: Ord, +{ + fn get_info(&self) -> &EvaluatorInfo { + &COLOR_OF_MINIMUM_INFO + } +} + +impl

FeatureNamesDescriptionsTrait for ColorOfMinimum

+where + P: Ord, +{ + fn get_names(&self) -> Vec<&str> { + vec![self.name.as_str()] + } + + fn get_descriptions(&self) -> Vec<&str> { + vec![self.description.as_str()] + } +} + +impl

MultiColorPassbandSetTrait

for ColorOfMinimum

+where + P: PassbandTrait, +{ + fn get_passband_set(&self) -> &PassbandSet

{ + &self.passband_set + } +} + +impl MultiColorEvaluator for ColorOfMinimum

+where + P: PassbandTrait, + T: Float, +{ + fn eval_multicolor_no_mcts_check( + &self, + mcts: &mut MultiColorTimeSeries, + ) -> Result, MultiColorEvaluatorError> { + let mut minima = [T::zero(); 2]; + for ((_passband, mcts), minimum) in mcts + .iter_matched_passbands_mut(self.passbands.iter()) + .zip(minima.iter_mut()) + { + let mcts = mcts.expect("MultiColorTimeSeries must have all required passbands"); + *minimum = mcts.m.get_min() + } + Ok(vec![minima[0] - minima[1]]) + } +} diff --git a/src/multicolor/features/mod.rs b/src/multicolor/features/mod.rs index 8b634dfb..54b5a4c7 100644 --- a/src/multicolor/features/mod.rs +++ b/src/multicolor/features/mod.rs @@ -1,2 +1,8 @@ +mod color_of_maximum; +pub use color_of_maximum::ColorOfMaximum; + mod color_of_median; pub use color_of_median::ColorOfMedian; + +mod color_of_minimum; +pub use color_of_minimum::ColorOfMinimum; diff --git a/src/multicolor/multicolor_feature.rs b/src/multicolor/multicolor_feature.rs index 94245814..61a63e75 100644 --- a/src/multicolor/multicolor_feature.rs +++ b/src/multicolor/multicolor_feature.rs @@ -3,7 +3,7 @@ use crate::error::MultiColorEvaluatorError; use crate::evaluator::{EvaluatorInfo, EvaluatorInfoTrait, FeatureNamesDescriptionsTrait}; use crate::feature::Feature; use crate::float_trait::Float; -use crate::multicolor::features::ColorOfMedian; +use crate::multicolor::features::{ColorOfMaximum, ColorOfMedian, ColorOfMinimum}; use crate::multicolor::multicolor_evaluator::*; use crate::multicolor::{MonochromeFeature, MultiColorExtractor}; @@ -28,7 +28,9 @@ where // Monochrome Features MonochromeFeature(MonochromeFeature>), // Features + ColorOfMaximum(ColorOfMaximum

), ColorOfMedian(ColorOfMedian

), + ColorOfMinimum(ColorOfMinimum

), } impl MultiColorFeature From 25088686ce2a9efd56c43ff0dddbaaf50729e9e3 Mon Sep 17 00:00:00 2001 From: Konstantin Malanchev Date: Tue, 6 Sep 2022 17:01:48 -0500 Subject: [PATCH 7/9] Flatter MultiColorTimeSeries --- src/data/multi_color_time_series.rs | 101 ++++++++++++++------ src/multicolor/features/color_of_maximum.rs | 12 ++- src/multicolor/features/color_of_median.rs | 12 ++- src/multicolor/features/color_of_minimum.rs | 12 ++- src/multicolor/monochrome_feature.rs | 35 +++---- src/multicolor/multicolor_evaluator.rs | 72 +++++++++----- src/multicolor/multicolor_extractor.rs | 26 +++-- 7 files changed, 180 insertions(+), 90 deletions(-) diff --git a/src/data/multi_color_time_series.rs b/src/data/multi_color_time_series.rs index bec58eaf..be2fc58c 100644 --- a/src/data/multi_color_time_series.rs +++ b/src/data/multi_color_time_series.rs @@ -1,15 +1,17 @@ use crate::data::TimeSeries; use crate::float_trait::Float; use crate::multicolor::PassbandTrait; -use crate::PassbandSet; +use crate::{DataSample, PassbandSet}; use itertools::Either; use itertools::EitherOrBoth; use itertools::Itertools; use std::collections::BTreeMap; -use std::ops::{Deref, DerefMut}; -pub struct MultiColorTimeSeries<'a, P: PassbandTrait, T: Float>(BTreeMap>); +pub struct MultiColorTimeSeries<'a, P: PassbandTrait, T: Float> { + mapping: BTreeMap>, + flat: Option>, +} impl<'a, 'p, P, T> MultiColorTimeSeries<'a, P, T> where @@ -17,20 +19,38 @@ where T: Float, { pub fn new(map: impl Into>>) -> Self { - Self(map.into()) + Self { + mapping: map.into(), + flat: None, + } + } + + pub fn flatten(&mut self) -> &FlatMultiColorTimeSeries<'static, P, T> { + self.flat + .get_or_insert_with(|| FlatMultiColorTimeSeries::from_mapping(&mut self.mapping)) + } + + pub fn passbands<'slf>( + &'slf self, + ) -> std::collections::btree_map::Keys<'slf, P, TimeSeries<'a, T>> + where + 'a: 'slf, + { + self.mapping.keys() } pub fn iter_passband_set<'slf, 'ps>( &'slf self, passband_set: &'ps PassbandSet

, - ) -> impl Iterator>)> + 'ps + ) -> impl Iterator>)> + 'slf where - 'a: 'ps, - 'slf: 'ps, - 'ps: 'slf, + 'a: 'slf, + 'ps: 'a, { match passband_set { - PassbandSet::AllAvailable => Either::Left(self.0.iter().map(|(p, ts)| (p, Some(ts)))), + PassbandSet::AllAvailable => { + Either::Left(self.mapping.iter().map(|(p, ts)| (p, Some(ts)))) + } PassbandSet::FixedSet(set) => Either::Right(self.iter_matched_passbands(set.iter())), } } @@ -38,15 +58,14 @@ where pub fn iter_passband_set_mut<'slf, 'ps>( &'slf mut self, passband_set: &'ps PassbandSet

, - ) -> impl Iterator>)> + 'ps + ) -> impl Iterator>)> + 'slf where - 'a: 'ps, - 'slf: 'ps, - 'ps: 'slf, + 'a: 'slf, + 'ps: 'a, { match passband_set { PassbandSet::AllAvailable => { - Either::Left(self.0.iter_mut().map(|(p, ts)| (p, Some(ts)))) + Either::Left(self.mapping.iter_mut().map(|(p, ts)| (p, Some(ts)))) } PassbandSet::FixedSet(set) => { Either::Right(self.iter_matched_passbands_mut(set.iter())) @@ -58,7 +77,7 @@ where &self, passband_it: impl Iterator, ) -> impl Iterator>)> { - passband_it.map(|p| (p, self.0.get(p))) + passband_it.map(|p| (p, self.mapping.get(p))) } pub fn iter_matched_passbands_mut( @@ -66,7 +85,7 @@ where passband_it: impl Iterator, ) -> impl Iterator>)> { passband_it - .merge_join_by(self.0.iter_mut(), |p1, (p2, _ts)| p1.cmp(p2)) + .merge_join_by(self.mapping.iter_mut(), |p1, (p2, _ts)| p1.cmp(p2)) .filter_map(|either_or_both| match either_or_both { // mcts misses required passband EitherOrBoth::Left(p) => Some((p, None)), @@ -78,24 +97,48 @@ where } } -impl<'a, P: PassbandTrait, T: Float> Deref for MultiColorTimeSeries<'a, P, T> { - type Target = BTreeMap>; - - fn deref(&self) -> &Self::Target { - &self.0 +impl<'a, P: PassbandTrait, T: Float> FromIterator<(P, TimeSeries<'a, T>)> + for MultiColorTimeSeries<'a, P, T> +{ + fn from_iter)>>(iter: I) -> Self { + Self { + mapping: iter.into_iter().collect(), + flat: None, + } } } -impl<'a, P: PassbandTrait, T: Float> DerefMut for MultiColorTimeSeries<'a, P, T> { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } +pub struct FlatMultiColorTimeSeries<'a, P: PassbandTrait, T: Float> { + pub t: DataSample<'a, T>, + pub m: DataSample<'a, T>, + pub w: DataSample<'a, T>, + pub passbands: Vec

, } -impl<'a, P: PassbandTrait, T: Float> FromIterator<(P, TimeSeries<'a, T>)> - for MultiColorTimeSeries<'a, P, T> +impl FlatMultiColorTimeSeries<'static, P, T> +where + P: PassbandTrait, + T: Float, { - fn from_iter)>>(iter: I) -> Self { - Self(iter.into_iter().collect()) + pub fn from_mapping(mapping: &mut BTreeMap>) -> Self { + let (t, m, w, passbands): (Vec<_>, Vec<_>, Vec<_>, _) = mapping + .iter_mut() + .map(|(p, ts)| { + itertools::multizip(( + ts.t.as_slice().iter().copied(), + ts.m.as_slice().iter().copied(), + ts.w.as_slice().iter().copied(), + std::iter::repeat(p.clone()), + )) + }) + .kmerge_by(|(t1, _, _, _), (t2, _, _, _)| t1 <= t2) + .multiunzip(); + + Self { + t: t.into(), + m: m.into(), + w: w.into(), + passbands, + } } } diff --git a/src/multicolor/features/color_of_maximum.rs b/src/multicolor/features/color_of_maximum.rs index 1c0a0bc3..7bc51a52 100644 --- a/src/multicolor/features/color_of_maximum.rs +++ b/src/multicolor/features/color_of_maximum.rs @@ -89,10 +89,14 @@ where P: PassbandTrait, T: Float, { - fn eval_multicolor_no_mcts_check( - &self, - mcts: &mut MultiColorTimeSeries, - ) -> Result, MultiColorEvaluatorError> { + fn eval_multicolor_no_mcts_check<'slf, 'a, 'mcts>( + &'slf self, + mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>, + ) -> Result, MultiColorEvaluatorError> + where + 'slf: 'a, + 'a: 'mcts, + { let mut maxima = [T::zero(); 2]; for ((_passband, mcts), maximum) in mcts .iter_matched_passbands_mut(self.passbands.iter()) diff --git a/src/multicolor/features/color_of_median.rs b/src/multicolor/features/color_of_median.rs index 9360bc6e..54789f65 100644 --- a/src/multicolor/features/color_of_median.rs +++ b/src/multicolor/features/color_of_median.rs @@ -98,10 +98,14 @@ where P: PassbandTrait, T: Float, { - fn eval_multicolor_no_mcts_check( - &self, - mcts: &mut MultiColorTimeSeries, - ) -> Result, MultiColorEvaluatorError> { + fn eval_multicolor_no_mcts_check<'slf, 'a, 'mcts>( + &'slf self, + mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>, + ) -> Result, MultiColorEvaluatorError> + where + 'slf: 'a, + 'a: 'mcts, + { let mut medians = [T::zero(); 2]; for ((passband, mcts), median) in mcts .iter_matched_passbands_mut(self.passbands.iter()) diff --git a/src/multicolor/features/color_of_minimum.rs b/src/multicolor/features/color_of_minimum.rs index cfa1c449..79701334 100644 --- a/src/multicolor/features/color_of_minimum.rs +++ b/src/multicolor/features/color_of_minimum.rs @@ -89,10 +89,14 @@ where P: PassbandTrait, T: Float, { - fn eval_multicolor_no_mcts_check( - &self, - mcts: &mut MultiColorTimeSeries, - ) -> Result, MultiColorEvaluatorError> { + fn eval_multicolor_no_mcts_check<'slf, 'a, 'mcts>( + &'slf self, + mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>, + ) -> Result, MultiColorEvaluatorError> + where + 'slf: 'a, + 'a: 'mcts, + { let mut minima = [T::zero(); 2]; for ((_passband, mcts), minimum) in mcts .iter_matched_passbands_mut(self.passbands.iter()) diff --git a/src/multicolor/monochrome_feature.rs b/src/multicolor/monochrome_feature.rs index 93de2ed3..d6a03430 100644 --- a/src/multicolor/monochrome_feature.rs +++ b/src/multicolor/monochrome_feature.rs @@ -106,23 +106,26 @@ where T: Float, F: FeatureEvaluator, { - fn eval_multicolor_no_mcts_check( - &self, - mcts: &mut MultiColorTimeSeries, - ) -> Result, MultiColorEvaluatorError> { + fn eval_multicolor_no_mcts_check<'slf, 'a, 'mcts>( + &'slf self, + mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>, + ) -> Result, MultiColorEvaluatorError> + where + 'slf: 'a, + 'a: 'mcts, + { match &self.passband_set { - PassbandSet::FixedSet(set) => set - .iter() - .map(|passband| { - self.feature.eval(mcts.get_mut(passband).expect( - "we checked all needed passbands are in mcts, but we still cannot find one", - )).map_err(|error| MultiColorEvaluatorError::MonochromeEvaluatorError { - passband: passband.name().into(), - error, - }) - }) - .flatten_ok() - .collect(), + PassbandSet::FixedSet(set) => { + mcts.iter_matched_passbands_mut(set.iter()) + .map(|(passband, ts)| { + self.feature.eval_no_ts_check( + ts.expect("we checked all needed passbands are in mcts, but we still cannot find one") + ).map_err(|error| MultiColorEvaluatorError::MonochromeEvaluatorError { + passband: passband.name().into(), + error, + }) + }).flatten_ok().collect() + } PassbandSet::AllAvailable => panic!("passband_set must be FixedSet variant here"), } } diff --git a/src/multicolor/multicolor_evaluator.rs b/src/multicolor/multicolor_evaluator.rs index 8406ebc2..dccc4365 100644 --- a/src/multicolor/multicolor_evaluator.rs +++ b/src/multicolor/multicolor_evaluator.rs @@ -50,12 +50,14 @@ enum InternalMctsError { } impl InternalMctsError { - fn into_multi_color_evaluator_error( + fn into_multi_color_evaluator_error<'mcts, 'a, 'ps, P, T>( self, - mcts: &MultiColorTimeSeries, - ps: &PassbandSet

, + mcts: &'mcts MultiColorTimeSeries<'a, P, T>, + ps: &'ps PassbandSet

, ) -> MultiColorEvaluatorError where + 'ps: 'a, + 'a: 'mcts, P: PassbandTrait, T: Float, { @@ -63,7 +65,7 @@ impl InternalMctsError { InternalMctsError::MultiColorEvaluatorError(e) => e, InternalMctsError::InternalWrongPassbandSet => { MultiColorEvaluatorError::wrong_passbands_error( - mcts.keys(), + mcts.passbands(), match ps { PassbandSet::FixedSet(ps) => ps.iter(), PassbandSet::AllAvailable => { @@ -88,26 +90,39 @@ where T: Float, { /// Version of [MultiColorEvaluator::eval_multicolor] without basic [MultiColorTimeSeries] checks - fn eval_multicolor_no_mcts_check( - &self, - mcts: &mut MultiColorTimeSeries, - ) -> Result, MultiColorEvaluatorError>; + fn eval_multicolor_no_mcts_check<'slf, 'a, 'mcts>( + &'slf self, + mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>, + ) -> Result, MultiColorEvaluatorError> + where + 'slf: 'a, + 'a: 'mcts; /// Vector of feature values or `EvaluatorError` - fn eval_multicolor( - &self, - mcts: &mut MultiColorTimeSeries, - ) -> Result, MultiColorEvaluatorError> { + fn eval_multicolor<'slf, 'a, 'mcts>( + &'slf self, + mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>, + ) -> Result, MultiColorEvaluatorError> + where + 'slf: 'a, + 'a: 'mcts, + P: 'a, + { self.check_mcts(mcts)?; self.eval_multicolor_no_mcts_check(mcts) } /// Returns vector of feature values and fill invalid components with given value - fn eval_or_fill_multicolor( - &self, - mcts: &mut MultiColorTimeSeries, + fn eval_or_fill_multicolor<'slf, 'a, 'mcts>( + &'slf self, + mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>, fill_value: T, - ) -> Result, MultiColorEvaluatorError> { + ) -> Result, MultiColorEvaluatorError> + where + 'slf: 'a, + 'a: 'mcts, + P: 'a, + { Ok(match self.eval_multicolor(mcts) { Ok(v) => v, Err(_) => vec![fill_value; self.size_hint()], @@ -115,10 +130,15 @@ where } /// Check [MultiColorTimeSeries] to have required passbands and individual [TimeSeries] are valid - fn check_mcts( - &self, - mcts: &mut MultiColorTimeSeries, - ) -> Result<(), MultiColorEvaluatorError> { + fn check_mcts<'slf, 'a, 'mcts>( + &'slf self, + mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>, + ) -> Result<(), MultiColorEvaluatorError> + where + 'slf: 'a, + 'a: 'mcts, + P: 'a, + { mcts.iter_passband_set_mut(self.get_passband_set()) .map(|(p, maybe_ts)| { maybe_ts @@ -187,10 +207,14 @@ mod tests { where T: Float, { - fn eval_multicolor_no_mcts_check( - &self, - _mcts: &mut MultiColorTimeSeries, T>, - ) -> Result, MultiColorEvaluatorError> { + fn eval_multicolor_no_mcts_check<'slf, 'a, 'mcts>( + &'slf self, + _mcts: &'mcts mut MultiColorTimeSeries<'a, MonochromePassband<'static, f64>, T>, + ) -> Result, MultiColorEvaluatorError> + where + 'slf: 'a, + 'a: 'mcts, + { Ok(vec![T::zero()]) } } diff --git a/src/multicolor/multicolor_extractor.rs b/src/multicolor/multicolor_extractor.rs index 6af4552c..5fe4d70d 100644 --- a/src/multicolor/multicolor_extractor.rs +++ b/src/multicolor/multicolor_extractor.rs @@ -122,22 +122,30 @@ where T: Float, MCF: MultiColorEvaluator, { - fn eval_multicolor_no_mcts_check( - &self, - mcts: &mut MultiColorTimeSeries, - ) -> Result, MultiColorEvaluatorError> { + fn eval_multicolor_no_mcts_check<'slf, 'a, 'mcts>( + &'slf self, + mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>, + ) -> Result, MultiColorEvaluatorError> + where + 'slf: 'a, + 'a: 'mcts, + { let mut vec = Vec::with_capacity(self.size_hint()); for x in &self.features { - vec.extend(x.eval_multicolor(mcts)?); + vec.extend(x.eval_multicolor_no_mcts_check(mcts)?); } Ok(vec) } - fn eval_or_fill_multicolor( - &self, - mcts: &mut MultiColorTimeSeries, + fn eval_or_fill_multicolor<'slf, 'a, 'mcts>( + &'slf self, + mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>, fill_value: T, - ) -> Result, MultiColorEvaluatorError> { + ) -> Result, MultiColorEvaluatorError> + where + 'slf: 'a, + 'a: 'mcts, + { self.features .iter() .map(|x| x.eval_or_fill_multicolor(mcts, fill_value)) From 47a12cd910cfa3c496f54d259c405787bc3bd4ce Mon Sep 17 00:00:00 2001 From: Konstantin Malanchev Date: Wed, 28 Sep 2022 14:59:53 -0500 Subject: [PATCH 8/9] Make MultiColorTimeSeries an enum --- src/data/multi_color_time_series.rs | 218 +++++++++++++++++--- src/multicolor/features/color_of_maximum.rs | 1 + src/multicolor/features/color_of_median.rs | 1 + src/multicolor/features/color_of_minimum.rs | 1 + src/multicolor/monochrome_feature.rs | 2 +- src/multicolor/multicolor_evaluator.rs | 11 +- 6 files changed, 201 insertions(+), 33 deletions(-) diff --git a/src/data/multi_color_time_series.rs b/src/data/multi_color_time_series.rs index be2fc58c..535543aa 100644 --- a/src/data/multi_color_time_series.rs +++ b/src/data/multi_color_time_series.rs @@ -6,11 +6,16 @@ use crate::{DataSample, PassbandSet}; use itertools::Either; use itertools::EitherOrBoth; use itertools::Itertools; -use std::collections::BTreeMap; +use std::collections::{BTreeMap, BTreeSet}; +use std::ops::{Deref, DerefMut}; -pub struct MultiColorTimeSeries<'a, P: PassbandTrait, T: Float> { - mapping: BTreeMap>, - flat: Option>, +pub enum MultiColorTimeSeries<'a, P: PassbandTrait, T: Float> { + Mapping(MappedMultiColorTimeSeries<'a, P, T>), + Flat(FlatMultiColorTimeSeries<'a, P, T>), + MappingFlat { + mapping: MappedMultiColorTimeSeries<'a, P, T>, + flat: FlatMultiColorTimeSeries<'a, P, T>, + }, } impl<'a, 'p, P, T> MultiColorTimeSeries<'a, P, T> @@ -18,16 +23,129 @@ where P: PassbandTrait + 'p, T: Float, { - pub fn new(map: impl Into>>) -> Self { - Self { - mapping: map.into(), - flat: None, + pub fn from_map(map: impl Into>>) -> Self { + Self::Mapping(MappedMultiColorTimeSeries::new(map)) + } + + pub fn from_flat( + t: impl Into>, + m: impl Into>, + w: impl Into>, + passbands: impl Into>, + ) -> Self { + Self::Flat(FlatMultiColorTimeSeries::new(t, m, w, passbands)) + } + + pub fn mapping_mut(&mut self) -> &mut MappedMultiColorTimeSeries<'a, P, T> { + if matches!(self, MultiColorTimeSeries::Flat(_)) { + let dummy_self = Self::Mapping(MappedMultiColorTimeSeries::new(BTreeMap::new())); + *self = match std::mem::replace(self, dummy_self) { + Self::Flat(mut flat) => { + let mapping = MappedMultiColorTimeSeries::from_flat(&mut flat); + Self::MappingFlat { mapping, flat } + } + _ => unreachable!(), + } + } + match self { + Self::Mapping(mapping) => mapping, + Self::Flat(_flat) => { + unreachable!("::Flat variant is already transofrmed to ::MappingFlat") + } + Self::MappingFlat { mapping, .. } => mapping, + } + } + + pub fn mapping(&self) -> Option<&MappedMultiColorTimeSeries<'a, P, T>> { + match self { + Self::Mapping(mapping) => Some(mapping), + Self::Flat(_flat) => None, + Self::MappingFlat { mapping, .. } => Some(mapping), + } + } + + pub fn flat_mut(&mut self) -> &mut FlatMultiColorTimeSeries<'a, P, T> { + if matches!(self, MultiColorTimeSeries::Mapping(_)) { + let dummy_self = Self::Mapping(MappedMultiColorTimeSeries::new(BTreeMap::new())); + *self = match std::mem::replace(self, dummy_self) { + Self::Mapping(mut mapping) => { + let flat = FlatMultiColorTimeSeries::from_mapping(&mut mapping); + Self::MappingFlat { mapping, flat } + } + _ => unreachable!(), + } + } + match self { + Self::Mapping(_mapping) => { + unreachable!("::Mapping veriant is already transformed to ::MappingFlat") + } + Self::Flat(flat) => flat, + Self::MappingFlat { flat, .. } => flat, + } + } + + pub fn flat(&self) -> Option<&FlatMultiColorTimeSeries<'a, P, T>> { + match self { + Self::Mapping(_mapping) => None, + Self::Flat(flat) => Some(flat), + Self::MappingFlat { flat, .. } => Some(flat), + } + } + + pub fn passbands<'slf>( + &'slf self, + ) -> Either< + std::collections::btree_map::Keys<'slf, P, TimeSeries<'a, T>>, + std::collections::btree_set::Iter

, + > + where + 'a: 'slf, + { + match self { + Self::Mapping(mapping) => Either::Left(mapping.passbands()), + Self::Flat(flat) => Either::Right(flat.passband_set.iter()), + Self::MappingFlat { mapping, .. } => Either::Left(mapping.passbands()), } } +} + +pub struct MappedMultiColorTimeSeries<'a, P: PassbandTrait, T: Float>( + BTreeMap>, +); - pub fn flatten(&mut self) -> &FlatMultiColorTimeSeries<'static, P, T> { - self.flat - .get_or_insert_with(|| FlatMultiColorTimeSeries::from_mapping(&mut self.mapping)) +impl<'a, 'p, P, T> MappedMultiColorTimeSeries<'a, P, T> +where + P: PassbandTrait + 'p, + T: Float, +{ + pub fn new(map: impl Into>>) -> Self { + Self(map.into()) + } + + pub fn from_flat(flat: &mut FlatMultiColorTimeSeries) -> Self { + let mut map = BTreeMap::new(); + let groups = itertools::multizip(( + flat.t.as_slice().iter(), + flat.m.as_slice().iter(), + flat.w.as_slice().iter(), + flat.passbands.iter(), + )) + .group_by(|(_t, _m, _w, p)| (*p).clone()); + for (p, group) in &groups { + let (t_vec, m_vec, w_vec) = map + .entry(p.clone()) + .or_insert_with(|| (vec![], vec![], vec![])); + for (&t, &m, &w, _p) in group { + t_vec.push(t); + m_vec.push(m); + w_vec.push(w); + } + } + Self( + map.into_iter() + .map(|(p, (t, m, w))| (p, TimeSeries::new(t, m, w))) + .collect(), + ) } pub fn passbands<'slf>( @@ -36,7 +154,7 @@ where where 'a: 'slf, { - self.mapping.keys() + self.keys() } pub fn iter_passband_set<'slf, 'ps>( @@ -48,9 +166,7 @@ where 'ps: 'a, { match passband_set { - PassbandSet::AllAvailable => { - Either::Left(self.mapping.iter().map(|(p, ts)| (p, Some(ts)))) - } + PassbandSet::AllAvailable => Either::Left(self.iter().map(|(p, ts)| (p, Some(ts)))), PassbandSet::FixedSet(set) => Either::Right(self.iter_matched_passbands(set.iter())), } } @@ -64,9 +180,7 @@ where 'ps: 'a, { match passband_set { - PassbandSet::AllAvailable => { - Either::Left(self.mapping.iter_mut().map(|(p, ts)| (p, Some(ts)))) - } + PassbandSet::AllAvailable => Either::Left(self.iter_mut().map(|(p, ts)| (p, Some(ts)))), PassbandSet::FixedSet(set) => { Either::Right(self.iter_matched_passbands_mut(set.iter())) } @@ -77,7 +191,7 @@ where &self, passband_it: impl Iterator, ) -> impl Iterator>)> { - passband_it.map(|p| (p, self.mapping.get(p))) + passband_it.map(|p| (p, self.get(p))) } pub fn iter_matched_passbands_mut( @@ -85,7 +199,7 @@ where passband_it: impl Iterator, ) -> impl Iterator>)> { passband_it - .merge_join_by(self.mapping.iter_mut(), |p1, (p2, _ts)| p1.cmp(p2)) + .merge_join_by(self.iter_mut(), |p1, (p2, _ts)| p1.cmp(p2)) .filter_map(|either_or_both| match either_or_both { // mcts misses required passband EitherOrBoth::Left(p) => Some((p, None)), @@ -98,13 +212,24 @@ where } impl<'a, P: PassbandTrait, T: Float> FromIterator<(P, TimeSeries<'a, T>)> - for MultiColorTimeSeries<'a, P, T> + for MappedMultiColorTimeSeries<'a, P, T> { fn from_iter)>>(iter: I) -> Self { - Self { - mapping: iter.into_iter().collect(), - flat: None, - } + Self(iter.into_iter().collect()) + } +} + +impl<'a, P: PassbandTrait, T: Float> Deref for MappedMultiColorTimeSeries<'a, P, T> { + type Target = BTreeMap>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl<'a, P: PassbandTrait, T: Float> DerefMut for MappedMultiColorTimeSeries<'a, P, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 } } @@ -113,13 +238,51 @@ pub struct FlatMultiColorTimeSeries<'a, P: PassbandTrait, T: Float> { pub m: DataSample<'a, T>, pub w: DataSample<'a, T>, pub passbands: Vec

, + passband_set: BTreeSet

, } -impl FlatMultiColorTimeSeries<'static, P, T> +impl<'a, P, T> FlatMultiColorTimeSeries<'a, P, T> where P: PassbandTrait, T: Float, { + pub fn new( + t: impl Into>, + m: impl Into>, + w: impl Into>, + passbands: impl Into>, + ) -> Self { + let t = t.into(); + let m = m.into(); + let w = w.into(); + let passbands = passbands.into(); + let passband_set = passbands.iter().cloned().collect(); + + assert_eq!( + t.sample.len(), + m.sample.len(), + "t and m should have the same size" + ); + assert_eq!( + m.sample.len(), + w.sample.len(), + "m and err should have the same size" + ); + assert_eq!( + t.sample.len(), + passbands.len(), + "t and passbands should have the same size" + ); + + Self { + t, + m, + w, + passbands, + passband_set, + } + } + pub fn from_mapping(mapping: &mut BTreeMap>) -> Self { let (t, m, w, passbands): (Vec<_>, Vec<_>, Vec<_>, _) = mapping .iter_mut() @@ -131,7 +294,7 @@ where std::iter::repeat(p.clone()), )) }) - .kmerge_by(|(t1, _, _, _), (t2, _, _, _)| t1 <= t2) + .kmerge_by(|(t1, _m1, _w1, _p1), (t2, _m2, _w2, _p2)| t1 <= t2) .multiunzip(); Self { @@ -139,6 +302,7 @@ where m: m.into(), w: w.into(), passbands, + passband_set: mapping.keys().cloned().collect(), } } } diff --git a/src/multicolor/features/color_of_maximum.rs b/src/multicolor/features/color_of_maximum.rs index 7bc51a52..9639f875 100644 --- a/src/multicolor/features/color_of_maximum.rs +++ b/src/multicolor/features/color_of_maximum.rs @@ -99,6 +99,7 @@ where { let mut maxima = [T::zero(); 2]; for ((_passband, mcts), maximum) in mcts + .mapping_mut() .iter_matched_passbands_mut(self.passbands.iter()) .zip(maxima.iter_mut()) { diff --git a/src/multicolor/features/color_of_median.rs b/src/multicolor/features/color_of_median.rs index 54789f65..a6779e4b 100644 --- a/src/multicolor/features/color_of_median.rs +++ b/src/multicolor/features/color_of_median.rs @@ -108,6 +108,7 @@ where { let mut medians = [T::zero(); 2]; for ((passband, mcts), median) in mcts + .mapping_mut() .iter_matched_passbands_mut(self.passbands.iter()) .zip(medians.iter_mut()) { diff --git a/src/multicolor/features/color_of_minimum.rs b/src/multicolor/features/color_of_minimum.rs index 79701334..72de5c7d 100644 --- a/src/multicolor/features/color_of_minimum.rs +++ b/src/multicolor/features/color_of_minimum.rs @@ -99,6 +99,7 @@ where { let mut minima = [T::zero(); 2]; for ((_passband, mcts), minimum) in mcts + .mapping_mut() .iter_matched_passbands_mut(self.passbands.iter()) .zip(minima.iter_mut()) { diff --git a/src/multicolor/monochrome_feature.rs b/src/multicolor/monochrome_feature.rs index d6a03430..2b483d29 100644 --- a/src/multicolor/monochrome_feature.rs +++ b/src/multicolor/monochrome_feature.rs @@ -116,7 +116,7 @@ where { match &self.passband_set { PassbandSet::FixedSet(set) => { - mcts.iter_matched_passbands_mut(set.iter()) + mcts.mapping_mut().iter_matched_passbands_mut(set.iter()) .map(|(passband, ts)| { self.feature.eval_no_ts_check( ts.expect("we checked all needed passbands are in mcts, but we still cannot find one") diff --git a/src/multicolor/multicolor_evaluator.rs b/src/multicolor/multicolor_evaluator.rs index dccc4365..c60dddf7 100644 --- a/src/multicolor/multicolor_evaluator.rs +++ b/src/multicolor/multicolor_evaluator.rs @@ -139,7 +139,8 @@ where 'a: 'mcts, P: 'a, { - mcts.iter_passband_set_mut(self.get_passband_set()) + mcts.mapping_mut() + .iter_passband_set_mut(self.get_passband_set()) .map(|(p, maybe_ts)| { maybe_ts .ok_or(InternalMctsError::InternalWrongPassbandSet) @@ -227,16 +228,16 @@ mod tests { let passband_v_capital = MonochromePassband::new(5500e-8, "V"); let passband_r_capital = MonochromePassband::new(6400e-8, "R"); let mut mcts = { - let mut passbands = BTreeMap::new(); - passbands.insert( + let mut mapping = BTreeMap::new(); + mapping.insert( passband_b_capital.clone(), TimeSeries::new_without_weight(&t, &m), ); - passbands.insert( + mapping.insert( passband_v_capital.clone(), TimeSeries::new_without_weight(&t, &m), ); - MultiColorTimeSeries::new(passbands) + MultiColorTimeSeries::from_map(mapping) }; let feature = TestTimeMultiColorFeature { From 77919a573588bc4abcd9b695d71c11802dce3105 Mon Sep 17 00:00:00 2001 From: Konstantin Malanchev Date: Mon, 24 Apr 2023 19:17:44 -0500 Subject: [PATCH 9/9] Make *Fit do not accept flat time series (#113) * Fix #112 * Fix never-run tests --- CHANGELOG.md | 2 +- src/features/bazin_fit.rs | 17 +++++++++++++---- src/features/villar_fit.rs | 17 +++++++++++++---- 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a93e02d2..6ab65c46 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,7 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed --- +- **Breaking:** {Bazin,Villar}Fit requires variability now and do not accept flat time series anymore https://github.com/light-curve/light-curve-feature/issues/112 https://github.com/light-curve/light-curve-feature/pull/113 ### Security diff --git a/src/features/bazin_fit.rs b/src/features/bazin_fit.rs index 4f8ec064..9dc6bae8 100644 --- a/src/features/bazin_fit.rs +++ b/src/features/bazin_fit.rs @@ -108,8 +108,8 @@ lazy_info!( t_required: true, m_required: true, w_required: true, - sorting_required: true, // improve reproducibility - variability_required: false, + sorting_required: true, // improves reproducibility + variability_required: true, ); struct Params<'a, T> { @@ -419,13 +419,22 @@ mod tests { check_feature!(BazinFit); feature_test!( - bazin_fit_plateau, + bazin_fit_almost_plateau, [BazinFit::default()], [0.0, 0.0, 10.0, 5.0, 5.0, 0.0], // initial model parameters and zero chi2 linspace(0.0, 10.0, 11), - [0.0; 11], + linspace(0.0, 1e-100, 11), // make it a bit non-flat ); + #[test] + fn bazin_fit_plateau() { + let fe = BazinFit::default(); + let t = linspace(0.0, 10.0, 11); + let f = [0.0; 11]; + let mut ts = TimeSeries::new_without_weight(&t, &f); + assert!(fe.eval(&mut ts).is_err()); + } + fn bazin_fit_noisy(eval: BazinFit) { const N: usize = 50; diff --git a/src/features/villar_fit.rs b/src/features/villar_fit.rs index 36b3990f..ca1f8041 100644 --- a/src/features/villar_fit.rs +++ b/src/features/villar_fit.rs @@ -126,8 +126,8 @@ lazy_info!( t_required: true, m_required: true, w_required: true, - sorting_required: true, // improve reproducibility - variability_required: false, + sorting_required: true, // improves reproducibility + variability_required: true, ); impl FitModelTrait for VillarFit @@ -610,13 +610,22 @@ mod tests { check_feature!(VillarFit); feature_test!( - villar_fit_plateau, + villar_fit_almost_plateau, [VillarFit::default()], [0.0, 0.0, 10.0, 5.0, 5.0, 0.0, 1.0, 0.0], // initial model parameters and zero chi2 linspace(0.0, 10.0, 11), - [0.0; 11], + linspace(0.0, 1e-100, 11), // make it a bit non-flat ); + #[test] + fn villar_fit_plateau() { + let fe = VillarFit::default(); + let t = linspace(0.0, 10.0, 11); + let f = [0.0; 11]; + let mut ts = TimeSeries::new_without_weight(&t, &f); + assert!(fe.eval(&mut ts).is_err()); + } + #[cfg(any( feature = "gsl", any(feature = "ceres-source", feature = "ceres-system")