forked from bostxavier/Serial-Speakers
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot_xp_edition.py
More file actions
213 lines (186 loc) · 7.28 KB
/
plot_xp_edition.py
File metadata and controls
213 lines (186 loc) · 7.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
from typing import Union
import re, argparse, json
from collections import defaultdict
import pathlib as pl
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import scienceplots
import pandas as pd
from novelties_bookshare.experiments.plot_utils import (
STRAT_COLOR_HINTS,
EDITION_COLOR_HINTS,
)
def get_params(metric_key: str) -> tuple[str, dict[str, str]]:
    """Split a ``s=<strategy>.e=<novel>,<edition>.<metric>`` metrics key.

    :param metric_key: key as found in a run's ``metrics.json``.
    :return: ``(metric_name, params)`` where ``params`` maps ``strategy``,
        ``novel`` and ``edition`` to their string values, or ``("", {})``
        when the key does not follow this format.
    """
    found = re.match(
        r"s=(?P<strategy>[^\.]+)\.e=(?P<novel>[^\,]+),(?P<edition>[^\.]+)\.(?P<metric>.*)",
        metric_key,
    )
    if not found:
        return "", {}
    groups = found.groupdict()
    # everything after the edition is the metric name (it may contain dots)
    metric_name = groups.pop("metric")
    return metric_name, groups
def get_params_mlm(metric_key: str) -> tuple[str, dict[str, str]]:
    """Split a ``w=<window>.e=<novel>,<edition>.<metric>`` metrics key.

    :param metric_key: key as found in a run's ``metrics.json``.
    :return: ``(metric_name, params)`` where ``params`` maps ``window``,
        ``novel`` and ``edition`` to their string values, or ``("", {})``
        when the key does not follow this format.
    """
    parsed = re.match(r"w=([^\.]+)\.e=([^\,]+),([^\.]+)\.(.*)", metric_key)
    if parsed is None:
        return "", {}
    *param_values, metric_name = parsed.groups()
    return metric_name, dict(zip(("window", "novel", "edition"), param_values))
def get_params_retokenize(metric_key: str) -> tuple[str, dict[str, str]]:
    """Split a ``t=<max_token_len>.s=<max_split_nb>.e=<novel>,<edition>.<metric>`` key.

    :param metric_key: key as found in a run's ``metrics.json``.
    :return: ``(metric_name, params)`` where ``params`` maps
        ``max_token_len``, ``max_split_nb``, ``novel`` and ``edition`` to their
        string values, or ``("", {})`` when the key does not follow this format.
    """
    pattern = r"t=([^\.]+)\.s=([^\.]+)\.e=([^\,]+),([^\.]+)\.(.*)"
    result = re.match(pattern, metric_key)
    if result is None:
        return "", {}
    param_names = ("max_token_len", "max_split_nb", "novel", "edition")
    # regex groups 1..4 are the parameters, group 5 is the metric name
    params = {name: result.group(i) for i, name in enumerate(param_names, start=1)}
    return result.group(5), params
def get_params_propagate(metric_key: str) -> tuple[str, dict[str, str]]:
    """Split a ``p=<pipeline>.e=<novel>,<edition>.<metric>`` metrics key.

    :param metric_key: key as found in a run's ``metrics.json``.
    :return: ``(metric_name, params)`` where ``params`` maps ``pipeline``,
        ``edition`` and ``novel`` to their string values, or ``("", {})``
        when the key does not follow this format.
    """
    matched = re.match(r"p=([^\.]+)\.e=([^\,]+),([^\.]+)\.(.*)", metric_key)
    if matched is not None:
        pipeline, novel, edition, metric_name = matched.groups()
        return metric_name, {"pipeline": pipeline, "edition": edition, "novel": novel}
    return "", {}
def format_bar_height(bar_value: Union[int, float]) -> str:
    """Render a bar height as annotation text.

    Floats are shown with two decimals; any other value goes through ``str``.
    """
    return f"{bar_value:.2f}" if isinstance(bar_value, float) else str(bar_value)
