diff --git a/django/applications/catmaid/control/skeleton.py b/django/applications/catmaid/control/skeleton.py index cc31a517bd..882de99ccf 100644 --- a/django/applications/catmaid/control/skeleton.py +++ b/django/applications/catmaid/control/skeleton.py @@ -16,7 +16,7 @@ from django.http import HttpRequest, HttpResponse, HttpResponseBadRequest, Http404, \ JsonResponse, StreamingHttpResponse from django.shortcuts import get_object_or_404 -from django.db import connection +from django.db import connection, transaction from django.db.models import Q from django.views.decorators.cache import never_cache from django.utils.decorators import method_decorator @@ -26,7 +26,9 @@ from rest_framework.response import Response from rest_framework.views import APIView -from catmaid.history import add_log_entry +from catmaid import locks +from catmaid.history import add_log_entry, Transaction, \ + find_latest_deleted_skeleton_transaction, undelete_neuron from catmaid.control import tracing from catmaid.models import (Project, UserRole, Class, ClassInstance, Review, ClassInstanceClassInstance, Relation, Sampler, Treenode, @@ -5152,6 +5154,58 @@ def post(self, request:Request, project_id, skeleton_id) -> JsonResponse: }) +RESTORABLE_SKELETON_DELETE_LABEL = 'skeletons.remove' + + +@api_view(['POST']) +@requires_user_role(UserRole.Annotate) +def restore_historic_skeleton(request:HttpRequest, project_id, skeleton_id): + """Restore the latest deleted historic version of a single skeleton.""" + project_id = int(project_id) + skeleton_id = int(skeleton_id) + + with transaction.atomic(): + cursor = connection.cursor() + cursor.execute(""" + SELECT pg_advisory_xact_lock(%(lock_id)s::bigint) + """, { + 'lock_id': locks.skeleton_restore_lock_id(skeleton_id), + }) + cursor.execute("SET LOCAL catmaid.user_id=%(user_id)s", { + 'user_id': request.user.id, + }) + + if ClassInstance.objects.filter(pk=skeleton_id).exists(): + raise ValueError(f"An object with ID {skeleton_id} already exists") + + restore_info = find_latest_deleted_skeleton_transaction( + project_id, skeleton_id) + if not restore_info: + raise ValueError( + f"No single-skeleton deleted historic skeleton found for " + f"skeleton {skeleton_id}") + + source_label = restore_info['label'] + if source_label != RESTORABLE_SKELETON_DELETE_LABEL: + raise ValueError( + f"Latest historic transaction for skeleton {skeleton_id} has " + f"missing or unsupported label {source_label}; expected " + f"{RESTORABLE_SKELETON_DELETE_LABEL}") + + tx = Transaction(restore_info['transaction_id'], + restore_info['execution_time']) + + undelete_neuron(project_id, tx, user_id=request.user.id) + + return JsonResponse({ + 'skeleton_id': skeleton_id, + 'transaction_id': restore_info['transaction_id'], + 'execution_time': restore_info['execution_time'], + 'source_label': source_label, + 'success': f"Restored skeleton {skeleton_id} from history.", + }) + + @api_view(['POST']) @requires_user_role(UserRole.Annotate) def delete_skeleton(request:HttpRequest, project_id, skeleton_id): diff --git a/django/applications/catmaid/control/transaction.py b/django/applications/catmaid/control/transaction.py index f54bb6538e..2fd3eaae5e 100644 --- a/django/applications/catmaid/control/transaction.py +++ b/django/applications/catmaid/control/transaction.py @@ -404,6 +404,8 @@ def get(self): ORDER BY t.edition_time DESC LIMIT 1; """), + 'skeletons.remove': QueryRef(location_queries, "neurons.remove"), + 'skeletons.restore': QueryRef(location_queries, "nodes.update_location"), 'textlabels.create': HistoryQuery(""" SELECT t.location_x, t.location_y, t.location_z FROM textlabel{history} t @@ -537,6 +539,8 @@ def get(self): WHERE so.{txid} = %s AND t.skeleton_id = so.skeleton_id ORDER BY t.edition_time DESC """), + 'skeletons.remove': QueryRef(skeleton_queries, "neurons.remove"), + 'skeletons.restore': QueryRef(skeleton_queries, "nodes.update_location"), 'textlabels.create': HistoryQuery(""" SELECT t.skeleton_id FROM textlabel{history} t diff --git a/django/applications/catmaid/history.py b/django/applications/catmaid/history.py index 1feb83fa78..774957d6da 100644 --- a/django/applications/catmaid/history.py +++ b/django/applications/catmaid/history.py @@ -153,6 +153,82 @@ def __str__(self): return "TX {} @ {}".format(self.id, self.time) +def find_latest_deleted_skeleton_transaction(project_id, skeleton_id): + """Find the newest transaction that removed the passed in skeleton. + + Only single-skeleton delete candidates are returned. Callers still have to + decide whether the transaction label is safe for their restore use case. + """ + cursor = connection.cursor() + cursor.execute(""" + WITH skeleton_class AS ( + SELECT id + FROM class + WHERE project_id = %(project_id)s + AND class_name = 'skeleton' + ), + candidates AS ( + SELECT ci.exec_transaction_id AS transaction_id, + upper(ci.sys_period) AS execution_time + FROM class_instance__history ci + JOIN skeleton_class sc + ON sc.id = ci.class_id + WHERE ci.project_id = %(project_id)s + AND ci.id = %(skeleton_id)s + AND ci.sys_period IS NOT NULL + AND NOT isempty(ci.sys_period) + AND NOT upper_inf(ci.sys_period) + GROUP BY ci.exec_transaction_id, upper(ci.sys_period) + ), + latest AS ( + SELECT transaction_id, execution_time + FROM candidates + ORDER BY execution_time DESC + LIMIT 1 + ), + affected_skeleton AS ( + SELECT ci.id AS skeleton_id + FROM class_instance__history ci + JOIN skeleton_class sc + ON sc.id = ci.class_id + JOIN latest + ON latest.transaction_id = ci.exec_transaction_id + WHERE ci.project_id = %(project_id)s + AND ci.sys_period IS NOT NULL + AND NOT isempty(ci.sys_period) + AND NOT upper_inf(ci.sys_period) + AND upper(ci.sys_period) >= latest.execution_time + ), + affected_summary AS ( + SELECT COUNT(DISTINCT skeleton_id) AS skeleton_count, + BOOL_OR(skeleton_id = %(skeleton_id)s) AS includes_requested + FROM affected_skeleton + ) + SELECT latest.transaction_id, + latest.execution_time::text, + cti.label + FROM latest + JOIN affected_summary affected + ON affected.skeleton_count = 1 + AND affected.includes_requested + LEFT JOIN catmaid_transaction_info cti + ON cti.transaction_id = latest.transaction_id + AND cti.execution_time = latest.execution_time + """, { + 'project_id': project_id, + 'skeleton_id': skeleton_id, + }) + result = cursor.fetchone() + if not result: + return None + + return { + 'transaction_id': result[0], + 'execution_time': result[1], + 'label': result[2], + } + + def get_historic_row_count_affected_by_tx(tx): """Counts how many historic rows reference the passed in transaction. Returned is a list of tuples (table_name, count). diff --git a/django/applications/catmaid/locks.py b/django/applications/catmaid/locks.py index 4663f2fdff..811652b45a 100644 --- a/django/applications/catmaid/locks.py +++ b/django/applications/catmaid/locks.py @@ -1,3 +1,6 @@ +import hashlib + + # The base lock is formed from the multiplication of all characters of "catmaid" # as ASCII: 99 * 97 * 116 * 109 * 97 * 105 * 100. base_lock_id = 123666608142000 @@ -6,3 +9,16 @@ spatial_update_event_lock = base_lock_id + 1 # Postgres advisory lock ID to update history update even handling history_update_event_lock = base_lock_id + 2 +# Postgres advisory lock namespace for historic skeleton restores +skeleton_restore_lock_namespace = base_lock_id + 3 + + +def skeleton_restore_lock_id(skeleton_id): + """Return a stable signed 64-bit advisory lock ID for a skeleton restore.""" + lock_key = f'{skeleton_restore_lock_namespace}:{int(skeleton_id)}'.encode( + 'ascii') + unsigned_lock_id = int.from_bytes( + hashlib.blake2b(lock_key, digest_size=8).digest(), 'big') + if unsigned_lock_id >= 2 ** 63: + return unsigned_lock_id - 2 ** 64 + return unsigned_lock_id diff --git a/django/applications/catmaid/tests/apis/test_skeletons.py b/django/applications/catmaid/tests/apis/test_skeletons.py index 3b4d4a3661..22799f98a9 100644 --- a/django/applications/catmaid/tests/apis/test_skeletons.py +++ b/django/applications/catmaid/tests/apis/test_skeletons.py @@ -15,9 +15,9 @@ from catmaid.control.annotation import _annotate_entities, annotations_for_skeleton from catmaid.control.skeleton import _get_neuronname_from_skeletonid from catmaid.models import ( - ClassInstance, ClassInstanceClassInstance, Log, Review, TreenodeConnector, - ReviewerWhitelist, Treenode, User, ClientDatastore, ClientData, - TreenodeClassInstance + Class, ClassInstance, ClassInstanceClassInstance, Log, Relation, Review, + SkeletonSummary, TreenodeConnector, ReviewerWhitelist, Treenode, User, + ClientDatastore, ClientData, TreenodeClassInstance ) from .common import CatmaidApiTestCase, CatmaidApiTransactionTestCase @@ -1411,6 +1411,32 @@ def test_skeleton_id_change(self): class SkeletonsApiTransactionTests(CatmaidApiTransactionTestCase): + def create_extra_skeleton_for_neuron(self, neuron_id): + skeleton_class = Class.objects.get(project_id=self.test_project_id, + class_name='skeleton') + model_of = Relation.objects.get(project_id=self.test_project_id, + relation_name='model_of') + skeleton = ClassInstance.objects.create(user=self.test_user, + project=self.test_project, class_column=skeleton_class, + name='extra test skeleton') + Treenode.objects.create(user=self.test_user, editor=self.test_user, + project=self.test_project, location_x=1, location_y=2, + location_z=3, parent=None, radius=-1, confidence=5, + skeleton=skeleton) + ClassInstanceClassInstance.objects.create(user=self.test_user, + project=self.test_project, relation=model_of, + class_instance_a=skeleton, class_instance_b_id=neuron_id) + return skeleton + + def transaction_label_count(self, label): + cursor = connection.cursor() + cursor.execute(""" + SELECT COUNT(*) + FROM catmaid_transaction_info + WHERE project_id = %s + AND label = %s + """, (self.test_project_id, label)) + return cursor.fetchone()[0] def test_import_skeleton(self): self.fake_authentication() @@ -2003,3 +2029,154 @@ def test_skeleton_deletion(self): self.assertEqual(0, TreenodeClassInstance.objects.filter(id=353).count()) self.assertEqual(log_count + 1, count_logs()) + + def test_restore_historic_skeleton(self): + self.fake_authentication() + skeleton_id = 1 + neuron_id = 2 + n_treenodes = Treenode.objects.filter(skeleton_id=skeleton_id).count() + skeleton_remove_count = self.transaction_label_count('skeletons.remove') + skeleton_restore_count = self.transaction_label_count('skeletons.restore') + + response = self.client.post( + '/%d/skeletons/%s/delete' % (self.test_project_id, skeleton_id)) + self.assertStatus(response) + self.assertEqual(skeleton_remove_count + 1, + self.transaction_label_count('skeletons.remove')) + + response = self.client.post( + '/%d/skeletons/%s/restore' % (self.test_project_id, skeleton_id)) + self.assertStatus(response) + parsed_response = json.loads(response.content.decode('utf-8')) + + self.assertEqual(skeleton_id, parsed_response['skeleton_id']) + self.assertNotIn('restored_skeleton_ids', parsed_response) + self.assertEqual('skeletons.remove', parsed_response['source_label']) + + self.assertEqual(n_treenodes, + Treenode.objects.filter(skeleton_id=skeleton_id).count()) + self.assertTrue(ClassInstance.objects.filter(id=skeleton_id).exists()) + self.assertTrue(ClassInstance.objects.filter(id=neuron_id).exists()) + self.assertTrue(ClassInstanceClassInstance.objects.filter( + class_instance_a=skeleton_id, class_instance_b=neuron_id, + relation__relation_name='model_of').exists()) + self.assertEqual(n_treenodes, + SkeletonSummary.objects.get(skeleton_id=skeleton_id).num_nodes) + + cursor = connection.cursor() + cursor.execute(""" + SELECT COUNT(*) + FROM treenode_edge te + JOIN treenode t + ON t.id = te.id + WHERE t.skeleton_id = %s + """, (skeleton_id,)) + self.assertEqual(n_treenodes, cursor.fetchone()[0]) + self.assertEqual(skeleton_restore_count + 1, + self.transaction_label_count('skeletons.restore')) + + def test_restore_historic_skeleton_after_split_delete(self): + self.fake_authentication() + + response = self.client.post( + '/%d/skeleton/split' % (self.test_project_id,), + { + 'treenode_id': 2394, + 'upstream_annotation_map': '{}', + 'downstream_annotation_map': '{}', + }) + self.assertStatus(response) + parsed_response = json.loads(response.content.decode('utf-8')) + skeleton_id = parsed_response['new_skeleton_id'] + n_treenodes = Treenode.objects.filter(skeleton_id=skeleton_id).count() + + response = self.client.post( + '/%d/skeletons/%s/delete' % (self.test_project_id, skeleton_id), + {'delete_multi_skeleton_neurons': 'false'}) + self.assertStatus(response) + self.assertFalse(ClassInstance.objects.filter(id=skeleton_id).exists()) + + response = self.client.post( + '/%d/skeletons/%s/restore' % (self.test_project_id, skeleton_id)) + self.assertStatus(response) + parsed_response = json.loads(response.content.decode('utf-8')) + + self.assertEqual(skeleton_id, parsed_response['skeleton_id']) + self.assertEqual('skeletons.remove', parsed_response['source_label']) + self.assertTrue(ClassInstance.objects.filter(id=skeleton_id).exists()) + self.assertEqual(n_treenodes, + Treenode.objects.filter(skeleton_id=skeleton_id).count()) + + def test_restore_historic_skeleton_rejects_multi_skeleton_transaction(self): + self.fake_authentication() + skeleton_id = 1 + neuron_id = 2 + extra_skeleton = self.create_extra_skeleton_for_neuron(neuron_id) + + response = self.client.post( + '/%d/neuron/%s/delete' % (self.test_project_id, neuron_id)) + self.assertStatus(response) + + response = self.client.post( + '/%d/skeletons/%s/restore' % (self.test_project_id, skeleton_id)) + self.assertStatus(response, 400) + self.assertFalse(ClassInstance.objects.filter(id=skeleton_id).exists()) + self.assertFalse(ClassInstance.objects.filter(id=extra_skeleton.id).exists()) + + def test_restore_historic_skeleton_rejects_neuron_delete_transaction(self): + self.fake_authentication() + skeleton_id = 1 + neuron_id = 2 + + response = self.client.post( + '/%d/neuron/%s/delete' % (self.test_project_id, neuron_id)) + self.assertStatus(response) + + response = self.client.post( + '/%d/skeletons/%s/restore' % (self.test_project_id, skeleton_id)) + self.assertStatus(response, 400) + self.assertFalse(ClassInstance.objects.filter(id=skeleton_id).exists()) + + def test_restore_historic_skeleton_rejects_unlabeled_transaction(self): + self.fake_authentication() + skeleton_id = 1 + + response = self.client.post( + '/%d/skeletons/%s/delete' % (self.test_project_id, skeleton_id)) + self.assertStatus(response) + + cursor = connection.cursor() + cursor.execute(""" + WITH latest AS ( + SELECT ci.exec_transaction_id AS transaction_id, + upper(ci.sys_period) AS execution_time + FROM class_instance__history ci + JOIN class c + ON c.id = ci.class_id + AND c.project_id = ci.project_id + AND c.class_name = 'skeleton' + WHERE ci.project_id = %s + AND ci.id = %s + AND ci.sys_period IS NOT NULL + AND NOT upper_inf(ci.sys_period) + ORDER BY upper(ci.sys_period) DESC + LIMIT 1 + ) + DELETE FROM catmaid_transaction_info cti + USING latest + WHERE cti.transaction_id = latest.transaction_id + AND cti.execution_time = latest.execution_time + """, (self.test_project_id, skeleton_id)) + + response = self.client.post( + '/%d/skeletons/%s/restore' % (self.test_project_id, skeleton_id)) + self.assertStatus(response, 400) + self.assertFalse(ClassInstance.objects.filter(id=skeleton_id).exists()) + + def test_restore_historic_skeleton_rejects_live_skeleton(self): + self.fake_authentication() + skeleton_id = 1 + + response = self.client.post( + '/%d/skeletons/%s/restore' % (self.test_project_id, skeleton_id)) + self.assertStatus(response, 400) diff --git a/django/applications/catmaid/urls.py b/django/applications/catmaid/urls.py index c725a29620..79e7f1b2bc 100644 --- a/django/applications/catmaid/urls.py +++ b/django/applications/catmaid/urls.py @@ -345,7 +345,8 @@ re_path(r'^(?P\d+)/skeletons/(?P\d+)/sampler-count$', skeleton.sampler_count), re_path(r'^(?P\d+)/skeletons/(?P\d+)/cable-length$', skeleton.cable_length), re_path(r'^(?P\d+)/skeletons/(?P\d+)/neuron-details$', skeleton.neurondetails), - re_path(r'^(?P\d+)/skeletons/(?P\d+)/delete$', skeleton.delete_skeleton), + re_path(r'^(?P\d+)/skeletons/(?P\d+)/delete$', record_view("skeletons.remove")(skeleton.delete_skeleton)), + re_path(r'^(?P\d+)/skeletons/(?P\d+)/restore$', record_view("skeletons.restore")(skeleton.restore_historic_skeleton)), re_path(r'^(?P\d+)/skeleton/split$', record_view("skeletons.split")(skeleton.split_skeleton)), re_path(r'^(?P\d+)/skeleton/ancestry$', skeleton.skeleton_ancestry), re_path(r'^(?P\d+)/skeleton/join$', record_view("skeletons.merge")(skeleton.join_skeleton)),