Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions requirements-test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
boto3
pytest
moto[sqs]
28 changes: 14 additions & 14 deletions sqs_launcher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,14 @@ def __init__(self, queue=None, queue_url=None, create_queue=False, visibility_ti
if not any([queue, queue_url]):
raise ValueError('Either `queue` or `queue_url` should be provided.')

if (
not os.environ.get('AWS_ACCOUNT_ID', None) and
not (boto3.Session().get_credentials().method in ['iam-role', 'assume-role', 'assume-role-with-web-identity'])
):
raise EnvironmentError('Environment variable `AWS_ACCOUNT_ID` not set and no role found.')
# Accept credentials from any provider boto3 can resolve (env vars,
# shared credentials file, config file, container/instance roles, SSO,
# assumed roles, ...). AWS_ACCOUNT_ID is only needed as a fallback when
# no credentials are discoverable. See issue #62.
if not os.environ.get('AWS_ACCOUNT_ID', None) and boto3.Session().get_credentials() is None:
raise EnvironmentError(
'No AWS credentials found and environment variable `AWS_ACCOUNT_ID` is not set.'
)

# new session for each instantiation
self._session = boto3.session.Session()
Expand All @@ -56,15 +59,12 @@ def __init__(self, queue=None, queue_url=None, create_queue=False, visibility_ti
self._serializer = serializer

if not queue_url:
queues = self._client.list_queues(QueueNamePrefix=self._queue_name)
exists = False
for q in queues.get('QueueUrls', []):
qname = q.split('/')[-1]
if qname == self._queue_name:
exists = True
self._queue_url = q

if not exists:
# Resolve the queue by its exact name instead of a name prefix, so a
# queue is not confused with another whose name it is a prefix of.
# See issues #43 and #61.
try:
self._queue_url = self._client.get_queue_url(QueueName=self._queue_name)['QueueUrl']
except self._client.exceptions.QueueDoesNotExist:
if create_queue:
q = self._client.create_queue(
QueueName=self._queue_name,
Expand Down
110 changes: 51 additions & 59 deletions sqs_listener/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,22 @@ def __init__(self, queue, **kwargs):
)
else:
boto3_session = None
if (
not os.environ.get('AWS_ACCOUNT_ID', None) and
not (boto3.Session().get_credentials().method in ['sso', 'iam-role', 'assume-role', 'assume-role-with-web-identity'])
):
raise EnvironmentError('Environment variable `AWS_ACCOUNT_ID` not set and no role found.')
# Accept credentials from any provider boto3 can resolve (env vars,
# shared credentials file, config file, container/instance roles,
# SSO, assumed roles, ...). AWS_ACCOUNT_ID is only needed as a
# fallback when no credentials are discoverable. See issue #62.
if not os.environ.get('AWS_ACCOUNT_ID', None) and boto3.Session().get_credentials() is None:
raise EnvironmentError(
'No AWS credentials found and environment variable `AWS_ACCOUNT_ID` is not set.'
)

self._queue_name = queue
self._poll_interval = kwargs.get("interval", 60)
self._queue_visibility_timeout = kwargs.get('visibility_timeout', '600')
self._error_queue_name = kwargs.get('error_queue', None)
self._error_queue_visibility_timeout = kwargs.get('error_visibility_timeout', '600')
self._queue_url = kwargs.get('queue_url', None)
self._error_queue_url = None
self._message_attribute_names = kwargs.get('message_attribute_names', [])
self._attribute_names = kwargs.get('attribute_names', [])
self._force_delete = kwargs.get('force_delete', False)
Expand All @@ -83,65 +87,50 @@ def _initialize_client(self):
ssl = False

sqs = self._session.client('sqs', region_name=self._region_name, endpoint_url=self._endpoint_name, use_ssl=ssl)
try:
queues = sqs.list_queues(QueueNamePrefix=self._queue_name)
except SSOTokenLoadError:
raise EnvironmentError('Error loading SSO Token. Reauthenticate via aws sso login.')

main_queue_exists = False
error_queue_exists = False
if 'QueueUrls' in queues:
for q in queues['QueueUrls']:
qname = q.split('/')[-1]
if qname == self._queue_name:
main_queue_exists = True
if self._error_queue_name and qname == self._error_queue_name:
error_queue_exists = True

# create queue if necessary.
# creation is idempotent, no harm in calling on a queue if it already exists.
# Resolve each queue by its exact name rather than by a shared name
# prefix. The old prefix-based lookup misidentified error queues that
# don't share the main queue's prefix (and CloudFormation-generated
# names), causing spurious queue creation and QueueNameExists errors.
# See issues #43 and #61.
if self._queue_url is None:
if not main_queue_exists:
sqs_logger.warning("main queue not found, creating now")

# is this a fifo queue?
if self._queue_name.endswith(".fifo"):
fifo_queue = "true"
q = sqs.create_queue(
QueueName=self._queue_name,
Attributes={
'VisibilityTimeout': self._queue_visibility_timeout, # 10 minutes
'FifoQueue': fifo_queue
}
)
else:
# need to avoid FifoQueue property for normal non-fifo queues
q = sqs.create_queue(
QueueName=self._queue_name,
Attributes={
'VisibilityTimeout': self._queue_visibility_timeout, # 10 minutes
}
)
self._queue_url = q['QueueUrl']

if self._error_queue_name and not error_queue_exists:
sqs_logger.warning("error queue not found, creating now")
q = sqs.create_queue(
QueueName=self._error_queue_name,
Attributes={
'VisibilityTimeout': self._queue_visibility_timeout # 10 minutes
}
self._queue_url = self._get_or_create_queue_url(
sqs, self._queue_name, self._queue_visibility_timeout
)

if self._error_queue_name:
self._error_queue_url = self._get_or_create_queue_url(
sqs, self._error_queue_name, self._error_queue_visibility_timeout
)

if self._queue_url is None:
if os.environ.get('AWS_ACCOUNT_ID', None):
qs = sqs.get_queue_url(QueueName=self._queue_name,
QueueOwnerAWSAccountId=os.environ.get('AWS_ACCOUNT_ID', None))
else:
qs = sqs.get_queue_url(QueueName=self._queue_name)
self._queue_url = qs['QueueUrl']
return sqs

def _get_or_create_queue_url(self, sqs, queue_name, visibility_timeout):
"""Return the URL of ``queue_name``, creating the queue only if absent.

Creation is skipped when the queue already exists, so pre-created
queues (including ones whose names don't share a common prefix) are
used as-is instead of triggering a QueueNameExists error.
"""
try:
account_id = os.environ.get('AWS_ACCOUNT_ID', None)
if account_id:
response = sqs.get_queue_url(QueueName=queue_name, QueueOwnerAWSAccountId=account_id)
else:
response = sqs.get_queue_url(QueueName=queue_name)
return response['QueueUrl']
except sqs.exceptions.QueueDoesNotExist:
sqs_logger.warning("queue %s not found, creating now", queue_name)
except SSOTokenLoadError:
raise EnvironmentError('Error loading SSO Token. Reauthenticate via aws sso login.')

attributes = {'VisibilityTimeout': visibility_timeout}
# FIFO queues must be created with the FifoQueue attribute; it must be
# omitted for standard queues.
if queue_name.endswith(".fifo"):
attributes['FifoQueue'] = "true"
return sqs.create_queue(QueueName=queue_name, Attributes=attributes)['QueueUrl']

def _start_listening(self):
# TODO consider incorporating output processing from here: https://github.com/debrouwere/sqs-antenna/blob/master/antenna/__init__.py
while True:
Expand Down Expand Up @@ -193,7 +182,10 @@ def _start_listening(self):
exc_type, exc_obj, exc_tb = sys.exc_info()

sqs_logger.info("Pushing exception to error queue")
error_launcher = SqsLauncher(queue=self._error_queue_name, create_queue=True)
# The error queue was already resolved/created during
# initialization, so reuse its URL instead of looking
# it up again on every failure.
error_launcher = SqsLauncher(queue_url=self._error_queue_url)
error_launcher.launch_message(
{
'exception_type': str(exc_type),
Expand Down
Empty file added tests/__init__.py
Empty file.
22 changes: 22 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Shared pytest fixtures for the SQS listener/launcher test-suite."""

import os

import pytest


@pytest.fixture(autouse=True)
def aws_environment(monkeypatch):
"""Provide a clean, predictable AWS environment for every test.

- A default region so boto3 clients can be created under moto.
- Dummy static credentials so ``get_credentials()`` resolves.
- ``AWS_ACCOUNT_ID`` removed so the credential gate is actually exercised
(several issues are specifically about behaviour when it is unset).
"""
monkeypatch.setenv("AWS_DEFAULT_REGION", "us-east-1")
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
monkeypatch.setenv("AWS_SECURITY_TOKEN", "testing")
monkeypatch.setenv("AWS_SESSION_TOKEN", "testing")
monkeypatch.delenv("AWS_ACCOUNT_ID", raising=False)
67 changes: 67 additions & 0 deletions tests/test_credentials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""Tests for the credential / AWS_ACCOUNT_ID gate.

Regression coverage for issue #62 (support of shared-credentials-file):
credentials resolved from a shared credentials file (or any other provider)
must be accepted even when ``AWS_ACCOUNT_ID`` is not set.
"""

import boto3
import pytest
from moto import mock_aws

from sqs_launcher import SqsLauncher
from sqs_listener import SqsListener


class _Listener(SqsListener):
def handle_message(self, body, attributes, messages_attributes):
pass


@mock_aws
def test_listener_accepts_shared_credentials_file(monkeypatch):
"""A non-role credential method (e.g. shared-credentials-file) must work."""
real_get_credentials = boto3.Session.get_credentials

def fake_get_credentials(self):
creds = real_get_credentials(self)
if creds is not None:
creds.method = "shared-credentials-file"
return creds

monkeypatch.setattr(boto3.Session, "get_credentials", fake_get_credentials)

# Should not raise EnvironmentError just because the method isn't a role.
listener = _Listener("some-queue", region_name="us-east-1")
assert listener._queue_url is not None


@mock_aws
def test_launcher_accepts_shared_credentials_file(monkeypatch):
real_get_credentials = boto3.Session.get_credentials

def fake_get_credentials(self):
creds = real_get_credentials(self)
if creds is not None:
creds.method = "shared-credentials-file"
return creds

monkeypatch.setattr(boto3.Session, "get_credentials", fake_get_credentials)

launcher = SqsLauncher("some-queue", create_queue=True)
assert launcher._queue_url is not None


def test_listener_raises_when_no_credentials_and_no_account_id(monkeypatch):
"""The safety net stays: no credentials AND no account id is still an error."""
monkeypatch.setattr(boto3.Session, "get_credentials", lambda self: None)

with pytest.raises(EnvironmentError):
_Listener("some-queue", region_name="us-east-1")


def test_launcher_raises_when_no_credentials_and_no_account_id(monkeypatch):
monkeypatch.setattr(boto3.Session, "get_credentials", lambda self: None)

with pytest.raises(EnvironmentError):
SqsLauncher("some-queue", create_queue=True)
55 changes: 55 additions & 0 deletions tests/test_error_queue_push.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""When a handler raises, the failing message must be forwarded to the error
queue that was resolved during initialization."""

import json

import boto3
from moto import mock_aws

from sqs_listener import SqsListener


class _Stop(Exception):
"""Sentinel used to break out of the otherwise-infinite poll loop."""


@mock_aws
def test_handler_exception_is_pushed_to_error_queue():
sqs = boto3.client("sqs", region_name="us-east-1")
sqs.create_queue(QueueName="work-queue")
error_url = sqs.create_queue(QueueName="dead-letters")["QueueUrl"]

class _Listener(SqsListener):
def handle_message(self, body, attributes, messages_attributes):
raise ValueError("boom")

listener = _Listener(
"work-queue", error_queue="dead-letters", region_name="us-east-1"
)

# First poll yields one message; the second poll aborts the loop so the
# test doesn't spin forever.
calls = {"n": 0}
messages = [
{"Messages": [{"ReceiptHandle": "rh", "Body": json.dumps({"hi": 1})}]},
]

def fake_receive(**kwargs):
if calls["n"] < len(messages):
msg = messages[calls["n"]]
calls["n"] += 1
return msg
raise _Stop()

listener._client.receive_message = fake_receive

try:
listener._start_listening()
except _Stop:
pass

received = sqs.receive_message(QueueUrl=error_url, WaitTimeSeconds=0)
assert "Messages" in received
payload = json.loads(received["Messages"][0]["Body"])
assert "boom" in payload["error_message"]
assert "ValueError" in payload["exception_type"]
Loading