Skip to content

Commit 89aed9c

Browse files
authored
Add sample_weight support to FScore metrics (#1816)
* Add sample_weight support to the FScore metrics. The test data was generated with this Colab notebook: https://colab.research.google.com/drive/1ymB5iOj9YCeBQ-g-8eMGpQhiHr7QUBM4
1 parent 379fe62 commit 89aed9c

File tree

2 files changed, with 96 additions (+96) and 19 deletions (-19).

tensorflow_addons/metrics/f_scores.py

Lines changed: 14 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -120,8 +120,6 @@ def _zero_wt_init(name):
120120
self.false_negatives = _zero_wt_init("false_negatives")
121121
self.weights_intermediate = _zero_wt_init("weights_intermediate")
122122

123-
# TODO: Add sample_weight support, currently it is
124-
# ignored during calculations.
125123
def update_state(self, y_true, y_pred, sample_weight=None):
126124
if self.threshold is None:
127125
threshold = tf.reduce_max(y_pred, axis=-1, keepdims=True)
@@ -131,17 +129,22 @@ def update_state(self, y_true, y_pred, sample_weight=None):
131129
else:
132130
y_pred = y_pred > self.threshold
133131

134-
y_true = tf.cast(y_true, tf.int32)
135-
y_pred = tf.cast(y_pred, tf.int32)
132+
y_true = tf.cast(y_true, self.dtype)
133+
y_pred = tf.cast(y_pred, self.dtype)
136134

137-
def _count_non_zero(val):
138-
non_zeros = tf.math.count_nonzero(val, axis=self.axis)
139-
return tf.cast(non_zeros, self.dtype)
135+
def _weighted_sum(val, sample_weight):
136+
if sample_weight is not None:
137+
val = tf.math.multiply(val, tf.expand_dims(sample_weight, 1))
138+
return tf.reduce_sum(val, axis=self.axis)
140139

141-
self.true_positives.assign_add(_count_non_zero(y_pred * y_true))
142-
self.false_positives.assign_add(_count_non_zero(y_pred * (y_true - 1)))
143-
self.false_negatives.assign_add(_count_non_zero((y_pred - 1) * y_true))
144-
self.weights_intermediate.assign_add(_count_non_zero(y_true))
140+
self.true_positives.assign_add(_weighted_sum(y_pred * y_true, sample_weight))
141+
self.false_positives.assign_add(
142+
_weighted_sum(y_pred * (1 - y_true), sample_weight)
143+
)
144+
self.false_negatives.assign_add(
145+
_weighted_sum((1 - y_pred) * y_true, sample_weight)
146+
)
147+
self.weights_intermediate.assign_add(_weighted_sum(y_true, sample_weight))
145148

146149
def result(self):
147150
precision = tf.math.divide_no_nan(

tensorflow_addons/metrics/tests/f_scores_test.py

Lines changed: 82 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,17 @@ def test_config_fbeta():
3838
assert fbeta_obj2.dtype == tf.float32
3939

4040

41-
def _test_tf(avg, beta, act, pred, threshold):
41+
def _test_tf(avg, beta, act, pred, sample_weights, threshold):
4242
act = tf.constant(act, tf.float32)
4343
pred = tf.constant(pred, tf.float32)
4444

4545
fbeta = FBetaScore(3, avg, beta, threshold)
46-
fbeta.update_state(act, pred)
46+
fbeta.update_state(act, pred, sample_weights)
4747
return fbeta.result().numpy()
4848

4949

50-
def _test_fbeta_score(actuals, preds, avg, beta_val, result, threshold):
51-
tf_score = _test_tf(avg, beta_val, actuals, preds, threshold)
50+
def _test_fbeta_score(actuals, preds, sample_weights, avg, beta_val, result, threshold):
51+
tf_score = _test_tf(avg, beta_val, actuals, preds, sample_weights, threshold)
5252
np.testing.assert_allclose(tf_score, result, atol=1e-7, rtol=1e-6)
5353

5454

@@ -58,7 +58,7 @@ def test_fbeta_perfect_score():
5858

5959
for avg_val in ["micro", "macro", "weighted"]:
6060
for beta in [0.5, 1.0, 2.0]:
61-
_test_fbeta_score(actuals, preds, avg_val, beta, 1.0, 0.66)
61+
_test_fbeta_score(actuals, preds, None, avg_val, beta, 1.0, 0.66)
6262

6363

6464
def test_fbeta_worst_score():
@@ -67,7 +67,7 @@ def test_fbeta_worst_score():
6767

6868
for avg_val in ["micro", "macro", "weighted"]:
6969
for beta in [0.5, 1.0, 2.0]:
70-
_test_fbeta_score(actuals, preds, avg_val, beta, 0.0, 0.66)
70+
_test_fbeta_score(actuals, preds, None, avg_val, beta, 0.0, 0.66)
7171

7272

7373
@pytest.mark.parametrize(
@@ -90,7 +90,7 @@ def test_fbeta_worst_score():
9090
def test_fbeta_random_score(avg_val, beta, result):
9191
preds = [[0.7, 0.7, 0.7], [1, 0, 0], [0.9, 0.8, 0]]
9292
actuals = [[0, 0, 1], [1, 1, 0], [1, 1, 1]]
93-
_test_fbeta_score(actuals, preds, avg_val, beta, result, 0.66)
93+
_test_fbeta_score(actuals, preds, None, avg_val, beta, result, 0.66)
9494

9595

9696
@pytest.mark.parametrize(
@@ -120,7 +120,61 @@ def test_fbeta_random_score_none(avg_val, beta, result):
120120
[0, 0, 1],
121121
]
122122
actuals = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 0, 0], [0, 0, 1]]
123-
_test_fbeta_score(actuals, preds, avg_val, beta, result, None)
123+
_test_fbeta_score(actuals, preds, None, avg_val, beta, result, None)
124+
125+
126+
@pytest.mark.parametrize(
127+
"avg_val, beta, sample_weights, result",
128+
[
129+
(None, 0.5, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [0.909091, 0.555556, 1.0]),
130+
(None, 0.5, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0]),
131+
(None, 0.5, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], [0.9375, 0.714286, 1.0]),
132+
(None, 1.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [0.8, 0.666667, 1.0]),
133+
(None, 1.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0]),
134+
(None, 1.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], [0.857143, 0.8, 1.0]),
135+
(None, 2.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [0.714286, 0.833333, 1.0]),
136+
(None, 2.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0]),
137+
(None, 2.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], [0.789474, 0.909091, 1.0]),
138+
("micro", 0.5, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.833333),
139+
("micro", 0.5, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 1.0),
140+
("micro", 0.5, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.9),
141+
("micro", 1.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.833333),
142+
("micro", 1.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 1.0),
143+
("micro", 1.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.9),
144+
("micro", 2.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.833333),
145+
("micro", 2.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 1.0),
146+
("micro", 2.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.9),
147+
("macro", 0.5, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.821549),
148+
("macro", 0.5, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 0.666667),
149+
("macro", 0.5, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.883929),
150+
("macro", 1.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.822222),
151+
("macro", 1.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 0.666667),
152+
("macro", 1.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.885714),
153+
("macro", 2.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.849206),
154+
("macro", 2.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 0.666667),
155+
("macro", 2.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.899522),
156+
("weighted", 0.5, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.880471),
157+
("weighted", 0.5, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 1.0),
158+
("weighted", 0.5, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.917857),
159+
("weighted", 1.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.844444),
160+
("weighted", 1.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 1.0),
161+
("weighted", 1.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.902857),
162+
("weighted", 2.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.829365),
163+
("weighted", 2.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 1.0),
164+
("weighted", 2.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.897608),
165+
],
166+
)
167+
def test_fbeta_weighted_random_score_none(avg_val, beta, sample_weights, result):
168+
preds = [
169+
[0.9, 0.1, 0],
170+
[0.2, 0.6, 0.2],
171+
[0, 0, 1],
172+
[0.4, 0.3, 0.3],
173+
[0, 0.9, 0.1],
174+
[0, 0, 1],
175+
]
176+
actuals = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 0, 0], [0, 0, 1]]
177+
_test_fbeta_score(actuals, preds, sample_weights, avg_val, beta, result, None)
124178

125179

126180
def test_keras_model():
@@ -147,6 +201,26 @@ def test_eq():
147201
np.testing.assert_allclose(fbeta.result().numpy(), f1.result().numpy())
148202

149203

204+
def test_sample_eq():
205+
f1 = F1Score(3)
206+
f1_weighted = F1Score(3)
207+
208+
preds = [
209+
[0.9, 0.1, 0],
210+
[0.2, 0.6, 0.2],
211+
[0, 0, 1],
212+
[0.4, 0.3, 0.3],
213+
[0, 0.9, 0.1],
214+
[0, 0, 1],
215+
]
216+
actuals = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 0, 0], [0, 0, 1]]
217+
sample_weights = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
218+
219+
f1.update_state(actuals, preds)
220+
f1_weighted(actuals, preds, sample_weights)
221+
np.testing.assert_allclose(f1.result().numpy(), f1_weighted.result().numpy())
222+
223+
150224
def test_keras_model_f1():
151225
f1 = F1Score(5)
152226
utils._get_model(f1, 5)

0 commit comments

Comments
 (0)