Skip to content

Commit 76d6623

Browse files
v0.1.4
1 parent 27aeb43 commit 76d6623

6 files changed

Lines changed: 494 additions & 817 deletions

File tree

notebooks/00_spot_doc.ipynb

Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -884,6 +884,264 @@
884884
"\n",
885885
"# root.mainloop()"
886886
]
887+
},
888+
{
889+
"cell_type": "code",
890+
"execution_count": 1,
891+
"metadata": {},
892+
"outputs": [],
893+
"source": [
894+
"def apk(actual, predicted, k=10):\n",
895+
" \"\"\"\n",
896+
" Computes the average precision at k.\n",
897+
" This function computes the average precision at k between two lists of\n",
898+
" items.\n",
899+
" Parameters\n",
900+
" ----------\n",
901+
" actual : list\n",
902+
" A list of elements that are to be predicted (order doesn't matter)\n",
903+
" predicted : list\n",
904+
" A list of predicted elements (order does matter)\n",
905+
" k : int, optional\n",
906+
" The maximum number of predicted elements\n",
907+
" Returns\n",
908+
" -------\n",
909+
" score : double\n",
910+
" The average precision at k over the input lists\n",
911+
" \"\"\"\n",
912+
" if len(predicted) > k:\n",
913+
" predicted = predicted[:k]\n",
914+
"\n",
915+
" score = 0.0\n",
916+
" num_hits = 0.0\n",
917+
"\n",
918+
" for i, p in enumerate(predicted):\n",
919+
" if p in actual and p not in predicted[:i]:\n",
920+
" num_hits += 1.0\n",
921+
" score += num_hits / (i + 1.0)\n",
922+
"\n",
923+
" if not actual:\n",
924+
" return 0.0\n",
925+
"\n",
926+
" return score / min(len(actual), k)\n",
927+
"\n",
928+
"\n",
929+
"def mapk(actual, predicted, k=10):\n",
930+
" \"\"\"\n",
931+
" Computes the mean average precision at k.\n",
932+
" This function computes the mean average precision at k between two lists\n",
933+
" of lists of items.\n",
934+
" Parameters\n",
935+
" ----------\n",
936+
" actual : list\n",
937+
" A list of lists of elements that are to be predicted\n",
938+
" (order doesn't matter in the lists)\n",
939+
" predicted : list\n",
940+
" A list of lists of predicted elements\n",
941+
" (order matters in the lists)\n",
942+
" k : int, optional\n",
943+
" The maximum number of predicted elements\n",
944+
" Returns\n",
945+
" -------\n",
946+
" score : double\n",
947+
" The mean average precision at k over the input lists\n",
948+
" \"\"\"\n",
949+
" return np.mean([apk(a, p, k) for a, p in zip(actual, predicted)])\n",
950+
"\n",
951+
"\n",
952+
"def mapk_score(y_true, y_pred, k=3):\n",
953+
" \"\"\" Wrapper for mapk func using numpy arrays\n",
954+
" Args:\n",
955+
" y_true (np.array): array of true values\n",
956+
" y_pred (np.array): array of predicted values\n",
957+
" k (int): number of predictions\n",
958+
" Returns:\n",
959+
" score (float): mean average precision at k\n",
960+
" Examples:\n",
961+
" >>> y_true = np.array([0, 1, 2, 2])\n",
962+
" >>> y_pred = np.array([[0.5, 0.2, 0.2], # 0 is in top 2\n",
963+
" [0.3, 0.4, 0.2], # 1 is in top 2\n",
964+
" [0.2, 0.4, 0.3], # 2 is in top 2\n",
965+
" [0.7, 0.2, 0.1]]) # 2 isn't in top 2\n",
966+
" >>> mapk_score(y_true, y_pred, k=1)\n",
967+
"        0.25\n",
968+
" >>> mapk_score(y_true, y_pred, k=2)\n",
969+
"        0.375\n",
970+
" >>> mapk_score(y_true, y_pred, k=3)\n",
971+
"        0.4583333333333333\n",
972+
" >>> mapk_score(y_true, y_pred, k=4)\n",
973+
"        0.4583333333333333\n",
974+
" >>> mapk_score(y_true, y_pred, k=5)\n",
975+
"        0.4583333333333333\n",
976+
" \"\"\"\n",
977+
" sorted_prediction_ids = np.argsort(-y_pred, axis=1)\n",
978+
" top_k_prediction_ids = sorted_prediction_ids[:, :k]\n",
979+
" score = mapk(y_true.reshape(-1, 1), top_k_prediction_ids, k=k)\n",
980+
" return score"
981+
]
982+
},
983+
{
984+
"cell_type": "code",
985+
"execution_count": 17,
986+
"metadata": {},
987+
"outputs": [
988+
{
989+
"data": {
990+
"text/plain": [
991+
"0.75"
992+
]
993+
},
994+
"execution_count": 17,
995+
"metadata": {},
996+
"output_type": "execute_result"
997+
}
998+
],
999+
"source": [
1000+
"import numpy as np\n",
1001+
"# from spotPython.utils.metrics import mapk_score\n",
1002+
"from sklearn.metrics import top_k_accuracy_score\n",
1003+
"\n",
1004+
"y_true = np.array([0, 1, 2, 2])\n",
1005+
"y_pred = np.array([[0.5, 0.2, 0.2], # 0 is in top 2\n",
1006+
" [0.3, 0.4, 0.2], # 1 is in top 2\n",
1007+
" [0.2, 0.4, 0.3], # 2 is in top 2\n",
1008+
" [0.7, 0.2, 0.1]]) # 2 isn't in top 2\n",
1009+
"top_k_accuracy_score(y_true, y_pred, k=2)\n",
1010+
"#mapk_score(y_true, y_pred, k=2)\n"
1011+
]
1012+
},
1013+
{
1014+
"cell_type": "code",
1015+
"execution_count": 7,
1016+
"metadata": {},
1017+
"outputs": [
1018+
{
1019+
"data": {
1020+
"text/plain": [
1021+
"0.4583333333333333"
1022+
]
1023+
},
1024+
"execution_count": 7,
1025+
"metadata": {},
1026+
"output_type": "execute_result"
1027+
}
1028+
],
1029+
"source": [
1030+
"mapk_score(y_true, y_pred, k=5)"
1031+
]
1032+
},
1033+
{
1034+
"cell_type": "code",
1035+
"execution_count": 11,
1036+
"metadata": {},
1037+
"outputs": [
1038+
{
1039+
"data": {
1040+
"text/plain": [
1041+
"(4,)"
1042+
]
1043+
},
1044+
"execution_count": 11,
1045+
"metadata": {},
1046+
"output_type": "execute_result"
1047+
}
1048+
],
1049+
"source": [
1050+
"y_true.shape"
1051+
]
1052+
},
1053+
{
1054+
"cell_type": "code",
1055+
"execution_count": 12,
1056+
"metadata": {},
1057+
"outputs": [
1058+
{
1059+
"data": {
1060+
"text/plain": [
1061+
"numpy.ndarray"
1062+
]
1063+
},
1064+
"execution_count": 12,
1065+
"metadata": {},
1066+
"output_type": "execute_result"
1067+
}
1068+
],
1069+
"source": [
1070+
"type(y_true)"
1071+
]
1072+
},
1073+
{
1074+
"cell_type": "code",
1075+
"execution_count": 13,
1076+
"metadata": {},
1077+
"outputs": [
1078+
{
1079+
"data": {
1080+
"text/plain": [
1081+
"array([[0],\n",
1082+
" [1],\n",
1083+
" [2],\n",
1084+
" [2]])"
1085+
]
1086+
},
1087+
"execution_count": 13,
1088+
"metadata": {},
1089+
"output_type": "execute_result"
1090+
}
1091+
],
1092+
"source": [
1093+
"y_true.reshape(-1, 1)"
1094+
]
1095+
},
1096+
{
1097+
"cell_type": "code",
1098+
"execution_count": 15,
1099+
"metadata": {},
1100+
"outputs": [
1101+
{
1102+
"data": {
1103+
"text/plain": [
1104+
"array([0, 1, 2, 2])"
1105+
]
1106+
},
1107+
"execution_count": 15,
1108+
"metadata": {},
1109+
"output_type": "execute_result"
1110+
}
1111+
],
1112+
"source": [
1113+
"y_true"
1114+
]
1115+
},
1116+
{
1117+
"cell_type": "code",
1118+
"execution_count": 18,
1119+
"metadata": {},
1120+
"outputs": [
1121+
{
1122+
"data": {
1123+
"text/plain": [
1124+
"array([[0, 1, 2],\n",
1125+
" [1, 0, 2],\n",
1126+
" [1, 2, 0],\n",
1127+
" [0, 1, 2]])"
1128+
]
1129+
},
1130+
"execution_count": 18,
1131+
"metadata": {},
1132+
"output_type": "execute_result"
1133+
}
1134+
],
1135+
"source": [
1136+
"np.argsort(-y_pred, axis=1)"
1137+
]
1138+
},
1139+
{
1140+
"cell_type": "code",
1141+
"execution_count": null,
1142+
"metadata": {},
1143+
"outputs": [],
1144+
"source": []
8871145
}
8881146
],
8891147
"metadata": {

0 commit comments

Comments
 (0)