@@ -38,17 +38,17 @@ def test_config_fbeta():
3838 assert fbeta_obj2 .dtype == tf .float32
3939
4040
def _test_tf(avg, beta, act, pred, sample_weights, threshold):
    """Run one FBetaScore update and return the resulting score as NumPy.

    Args:
        avg: Averaging mode forwarded to ``FBetaScore``.
        beta: Beta value forwarded to ``FBetaScore``.
        act: Ground-truth labels; converted to a float32 tensor.
        pred: Predictions; converted to a float32 tensor.
        sample_weights: Forwarded unchanged as the third argument of
            ``update_state``; callers in this file pass ``None`` or a list.
        threshold: Threshold forwarded to ``FBetaScore``.

    Returns:
        The value of ``FBetaScore.result()`` converted with ``.numpy()``.
    """
    y_true = tf.constant(act, tf.float32)
    y_pred = tf.constant(pred, tf.float32)

    # NOTE(review): the leading 3 is presumably the number of classes (all
    # fixtures in this file have three columns) -- confirm against FBetaScore.
    metric = FBetaScore(3, avg, beta, threshold)
    metric.update_state(y_true, y_pred, sample_weights)

    score = metric.result()
    return score.numpy()
4848
4949
def _test_fbeta_score(
    actuals, preds, sample_weights, avg, beta_val, result, threshold
):
    """Assert that the score computed by ``_test_tf`` matches ``result``.

    Args:
        actuals: Ground-truth labels.
        preds: Predictions.
        sample_weights: Per-sample weights, or ``None``.
        avg: Averaging mode.
        beta_val: Beta value.
        result: Expected score (scalar or per-class list).
        threshold: Threshold, or ``None``.

    Raises:
        AssertionError: If the computed score is outside the tolerance
            ``atol=1e-7, rtol=1e-6`` of ``result``.
    """
    computed = _test_tf(
        avg=avg,
        beta=beta_val,
        act=actuals,
        pred=preds,
        sample_weights=sample_weights,
        threshold=threshold,
    )
    np.testing.assert_allclose(computed, result, rtol=1e-6, atol=1e-7)
