Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 51 additions & 7 deletions src/semble/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import json
import logging
import time
from collections import OrderedDict
from collections.abc import Sequence
from pathlib import Path
Expand All @@ -11,7 +12,7 @@
from mcp.server.fastmcp import FastMCP
from pydantic import Field

from semble.cache import save_index_to_cache
from semble.cache import get_validated_cache, save_index_to_cache
from semble.index import SembleIndex
from semble.index.dense import load_model
from semble.types import ContentType
Expand All @@ -25,6 +26,7 @@
)

_CACHE_MAX_SIZE = 10 # Max number of cached indexes to keep in memory
_MIN_REVALIDATE_FACTOR = 3 # Don't recheck staleness sooner than this many times the last build's duration


async def _get_index(
Expand Down Expand Up @@ -169,6 +171,7 @@ def __init__(self, content: Sequence[ContentType] = (ContentType.CODE,)) -> None
self._model_ready = asyncio.Event()
self._content = content
self._tasks: OrderedDict[str, asyncio.Task[SembleIndex]] = OrderedDict() # ordered for LRU eviction
self._revalidate_after: dict[str, float] = {} # cache_key -> monotonic time, staleness check is gated until

async def _await_model(self) -> str:
"""Block until the model is installed; re-raise the load error if it failed."""
Expand Down Expand Up @@ -196,22 +199,63 @@ def _build_and_cache_index(self, source: str, ref: str | None, model_path: str,
logger.warning("Failed to save index cache for %r", cache_key, exc_info=True)
return index

async def _build_and_track(self, source: str, ref: str | None, model_path: str, cache_key: str) -> SembleIndex:
"""Build an index and, for local paths, record when its staleness cooldown ends.

The cooldown write happens after the await, i.e. back on the event loop thread,
regardless of which thread `_build_and_cache_index` itself ran on.
"""
start = time.monotonic()
index = await asyncio.to_thread(self._build_and_cache_index, source, ref, model_path, cache_key)
if not is_git_url(source):
finished = time.monotonic()
self._revalidate_after[cache_key] = finished + (finished - start) * _MIN_REVALIDATE_FACTOR
return index

def evict(self, source: str) -> None:
self._tasks.pop(self._compute_cache_key(source), None)
cache_key = self._compute_cache_key(source)
self._tasks.pop(cache_key, None)
self._revalidate_after.pop(cache_key, None)

async def _evict_if_stale(self, source: str, cache_key: str) -> None:
"""Evict a cached local-path entry whose on-disk cache no longer matches its files.

Skipped while inside the cooldown window so repos that are slow to build aren't
rebuilt faster than they can be served.
"""
cached = self._tasks.get(cache_key)
if (
cached is None
or is_git_url(source)
or not cached.done()
or cached.cancelled()
or cached.exception() is not None
):
return
if time.monotonic() < self._revalidate_after.get(cache_key, 0.0):
return
validated = await asyncio.to_thread(get_validated_cache, cache_key, self._model_path, self._content)
# Only evict if this entry hasn't already been replaced by a concurrent caller.
if validated is None and self._tasks.get(cache_key) is cached:
self.evict(source)

async def get(self, source: str, ref: str | None = None) -> SembleIndex:
"""Return an index for the requested source, building and caching it on first access."""
"""Return an index for the requested source, building and caching it on first access.

Local paths are revalidated against the on-disk cache on every call (subject to a
cooldown scaled by build time), so an entry is rebuilt once its files change.
"""
cache_key = self._compute_cache_key(source, ref)
await self._evict_if_stale(source, cache_key)

if cache_key not in self._tasks:
model_path = await self._await_model()
# Re-check after the await: another caller may have populated the entry.
if cache_key not in self._tasks:
if len(self._tasks) >= _CACHE_MAX_SIZE:
self._tasks.popitem(last=False)
self._tasks[cache_key] = asyncio.create_task(
asyncio.to_thread(self._build_and_cache_index, source, ref, model_path, cache_key)
)
evicted_key, _ = self._tasks.popitem(last=False)
self._revalidate_after.pop(evicted_key, None)
self._tasks[cache_key] = asyncio.create_task(self._build_and_track(source, ref, model_path, cache_key))
self._tasks.move_to_end(cache_key)
task = self._tasks[cache_key]
try:
Expand Down
2 changes: 1 addition & 1 deletion src/semble/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__version_triple__ = (0, 4, 0)
__version_triple__ = (0, 4, 1)
__version__ = ".".join(map(str, __version_triple__))
85 changes: 85 additions & 0 deletions tests/test_mcp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import threading
import time
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock, patch
Expand Down Expand Up @@ -135,6 +136,7 @@ async def test_index_cache_builds_and_caches(
with (
patch(f"semble.mcp.SembleIndex.{patch_target}", return_value=fake_index) as mock_build,
patch("semble.mcp.save_index_to_cache") as mock_save,
patch("semble.mcp.get_validated_cache", return_value=Path("/fake/cache")),
):
first = await cache.get(resolved_source)
second = await cache.get(resolved_source)
Expand All @@ -144,6 +146,89 @@ async def test_index_cache_builds_and_caches(
mock_save.assert_called_once_with(fake_index, cache._compute_cache_key(resolved_source))


@pytest.mark.anyio
@pytest.mark.parametrize(
("source", "patch_target", "expected_build_calls", "validate_called"),
[
("local_tmp_path", "from_path", 2, True),
("https://github.com/org/repo", "from_git", 1, False),
],
ids=["local_path_rebuilds_when_stale", "git_url_skips_revalidation"],
)
async def test_index_cache_staleness_check_scope(
cache: _IndexCache,
tmp_path: Path,
source: str,
patch_target: str,
expected_build_calls: int,
validate_called: bool,
) -> None:
"""Local paths are revalidated (and rebuilt when stale) on every get(); git URLs never are."""
resolved_source = str(tmp_path) if source == "local_tmp_path" else source
with (
patch(f"semble.mcp.SembleIndex.{patch_target}", return_value=MagicMock()) as mock_build,
patch("semble.mcp.save_index_to_cache"),
patch("semble.mcp.get_validated_cache", return_value=None) as mock_validate,
# Disable the cooldown: real build duration (here, just thread-dispatch overhead) would
# otherwise sometimes exceed the gap between the two get() calls below, flaking the test.
patch("semble.mcp._MIN_REVALIDATE_FACTOR", 0),
):
await cache.get(resolved_source)
await cache.get(resolved_source)
assert mock_build.call_count == expected_build_calls
assert mock_validate.called is validate_called


@pytest.mark.anyio
async def test_index_cache_skips_staleness_check_during_cooldown(cache: _IndexCache, tmp_path: Path) -> None:
"""A slow-to-build local path is not revalidated again until its cooldown elapses."""
cache_key = str(tmp_path.resolve())
cache._tasks[cache_key] = asyncio.create_task(_succeed())
await asyncio.sleep(0) # let the task finish
cache._revalidate_after[cache_key] = time.monotonic() + 30.0 # a build that took 10s, just finished
with patch("semble.mcp.get_validated_cache") as mock_validate:
await cache._evict_if_stale(str(tmp_path), cache_key)
mock_validate.assert_not_called()


async def _succeed() -> MagicMock:
return MagicMock()


@pytest.mark.anyio
async def test_index_cache_skips_staleness_check_for_failed_task(cache: _IndexCache, tmp_path: Path) -> None:
"""A cached entry that finished with an exception is not revalidated; it is left for the normal retry path."""

async def _raise() -> MagicMock:
raise RuntimeError("boom")

cache._tasks[str(tmp_path.resolve())] = asyncio.create_task(_raise())
await asyncio.sleep(0) # let the task finish
with patch("semble.mcp.get_validated_cache") as mock_validate:
await cache._evict_if_stale(str(tmp_path), str(tmp_path.resolve()))
mock_validate.assert_not_called()


@pytest.mark.anyio
async def test_index_cache_does_not_evict_entry_replaced_during_validation(cache: _IndexCache, tmp_path: Path) -> None:
"""If a concurrent caller already replaced a stale entry, _evict_if_stale must not evict the new one."""
cache_key = str(tmp_path.resolve())
cache._tasks[cache_key] = asyncio.create_task(_succeed())
await asyncio.sleep(0)
cache._revalidate_after[cache_key] = 0.0 # cooldown already elapsed

replacement_task = object()

def _replace_entry_then_report_stale(*args: object, **kwargs: object) -> None:
# Simulate a concurrent get() winning the race and installing a fresh task first.
cache._tasks[cache_key] = replacement_task # type: ignore[assignment]
return None

with patch("semble.mcp.get_validated_cache", side_effect=_replace_entry_then_report_stale):
await cache._evict_if_stale(str(tmp_path), cache_key)
assert cache._tasks.get(cache_key) is replacement_task


@pytest.mark.anyio
async def test_index_cache_evicts_on_failure(cache: _IndexCache, tmp_path: Path) -> None:
"""A failed build evicts the entry so the next call can retry."""
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading