diff --git a/hashtopolis/__init__.py b/hashtopolis/__init__.py index f10fc27..c864c8c 100644 --- a/hashtopolis/__init__.py +++ b/hashtopolis/__init__.py @@ -10,11 +10,13 @@ ObjectDoesNotExist, MultipleObjectsReturned, ModelBase, - Model + Model, + Helper ) # models from .hashtopolis import ( + ApiToken, AccessGroup, Agent, AgentStat, diff --git a/hashtopolis/hashtopolis.py b/hashtopolis/hashtopolis.py index 8d5b223..a8b02e0 100644 --- a/hashtopolis/hashtopolis.py +++ b/hashtopolis/hashtopolis.py @@ -60,6 +60,16 @@ def __init__(self): self.username = self._cfg['username'] self.password = self._cfg['password'] + @classmethod + def with_credentials(cls, uri, username, password): + """Create a config with explicit credentials instead of reading from a config file.""" + config = cls.__new__(cls) + config._hashtopolis_uri = uri + config._api_endpoint = uri + '/api/v2' + config.username = username + config.password = password + return config + class HashtopolisResponseError(HashtopolisError): pass @@ -106,22 +116,26 @@ def __init__(self, model_uri, config): self._hashtopolis_uri = config._hashtopolis_uri self.config = config - def authenticate(self): - if self._api_endpoint not in HashtopolisConnector.token: - # Request access TOKEN, used throughout the test - - logger.info("Start authentication") + def authenticate(self, auth=None): + if auth is not None: + logger.info("Start authentication with provided credentials") auth_uri = self._api_endpoint + '/auth/token' - auth = (self.config.username, self.config.password) r = requests.post(auth_uri, auth=auth) self.validate_status_code(r, [201], "Authentication failed") - r_json = self.resp_to_json(r) - HashtopolisConnector.token[self._api_endpoint] = r_json['token'] - HashtopolisConnector.token_expires[self._api_endpoint] = r_json['token'] - - self._token = HashtopolisConnector.token[self._api_endpoint] - self._token_expires = HashtopolisConnector.token_expires[self._api_endpoint] + self._token = r_json['token'] + self._token_expires = r_json['token'] + else: + if self._api_endpoint not in HashtopolisConnector.token: + logger.info("Start authentication") + auth_uri = self._api_endpoint + '/auth/token' + r = requests.post(auth_uri, auth=(self.config.username, self.config.password)) + self.validate_status_code(r, [201], "Authentication failed") + r_json = self.resp_to_json(r) + HashtopolisConnector.token[self._api_endpoint] = r_json['token'] + HashtopolisConnector.token_expires[self._api_endpoint] = r_json['token'] + self._token = HashtopolisConnector.token[self._api_endpoint] + self._token_expires = HashtopolisConnector.token_expires[self._api_endpoint] self._headers = { 'Authorization': 'Bearer ' + self._token @@ -215,8 +229,8 @@ def get_single_page(self, page, filter): return response["data"] # todo refactor start_offset into page variable - def filter(self, include, ordering, filter, start_offset): - self.authenticate() + def filter(self, include, ordering, filter, start_offset, auth=None): + self.authenticate(auth=auth) headers = self._headers after_dict = {"primary": {"id": start_offset}} @@ -394,12 +408,13 @@ def count(self, filter): # Build Django ORM style django.query interface class QuerySet(): - def __init__(self, cls, include=None, ordering=None, filters=None, pages=None): + def __init__(self, cls, include=None, ordering=None, filters=None, pages=None, auth=None): self.cls = cls self.include = include self.ordering = ordering self.filters = filters self.pages = pages + self.auth = auth def __iter__(self): yield from self.__getitem__(slice(None, None, 1)) @@ -431,7 +446,7 @@ def filter_(self, start, stop, step): filters['id'] = filters['pk'] del filters['pk'] - filter_generator = self.cls.get_conn().filter(self.include, self.ordering, filters, start_offset=cursor) + filter_generator = self.cls.get_conn().filter(self.include, self.ordering, filters, start_offset=cursor, auth=self.auth) while index < stop: # Fetch new entries in chunks default to server @@ -469,6 +484,10 @@ def page(self, **pages): def all(self): # yield from self return self + + def authenticate(self, auth): + self.auth = auth + return self def get(self, **filters): if filters: @@ -760,6 +779,10 @@ def uri(self): ## # Begin of API objects # +class ApiToken(Model, uri="/ui/apiTokens"): + pass + + class AccessGroup(Model, uri="/ui/accessgroups"): pass