diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..f33a02cd --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,12 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for more information: +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates +# https://containers.dev/guide/dependabot + +version: 2 +updates: + - package-ecosystem: "devcontainers" + directory: "/" + schedule: + interval: weekly diff --git a/roboflow/models/keypoint_detection.py b/roboflow/models/keypoint_detection.py index a0e86561..ee2a7699 100644 --- a/roboflow/models/keypoint_detection.py +++ b/roboflow/models/keypoint_detection.py @@ -26,6 +26,7 @@ def __init__( id: str, name: Optional[str] = None, version: Optional[str] = None, + confidence: Optional[int] = 10, local: Optional[str] = None, ): """ @@ -37,6 +38,7 @@ def __init__( name (str): is the name of the project version (str): version number local (str): localhost address and port if pointing towards local inference engine + confidence (int): A threshold for the returned predictions on a scale of 0-100. colors (dict): colors to use for the image preprocessing (dict): preprocessing to use for the image @@ -48,6 +50,7 @@ def __init__( self.__api_key = api_key self.id = id self.name = name + self.confidence = confidence self.version = version self.base_url = "https://detect.roboflow.com/" @@ -150,6 +153,7 @@ def __generate_url(self): self.base_url + without_workspace + "/" + str(version), "?api_key=" + self.__api_key, "&name=YOUR_IMAGE.jpg", + f"&confidence={self.confidence}", ] ) diff --git a/tests/annotations/keypoint-detection-annotations/MM2A_46_R_T_predictions.json b/tests/annotations/keypoint-detection-annotations/MM2A_46_R_T_predictions.json new file mode 100644 index 00000000..cec83f9b --- /dev/null +++ b/tests/annotations/keypoint-detection-annotations/MM2A_46_R_T_predictions.json @@ -0,0 +1,426 @@ +{ + "inference_id": "4b39e84f-88ce-4d27-880c-57bf949029e7", + "time": 0.05072031899999274, + "image": { + "width": 142, + "height": 327 + }, + "predictions": [ + { + "x": 59.5, + "y": 233.5, + "width": 25.0, + "height": 11.0, + "confidence": 0.763361394405365, + "class": "vertebra", + "class_id": 0, + "detection_id": "500623ad-1ca9-4604-a217-6e057cf3f588", + "keypoints": [ + { + "x": 47.0, + "y": 240.0, + "confidence": 0.9998906850814819, + "class_id": 0, + "class_name": "start" + }, + { + "x": 72.0, + "y": 227.0, + "confidence": 0.9996753931045532, + "class_id": 1, + "class_name": "end" + } + ] + }, + { + "x": 48.5, + "y": 210.0, + "width": 25.0, + "height": 10.0, + "confidence": 0.7600339651107788, + "class": "vertebra", + "class_id": 0, + "detection_id": "71b9fcd9-4351-47a2-a583-ad9d5e96b604", + "keypoints": [ + { + "x": 36.0, + "y": 215.0, + "confidence": 0.9999080896377563, + "class_id": 0, + "class_name": "start" + }, + { + "x": 61.0, + "y": 205.0, + "confidence": 0.9991416931152344, + "class_id": 1, + "class_name": "end" + } + ] + }, + { + "x": 41.0, + "y": 187.5, + "width": 24.0, + "height": 9.0, + "confidence": 0.742439866065979, + "class": "vertebra", + "class_id": 0, + "detection_id": "b04105e9-767e-4e76-9ee9-aadae1326658", + "keypoints": [ + { + "x": 29.0, + "y": 192.0, + "confidence": 0.9993617534637451, + "class_id": 0, + "class_name": "start" + }, + { + "x": 54.0, + "y": 183.0, + "confidence": 0.9988169074058533, + "class_id": 1, + "class_name": "end" + } + ] + }, + { + "x": 56.0, + "y": 80.5, + "width": 20.0, + "height": 7.0, + "confidence": 0.6737987995147705, + "class": "vertebra", + "class_id": 0, + "detection_id": "d519e047-8703-4000-95c7-3f29ffb7c233", + "keypoints": [ + { + "x": 46.0, + "y": 77.0, + "confidence": 0.9988997578620911, + "class_id": 0, + "class_name": "start" + }, + { + "x": 66.0, + "y": 85.0, + "confidence": 0.9990716576576233, + "class_id": 1, + "class_name": "end" + } + ] + }, + { + "x": 49.0, + "y": 98.0, + "width": 20.0, + "height": 10.0, + "confidence": 0.6587967872619629, + "class": "vertebra", + "class_id": 0, + "detection_id": "b9e46234-d571-4141-8e92-18a9c61cb888", + "keypoints": [ + { + "x": 40.0, + "y": 93.0, + "confidence": 0.9998961687088013, + "class_id": 0, + "class_name": "start" + }, + { + "x": 59.0, + "y": 102.0, + "confidence": 0.9997531175613403, + "class_id": 1, + "class_name": "end" + } + ] + }, + { + "x": 69.5, + "y": 259.0, + "width": 31.0, + "height": 8.0, + "confidence": 0.5930185914039612, + "class": "vertebra", + "class_id": 0, + "detection_id": "edf5846c-8858-4dbc-9a86-128340b4ecfd", + "keypoints": [ + { + "x": 55.0, + "y": 262.0, + "confidence": 0.9956279993057251, + "class_id": 0, + "class_name": "start" + }, + { + "x": 85.0, + "y": 256.0, + "confidence": 0.9995435476303101, + "class_id": 1, + "class_name": "end" + } + ] + }, + { + "x": 41.0, + "y": 113.0, + "width": 22.0, + "height": 6.0, + "confidence": 0.5826466083526611, + "class": "vertebra", + "class_id": 0, + "detection_id": "e5a09d46-1dda-4957-aa8c-2051febde9dc", + "keypoints": [ + { + "x": 30.0, + "y": 110.0, + "confidence": 0.9999384880065918, + "class_id": 0, + "class_name": "start" + }, + { + "x": 52.0, + "y": 117.0, + "confidence": 0.9998559951782227, + "class_id": 1, + "class_name": "end" + } + ] + }, + { + "x": 70.0, + "y": 45.5, + "width": 18.0, + "height": 5.0, + "confidence": 0.49985209107398987, + "class": "vertebra", + "class_id": 0, + "detection_id": "5373c45d-6ab5-474c-bf1b-0f80398e4f50", + "keypoints": [ + { + "x": 61.0, + "y": 43.0, + "confidence": 0.9988017082214355, + "class_id": 0, + "class_name": "start" + }, + { + "x": 80.0, + "y": 48.0, + "confidence": 0.9974247813224792, + "class_id": 1, + "class_name": "end" + } + ] + }, + { + "x": 62.5, + "y": 63.5, + "width": 19.0, + "height": 7.0, + "confidence": 0.46164435148239136, + "class": "vertebra", + "class_id": 0, + "detection_id": "93961590-3596-4f86-86ef-3115f27af571", + "keypoints": [ + { + "x": 53.0, + "y": 60.0, + "confidence": 0.9995067715644836, + "class_id": 0, + "class_name": "start" + }, + { + "x": 72.0, + "y": 67.0, + "confidence": 0.9983217716217041, + "class_id": 1, + "class_name": "end" + } + ] + }, + { + "x": 35.5, + "y": 168.0, + "width": 25.0, + "height": 6.0, + "confidence": 0.4455893933773041, + "class": "vertebra", + "class_id": 0, + "detection_id": "02949522-1446-4678-b580-37397a6e3544", + "keypoints": [ + { + "x": 23.0, + "y": 171.0, + "confidence": 0.9996205568313599, + "class_id": 0, + "class_name": "start" + }, + { + "x": 48.0, + "y": 165.0, + "confidence": 0.9966169595718384, + "class_id": 1, + "class_name": "end" + } + ] + }, + { + "x": 78.0, + "y": 284.0, + "width": 32.0, + "height": 8.0, + "confidence": 0.44538000226020813, + "class": "vertebra", + "class_id": 0, + "detection_id": "2ad879d8-901e-4647-aa58-52d3de28d5fa", + "keypoints": [ + { + "x": 62.0, + "y": 288.0, + "confidence": 0.9988986253738403, + "class_id": 0, + "class_name": "start" + }, + { + "x": 93.0, + "y": 282.0, + "confidence": 0.9989535808563232, + "class_id": 1, + "class_name": "end" + } + ] + }, + { + "x": 33.0, + "y": 150.0, + "width": 24.0, + "height": 2.0, + "confidence": 0.28537100553512573, + "class": "vertebra", + "class_id": 0, + "detection_id": "c522b624-ff97-46d7-b90f-4ea04e5ddbbd", + "keypoints": [ + { + "x": 21.0, + "y": 151.0, + "confidence": 0.9995453357696533, + "class_id": 0, + "class_name": "start" + }, + { + "x": 45.0, + "y": 150.0, + "confidence": 0.9993085861206055, + "class_id": 1, + "class_name": "end" + } + ] + }, + { + "x": 82.0, + "y": 313.0, + "width": 34.0, + "height": 6.0, + "confidence": 0.2552550435066223, + "class": "vertebra", + "class_id": 0, + "detection_id": "a420b97c-d316-41a6-895e-cd342795af4d", + "keypoints": [ + { + "x": 64.0, + "y": 316.0, + "confidence": 0.9955296516418457, + "class_id": 0, + "class_name": "start" + }, + { + "x": 99.0, + "y": 311.0, + "confidence": 0.9899979829788208, + "class_id": 1, + "class_name": "end" + } + ] + }, + { + "x": 37.0, + "y": 126.0, + "width": 24.0, + "height": 6.0, + "confidence": 0.2176252007484436, + "class": "vertebra", + "class_id": 0, + "detection_id": "8600c8bf-c3f6-46c1-a5a6-a602637d0d05", + "keypoints": [ + { + "x": 25.0, + "y": 124.0, + "confidence": 0.9993969798088074, + "class_id": 0, + "class_name": "start" + }, + { + "x": 49.0, + "y": 127.0, + "confidence": 0.9985653758049011, + "class_id": 1, + "class_name": "end" + } + ] + }, + { + "x": 35.5, + "y": 132.0, + "width": 23.0, + "height": 4.0, + "confidence": 0.14819690585136414, + "class": "vertebra", + "class_id": 0, + "detection_id": "8bda8ccd-a834-41b3-a40e-848c1fbd4de2", + "keypoints": [ + { + "x": 24.0, + "y": 130.0, + "confidence": 0.9997155666351318, + "class_id": 0, + "class_name": "start" + }, + { + "x": 47.0, + "y": 134.0, + "confidence": 0.9994645118713379, + "class_id": 1, + "class_name": "end" + } + ] + }, + { + "x": 74.0, + "y": 18.0, + "width": 24.0, + "height": 2.0, + "confidence": 0.14375203847885132, + "class": "vertebra", + "class_id": 0, + "detection_id": "fd657847-2461-40f0-8219-8c2c33580153", + "keypoints": [ + { + "x": 62.0, + "y": 18.0, + "confidence": 0.9981837272644043, + "class_id": 0, + "class_name": "start" + }, + { + "x": 85.0, + "y": 19.0, + "confidence": 0.996793806552887, + "class_id": 1, + "class_name": "end" + } + ] + } + ] +} diff --git a/tests/images/MM2A_46_R_T.png b/tests/images/MM2A_46_R_T.png new file mode 100644 index 00000000..8674705d Binary files /dev/null and b/tests/images/MM2A_46_R_T.png differ diff --git a/tests/models/test_keypoint_detection.py b/tests/models/test_keypoint_detection.py new file mode 100644 index 00000000..57bfaee5 --- /dev/null +++ b/tests/models/test_keypoint_detection.py @@ -0,0 +1,68 @@ +import json +import os +import unittest + +import responses +from requests.exceptions import HTTPError + +from roboflow.models.keypoint_detection import KeypointDetectionModel +from roboflow.util.prediction import PredictionGroup + +with open(os.path.join("tests", "annotations", "keypoint-detection-annotations", "MM2A_46_R_T_predictions.json")) as f: + MOCK_RESPONSE = json.load(f) + + +class TestKeypointDetection(unittest.TestCase): + api_key = "my-api-key" + workspace = "roboflow" + dataset_id = "test-123" + version = "23" + + api_url = f"https://detect.roboflow.com/{dataset_id}/{version}" + + _default_params = {"api_key": api_key, "confidence": "10", "name": "YOUR_IMAGE.jpg"} + + def setUp(self): + super().setUp() + self.version_id = f"{self.workspace}/{self.dataset_id}/{self.version}" + + def test_init_sets_attributes(self): + instance = KeypointDetectionModel(self.api_key, self.version_id, version=self.version) + + self.assertEqual(instance.id, self.version_id) + self.assertEqual(instance.api_key, self.api_key) + self.assertEqual(instance.version, self.version) + self.assertEqual(instance.base_url, "https://detect.roboflow.com/") + + @responses.activate + def test_predict_local_image(self): + instance = KeypointDetectionModel(self.api_key, self.version_id, version=self.version) + + responses.add(responses.POST, self.api_url, json=MOCK_RESPONSE, status=200) + + result = instance.predict("tests/images/MM2A_46_R_T.jpg") + + self.assertIsInstance(result, PredictionGroup) + self.assertEqual(len(result.predictions), 1) + self.assertEqual(result.predictions[0].confidence, 0.544) + + @responses.activate + def test_predict_with_confidence(self): + instance = KeypointDetectionModel(self.api_key, self.version_id, version=self.version) + + responses.add(responses.POST, self.api_url, json=MOCK_RESPONSE, status=200) + + result = instance.predict("tests/images/MM2A_46_R_T.jpg", confidence=30) + + self.assertIsInstance(result, PredictionGroup) + request = responses.calls[0].request + self.assertEqual(request.params["confidence"], "30") + + @responses.activate + def test_predict_error_response(self): + instance = KeypointDetectionModel(self.api_key, self.version_id, version=self.version) + + responses.add(responses.POST, self.api_url, json={"error": "Invalid API key"}, status=401) + + with self.assertRaises(HTTPError): + instance.predict("tests/images/MM2A_46_R_T.jpg")