|
884 | 884 | "\n", |
885 | 885 | "# root.mainloop()" |
886 | 886 | ] |
| 887 | + }, |
| 888 | + { |
| 889 | + "cell_type": "code", |
| 890 | + "execution_count": 1, |
| 891 | + "metadata": {}, |
| 892 | + "outputs": [], |
| 893 | + "source": [ |
| 894 | + "def apk(actual, predicted, k=10):\n", |
| 895 | + " \"\"\"\n", |
| 896 | + " Computes the average precision at k.\n", |
| 897 | + " This function computes the average precision at k between two lists of\n", |
| 898 | + " items.\n", |
| 899 | + " Parameters\n", |
| 900 | + " ----------\n", |
| 901 | + " actual : list\n", |
| 902 | + " A list of elements that are to be predicted (order doesn't matter)\n", |
| 903 | + " predicted : list\n", |
| 904 | + " A list of predicted elements (order does matter)\n", |
| 905 | + " k : int, optional\n", |
| 906 | + " The maximum number of predicted elements\n", |
| 907 | + " Returns\n", |
| 908 | + " -------\n", |
| 909 | + " score : double\n", |
| 910 | + " The average precision at k over the input lists\n", |
| 911 | + " \"\"\"\n", |
| 912 | + " if len(predicted) > k:\n", |
| 913 | + " predicted = predicted[:k]\n", |
| 914 | + "\n", |
| 915 | + " score = 0.0\n", |
| 916 | + " num_hits = 0.0\n", |
| 917 | + "\n", |
| 918 | + " for i, p in enumerate(predicted):\n", |
| 919 | + " if p in actual and p not in predicted[:i]:\n", |
| 920 | + " num_hits += 1.0\n", |
| 921 | + " score += num_hits / (i + 1.0)\n", |
| 922 | + "\n", |
| 923 | + " if not actual:\n", |
| 924 | + " return 0.0\n", |
| 925 | + "\n", |
| 926 | + " return score / min(len(actual), k)\n", |
| 927 | + "\n", |
| 928 | + "\n", |
| 929 | + "def mapk(actual, predicted, k=10):\n", |
| 930 | + " \"\"\"\n", |
| 931 | + " Computes the mean average precision at k.\n", |
| 932 | + " This function computes the mean average precision at k between two lists\n", |
| 933 | + " of lists of items.\n", |
| 934 | + " Parameters\n", |
| 935 | + " ----------\n", |
| 936 | + " actual : list\n", |
| 937 | + " A list of lists of elements that are to be predicted\n", |
| 938 | + " (order doesn't matter in the lists)\n", |
| 939 | + " predicted : list\n", |
| 940 | + " A list of lists of predicted elements\n", |
| 941 | + " (order matters in the lists)\n", |
| 942 | + " k : int, optional\n", |
| 943 | + " The maximum number of predicted elements\n", |
| 944 | + " Returns\n", |
| 945 | + " -------\n", |
| 946 | + " score : double\n", |
| 947 | + " The mean average precision at k over the input lists\n", |
| 948 | + " \"\"\"\n", |
| 949 | + " return np.mean([apk(a, p, k) for a, p in zip(actual, predicted)])\n", |
| 950 | + "\n", |
| 951 | + "\n", |
| 952 | + "def mapk_score(y_true, y_pred, k=3):\n", |
| 953 | + " \"\"\" Wrapper for mapk func using numpy arrays\n", |
| 954 | + " Args:\n", |
| 955 | + " y_true (np.array): array of true values\n", |
| 956 | + " y_pred (np.array): array of predicted values\n", |
| 957 | + " k (int): number of predictions\n", |
| 958 | + " Returns:\n", |
| 959 | + " score (float): mean average precision at k\n", |
| 960 | + " Examples:\n", |
| 961 | + " >>> y_true = np.array([0, 1, 2, 2])\n", |
| 962 | + " >>> y_pred = np.array([[0.5, 0.2, 0.2], # 0 is in top 2\n", |
| 963 | + " [0.3, 0.4, 0.2], # 1 is in top 2\n", |
| 964 | + " [0.2, 0.4, 0.3], # 2 is in top 2\n", |
| 965 | + " [0.7, 0.2, 0.1]]) # 2 isn't in top 2\n", |
| 966 | + " >>> mapk_score(y_true, y_pred, k=1)\n", |
| 967 | + " 0.25\n", |
| 968 | + " >>> mapk_score(y_true, y_pred, k=2)\n", |
| 969 | + " 0.375\n", |
| 970 | + " >>> mapk_score(y_true, y_pred, k=3)\n", |
| 971 | + " 0.4583333333333333\n", |
| 972 | + " >>> mapk_score(y_true, y_pred, k=4)\n", |
| 973 | + " 0.4583333333333333\n", |
| 974 | + " >>> mapk_score(y_true, y_pred, k=5)\n", |
| 975 | + " 0.4583333333333333\n", |
| 976 | + " \"\"\"\n", |
| 977 | + " sorted_prediction_ids = np.argsort(-y_pred, axis=1)\n", |
| 978 | + " top_k_prediction_ids = sorted_prediction_ids[:, :k]\n", |
| 979 | + " score = mapk(y_true.reshape(-1, 1), top_k_prediction_ids, k=k)\n", |
| 980 | + " return score" |
| 981 | + ] |
| 982 | + }, |
| 983 | + { |
| 984 | + "cell_type": "code", |
| 985 | + "execution_count": 17, |
| 986 | + "metadata": {}, |
| 987 | + "outputs": [ |
| 988 | + { |
| 989 | + "data": { |
| 990 | + "text/plain": [ |
| 991 | + "0.75" |
| 992 | + ] |
| 993 | + }, |
| 994 | + "execution_count": 17, |
| 995 | + "metadata": {}, |
| 996 | + "output_type": "execute_result" |
| 997 | + } |
| 998 | + ], |
| 999 | + "source": [ |
| 1000 | + "import numpy as np\n", |
| 1001 | + "# from spotPython.utils.metrics import mapk_score\n", |
| 1002 | + "from sklearn.metrics import top_k_accuracy_score\n", |
| 1003 | + "\n", |
| 1004 | + "y_true = np.array([0, 1, 2, 2])\n", |
| 1005 | + "y_pred = np.array([[0.5, 0.2, 0.2], # 0 is in top 2\n", |
| 1006 | + " [0.3, 0.4, 0.2], # 1 is in top 2\n", |
| 1007 | + " [0.2, 0.4, 0.3], # 2 is in top 2\n", |
| 1008 | + " [0.7, 0.2, 0.1]]) # 2 isn't in top 2\n", |
| 1009 | + "top_k_accuracy_score(y_true, y_pred, k=2)\n", |
| 1010 | + "#mapk_score(y_true, y_pred, k=2)\n" |
| 1011 | + ] |
| 1012 | + }, |
| 1013 | + { |
| 1014 | + "cell_type": "code", |
| 1015 | + "execution_count": 7, |
| 1016 | + "metadata": {}, |
| 1017 | + "outputs": [ |
| 1018 | + { |
| 1019 | + "data": { |
| 1020 | + "text/plain": [ |
| 1021 | + "0.4583333333333333" |
| 1022 | + ] |
| 1023 | + }, |
| 1024 | + "execution_count": 7, |
| 1025 | + "metadata": {}, |
| 1026 | + "output_type": "execute_result" |
| 1027 | + } |
| 1028 | + ], |
| 1029 | + "source": [ |
| 1030 | + "mapk_score(y_true, y_pred, k=5)" |
| 1031 | + ] |
| 1032 | + }, |
| 1033 | + { |
| 1034 | + "cell_type": "code", |
| 1035 | + "execution_count": 11, |
| 1036 | + "metadata": {}, |
| 1037 | + "outputs": [ |
| 1038 | + { |
| 1039 | + "data": { |
| 1040 | + "text/plain": [ |
| 1041 | + "(4,)" |
| 1042 | + ] |
| 1043 | + }, |
| 1044 | + "execution_count": 11, |
| 1045 | + "metadata": {}, |
| 1046 | + "output_type": "execute_result" |
| 1047 | + } |
| 1048 | + ], |
| 1049 | + "source": [ |
| 1050 | + "y_true.shape" |
| 1051 | + ] |
| 1052 | + }, |
| 1053 | + { |
| 1054 | + "cell_type": "code", |
| 1055 | + "execution_count": 12, |
| 1056 | + "metadata": {}, |
| 1057 | + "outputs": [ |
| 1058 | + { |
| 1059 | + "data": { |
| 1060 | + "text/plain": [ |
| 1061 | + "numpy.ndarray" |
| 1062 | + ] |
| 1063 | + }, |
| 1064 | + "execution_count": 12, |
| 1065 | + "metadata": {}, |
| 1066 | + "output_type": "execute_result" |
| 1067 | + } |
| 1068 | + ], |
| 1069 | + "source": [ |
| 1070 | + "type(y_true)" |
| 1071 | + ] |
| 1072 | + }, |
| 1073 | + { |
| 1074 | + "cell_type": "code", |
| 1075 | + "execution_count": 13, |
| 1076 | + "metadata": {}, |
| 1077 | + "outputs": [ |
| 1078 | + { |
| 1079 | + "data": { |
| 1080 | + "text/plain": [ |
| 1081 | + "array([[0],\n", |
| 1082 | + " [1],\n", |
| 1083 | + " [2],\n", |
| 1084 | + " [2]])" |
| 1085 | + ] |
| 1086 | + }, |
| 1087 | + "execution_count": 13, |
| 1088 | + "metadata": {}, |
| 1089 | + "output_type": "execute_result" |
| 1090 | + } |
| 1091 | + ], |
| 1092 | + "source": [ |
| 1093 | + "y_true.reshape(-1, 1)" |
| 1094 | + ] |
| 1095 | + }, |
| 1096 | + { |
| 1097 | + "cell_type": "code", |
| 1098 | + "execution_count": 15, |
| 1099 | + "metadata": {}, |
| 1100 | + "outputs": [ |
| 1101 | + { |
| 1102 | + "data": { |
| 1103 | + "text/plain": [ |
| 1104 | + "array([0, 1, 2, 2])" |
| 1105 | + ] |
| 1106 | + }, |
| 1107 | + "execution_count": 15, |
| 1108 | + "metadata": {}, |
| 1109 | + "output_type": "execute_result" |
| 1110 | + } |
| 1111 | + ], |
| 1112 | + "source": [ |
| 1113 | + "y_true" |
| 1114 | + ] |
| 1115 | + }, |
| 1116 | + { |
| 1117 | + "cell_type": "code", |
| 1118 | + "execution_count": 18, |
| 1119 | + "metadata": {}, |
| 1120 | + "outputs": [ |
| 1121 | + { |
| 1122 | + "data": { |
| 1123 | + "text/plain": [ |
| 1124 | + "array([[0, 1, 2],\n", |
| 1125 | + " [1, 0, 2],\n", |
| 1126 | + " [1, 2, 0],\n", |
| 1127 | + " [0, 1, 2]])" |
| 1128 | + ] |
| 1129 | + }, |
| 1130 | + "execution_count": 18, |
| 1131 | + "metadata": {}, |
| 1132 | + "output_type": "execute_result" |
| 1133 | + } |
| 1134 | + ], |
| 1135 | + "source": [ |
| 1136 | + "np.argsort(-y_pred, axis=1)" |
| 1137 | + ] |
| 1138 | + }, |
| 1139 | + { |
| 1140 | + "cell_type": "code", |
| 1141 | + "execution_count": null, |
| 1142 | + "metadata": {}, |
| 1143 | + "outputs": [], |
| 1144 | + "source": [] |
887 | 1145 | } |
888 | 1146 | ], |
889 | 1147 | "metadata": { |
|
0 commit comments