diff --git a/src/core/embeddings.rs b/src/core/embeddings.rs index be77c5b..dfbdc9f 100644 --- a/src/core/embeddings.rs +++ b/src/core/embeddings.rs @@ -131,10 +131,15 @@ impl EmbeddingEngine { } } -fn normalize(input: &[f32]) -> Vec { +pub(crate) fn normalize(input: &[f32]) -> Vec { + // Check for NaN or Infinity in input + if input.iter().any(|x| x.is_nan() || x.is_infinite()) { + return vec![0.0; input.len()]; + } + let magnitude: f32 = input.iter().map(|x| x * x).sum::().sqrt(); - if magnitude == 0.0 { - return input.to_vec(); + if magnitude == 0.0 || magnitude.is_nan() || magnitude.is_infinite() { + return vec![0.0; input.len()]; } input.iter().map(|x| x / magnitude).collect() }