# Metric name -> y-axis label used when plotting that metric. The lenient and
# strict entity-error variants share the generic entity-error label.
METRIC_TO_YLABEL = dict(
    errors_nb="Number of errors",
    precision_errors_nb="Number of precision errors",
    duration_s="Duration in seconds",
    errors_percent="Percentage of errors",
    entity_errors_nb="Number of entity errors",
    entity_errors_percent="Percentage of entity errors",
    entity_errors_percent_lenient="Percentage of entity errors",
    entity_errors_percent_strict="Percentage of entity errors",
)
# Metric name -> y-axis tick formatter. These metrics are rendered as
# percentages where a value of 1.0 is displayed as 100%. Metrics absent from
# this mapping keep matplotlib's default formatter. Each metric gets its own
# formatter instance.
METRIC_TO_YFORMATTER = {
    metric: mtick.PercentFormatter(1.0)
    for metric in (
        "errors_percent",
        "entity_errors_percent",
        "entity_errors_percent_lenient",
        "entity_errors_percent_strict",
    )
}
# Parameters identifying a result row in the retokenize experiments. The
# deprecated split experiment uses the same keys.
_RETOKENIZE_PARAMS = ("max_token_len", "max_split_nb", "edition")

# Experiment name -> ordered names of the parameters that identify one result
# row. load_xp uses these names, in this order, as the first DataFrame
# columns; the plotting code pivots with "edition" as index and the remaining
# names as columns.
XP_PARAMS_KEY = {
    "xp_edition": ["strategy", "edition"],
    "xp_edition_ner_novelties": ["strategy", "edition"],
    "xp_edition_mlm_params": ["window", "edition"],
    # deprecated in favour of xp_edition_retokenize_params
    "xp_edition_split_params": list(_RETOKENIZE_PARAMS),
    "xp_edition_retokenize_params": list(_RETOKENIZE_PARAMS),
    "xp_edition_propagate_order": ["pipeline", "edition"],
}
# Experiment name -> function splitting a metrics.json key into
# (metric_name, params). load_xp looks this mapping and XP_PARAMS_KEY up with
# the same experiment name, so both must list the same experiments.
XP_GET_PARAMS_FN = dict(
    xp_edition=get_params,
    xp_edition_ner_novelties=get_params,
    xp_edition_mlm_params=get_params_mlm,
    # deprecated experiment, parsed like the retokenize one
    xp_edition_split_params=get_params_retokenize,
    xp_edition_retokenize_params=get_params_retokenize,
    xp_edition_propagate_order=get_params_propagate,
)
def load_xp(path: pl.Path) -> tuple[str, pd.DataFrame]:
    """Load one experiment run directory into a DataFrame.

    The experiment name is read from ``run.json`` and the logged metrics from
    ``metrics.json``. Each metrics key is parsed with the experiment's entry in
    ``XP_GET_PARAMS_FN``. Metrics sharing the same parameter values (the names
    in ``XP_PARAMS_KEY[xp_name]``) are gathered into a single row.

    :param path: run directory containing ``run.json`` and ``metrics.json``.
    :return: ``(xp_name, df)``. ``df`` has one column per parameter of
        ``XP_PARAMS_KEY[xp_name]``, in that order, followed by one column per
        metric. Only the first logged value of each metric is kept. A row that
        lacks a metric logged by another row gets NaN for it.
    :raises KeyError: if the experiment name is missing from
        ``XP_GET_PARAMS_FN`` / ``XP_PARAMS_KEY``.
    :raises ValueError: if a metrics key does not match the experiment's format.
    """
    # JSON text is UTF-8 by specification; don't depend on the locale encoding
    with open(path / "run.json", encoding="utf-8") as f:
        run_data = json.load(f)
    xp_name = run_data["experiment"]["name"]

    get_params_fn = XP_GET_PARAMS_FN[xp_name]
    params_names = XP_PARAMS_KEY[xp_name]

    with open(path / "metrics.json", encoding="utf-8") as f:
        data = json.load(f)

    # group metric values by the tuple of their parameter values
    rows: dict[tuple, dict] = defaultdict(dict)
    for key, metric_dict in data.items():
        metric_name, params = get_params_fn(key)
        if not params:
            raise ValueError(
                f"cannot parse metric key {key!r} of experiment {xp_name!r} in {path}"
            )
        params_key = tuple(params[k] for k in params_names)
        rows[params_key][metric_name] = metric_dict["values"][0]

    # one record per row: pandas aligns columns by name, so rows that logged
    # different sets of metrics cannot end up with shifted or ragged columns
    records = [
        {**dict(zip(params_names, params_values)), **metrics}
        for params_values, metrics in rows.items()
    ]
    df = pd.DataFrame(records)
    return xp_name, df
