Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 61 additions & 13 deletions agentex/src/domain/services/task_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from fastapi import Depends

from src.adapters.crud_store.exceptions import ItemDoesNotExist
from src.adapters.streams.adapter_redis import DRedisStreamRepository
from src.api.schemas.authorization_types import AgentexResource
from src.domain.entities.agents import ACPType, AgentEntity
from src.domain.entities.events import EventEntity
from src.domain.entities.task_message_updates import TaskMessageUpdateEntity
Expand All @@ -14,6 +16,7 @@
from src.domain.repositories.task_repository import DTaskRepository
from src.domain.repositories.task_state_repository import DTaskStateRepository
from src.domain.services.agent_acp_service import DAgentACPService
from src.domain.services.authorization_service import DAuthorizationService
from src.utils.ids import orm_id
from src.utils.logging import make_logger
from src.utils.stream_topics import get_task_event_stream_topic
Expand All @@ -33,12 +36,14 @@ def __init__(
task_repository: DTaskRepository,
event_repository: DEventRepository,
stream_repository: DRedisStreamRepository,
authorization_service: DAuthorizationService,
):
self.acp_client = acp_client
self.task_state_repository = task_state_repository
self.task_repository = task_repository
self.event_repository = event_repository
self.stream_repository = stream_repository
self.authorization_service = authorization_service

async def create_task(
self,
Expand All @@ -59,19 +64,33 @@ async def create_task(
Returns:
Task containing the created task info
"""

task_entity = await self.task_repository.create(
agent_id=agent.id,
task=TaskEntity(
id=orm_id(),
name=task_name,
status=TaskStatus.RUNNING,
status_reason="Task created, forwarding to ACP server",
params=task_params,
task_metadata=task_metadata,
),
# Register in the authorization service before persisting: a registration
# failure aborts the request with no orphaned row. If the persist fails
# after a successful registration, the compensating deregister_resource
# below prevents a dangling authorization entry. Both calls are no-ops
# when the authorization service is disabled for this account.
task_entity = TaskEntity(
id=orm_id(),
name=task_name,
status=TaskStatus.RUNNING,
status_reason="Task created, forwarding to ACP server",
params=task_params,
task_metadata=task_metadata,
)
await self.authorization_service.register_resource(
AgentexResource.task(task_entity.id),
parent=AgentexResource.agent(agent.id),
)
return task_entity
try:
return await self.task_repository.create(
agent_id=agent.id,
task=task_entity,
)
except Exception:
await self.authorization_service.deregister_resource(
AgentexResource.task(task_entity.id),
)
raise
Comment thread
greptile-apps[bot] marked this conversation as resolved.

async def create_task_and_forward_to_acp(
self,
Expand All @@ -91,7 +110,9 @@ async def create_task_and_forward_to_acp(
Task containing the created task info
"""
task_entity = await self.create_task(
agent=agent, task_name=task_name, task_params=task_params
agent=agent,
task_name=task_name,
task_params=task_params,
)

if agent.acp_type == ACPType.SYNC:
Expand Down Expand Up @@ -214,8 +235,35 @@ async def delete_task(self, id: str | None = None, name: str | None = None) -> N
"""
Delete a task from the repository.
"""
# Delete first (Postgres is the source of truth for existence), then
# deregister best-effort: a deregister failure is logged and swallowed
# rather than failing a delete that already succeeded.
# Resolve the id before the delete so we can pass it to deregister_resource;
# looking it up by name afterwards would race. If the name doesn't resolve,
# swallow ItemDoesNotExist and let delete() surface its own native error
# so the missing-task error contract is unchanged.
task_id_for_deregister: str | None = id
if task_id_for_deregister is None and name is not None:
try:
task = await self.task_repository.get(name=name)
task_id_for_deregister = task.id
except ItemDoesNotExist:
task_id_for_deregister = None

await self.task_repository.delete(id=id, name=name)

if task_id_for_deregister is not None:
try:
await self.authorization_service.deregister_resource(
AgentexResource.task(task_id_for_deregister),
)
except Exception:
logger.exception(
"task authorization deregister failed for task %s after successful delete; "
"the deregistration failure has been swallowed",
task_id_for_deregister,
)

async def list_tasks(
self,
*,
Expand Down
27 changes: 25 additions & 2 deletions agentex/tests/fixtures/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Provides factory functions and specific fixtures for creating services with test repositories.
"""

from unittest.mock import MagicMock, Mock
from unittest.mock import AsyncMock, MagicMock, Mock

import pytest

Expand All @@ -12,6 +12,24 @@
# =============================================================================


def make_noop_authorization_service() -> Mock:
"""Shared noop AuthorizationService mock for tests that don't exercise authz.

``principal_context`` is ``None``, and
``grant``/``revoke``/``register_resource``/``deregister_resource`` are async
no-ops returning ``None`` — matching the real service signature. Use this
anywhere a test just needs to construct ``AgentTaskService`` without caring
about authorization behavior.
"""
svc = Mock()
svc.principal_context = None
svc.grant = AsyncMock(return_value=None)
svc.revoke = AsyncMock(return_value=None)
svc.register_resource = AsyncMock(return_value=None)
svc.deregister_resource = AsyncMock(return_value=None)
return svc


def create_task_message_service(task_message_repository):
"""Factory function to create TaskMessageService with given repository"""
from src.domain.services.task_message_service import TaskMessageService
Expand Down Expand Up @@ -52,16 +70,21 @@ def create_task_service(
event_repository,
agent_acp_service,
redis_stream_repository,
authorization_service=None,
):
"""Factory function to create AgentTaskService with given repositories and services"""
"""Factory function to create AgentTaskService with given repositories and services."""
from src.domain.services.task_service import AgentTaskService

if authorization_service is None:
authorization_service = make_noop_authorization_service()

return AgentTaskService(
task_repository=task_repository,
task_state_repository=task_state_repository,
event_repository=event_repository,
acp_client=agent_acp_service,
stream_repository=redis_stream_repository,
authorization_service=authorization_service,
)


Expand Down
3 changes: 3 additions & 0 deletions agentex/tests/integration/fixtures/integration_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from src.config.dependencies import GlobalDependencies
from src.config.environment_variables import EnvironmentVariables

from tests.fixtures.services import make_noop_authorization_service


@pytest.fixture(scope="session")
def event_loop():
Expand Down Expand Up @@ -455,6 +457,7 @@ async def send_message(self, *args, **kwargs):
task_repository=isolated_repositories["task_repository"],
event_repository=isolated_repositories["event_repository"],
stream_repository=isolated_repositories["redis_stream_repository"],
authorization_service=make_noop_authorization_service(),
)

return TasksUseCase(task_service=task_service)
Expand Down
28 changes: 16 additions & 12 deletions agentex/tests/integration/test_task_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from src.domain.use_cases.tasks_use_case import TasksUseCase
from src.utils.ids import orm_id

from tests.fixtures.services import make_noop_authorization_service


@pytest.mark.asyncio
@pytest.mark.integration
Expand Down Expand Up @@ -76,6 +78,7 @@ async def send_message(self, *args, **kwargs):
task_repository=isolated_repositories["task_repository"],
event_repository=isolated_repositories["event_repository"],
stream_repository=isolated_repositories["redis_stream_repository"],
authorization_service=make_noop_authorization_service(),
)

return TasksUseCase(task_service=task_service)
Expand Down Expand Up @@ -103,6 +106,7 @@ async def send_message(self, *args, **kwargs):
task_repository=isolated_repositories["task_repository"],
event_repository=isolated_repositories["event_repository"],
stream_repository=isolated_repositories["redis_stream_repository"],
authorization_service=make_noop_authorization_service(),
)

environment_variables = EnvironmentVariables.refresh()
Expand Down Expand Up @@ -194,17 +198,17 @@ async def collect_stream_events():
pass

# Then - Verify the stream event was received
assert (
len(stream_events) >= 1
), f"Expected at least 1 stream event, got {len(stream_events)}"
assert len(stream_events) >= 1, (
f"Expected at least 1 stream event, got {len(stream_events)}"
)

# Find the task_updated event
task_updated_events = [
e for e in stream_events if e.get("type") == "task_updated"
]
assert (
len(task_updated_events) >= 1
), f"Expected task_updated event, got events: {[e.get('type') for e in stream_events]}"
assert len(task_updated_events) >= 1, (
f"Expected task_updated event, got events: {[e.get('type') for e in stream_events]}"
)

task_updated_event = task_updated_events[0]

Expand Down Expand Up @@ -389,9 +393,9 @@ async def collect_stream_events():
task_updated_events = [
e for e in stream_events if e.get("type") == "task_updated"
]
assert (
len(task_updated_events) >= 3
), f"Expected at least 3 task_updated events, got {len(task_updated_events)}"
assert len(task_updated_events) >= 3, (
f"Expected at least 3 task_updated events, got {len(task_updated_events)}"
)

# Verify each event has the correct metadata for its update
versions = [
Expand Down Expand Up @@ -599,8 +603,8 @@ async def collect_stream_data():
pass

# Then - Verify we received at least 2 pings
assert (
ping_count >= 2
), f"Expected at least 2 ping messages during idle period, got {ping_count}"
assert ping_count >= 2, (
f"Expected at least 2 ping messages during idle period, got {ping_count}"
)

print(f"✅ Stream sent {ping_count} keepalive pings during idle period")
Empty file.
Loading
Loading