-
Notifications
You must be signed in to change notification settings - Fork 108
fix: implement update_message() for guardrail redaction support #388
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
c4e826d
2ea85c1
a6df5f4
d649bda
1968765
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -628,36 +628,107 @@ def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs | |
| Optional[SessionMessage]: The message if found, None otherwise. | ||
|
|
||
| Note: | ||
| This should not be called as (as of now) only the `update_message` method calls this method and | ||
| updating messages is not supported in AgentCore Memory. | ||
| This is primarily used internally by the `update_message` method to read | ||
| the original event before replacing it. | ||
| """ | ||
| result = self.memory_client.gmdp_client.get_event( | ||
| memoryId=self.config.memory_id, actorId=self.config.actor_id, sessionId=session_id, eventId=message_id | ||
| ) | ||
| return SessionMessage.from_dict(result) if result else None | ||
|
|
||
| def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage, **kwargs: Any) -> None: | ||
| """Update message data. | ||
| """Update message data in AgentCore Memory. | ||
|
|
||
| Note: AgentCore Memory doesn't support updating events, | ||
| so this is primarily for validation and logging. | ||
| Since AgentCore Memory events are immutable, this method performs an update by | ||
| creating a new event with the updated content and deleting the old event. | ||
| This enables features like guardrail redaction via Strands' redact_latest_message(). | ||
|
|
||
| If the message has not yet been persisted (e.g., still in the message buffer when | ||
| batch_size > 1), the buffered message is replaced in-place instead. | ||
|
|
||
| Args: | ||
| session_id (str): The session ID containing the message. | ||
| agent_id (str): The agent ID associated with the message. | ||
| session_message (SessionMessage): The message to update. | ||
| session_message (SessionMessage): The message to update (with updated content | ||
| and the original message_id/eventId). | ||
| **kwargs (Any): Additional keyword arguments. | ||
|
|
||
| Raises: | ||
| SessionException: If session ID doesn't match configuration. | ||
| SessionException: If session ID doesn't match configuration or update fails. | ||
| """ | ||
| if session_id != self.config.session_id: | ||
| raise SessionException(f"Session ID mismatch: expected {self.config.session_id}, got {session_id}") | ||
|
|
||
| logger.debug( | ||
| "Message update requested for message: %s (AgentCore Memory doesn't support updates)", | ||
| {session_message.message_id}, | ||
| ) | ||
| old_message_id = session_message.message_id | ||
|
|
||
| # If message hasn't been persisted yet (still in buffer), update it there | ||
| if old_message_id is None: | ||
| if self._update_buffered_message(session_message): | ||
| logger.debug("Updated buffered message (not yet persisted to AgentCore Memory)") | ||
| return | ||
| logger.debug("Message has no event ID and was not found in buffer - skipping update") | ||
| return | ||
|
|
||
| # Create a new event with the updated message content | ||
| try: | ||
| updated_message = SessionMessage( | ||
|
notgitika marked this conversation as resolved.
|
||
| message=session_message.message, | ||
| message_id=0, | ||
| created_at=session_message.created_at, | ||
| ) | ||
| new_event = self.create_message(session_id, agent_id, updated_message) | ||
| except Exception as e: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. P1 — Potential data loss when Two scenarios:
Suggestion: Guard against both cases: new_event = self.create_message(session_id, agent_id, updated_message)
if not new_event or not new_event.get("eventId"):
logger.warning("create_message did not return an eventId — skipping delete of old event %s", old_message_id)
return |
||
| logger.error("Failed to update message in AgentCore Memory: %s", e) | ||
| raise SessionException(f"Failed to update message: {e}") from e | ||
|
|
||
| # Delete the old event; if this fails, roll back the newly created event | ||
| try: | ||
| self.memory_client.gmdp_client.delete_event( | ||
|
notgitika marked this conversation as resolved.
|
||
| memoryId=self.config.memory_id, | ||
| actorId=self.config.actor_id, | ||
| sessionId=session_id, | ||
| eventId=old_message_id, | ||
| ) | ||
| except Exception as delete_error: | ||
| logger.warning( | ||
| "Failed to delete old event %s after creating replacement: %s. Attempting rollback.", | ||
| old_message_id, | ||
| delete_error, | ||
| ) | ||
| new_event_id = new_event.get("eventId") if new_event else None | ||
| if new_event_id: | ||
| try: | ||
| self.memory_client.gmdp_client.delete_event( | ||
| memoryId=self.config.memory_id, | ||
| actorId=self.config.actor_id, | ||
| sessionId=session_id, | ||
| eventId=new_event_id, | ||
| ) | ||
| logger.info("Rolled back new event %s after failed delete of old event", new_event_id) | ||
| except Exception as rollback_error: | ||
| logger.error( | ||
| "Rollback failed: could not delete new event %s: %s. Both old (%s) and new events may exist.", | ||
| new_event_id, | ||
| rollback_error, | ||
| old_message_id, | ||
| ) | ||
| raise SessionException( | ||
| f"Failed to update message: could not delete old event: {delete_error}" | ||
| ) from delete_error | ||
|
|
||
| # Update _latest_agent_message so it doesn't hold a stale eventId | ||
| new_event_id = new_event.get("eventId") if new_event else None | ||
| latest_messages = getattr(self, "_latest_agent_message", None) | ||
| if new_event_id and latest_messages and agent_id in latest_messages: | ||
| old_latest = latest_messages[agent_id] | ||
| if old_latest.message_id == old_message_id: | ||
| self._latest_agent_message[agent_id] = SessionMessage( | ||
| message=session_message.message, | ||
| message_id=new_event_id, | ||
| created_at=session_message.created_at, | ||
| ) | ||
|
|
||
| logger.info("Updated message in AgentCore Memory: replaced event %s", old_message_id) | ||
|
|
||
| def list_messages( | ||
| self, | ||
|
|
@@ -857,6 +928,44 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None: | |
|
|
||
| # region Batching support | ||
|
|
||
| def _update_buffered_message(self, session_message: SessionMessage) -> bool: | ||
| """Attempt to update a message that is still in the send buffer. | ||
|
|
||
| When batch_size > 1, messages may not yet be persisted to AgentCore Memory. | ||
| This method finds the most recent buffered message matching the session_message's | ||
| content role and replaces it with the updated content. | ||
|
|
||
| Args: | ||
| session_message (SessionMessage): The message with updated content. | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. P2 — Race condition between buffer update and flush
The fix would be to hold |
||
| Returns: | ||
| bool: True if a buffered message was found and updated, False otherwise. | ||
| """ | ||
| updated_messages = self.converter.message_to_payload(session_message) | ||
| if not updated_messages: | ||
| return False | ||
|
|
||
| is_blob = self.converter.exceeds_conversational_limit(updated_messages[0]) | ||
|
|
||
| with self._message_lock: | ||
| # Search from the end (most recent) to find the message to update | ||
| for i in range(len(self._message_buffer) - 1, -1, -1): | ||
| buf = self._message_buffer[i] | ||
| if buf.session_id == self.config.session_id and buf.messages: | ||
| # Match by role - the most recent message with the same role | ||
| existing_role = buf.messages[0][1] if not buf.is_blob else None | ||
| new_role = updated_messages[0][1] if not is_blob else None | ||
| if existing_role == new_role: | ||
| self._message_buffer[i] = BufferedMessage( | ||
| session_id=buf.session_id, | ||
| messages=updated_messages, | ||
| is_blob=is_blob, | ||
| timestamp=buf.timestamp, | ||
| metadata=buf.metadata, | ||
| ) | ||
| return True | ||
| return False | ||
|
|
||
| def _flush_messages_only(self) -> list[dict[str, Any]]: | ||
| """Flush only buffered messages to AgentCore Memory. | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -496,20 +496,74 @@ def test_read_message_not_found(self, session_manager, mock_memory_client): | |
|
|
||
| assert result is None | ||
|
|
||
| def test_update_message(self, session_manager): | ||
| """Test updating a message.""" | ||
| message = SessionMessage(message={"role": "user", "content": [{"text": "Hello"}]}, message_id=1) | ||
| def test_update_message(self, session_manager, mock_memory_client): | ||
| """Test updating a persisted message creates new event and deletes old one.""" | ||
| mock_memory_client.create_event.return_value = {"eventId": "new_event_456"} | ||
|
|
||
| message = SessionMessage( | ||
| message={"role": "user", "content": [{"text": "redacted"}]}, | ||
| message_id="old_event_123", | ||
| created_at="2024-01-01T12:00:00Z", | ||
| ) | ||
|
|
||
| # Should not raise any exceptions | ||
| session_manager.update_message("test-session-456", "test-agent-123", message) | ||
|
|
||
| # Verify new event was created | ||
| mock_memory_client.create_event.assert_called_once() | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Test gap — content of created event not verified
Suggest adding: create_kwargs = mock_memory_client.create_event.call_args.kwargs
# or inspect the payload to verify it contains "redacted" |
||
| # Verify old event was deleted | ||
| mock_memory_client.gmdp_client.delete_event.assert_called_once() | ||
| delete_kwargs = mock_memory_client.gmdp_client.delete_event.call_args.kwargs | ||
| assert delete_kwargs["eventId"] == "old_event_123" | ||
| assert delete_kwargs["memoryId"] == "test-memory-123" | ||
| assert delete_kwargs["actorId"] == "test-actor-789" | ||
| assert delete_kwargs["sessionId"] == "test-session-456" | ||
|
|
||
| def test_update_message_wrong_session(self, session_manager): | ||
| """Test updating a message with wrong session ID.""" | ||
| message = SessionMessage(message={"role": "user", "content": [{"text": "Hello"}]}, message_id=1) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Test gap — The source code explicitly updates Suggest adding a test that:
|
||
|
|
||
| with pytest.raises(SessionException, match="Session ID mismatch"): | ||
| session_manager.update_message("wrong-session-id", "test-agent-123", message) | ||
|
|
||
| def test_update_message_no_message_id(self, session_manager): | ||
| """Test updating a message with no message_id (not yet persisted) skips gracefully.""" | ||
| message = SessionMessage( | ||
| message={"role": "user", "content": [{"text": "redacted"}]}, | ||
| message_id=None, | ||
| created_at="2024-01-01T12:00:00Z", | ||
| ) | ||
|
|
||
| # Should not raise - just skips since message isn't persisted and buffer is empty | ||
| session_manager.update_message("test-session-456", "test-agent-123", message) | ||
|
|
||
| def test_update_message_create_fails(self, session_manager, mock_memory_client): | ||
| """Test update_message raises SessionException when create fails.""" | ||
| mock_memory_client.create_event.side_effect = Exception("API Error") | ||
|
|
||
| message = SessionMessage( | ||
| message={"role": "user", "content": [{"text": "redacted"}]}, | ||
| message_id="old_event_123", | ||
| created_at="2024-01-01T12:00:00Z", | ||
| ) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Test gap — When mock_memory_client.gmdp_client.delete_event.assert_not_called() |
||
|
|
||
| with pytest.raises(SessionException, match="Failed to update message"): | ||
| session_manager.update_message("test-session-456", "test-agent-123", message) | ||
|
|
||
| def test_update_message_delete_fails(self, session_manager, mock_memory_client): | ||
| """Test update_message raises SessionException when delete fails.""" | ||
| mock_memory_client.create_event.return_value = {"eventId": "new_event_456"} | ||
| mock_memory_client.gmdp_client.delete_event.side_effect = Exception("Delete failed") | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Test gap — rollback not verified Setting Missing tests:
|
||
| message = SessionMessage( | ||
| message={"role": "user", "content": [{"text": "redacted"}]}, | ||
| message_id="old_event_123", | ||
| created_at="2024-01-01T12:00:00Z", | ||
| ) | ||
|
|
||
| with pytest.raises(SessionException, match="Failed to update message"): | ||
| session_manager.update_message("test-session-456", "test-agent-123", message) | ||
|
|
||
| def test_list_messages_with_limit(self, session_manager, mock_memory_client): | ||
| """Test listing messages with limit.""" | ||
| mock_memory_client.list_events.return_value = [ | ||
|
|
@@ -1366,6 +1420,31 @@ def test_pending_message_count_with_buffered_messages(self, batching_session_man | |
| # Verify no events were sent (still buffered) | ||
| mock_memory_client.create_event.assert_not_called() | ||
|
|
||
| def test_update_buffered_message(self, batching_session_manager, mock_memory_client): | ||
| """Test update_message replaces a buffered message in-place when message_id is None.""" | ||
| # Add a user message to buffer | ||
| message = SessionMessage( | ||
| message={"role": "user", "content": [{"text": "offensive content"}]}, | ||
| message_id=0, | ||
| created_at="2024-01-01T12:00:00Z", | ||
| ) | ||
| batching_session_manager.create_message("test-session-456", "test-agent", message) | ||
| assert batching_session_manager.pending_message_count() == 1 | ||
|
|
||
| # Update with redacted content (message_id=None simulates unbatched message) | ||
| redacted = SessionMessage( | ||
| message={"role": "user", "content": [{"text": "Message redacted by guardrail"}]}, | ||
| message_id=None, | ||
| created_at="2024-01-01T12:00:00Z", | ||
| ) | ||
| batching_session_manager.update_message("test-session-456", "test-agent", redacted) | ||
|
|
||
| # Buffer should still have 1 message but with updated content | ||
| assert batching_session_manager.pending_message_count() == 1 | ||
| # No API calls should have been made (still buffered) | ||
| mock_memory_client.create_event.assert_not_called() | ||
| mock_memory_client.gmdp_client.delete_event.assert_not_called() | ||
|
|
||
| def test_buffer_auto_flushes_at_batch_size(self, batching_session_manager, mock_memory_client): | ||
| """Test buffer automatically flushes when reaching batch_size.""" | ||
| mock_memory_client.gmdp_client.create_event.return_value = {"eventId": "event_123"} | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.