Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ dependencies = [
[project.optional-dependencies]
mcp = [
"mcp>=1.0,<2.0",
"watchfiles>=0.21",
]
benchmark = [
"sentence-transformers>=3.0",
Expand Down
9 changes: 1 addition & 8 deletions src/semble/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,6 @@ def _mcp_main() -> None:
prog="semble",
description="Instant local code search for agents.",
)
parser.add_argument(
"path",
nargs="?",
default=None,
help="Local directory or git URL to pre-index at startup (optional).",
)
parser.add_argument("--ref", default=None, help="Branch or tag to check out (git URLs only).")
_add_content_args(parser)
args = parser.parse_args()
if any(find_spec(dep) is None for dep in get_package_extras("semble", "mcp")):
Expand All @@ -87,7 +80,7 @@ def _mcp_main() -> None:
from semble.mcp import serve

content = _resolve_content(args.content, args.include_text_files)
asyncio.run(serve(args.path, ref=args.ref, content=content))
asyncio.run(serve(content))


def _resolve_content(content: list[str], include_text_files: bool) -> list[ContentType]:
Expand Down
63 changes: 14 additions & 49 deletions src/semble/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from pathlib import Path
from typing import Annotated

import watchfiles
from mcp.server.fastmcp import FastMCP
from pydantic import Field

Expand All @@ -21,35 +20,27 @@
logger = logging.getLogger(__name__)

_REPO_DESCRIPTION = (
"https:// or http:// git URL (e.g. https://github.com/org/repo) or local directory path to index and search. "
"Required when no default index was configured at startup. "
"The index is cached after the first call, so repeat queries are fast."
"A local directory path or https:// or http:// git URL (e.g. https://github.com/org/repo) to index and "
"search. The index is cached after the first call, so repeat queries are fast."
)
Comment thread
stephantul marked this conversation as resolved.

_CACHE_MAX_SIZE = 10 # Max number of cached indexes to keep in memory


async def _get_index(
repo: str | None,
default_source: str | None,
repo: str,
cache: _IndexCache,
) -> SembleIndex:
"""Return a cached index for a repo, rejecting unsafe git transport schemes."""
if repo is not None and is_git_url(repo) and not repo.startswith(("https://", "http://")):
if is_git_url(repo) and not repo.startswith(("https://", "http://")):
raise ValueError(f"Only https://, http://, or local directory paths are accepted as `repo`. Got: {repo!r}")
source = repo or default_source
if not source:
raise ValueError(
"No repo specified and no default index. "
"Pass an https:// or http:// git URL or local directory path as `repo`."
)
try:
return await cache.get(source)
return await cache.get(repo)
except Exception as exc:
raise ValueError(f"Failed to index {source!r}: {exc}") from exc
raise ValueError(f"Failed to index {repo!r}: {exc}") from exc


def create_server(cache: _IndexCache, default_source: str | None = None) -> FastMCP:
def create_server(cache: _IndexCache) -> FastMCP:
"""Build and return a configured FastMCP server backed by the given cache."""
server = FastMCP(
"semble",
Expand All @@ -66,7 +57,7 @@ def create_server(cache: _IndexCache, default_source: str | None = None) -> Fast
@server.tool()
async def search(
query: Annotated[str, Field(description="Natural language or code query.")],
repo: Annotated[str | None, Field(description=_REPO_DESCRIPTION)] = None,
repo: Annotated[str, Field(description=_REPO_DESCRIPTION)],
top_k: Annotated[int, Field(description="Number of results to return.", ge=1)] = 5,
max_snippet_lines: Annotated[
int | None,
Expand All @@ -89,7 +80,7 @@ async def search(
Pass a git URL or local path as `repo`; indexes are cached for the session.
"""
try:
index = await _get_index(repo, default_source, cache)
index = await _get_index(repo, cache)
except ValueError as exc:
return str(exc)
results = index.search(query, top_k=top_k, max_snippet_lines=max_snippet_lines)
Expand All @@ -104,7 +95,7 @@ async def find_related(
Field(description="Path to the file as stored in the index (use file_path from a search result)."),
],
line: Annotated[int, Field(description="Line number (1-indexed).")],
repo: Annotated[str | None, Field(description=_REPO_DESCRIPTION)] = None,
repo: Annotated[str, Field(description=_REPO_DESCRIPTION)],
top_k: Annotated[int, Field(description="Number of similar chunks to return.", ge=1)] = 5,
max_snippet_lines: Annotated[
int | None,
Expand All @@ -124,7 +115,7 @@ async def find_related(
Pass `file_path` and `line` from a prior search result.
"""
try:
index = await _get_index(repo, default_source, cache)
index = await _get_index(repo, cache)
except ValueError as exc:
return str(exc)
chunk = resolve_chunk(index.chunks, file_path, line)
Expand All @@ -143,15 +134,13 @@ async def find_related(


async def serve(
path: str | None = None,
ref: str | None = None,
content: Sequence[ContentType] = (ContentType.CODE,),
) -> None:
"""Start an MCP stdio server, optionally pre-indexing a default source."""
"""Start an MCP stdio server."""
cache = _IndexCache(content=content)

async def _load_and_prewarm() -> None:
"""Pre-load the model and optionally pre-index the default source in parallel with starting the server."""
"""Pre-load the embedding model in parallel with starting the server."""
try:
_, cache._model_path = await asyncio.to_thread(load_model)
except Exception as exc:
Expand All @@ -160,16 +149,9 @@ async def _load_and_prewarm() -> None:
return
finally:
cache._model_ready.set()
if path:
try:
await cache.get(path, ref=ref)
except Exception:
logger.warning("Failed to pre-index %r at startup", path, exc_info=True)
if not is_git_url(path):
await cache.start_watcher(path)

init_task = asyncio.create_task(_load_and_prewarm())
server = create_server(cache, default_source=path)
server = create_server(cache)
try:
await server.run_stdio_async()
finally:
Expand All @@ -187,7 +169,6 @@ def __init__(self, content: Sequence[ContentType] = (ContentType.CODE,)) -> None
self._model_ready = asyncio.Event()
self._content = content
self._tasks: OrderedDict[str, asyncio.Task[SembleIndex]] = OrderedDict() # ordered for LRU eviction
self._watcher_task: asyncio.Task[None] | None = None

async def _await_model(self) -> str:
"""Block until the model is installed; re-raise the load error if it failed."""
Expand Down Expand Up @@ -218,22 +199,6 @@ def _build_and_cache_index(self, source: str, ref: str | None, model_path: str,
def evict(self, source: str) -> None:
self._tasks.pop(self._compute_cache_key(source), None)

async def start_watcher(self, path: str) -> None:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I'm missing something, but I think that by removing this, the MCP index never updates anymore?

"""Start a background task that re-indexes the path whenever files change."""
self._watcher_task = asyncio.create_task(self._watch_loop(path))

async def _watch_loop(self, path: str) -> None:
"""Watch the given path for changes and evict the cache entry on changes."""
try:
async for _ in watchfiles.awatch(path):
self.evict(path)
try:
await self.get(path)
except Exception:
logger.warning("Failed to rebuild index for %r after file change", path, exc_info=True)
except Exception:
pass

async def get(self, source: str, ref: str | None = None) -> SembleIndex:
"""Return an index for the requested source, building and caching it on first access."""
cache_key = self._compute_cache_key(source, ref)
Expand Down
7 changes: 0 additions & 7 deletions src/semble/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,6 @@ class SearchResult:
chunk: Chunk
score: float

def to_dict(self) -> dict[str, Any]:

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was somehow a pre-existing piece of dead code. Strange!

"""Dump a search result to a dict."""
return {
"chunk": self.chunk.to_dict(),
"score": self.score,
}


@dataclass(frozen=True, slots=True)
class IndexStats:
Expand Down
1 change: 0 additions & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
@pytest.mark.parametrize(
"argv",
[
["semble", "/some/path", "--ref", "main"],

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left this as a parametrized test even though there's only a single option now; that seemed sensible given that we want to expand this again.

["semble"],
],
)
Expand Down
69 changes: 15 additions & 54 deletions tests/test_mcp.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import asyncio
import threading
from pathlib import Path
from typing import Any, AsyncGenerator
from unittest.mock import AsyncMock, MagicMock, patch
from typing import Any
from unittest.mock import MagicMock, patch

import pytest
from model2vec import StaticModel
Expand All @@ -26,15 +26,14 @@ async def _call_tool(
index_method: str,
index_return: list[SearchResult],
index_chunks: list[Chunk] | None = None,
default_source: str | None = "/some/path",
) -> str:
"""Patch SembleIndex.from_path with a fake index and invoke the tool, returning the text."""
fake_index = MagicMock()
getattr(fake_index, index_method).return_value = index_return
if index_chunks is not None:
fake_index.chunks = index_chunks
with patch("semble.mcp.SembleIndex.from_path", return_value=fake_index):
server = create_server(cache, default_source=default_source)
server = create_server(cache)
result = await server.call_tool(tool, args)
return _tool_text(result)

Expand Down Expand Up @@ -176,21 +175,6 @@ async def test_index_cache_ignores_cache_save_failure(cache: _IndexCache, tmp_pa
assert await cache.get(str(tmp_path)) is fake_index


@pytest.mark.anyio
@pytest.mark.parametrize(
("tool", "args"),
[
("search", {"query": "foo"}),
("find_related", {"file_path": "src/foo.py", "line": 10}),
],
)
async def test_tool_no_repo_no_default(cache: _IndexCache, tool: str, args: dict[str, object]) -> None:
"""Both tools return an error message when no repo and no default source are given."""
server = create_server(cache, default_source=None)
result = await server.call_tool(tool, args)
assert "No repo specified" in _tool_text(result)


@pytest.mark.anyio
@pytest.mark.parametrize(
("tool", "args"),
Expand All @@ -215,7 +199,7 @@ async def test_tool_index_failure(cache: _IndexCache, tool: str, args: dict[str,
[
pytest.param(
"search",
{"query": "bar"},
{"query": "bar", "repo": "/some/path"},
"search",
[SearchResult(chunk=make_chunk("def bar(): pass", "src/bar.py"), score=0.9)],
None,
Expand All @@ -224,7 +208,7 @@ async def test_tool_index_failure(cache: _IndexCache, tool: str, args: dict[str,
),
pytest.param(
"search",
{"query": "nothing"},
{"query": "nothing", "repo": "/some/path"},
"search",
[],
None,
Expand All @@ -233,7 +217,7 @@ async def test_tool_index_failure(cache: _IndexCache, tool: str, args: dict[str,
),
pytest.param(
"find_related",
{"file_path": "src/foo.py", "line": 1},
{"file_path": "src/foo.py", "line": 1, "repo": "/some/path"},
"find_related",
[SearchResult(chunk=make_chunk("class Foo: pass", "src/foo.py"), score=0.8)],
[make_chunk("class Foo: pass", "src/foo.py")],
Expand All @@ -242,7 +226,7 @@ async def test_tool_index_failure(cache: _IndexCache, tool: str, args: dict[str,
),
pytest.param(
"find_related",
{"file_path": "src/foo.py", "line": 1},
{"file_path": "src/foo.py", "line": 1, "repo": "/some/path"},
"find_related",
[],
[make_chunk("class Foo: pass", "src/foo.py")],
Expand All @@ -251,7 +235,7 @@ async def test_tool_index_failure(cache: _IndexCache, tool: str, args: dict[str,
),
pytest.param(
"find_related",
{"file_path": "src/unknown.py", "line": 1},
{"file_path": "src/unknown.py", "line": 1, "repo": "/some/path"},
"find_related",
[],
[],
Expand All @@ -277,21 +261,16 @@ async def test_tool_output(

@pytest.mark.anyio
@pytest.mark.parametrize(
("with_path", "load_err", "from_path_err", "stdio_yields"),
("load_err", "stdio_yields"),
[
(True, None, None, True),
(False, None, None, True),
(False, RuntimeError("boom"), None, True),
(True, None, RuntimeError("boom"), True),
(False, None, None, False),
(None, True),
(RuntimeError("boom"), True),
(None, False),
],
ids=["pre_index", "no_path", "model_load_fails", "prewarm_fails", "cancel_pending_init"],
ids=["model_loads", "model_load_fails", "cancel_pending_init"],
)
async def test_serve_runs_stdio(
tmp_path: Path,
with_path: bool,
load_err: Exception | None,
from_path_err: Exception | None,
stdio_yields: bool,
) -> None:
"""serve() runs stdio and handles all background init outcomes without raising."""
Expand All @@ -303,14 +282,11 @@ async def fake_stdio() -> None:
load_kwargs = (
{"side_effect": load_err} if load_err else {"return_value": (MagicMock(spec=StaticModel), "/fake/model")}
)
fp_kwargs = {"side_effect": from_path_err} if from_path_err else {"return_value": MagicMock()}
with (
patch("semble.mcp.load_model", **load_kwargs),
patch("semble.mcp.SembleIndex.from_path", **fp_kwargs),
patch.object(_IndexCache, "start_watcher", new_callable=AsyncMock),
patch("mcp.server.fastmcp.FastMCP.run_stdio_async", side_effect=fake_stdio) as mock_run,
):
await (serve(str(tmp_path)) if with_path else serve())
await serve()

mock_run.assert_called_once()

Expand Down Expand Up @@ -379,7 +355,7 @@ async def test_tool_rejects_unsafe_repo(
cache: _IndexCache, repo: str, tool: str, extra_args: dict[str, object]
) -> None:
"""Both tools reject unsafe git transport schemes (ssh://, file://, SCP-form) supplied as repo."""
server = create_server(cache, default_source=None)
server = create_server(cache)
result = await server.call_tool(tool, {**extra_args, "repo": repo})
assert "Only https://" in _tool_text(result)

Expand Down Expand Up @@ -411,18 +387,3 @@ def test_cache_evict(cache: _IndexCache, tmp_path: Path) -> None:
def test_cache_evict_missing(cache: _IndexCache, tmp_path: Path) -> None:
"""evict() on an unknown path is a no-op."""
cache.evict(str(tmp_path)) # should not raise


@pytest.mark.anyio
async def test_watch_loop(cache: _IndexCache, tmp_path: Path) -> None:
"""_watch_loop rebuilds on change (inner errors swallowed) and exits cleanly on watcher error."""

async def fake_awatch(_path: str) -> AsyncGenerator:
yield set()
raise RuntimeError("watcher died")

with patch("semble.mcp.watchfiles.awatch", fake_awatch):
with patch("semble.mcp.SembleIndex.from_path", side_effect=RuntimeError("build failed")):
await cache.start_watcher(str(tmp_path))
assert cache._watcher_task is not None
await cache._watcher_task
Loading