|
3558 | 3558 | "outputs": [], |
3559 | 3559 | "source": [] |
3560 | 3560 | }, |
| 3561 | + { |
| 3562 | + "cell_type": "markdown", |
| 3563 | + "metadata": {}, |
| 3564 | + "source": [ |
| 3565 | + "# Subset Selection" |
| 3566 | + ] |
| 3567 | + }, |
| 3568 | + { |
| 3569 | + "cell_type": "code", |
| 3570 | + "execution_count": 3, |
| 3571 | + "metadata": {}, |
| 3572 | + "outputs": [], |
| 3573 | + "source": [ |
| 3574 | + "import numpy as np\n", |
| 3575 | + "from sklearn.cluster import KMeans\n", |
| 3576 | + "\n", |
| 3577 | + "def select_distant_points(X, y, k):\n", |
| 3578 | + " \"\"\"\n", |
| 3579 | + " Selects k points that are distant from each other using a clustering approach.\n", |
| 3580 | + " \n", |
| 3581 | + " :param X: np.array of shape (n, k), with n points in k-dimensional space.\n", |
| 3582 | + " :param y: np.array of length n, with values corresponding to each point in X.\n", |
| 3583 | + " :param k: The number of distant points to select.\n", |
| 3584 | + " :return: Selected k points from X and their corresponding y values.\n", |
| 3585 | + " \"\"\"\n", |
| 3586 | + " # Perform k-means clustering to find k clusters\n", |
| 3587 | + " kmeans = KMeans(n_clusters=k, random_state=0, n_init=\"auto\").fit(X)\n", |
| 3588 | + " \n", |
| 3589 | + " # Find the closest point in X to each cluster center\n", |
| 3590 | + " selected_points = np.array([X[np.argmin(np.linalg.norm(X - center, axis=1))] for center in kmeans.cluster_centers_])\n", |
| 3591 | + " \n", |
| 3592 | + " # Find indices of the selected points in the original X array\n", |
| 3593 | + " indices = np.array([np.where(np.all(X==point, axis=1))[0][0] for point in selected_points])\n", |
| 3594 | + " \n", |
| 3595 | + " # Select the corresponding y values\n", |
| 3596 | + " selected_y = y[indices]\n", |
| 3597 | + " \n", |
| 3598 | + " return selected_points, selected_y\n" |
| 3599 | + ] |
| 3600 | + }, |
| 3601 | + { |
| 3602 | + "cell_type": "code", |
| 3603 | + "execution_count": 4, |
| 3604 | + "metadata": {}, |
| 3605 | + "outputs": [ |
| 3606 | + { |
| 3607 | + "name": "stdout", |
| 3608 | + "output_type": "stream", |
| 3609 | + "text": [ |
| 3610 | + "Selected Points: [[0.77482755 0.11776665]\n", |
| 3611 | + " [0.1600672 0.5466571 ]\n", |
| 3612 | + " [0.87752562 0.66913902]\n", |
| 3613 | + " [0.37216814 0.33013892]\n", |
| 3614 | + " [0.37977024 0.83643457]]\n", |
| 3615 | + "Corresponding y values: [0.79945132 0.63677214 0.17382713 0.97910053 0.26962361]\n" |
| 3616 | + ] |
| 3617 | + } |
| 3618 | + ], |
| 3619 | + "source": [ |
| 3620 | + "X = np.random.rand(100, 2) # Generate some random points\n", |
| 3621 | + "y = np.random.rand(100) # Random corresponding y values\n", |
| 3622 | + "k = 5\n", |
| 3623 | + "\n", |
| 3624 | + "selected_points, selected_y = select_distant_points(X, y, k)\n", |
| 3625 | + "print(\"Selected Points:\", selected_points)\n", |
| 3626 | + "print(\"Corresponding y values:\", selected_y)" |
| 3627 | + ] |
| 3628 | + }, |
3561 | 3629 | { |
3562 | 3630 | "cell_type": "code", |
3563 | 3631 | "execution_count": null, |
|
0 commit comments