diff --git a/packages/markitdown/src/markitdown/__main__.py b/packages/markitdown/src/markitdown/__main__.py index ccb44b64b..802bf553d 100644 --- a/packages/markitdown/src/markitdown/__main__.py +++ b/packages/markitdown/src/markitdown/__main__.py @@ -119,6 +119,20 @@ def main(): help="Comma-separated list of file types to route to Content Understanding (e.g., pdf,jpeg,mp4). If omitted, all supported types are routed.", ) + parser.add_argument( + "--docintel-model-id", + type=str, + default=None, + help="Document Intelligence model ID (e.g., 'prebuilt-layout', 'prebuilt-invoice', or a custom model ID). Defaults to 'prebuilt-layout'.", + ) + + parser.add_argument( + "--docintel-query-fields", + type=str, + default=None, + help="Comma-separated list of field names to extract via the Document Intelligence queryFields add-on (OCR file types only).", + ) + parser.add_argument( "-p", "--use-plugins", @@ -208,9 +222,19 @@ def main(): elif args.filename is None: _exit_with_error("Filename is required when using Document Intelligence.") - markitdown = MarkItDown( - enable_plugins=args.use_plugins, docintel_endpoint=args.endpoint - ) + docintel_kwargs: Dict[str, Any] = { + "docintel_endpoint": args.endpoint, + } + if args.docintel_model_id: + docintel_kwargs["docintel_model_id"] = args.docintel_model_id + if args.docintel_query_fields: + fields = [ + f.strip() for f in args.docintel_query_fields.split(",") if f.strip() + ] + if fields: + docintel_kwargs["docintel_query_fields"] = fields + + markitdown = MarkItDown(enable_plugins=args.use_plugins, **docintel_kwargs) elif args.use_cu: if args.cu_endpoint is None: _exit_with_error( diff --git a/packages/markitdown/src/markitdown/_markitdown.py b/packages/markitdown/src/markitdown/_markitdown.py index f6aa4df0e..7ffad0996 100644 --- a/packages/markitdown/src/markitdown/_markitdown.py +++ b/packages/markitdown/src/markitdown/_markitdown.py @@ -222,6 +222,14 @@ def enable_builtins(self, **kwargs) -> None: if docintel_version is 
not None: docintel_args["api_version"] = docintel_version + docintel_model_id = kwargs.get("docintel_model_id") + if docintel_model_id is not None: + docintel_args["model_id"] = docintel_model_id + + docintel_query_fields = kwargs.get("docintel_query_fields") + if docintel_query_fields is not None: + docintel_args["query_fields"] = docintel_query_fields + self.register_converter( DocumentIntelligenceConverter(**docintel_args), ) diff --git a/packages/markitdown/src/markitdown/converters/_doc_intel_converter.py b/packages/markitdown/src/markitdown/converters/_doc_intel_converter.py index fd843f231..98aafcf46 100644 --- a/packages/markitdown/src/markitdown/converters/_doc_intel_converter.py +++ b/packages/markitdown/src/markitdown/converters/_doc_intel_converter.py @@ -1,12 +1,16 @@ import sys import re import os -from typing import BinaryIO, Any, List +from datetime import date, datetime, time +from typing import BinaryIO, Any, List, Optional from enum import Enum from .._base_converter import DocumentConverter, DocumentConverterResult from .._stream_info import StreamInfo from .._exceptions import MissingDependencyException +from .. import __version__ as _markitdown_version + +_USER_AGENT = f"markitdown-docintel/{_markitdown_version}" # Try loading optional (but in this case, required) dependencies # Save reporting of any exceptions for later @@ -127,6 +131,170 @@ def _get_file_extensions(types: List[DocumentIntelligenceFileType]) -> List[str] return extensions +def _field_value(field: Any) -> Any: + """ + Extract a serializable Python value from a Document Intelligence DocumentField. + + Returns the most specific typed value when available, falling back to the + raw ``content`` string. Returns ``None`` when nothing usable is present. + """ + if field is None: + return None + + # Typed scalar values (in rough order of specificity). 
+ for attr in ( + "value_string", + "value_boolean", + "value_integer", + "value_number", + "value_date", + "value_time", + "value_phone_number", + "value_country_region", + "value_selection_mark", + "value_signature", + ): + v = getattr(field, attr, None) + if v is not None: + if isinstance(v, (date, datetime, time)): + return v.isoformat() + return v + + # Currency: { amount, currencySymbol, currencyCode } + cur = getattr(field, "value_currency", None) + if cur is not None: + amount = getattr(cur, "amount", None) + code = getattr(cur, "currency_code", None) or getattr( + cur, "currency_symbol", None + ) + if amount is not None and code: + return f"{amount} {code}" + if amount is not None: + return amount + + # Address: serialize to its content/string form. + addr = getattr(field, "value_address", None) + if addr is not None: + return getattr(field, "content", None) or str(addr) + + # Array of fields -> list of values. + arr = getattr(field, "value_array", None) + if arr is not None: + return [_field_value(item) for item in arr] + + # Object of fields -> dict of values. + obj = getattr(field, "value_object", None) + if obj is not None: + return {k: _field_value(v) for k, v in obj.items()} + + # Last resort: the raw extracted text. + return getattr(field, "content", None) + + +def _yaml_scalar(value: Any) -> str: + """Render a scalar value as a YAML string.""" + if value is None: + return "null" + if isinstance(value, bool): + return "true" if value else "false" + if isinstance(value, (int, float)): + return repr(value) + s = str(value) + # Quote when necessary: contains special chars, leading/trailing whitespace, + # or characters that would confuse a YAML parser. + if ( + s == "" + or s != s.strip() + or any(c in s for c in ":#&*!|>'\"%@`\n\r\t") + or s.lower() in ("null", "true", "false", "yes", "no", "~") + ): + # Escape backslashes and double quotes; collapse newlines. 
+ escaped = ( + s.replace("\\", "\\\\") + .replace('"', '\\"') + .replace("\n", "\\n") + .replace("\r", "\\r") + .replace("\t", "\\t") + ) + return f'"{escaped}"' + return s + + +def _yaml_dump(value: Any, indent: int = 0) -> str: + """Minimal YAML emitter for scalars, lists, and dicts of scalars/lists/dicts.""" + pad = " " * indent + if isinstance(value, dict): + if not value: + return f"{pad}{{}}" + lines: List[str] = [] + for k, v in value.items(): + key = _yaml_scalar(k) + if isinstance(v, (dict, list)) and v: + lines.append(f"{pad}{key}:") + lines.append(_yaml_dump(v, indent + 1)) + else: + lines.append( + f"{pad}{key}: {_yaml_scalar(v) if not isinstance(v, (dict, list)) else ('{}' if isinstance(v, dict) else '[]')}" + ) + return "\n".join(lines) + if isinstance(value, list): + if not value: + return f"{pad}[]" + lines = [] + for item in value: + if isinstance(item, (dict, list)) and item: + lines.append(f"{pad}-") + lines.append(_yaml_dump(item, indent + 1)) + else: + lines.append( + f"{pad}- {_yaml_scalar(item) if not isinstance(item, (dict, list)) else ('{}' if isinstance(item, dict) else '[]')}" + ) + return "\n".join(lines) + return f"{pad}{_yaml_scalar(value)}" + + +def _fields_to_front_matter(documents: Any, model_id: Optional[str] = None) -> str: + """ + Build a YAML front matter block from ``AnalyzeResult.documents[*].fields``. + + Returns an empty string when there are no documents or no non-empty fields. + Multiple documents are merged into a single ``fields`` mapping; on duplicate + keys, the value from the later document wins. + + The shape mirrors the Content Understanding converter's front matter so that + downstream consumers (e.g., LLM pipelines) can parse both uniformly: + + --- + modelId: prebuilt-invoice + fields: + VendorName: Contoso Ltd. 
+ InvoiceTotal: 1250.0 + --- + """ + if not documents: + return "" + + merged: dict = {} + for doc in documents: + fields = getattr(doc, "fields", None) or {} + for name, field in fields.items(): + value = _field_value(field) + if value is None or value == "" or value == [] or value == {}: + continue + merged[name] = value + + if not merged: + return "" + + payload: dict = {} + if model_id: + payload["modelId"] = model_id + payload["fields"] = merged + + body = _yaml_dump(payload) + return f"---\n{body}\n---\n\n" + + class DocumentIntelligenceConverter(DocumentConverter): """Specialized DocumentConverter that uses Document Intelligence to extract text from documents.""" @@ -134,8 +302,10 @@ def __init__( self, *, endpoint: str, - api_version: str = "2024-07-31-preview", + api_version: str = "2024-11-30", credential: AzureKeyCredential | TokenCredential | None = None, + model_id: str = "prebuilt-layout", + query_fields: Optional[List[str]] = None, file_types: List[DocumentIntelligenceFileType] = [ DocumentIntelligenceFileType.DOCX, DocumentIntelligenceFileType.PPTX, @@ -152,13 +322,19 @@ def __init__( Args: endpoint (str): The endpoint for the Document Intelligence service. - api_version (str): The API version to use. Defaults to "2024-07-31-preview". + api_version (str): The API version to use. Defaults to "2024-11-30" (GA). credential (AzureKeyCredential | TokenCredential | None): The credential to use for authentication. + model_id (str): The Document Intelligence model ID to use (e.g., "prebuilt-layout", + "prebuilt-invoice", "prebuilt-receipt", or a custom model ID). Defaults to "prebuilt-layout". + query_fields (List[str] | None): Optional list of field names to extract via the DI + ``queryFields`` add-on. Only applied to OCR-supported file types (PDF/images). file_types (List[DocumentIntelligenceFileType]): The file types to accept. Defaults to all supported file types. 
""" super().__init__() self._file_types = file_types + self._model_id = model_id + self._query_fields = list(query_fields) if query_fields else None # Raise an error if the dependencies are not available. # This is different than other converters since this one isn't even instantiated @@ -184,6 +360,7 @@ def __init__( endpoint=self.endpoint, api_version=self.api_version, credential=credential, + user_agent=_USER_AGENT, ) def accepts( @@ -228,11 +405,14 @@ def _analysis_features(self, stream_info: StreamInfo) -> List[str]: if mimetype.startswith(prefix): return [] - return [ + features = [ DocumentAnalysisFeature.FORMULAS, # enable formula extraction DocumentAnalysisFeature.OCR_HIGH_RESOLUTION, # enable high resolution OCR DocumentAnalysisFeature.STYLE_FONT, # enable font style extraction ] + if self._query_fields: + features.append(DocumentAnalysisFeature.QUERY_FIELDS) + return features def convert( self, @@ -240,15 +420,32 @@ def convert( stream_info: StreamInfo, **kwargs: Any, # Options to pass to the converter ) -> DocumentConverterResult: + # Build optional kwargs so that we only pass query_fields when the + # QUERY_FIELDS feature is actually enabled for this file type. 
+ features = self._analysis_features(stream_info) + extra: dict = {} + if self._query_fields and DocumentAnalysisFeature.QUERY_FIELDS in features: + extra["query_fields"] = self._query_fields + # Extract the text using Azure Document Intelligence poller = self.doc_intel_client.begin_analyze_document( - model_id="prebuilt-layout", + model_id=self._model_id, body=AnalyzeDocumentRequest(bytes_source=file_stream.read()), - features=self._analysis_features(stream_info), + features=features, output_content_format=CONTENT_FORMAT, # TODO: replace with "ContentFormat.MARKDOWN" when the bug is fixed + **extra, ) result: AnalyzeResult = poller.result() # remove comments from the markdown content generated by Doc Intelligence and append to markdown string markdown_text = re.sub(r"", "", result.content, flags=re.DOTALL) + + # Prepend YAML front matter when DI returned structured fields (e.g., from + # prebuilt-invoice/-receipt, custom models, or queryFields). + front_matter = _fields_to_front_matter( + getattr(result, "documents", None), model_id=self._model_id + ) + if front_matter: + markdown_text = front_matter + markdown_text + return DocumentConverterResult(markdown=markdown_text) diff --git a/packages/markitdown/tests/test_docintel_converter.py b/packages/markitdown/tests/test_docintel_converter.py new file mode 100644 index 000000000..5b872f0df --- /dev/null +++ b/packages/markitdown/tests/test_docintel_converter.py @@ -0,0 +1,342 @@ +"""Unit tests for the DocumentIntelligenceConverter improvements. + +These tests exercise the converter without making any network calls. They use +``__new__`` to bypass ``__init__`` (which would construct a real +``DocumentIntelligenceClient``) and instead inject a mock client. 
"""Unit tests for the DocumentIntelligenceConverter improvements.

These tests exercise the converter without making any network calls. They use
``__new__`` to bypass ``__init__`` (which would construct a real
``DocumentIntelligenceClient``) and instead inject a mock client.
"""

import inspect
import io
from datetime import date
from types import SimpleNamespace
from unittest import mock

import pytest

from markitdown._stream_info import StreamInfo
from markitdown.converters import _doc_intel_converter as di_mod
from markitdown.converters._doc_intel_converter import (
    DocumentIntelligenceConverter,
    DocumentIntelligenceFileType,
    _USER_AGENT,
    _field_value,
    _fields_to_front_matter,
    _yaml_dump,
)

# --------- helpers ---------------------------------------------------------


def _bare_converter(
    *,
    file_types=None,
    model_id="prebuilt-layout",
    query_fields=None,
    client=None,
):
    """Build a converter without calling __init__ (no real DI client)."""
    conv = DocumentIntelligenceConverter.__new__(DocumentIntelligenceConverter)
    conv._file_types = file_types or [
        DocumentIntelligenceFileType.PDF,
        DocumentIntelligenceFileType.DOCX,
    ]
    conv._model_id = model_id
    conv._query_fields = list(query_fields) if query_fields else None
    conv.endpoint = "https://example.cognitiveservices.azure.com/"
    conv.api_version = "2024-11-30"
    conv.doc_intel_client = client
    return conv


def _mock_field(**kwargs):
    """A SimpleNamespace with all DocumentField value_* attrs defaulting to None."""
    defaults = {
        "value_string": None,
        "value_boolean": None,
        "value_integer": None,
        "value_number": None,
        "value_date": None,
        "value_time": None,
        "value_phone_number": None,
        "value_country_region": None,
        "value_selection_mark": None,
        "value_signature": None,
        "value_currency": None,
        "value_address": None,
        "value_array": None,
        "value_object": None,
        "content": None,
    }
    defaults.update(kwargs)
    return SimpleNamespace(**defaults)


def _client_returning(content, documents=None):
    """A mock DI client whose analyze poller yields the given content/documents."""
    poller = mock.MagicMock()
    poller.result.return_value = SimpleNamespace(content=content, documents=documents)
    client = mock.MagicMock()
    client.begin_analyze_document.return_value = poller
    return client


def _pdf_info():
    """StreamInfo describing a PDF (an OCR-capable file type)."""
    return StreamInfo(extension=".pdf", mimetype="application/pdf")


def _docx_info():
    """StreamInfo describing a DOCX (an Office file type)."""
    return StreamInfo(
        extension=".docx",
        mimetype="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    )


# --------- Phase 1: API version + user agent -------------------------------


def test_default_api_version_is_2024_11_30():
    """The default api_version must be the GA value '2024-11-30'."""
    sig = inspect.signature(DocumentIntelligenceConverter.__init__)
    assert sig.parameters["api_version"].default == "2024-11-30"


def test_user_agent_string_format():
    """User agent should start with 'markitdown-docintel/'."""
    assert _USER_AGENT.startswith("markitdown-docintel/")
    assert len(_USER_AGENT) > len("markitdown-docintel/")


def test_client_constructed_with_user_agent_and_api_version():
    """__init__ should pass user_agent and api_version to DocumentIntelligenceClient."""
    fake_client = mock.MagicMock()
    with mock.patch.object(
        di_mod, "DocumentIntelligenceClient", return_value=fake_client
    ) as ctor:
        DocumentIntelligenceConverter(
            endpoint="https://example.cognitiveservices.azure.com/",
            credential=mock.MagicMock(),
        )
    kwargs = ctor.call_args.kwargs
    assert kwargs["api_version"] == "2024-11-30"
    assert kwargs["user_agent"] == _USER_AGENT


# --------- Phase 2: configurable model_id ----------------------------------


def test_default_model_id():
    """Default model_id preserves existing behavior."""
    sig = inspect.signature(DocumentIntelligenceConverter.__init__)
    assert sig.parameters["model_id"].default == "prebuilt-layout"


def test_convert_uses_default_model_id():
    """Without overrides, convert() calls begin_analyze_document with prebuilt-layout."""
    client = _client_returning("# hi")

    conv = _bare_converter(client=client)
    conv.convert(io.BytesIO(b"data"), _pdf_info())

    assert (
        client.begin_analyze_document.call_args.kwargs["model_id"] == "prebuilt-layout"
    )


def test_convert_uses_overridden_model_id():
    """A custom model_id is forwarded to begin_analyze_document."""
    client = _client_returning("# hi")

    conv = _bare_converter(model_id="prebuilt-invoice", client=client)
    conv.convert(io.BytesIO(b"data"), _pdf_info())

    assert (
        client.begin_analyze_document.call_args.kwargs["model_id"] == "prebuilt-invoice"
    )


# --------- Phase 3: YAML front matter --------------------------------------


def test_field_value_typed_scalars():
    assert _field_value(_mock_field(value_string="Contoso")) == "Contoso"
    assert _field_value(_mock_field(value_integer=42)) == 42
    assert _field_value(_mock_field(value_number=12.5)) == 12.5
    assert _field_value(_mock_field(value_boolean=True)) is True
    assert _field_value(_mock_field(value_date=date(2026, 3, 15))) == "2026-03-15"


def test_field_value_currency():
    cur = SimpleNamespace(amount=1250.0, currency_code="USD", currency_symbol="$")
    assert _field_value(_mock_field(value_currency=cur)) == "1250.0 USD"


def test_field_value_falls_back_to_content():
    assert _field_value(_mock_field(content="raw text")) == "raw text"


def test_field_value_array_of_scalars():
    items = [_mock_field(value_string="A"), _mock_field(value_string="B")]
    assert _field_value(_mock_field(value_array=items)) == ["A", "B"]


def test_fields_to_front_matter_empty_when_no_documents():
    assert _fields_to_front_matter(None) == ""
    assert _fields_to_front_matter([]) == ""


def test_fields_to_front_matter_empty_when_no_fields():
    doc = SimpleNamespace(fields={})
    assert _fields_to_front_matter([doc]) == ""


def test_fields_to_front_matter_basic():
    doc = SimpleNamespace(
        fields={
            "VendorName": _mock_field(value_string="Contoso Ltd."),
            "InvoiceTotal": _mock_field(value_number=1250.0),
        }
    )
    fm = _fields_to_front_matter([doc], model_id="prebuilt-invoice")
    assert fm.startswith("---\n")
    assert fm.endswith("---\n\n")
    assert "modelId: prebuilt-invoice" in fm
    assert "fields:" in fm
    assert " VendorName: Contoso Ltd." in fm
    assert " InvoiceTotal: 1250.0" in fm


def test_fields_to_front_matter_omits_model_id_when_not_provided():
    doc = SimpleNamespace(fields={"X": _mock_field(value_string="y")})
    fm = _fields_to_front_matter([doc])
    assert "modelId:" not in fm
    assert "fields:" in fm


def test_fields_with_special_chars_are_quoted():
    doc = SimpleNamespace(
        fields={"Note": _mock_field(value_string="line1\nline2: with colon")}
    )
    fm = _fields_to_front_matter([doc])
    # Value contains both \n and ':' so it must be quoted.
    assert ' Note: "line1\\nline2: with colon"' in fm


def test_yaml_dump_nested_dict():
    out = _yaml_dump({"a": 1, "b": {"c": "x"}})
    assert "a: 1" in out
    assert "b:" in out
    assert " c: x" in out


def test_convert_prepends_front_matter_when_fields_present():
    doc = SimpleNamespace(fields={"VendorName": _mock_field(value_string="Contoso")})
    client = _client_returning("# Invoice\n\nbody", documents=[doc])

    conv = _bare_converter(model_id="prebuilt-invoice", client=client)
    result = conv.convert(io.BytesIO(b"data"), _pdf_info())

    assert result.markdown.startswith("---\n")
    assert "modelId: prebuilt-invoice" in result.markdown
    assert " VendorName: Contoso" in result.markdown
    assert "# Invoice" in result.markdown


def test_convert_no_front_matter_when_no_documents():
    client = _client_returning("# Layout")

    conv = _bare_converter(client=client)
    result = conv.convert(io.BytesIO(b"data"), _pdf_info())

    assert not result.markdown.startswith("---")
    assert result.markdown.startswith("# Layout")


# --------- Phase 4: query fields -------------------------------------------


def test_query_fields_adds_feature_for_ocr_types():
    # The Azure SDK is an optional dependency: skip (rather than error) when
    # it is not installed.
    models = pytest.importorskip("azure.ai.documentintelligence.models")

    conv = _bare_converter(query_fields=["VendorName", "Total"])
    features = conv._analysis_features(_pdf_info())

    assert models.DocumentAnalysisFeature.QUERY_FIELDS in features


def test_query_fields_skipped_for_office_types():
    conv = _bare_converter(query_fields=["VendorName"])
    features = conv._analysis_features(_docx_info())
    # Office types skip OCR features entirely.
    assert features == []


def test_query_fields_passed_to_begin_analyze_document_for_pdf():
    client = _client_returning("x")

    conv = _bare_converter(query_fields=["A", "B"], client=client)
    conv.convert(io.BytesIO(b"data"), _pdf_info())

    assert client.begin_analyze_document.call_args.kwargs.get("query_fields") == [
        "A",
        "B",
    ]


def test_query_fields_not_passed_for_office_types():
    client = _client_returning("x")

    conv = _bare_converter(query_fields=["A"], client=client)
    conv.convert(io.BytesIO(b"data"), _docx_info())

    assert "query_fields" not in client.begin_analyze_document.call_args.kwargs


# --------- _markitdown.py wiring -------------------------------------------


def test_markitdown_forwards_docintel_kwargs(monkeypatch):
    """MarkItDown(...) should forward docintel_model_id / docintel_query_fields."""
    from markitdown import _markitdown as md_mod

    captured = {}

    class _Fake:
        def __init__(self, **kwargs):
            captured.update(kwargs)

    monkeypatch.setattr(md_mod, "DocumentIntelligenceConverter", _Fake)

    md_mod.MarkItDown(
        docintel_endpoint="https://example.cognitiveservices.azure.com/",
        docintel_model_id="prebuilt-invoice",
        docintel_query_fields=["A", "B"],
    )

    assert captured.get("endpoint") == "https://example.cognitiveservices.azure.com/"
    assert captured.get("model_id") == "prebuilt-invoice"
    assert captured.get("query_fields") == ["A", "B"]