def _parse_args() -> argparse.Namespace:
    """Parse the command line and check that the mandatory options are set."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-r",
        "--runs",
        nargs="*",
        type=pl.Path,
        help="A list of runs to plot. They must be of same nature (i.e. obtained with the same experiment script).",
    )
    parser.add_argument(
        "-m",
        "--metric",
        type=str,
        # derived from the label table so the list of known metrics stays in sync
        help="one of: " + ", ".join(f"'{metric}'" for metric in METRIC_TO_YLABEL),
    )
    parser.add_argument("-l", "--log-scale", action="store_true")
    parser.add_argument("-a", "--annotate-values", action="store_true")
    parser.add_argument("-e", "--exclude-strategies", nargs="*", type=str, default=None)
    parser.add_argument("-o", "--output-file", type=pl.Path, default=None)
    args = parser.parse_args()
    # parser.error rather than assert: asserts are stripped under `python -O`,
    # and args.runs is None (not an empty list) when -r is omitted.
    if not args.metric:
        parser.error("the -m/--metric option is required")
    if not args.runs:
        parser.error("at least one run must be given with -r/--runs")
    return args


def _load_runs(runs: list[pl.Path]) -> tuple[str, pd.DataFrame]:
    """Load several runs of the same experiment and concatenate their rows."""
    xp_name, first_df = load_xp(runs[0])
    frames = [first_df]
    for run in runs[1:]:
        run_xp_name, run_df = load_xp(run)
        # runs of different experiments have different parameter columns
        if run_xp_name != xp_name:
            raise SystemExit(
                f"run {run} comes from experiment {run_xp_name!r}, expected "
                f"{xp_name!r}: all runs must come from the same experiment"
            )
        frames.append(run_df)
    return xp_name, pd.concat(frames)


def _metric_table(df: pd.DataFrame, xp_name: str, metric: str) -> pd.DataFrame:
    """Pivot ``metric`` into one row per edition and one column per setting."""
    params_names = XP_PARAMS_KEY[xp_name]
    if metric not in df.columns:
        available = [col for col in df.columns if col not in params_names]
        raise SystemExit(
            f"unknown metric {metric!r} for experiment {xp_name!r}; available: {available}"
        )
    table = df.pivot(
        index="edition",
        columns=[k for k in params_names if k != "edition"],
        values=metric,
    )
    # order columns by decreasing mean of the metric...
    table = table[table.mean().sort_values(ascending=False).index]
    try:  # ...unless the settings are integers, then sort them ascending
        table = table.sort_index(axis=1, key=lambda x: x.astype(int))
    except ValueError:
        pass
    return table


def main() -> None:
    """Plot one metric of one or more runs as bars grouped by edition."""
    args = _parse_args()
    xp_name, df = _load_runs(args.runs)

    if args.exclude_strategies:
        if "strategy" not in df.columns:
            raise SystemExit(
                f"-e/--exclude-strategies only applies to experiments with a "
                f"'strategy' parameter, not to {xp_name!r}"
            )
        df = df[~df["strategy"].isin(args.exclude_strategies)]

    print(f"{xp_name=}")
    print(df)

    table = _metric_table(df, xp_name, args.metric)

    # the "science" style is provided by the scienceplots package
    plt.style.use("science")
    plt.rcParams.update({"font.size": 10})

    # when columns are strategies, we use STRAT_COLOR_HINTS to provide
    # the hinted color for each bar. Otherwise, we simply use the default.
    if all(col in STRAT_COLOR_HINTS for col in table.columns):
        ax = table.plot.bar(color=[STRAT_COLOR_HINTS[strat] for strat in table.columns])
    else:
        ax = table.plot.bar()

    # apply EDITION_COLOR_HINTS to the edition tick labels
    for label in ax.get_xticklabels():
        if label.get_text() in EDITION_COLOR_HINTS:
            label.set_color(EDITION_COLOR_HINTS[label.get_text()])
    ax.legend(ncols=3)

    if args.annotate_values:
        for patch in ax.patches:
            height = patch.get_height()
            # pivot leaves NaN for missing (edition, setting) pairs
            if pd.isna(height):
                continue
            ax.annotate(
                format_bar_height(height),
                (patch.get_x() * 1.005, height * 1.005),
                fontsize=8,
            )

    ax.set_xlabel("Edition")
    # fall back to the raw metric name for metrics without a dedicated label
    ylabel = METRIC_TO_YLABEL.get(args.metric, args.metric)
    if args.log_scale:
        ylabel += " (log scale)"
    ax.set_ylabel(ylabel)
    if args.metric in METRIC_TO_YFORMATTER:
        ax.yaxis.set_major_formatter(METRIC_TO_YFORMATTER[args.metric])
    ax.grid()
    if args.log_scale:
        ax.set_yscale("log")

    fig = plt.gcf()
    fig.set_size_inches(8, 4)
    if args.output_file is not None:
        plt.savefig(args.output_file)
    else:
        plt.show()


if __name__ == "__main__":
    main()