5353
5454
@@ -58,7 +58,7 @@ def test_fbeta_perfect_score():
5858
5959 for avg_val in ["micro" , "macro" , "weighted" ]:
6060 for beta in [0.5 , 1.0 , 2.0 ]:
61- _test_fbeta_score (actuals , preds , avg_val , beta , 1.0 , 0.66 )
61+ _test_fbeta_score (actuals , preds , None , avg_val , beta , 1.0 , 0.66 )
6262
6363
6464def test_fbeta_worst_score ():
@@ -67,7 +67,7 @@ def test_fbeta_worst_score():
6767
6868 for avg_val in ["micro" , "macro" , "weighted" ]:
6969 for beta in [0.5 , 1.0 , 2.0 ]:
70- _test_fbeta_score (actuals , preds , avg_val , beta , 0.0 , 0.66 )
70+ _test_fbeta_score (actuals , preds , None , avg_val , beta , 0.0 , 0.66 )
7171
7272
7373@pytest .mark .parametrize (
@@ -90,7 +90,7 @@ def test_fbeta_worst_score():
9090def test_fbeta_random_score (avg_val , beta , result ):
9191 preds = [[0.7 , 0.7 , 0.7 ], [1 , 0 , 0 ], [0.9 , 0.8 , 0 ]]
9292 actuals = [[0 , 0 , 1 ], [1 , 1 , 0 ], [1 , 1 , 1 ]]
93- _test_fbeta_score (actuals , preds , avg_val , beta , result , 0.66 )
93+ _test_fbeta_score (actuals , preds , None , avg_val , beta , result , 0.66 )
9494
9595
9696@pytest .mark .parametrize (
@@ -120,7 +120,61 @@ def test_fbeta_random_score_none(avg_val, beta, result):
120120 [0 , 0 , 1 ],
121121 ]
122122 actuals = [[1 , 0 , 0 ], [0 , 1 , 0 ], [0 , 0 , 1 ], [1 , 0 , 0 ], [1 , 0 , 0 ], [0 , 0 , 1 ]]
123- _test_fbeta_score (actuals , preds , avg_val , beta , result , None )
123+ _test_fbeta_score (actuals , preds , None , avg_val , beta , result , None )
124+
125+
@pytest.mark.parametrize(
    "avg_val, beta, sample_weights, result",
    [
        # avg_val=None: `result` holds one expected score per class.
        (None, 0.5, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [0.909091, 0.555556, 1.0]),
        (None, 0.5, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0]),
        (None, 0.5, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], [0.9375, 0.714286, 1.0]),
        (None, 1.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [0.8, 0.666667, 1.0]),
        (None, 1.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0]),
        (None, 1.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], [0.857143, 0.8, 1.0]),
        (None, 2.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [0.714286, 0.833333, 1.0]),
        (None, 2.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0]),
        (None, 2.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], [0.789474, 0.909091, 1.0]),
        # avg_val="micro": `result` is a single scalar.
        ("micro", 0.5, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.833333),
        ("micro", 0.5, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 1.0),
        ("micro", 0.5, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.9),
        ("micro", 1.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.833333),
        ("micro", 1.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 1.0),
        ("micro", 1.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.9),
        ("micro", 2.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.833333),
        ("micro", 2.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 1.0),
        ("micro", 2.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.9),
        # avg_val="macro": `result` is a single scalar.
        ("macro", 0.5, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.821549),
        ("macro", 0.5, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 0.666667),
        ("macro", 0.5, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.883929),
        ("macro", 1.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.822222),
        ("macro", 1.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 0.666667),
        ("macro", 1.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.885714),
        ("macro", 2.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.849206),
        ("macro", 2.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 0.666667),
        ("macro", 2.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.899522),
        # avg_val="weighted": `result` is a single scalar.
        ("weighted", 0.5, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.880471),
        ("weighted", 0.5, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 1.0),
        ("weighted", 0.5, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.917857),
        ("weighted", 1.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.844444),
        ("weighted", 1.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 1.0),
        ("weighted", 1.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.902857),
        ("weighted", 2.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.829365),
        ("weighted", 2.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 1.0),
        ("weighted", 2.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.897608),
    ],
)
def test_fbeta_weighted_random_score_none(avg_val, beta, sample_weights, result):
    """Check sample-weighted F-beta scores on a fixed six-sample fixture.

    Each parametrized case supplies an averaging mode, a beta, one weight per
    sample, and the expected score. The threshold is passed as ``None``.
    """
    # One one-hot row per sample, three classes.
    actuals = [
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
        [1, 0, 0],
        [1, 0, 0],
        [0, 0, 1],
    ]
    # Per-class scores for the same six samples.
    preds = [
        [0.9, 0.1, 0],
        [0.2, 0.6, 0.2],
        [0, 0, 1],
        [0.4, 0.3, 0.3],
        [0, 0.9, 0.1],
        [0, 0, 1],
    ]
    threshold = None
    _test_fbeta_score(
        actuals, preds, sample_weights, avg_val, beta, result, threshold
    )
124178
125179
126180def test_keras_model ():
@@ -147,6 +201,26 @@ def test_eq():
147201 np .testing .assert_allclose (fbeta .result ().numpy (), f1 .result ().numpy ())
148202
149203
def test_sample_eq():
    """Check that all-ones sample weights give the same F1 as no weights."""
    unweighted = F1Score(3)
    weighted = F1Score(3)

    actuals = [
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
        [1, 0, 0],
        [1, 0, 0],
        [0, 0, 1],
    ]
    preds = [
        [0.9, 0.1, 0],
        [0.2, 0.6, 0.2],
        [0, 0, 1],
        [0.4, 0.3, 0.3],
        [0, 0.9, 0.1],
        [0, 0, 1],
    ]
    # One unit weight per sample.
    sample_weights = [1.0] * len(actuals)

    # The unweighted metric is fed through update_state; the weighted one is
    # invoked by calling the metric object directly with the weights.
    unweighted.update_state(actuals, preds)
    weighted(actuals, preds, sample_weights)

    unweighted_score = unweighted.result().numpy()
    weighted_score = weighted.result().numpy()
    np.testing.assert_allclose(unweighted_score, weighted_score)
222+
223+
150224def test_keras_model_f1 ():
151225 f1 = F1Score (5 )
152226 utils ._get_model (f1 , 5 )
0 commit comments