18 changes: 18 additions & 0 deletions farm/evaluation/metrics.py
@@ -71,7 +71,20 @@ def pearson_and_spearman(preds, labels):
        "corr": (pearson_corr + spearman_corr) / 2,
    }


def compute_metrics(metric, preds, labels):
    """
    Calculate the named metric values for the list of predictions vs the list of labels.

    :param metric: The name of a predefined metric; a function that takes a list of predictions
        and a list of labels and returns a dict mapping metric names to values; or, recursively,
        a list of such metrics. Predefined metrics are: mcc, acc, acc_f1, pear_spear, seq_f1,
        f1_macro, squad, mse, r2, top_n_accuracy, text_similarity_metric.
    :type metric: str, function, or list
    :param preds: list of predictions
    :param labels: list of target labels
    :return: a dictionary mapping metric names to values
    """
    assert len(preds) == len(labels)
    if metric == "mcc":
        return {"mcc": matthews_corrcoef(labels, preds)}
@@ -98,6 +111,11 @@ def compute_metrics(metric, preds, labels):
        return text_similarity_metric(preds, labels)
    # elif metric == "masked_accuracy":
    #     return simple_accuracy(preds, labels, ignore=-1)
    elif isinstance(metric, list):
        ret = {}
        for m in metric:
            ret.update(compute_metrics(m, preds, labels))
Contributor:
nice one! recursive call covering all cases, even registered metrics or list of lists.

        return ret
    elif metric in registered_metrics:
        metric_func = registered_metrics[metric]
        return metric_func(preds, labels)
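For context on the new list branch: a minimal usage sketch, not part of the diff. The preds/labels values are made up, and importing registered_metrics directly to add a custom metric is an assumption (the diff only shows that the module-level dict exists and is consulted by the final elif):

from farm.evaluation.metrics import compute_metrics, registered_metrics

preds = ["a", "a", "b", "b"]
labels = ["a", "b", "b", "b"]

# Predefined metric names can now be combined, and nesting is allowed;
# the recursive call merges each result dict into a single mapping.
print(compute_metrics(["acc", ["f1_macro"]], preds, labels))
# -> {'acc': 0.75, 'f1_macro': 0.733...}

# A metric added to the registered_metrics dict is resolved by the final
# elif branch above, so it can be mixed in by name as well.
def error_rate(preds, labels):
    wrong = sum(p != l for p, l in zip(preds, labels))
    return {"error_rate": wrong / len(labels)}

registered_metrics["error_rate"] = error_rate
print(compute_metrics(["acc", "error_rate"], preds, labels))
# -> {'acc': 0.75, 'error_rate': 0.25}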
47 changes: 47 additions & 0 deletions test/test_evaluation_metrics.py
@@ -0,0 +1,47 @@
import pytest
import math
from farm.evaluation.metrics import compute_metrics


def test_compute_metrics_basic():
    # check we get some exception, may not always be the AssertionError we get now
    with pytest.raises(Exception):
        compute_metrics("acc", ["x"] * 10, [""] * 11)
    ret = compute_metrics("acc", [], [])
    assert isinstance(ret, dict)
    assert "acc" in ret
    assert math.isnan(ret["acc"])
    with pytest.raises(Exception):
        compute_metrics("asdfasdf", ["a"], ["b"])
    ls = ["a"] * 5
    ls.extend(["b"] * 5)
    ps = ["a"] * 10
    ret = compute_metrics("acc", ps, ls)
    assert ret["acc"] == 0.5
    ret = compute_metrics("acc", ls, ps)
    assert ret["acc"] == 0.5
    ret = compute_metrics("f1_macro", ps, ls)
    assert ret["f1_macro"] == 1 / 3
    ret = compute_metrics("f1_macro", ls, ps)
    assert ret["f1_macro"] == 1 / 3
    ret = compute_metrics(["f1_macro", "acc"], ps, ls)
    assert isinstance(ret, dict)
    assert len(ret) == 2
    assert "acc" in ret
    assert "f1_macro" in ret
    assert ret["f1_macro"] == 1 / 3
    assert ret["acc"] == 0.5
    ret = compute_metrics(["f1_macro", "acc", "acc"], ps, ls)
    assert isinstance(ret, dict)
    assert len(ret) == 2
    assert "acc" in ret
    assert "f1_macro" in ret
    assert ret["f1_macro"] == 1 / 3
    assert ret["acc"] == 0.5
    ret = compute_metrics(["f1_macro", ["acc"]], ps, ls)
    assert isinstance(ret, dict)
    assert len(ret) == 2
    assert "acc" in ret
    assert "f1_macro" in ret
    assert ret["f1_macro"] == 1 / 3
    assert ret["acc"] == 0.5
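A side note on the empty-input case asserted above: assuming "acc" is computed as a numpy mean over element-wise matches (as simple_accuracy in this module appears to do), empty inputs yield NaN rather than an error, which is why the test uses math.isnan instead of an equality check. A standalone illustration:

import numpy as np

# The mean of an empty boolean array is NaN; numpy emits a
# RuntimeWarning ("Mean of empty slice") instead of raising.
matches = np.array([]) == np.array([])
print(matches.mean())  # nan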