diff --git a/paderbox/math/vector.py b/paderbox/math/vector.py index a84cc055..73f1cebd 100644 --- a/paderbox/math/vector.py +++ b/paderbox/math/vector.py @@ -19,13 +19,15 @@ def cos_similarity(A, B): def cos_distance(a, b): """ - cosine distance between vector a and b: 1/2*(1-a*b/|a|*|b|) + cosine distance between array A and B + Args: + A: array with shape (...,d) + B: array with shape (...,d) - :param a: vector a (1xN or Nx1 numpy array) - :param b: vector b (1xN or Nx1 numpy array) - :return: distance (scalar) + Returns: + cosine distance with shape (...) """ - return 0.5 * (1 - sum(a * b) / np.sqrt(sum(a ** 2) * sum(b ** 2))) + return 0.5 * (1 - cos_similarity(A, B)) def normalize_vector_to_unit_length(data):