diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 012266d..3d5fff4 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -29,10 +29,6 @@ jobs:
         os: ["ubuntu-24.04", "macos-14", "windows-2022"]
         python-version: ["3.10", "3.11", "3.12"]
         include:
-          - os: ubuntu-22.04
-            python-version: 3.7
-          - os: ubuntu-24.04
-            python-version: 3.8
           - os: ubuntu-24.04
             python-version: 3.9
           - os: ubuntu-24.04
diff --git a/docs/available_parameters.rst b/docs/available_parameters.rst
index 9e34df0..0ac9df7 100644
--- a/docs/available_parameters.rst
+++ b/docs/available_parameters.rst
@@ -291,7 +291,9 @@ SW_OPER_VOBS_1M_2\_:SecularVariation  ``SiteCode,B_SV,sigma_SV``
 ``models``
 ----------
 
-Models are evaluated along the satellite track at the positions of the time series that has been requested. These must be used together with one of the MAG collections, and one or both of the "F" and "B_NEC" measurements. This can yield either the model values together with the measurements, or the data-model residuals.
+When requesting a MAG-type collection, geomagnetic models can be evaluated on demand along the satellite track (i.e. at the same times and positions as the data). You have the choice of receiving either the model values together with the measurements (i.e. ``B_NEC`` & ``B_NEC_Model``), or simply the data-model residuals (i.e. ``B_NEC_res_Model``).
+
+To evaluate models at arbitrary coordinates (i.e. without a data request), see :py:meth:`viresclient.SwarmRequest.eval_model` and :py:meth:`viresclient.SwarmRequest.eval_model_for_cdf_file`.
 
 .. note::
 
@@ -299,7 +301,7 @@ Models are evaluated along the satellite track at the positions of the time seri
 
     ``models=["'CHAOS-full' = 'CHAOS-Core' + 'CHAOS-Static' + 'CHAOS-MMA-Primary' + 'CHAOS-MMA-Secondary'"]`` `(click for more info) `_
 
-    This composed model can also be accessed by an alias: ``models=["CHAOS"]`` which represents the full CHAOS model
+    This composed model (core + crust + magnetosphere) can also be accessed by an alias: ``models=["CHAOS"]``. Note that this does not include the ionospheric part (``"CHAOS-MIO"``) which was added to the CHAOS series in `CHAOS-8 `_.
 
 See `Magnetic Earth `_ for an introduction to geomagnetic models.
 
@@ -326,6 +328,7 @@ Models are evaluated along the satellite track at the positions of the time seri
     CHAOS-Core,                              # Core
     CHAOS-Static,                            # Lithosphere
     CHAOS-MMA-Primary, CHAOS-MMA-Secondary   # Magnetosphere
+    CHAOS-MIO                                # Polar ionosphere
 
     # Other lithospheric models:
     MF7, LCS-1
diff --git a/docs/release_notes.rst b/docs/release_notes.rst
index 4d1942a..9252e77 100644
--- a/docs/release_notes.rst
+++ b/docs/release_notes.rst
@@ -4,6 +4,17 @@ Release notes
 Change log
 ----------
 
+Changes from 0.13.0 to 0.14.0
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Coinciding with `VirES server version 3.16 `_
+
+- Added the ``"CHAOS-MIO"`` magnetic model - the new ionospheric part of CHAOS. Note that the ``"CHAOS"`` alias does not include ``"CHAOS-MIO"``
+- Added support for a new feature to evaluate models at arbitrary coordinates. See:
+    - :py:meth:`viresclient.SwarmRequest.eval_model`
+    - :py:meth:`viresclient.SwarmRequest.eval_model_for_cdf_file`
+- Added support for spline interpolation of magnetic models when requesting ``MAGx_HR`` data.
+  Can be disabled with the ``do_not_interpolate_models`` option in :py:meth:`viresclient.SwarmRequest.set_products`
+
 Changes from 0.12.3 to 0.13.0
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
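For orientation, a minimal usage sketch of the headline 0.14.0 model changes (standard viresclient API; the exact names of the returned model variables are determined by the server):

from viresclient import SwarmRequest

request = SwarmRequest()
request.set_collection("SW_OPER_MAGA_LR_1B")  # Swarm Alpha, low-rate magnetic data
request.set_products(
    measurements=["B_NEC"],
    # The "CHAOS" alias (core + crust + magnetosphere) must be combined
    # with "CHAOS-MIO" explicitly to include the new ionospheric part:
    models=["CHAOS", "CHAOS-MIO"],
)
data = request.get_between("2025-01-01T00:00:00", "2025-01-01T01:00:00")
ds = data.as_xarray()  # measurements alongside the evaluated model values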
diff --git a/pyproject.toml b/pyproject.toml
index 23eb1f7..2df2101 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,8 +26,6 @@ classifiers = [
     "Operating System :: OS Independent",
     "Programming Language :: Python",
     "Programming Language :: Python :: 3",
-    "Programming Language :: Python :: 3.7",
-    "Programming Language :: Python :: 3.8",
     "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
@@ -38,6 +36,7 @@ classifiers = [
 dynamic = ["version"]
 dependencies = [
     "cdflib >= 0.3.9",
+    "h5py >= 3.12.1",
     "Jinja2 >= 2.10",
     "netCDF4 >= 1.5.3; python_version>='3.8'",
     "netCDF4 >= 1.5.3, <= 1.5.8; python_version<='3.7'",
diff --git a/src/viresclient/__init__.py b/src/viresclient/__init__.py
index 280b0c1..05f9662 100644
--- a/src/viresclient/__init__.py
+++ b/src/viresclient/__init__.py
@@ -35,4 +35,4 @@
 from ._config import ClientConfig, set_token
 from ._data_handling import ReturnedData, ReturnedDataFile
 
-__version__ = "0.13.0"
+__version__ = "0.14.0"
diff --git a/src/viresclient/_client.py b/src/viresclient/_client.py
index 4f5946d..5306979 100644
--- a/src/viresclient/_client.py
+++ b/src/viresclient/_client.py
@@ -462,6 +462,8 @@ def _get(
         message=None,
         show_progress=True,
         leave_progress_bar=True,
+        content_type=None,
+        headers=None,
     ):
         """Make a request and handle response according to response_handler
 
@@ -484,13 +486,23 @@
                     request,
                     handler=response_handler,
                     status_handler=progressbar.update,
+                    content_type=content_type,
+                    headers=headers,
                 )
             else:
                 return self._wps_service.retrieve_async(
-                    request, handler=response_handler
+                    request,
+                    handler=response_handler,
+                    content_type=content_type,
+                    headers=headers,
                 )
         else:
-            return self._wps_service.retrieve(request, handler=response_handler)
+            return self._wps_service.retrieve(
+                request,
+                handler=response_handler,
+                content_type=content_type,
+                headers=headers,
+            )
     except WPSError:
         raise RuntimeError(
             "Server error. Or perhaps the request is invalid? "
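The new content_type/headers arguments are threaded through _get into the WPS service calls; the response-handler contract itself is unchanged: a handler receives the open response stream and returns whatever the caller should get back. A minimal sketch of that contract (the helper name is illustrative; the pattern mirrors the handlers used by the new eval_model methods further down):

import shutil

def save_response_to(path, chunksize=1024 * 1024):
    """Build a response handler that streams the response body to disk."""
    def handler(file_obj):
        with open(path, "wb") as destination:
            shutil.copyfileobj(file_obj, destination, chunksize)
        return path
    return handler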
" diff --git a/src/viresclient/_client_swarm.py b/src/viresclient/_client_swarm.py index dd082ad..7cc6e4e 100644 --- a/src/viresclient/_client_swarm.py +++ b/src/viresclient/_client_swarm.py @@ -3,12 +3,16 @@ import datetime import json import os +import shutil import sys +import uuid from collections import OrderedDict from io import StringIO from textwrap import dedent from warnings import warn +import h5py +from numpy import asarray from pandas import read_csv from tqdm import tqdm @@ -16,6 +20,7 @@ from ._data import CONFIG_SWARM from ._data_handling import ReturnedDataFile from ._wps.environment import JINJA2_ENVIRONMENT +from ._wps.multipart import generate_multipart_request from ._wps.time_util import parse_datetime TEMPLATE_FILES = { @@ -27,6 +32,7 @@ "get_observatories": "vires_get_observatories.xml", "get_conjunctions": "vires_get_conjunctions.xml", "get_collection_info": "vires_get_collection_info.xml", + "eval_model_mp": "model_eval_multipart_payload.xml", } REFERENCES = { @@ -60,6 +66,10 @@ "CHAOS-8 Secondary (internal) magnetospheric field", " http://www.spacecenter.dk/files/magnetic-models/CHAOS-8/ ", ), + "CHAOS-MIO": ( + "CHAOS-8 Ionospheric field", + " http://www.spacecenter.dk/files/magnetic-models/CHAOS-8/ ", + ), "MF7": ( "MF7 crustal field model, derived from CHAMP satellite observations", " http://geomag.org/models/MF7.html", @@ -262,6 +272,7 @@ class SwarmWPSInputs(WPSInputs): "response_type", "custom_shc", "ignore_cached_models", + "do_not_interpolate_models", ] def __init__( @@ -276,6 +287,7 @@ def __init__( response_type=None, custom_shc=None, ignore_cached_models=False, + do_not_interpolate_models=False, ): # Set up default values # Obligatory - these must be replaced before the request is made @@ -291,6 +303,7 @@ def __init__( self.sampling_step = None if sampling_step is None else sampling_step self.custom_shc = None if custom_shc is None else custom_shc self.ignore_cached_models = ignore_cached_models + self.do_not_interpolate_models = do_not_interpolate_models @property def collection_ids(self): @@ -355,6 +368,17 @@ def ignore_cached_models(self, value): else: raise TypeError + @property + def do_not_interpolate_models(self): + return self._do_not_interpolate_models + + @do_not_interpolate_models.setter + def do_not_interpolate_models(self, value): + if isinstance(value, bool): + self._do_not_interpolate_models = value + else: + raise TypeError + @property def begin_time(self): return self._begin_time @@ -1549,6 +1573,7 @@ class SwarmRequest(ClientRequest): "CHAOS-Static", "CHAOS-MMA-Primary", "CHAOS-MMA-Secondary", + "CHAOS-MIO", "MCO_SHA_2C", "MCO_SHA_2D", "MLI_SHA_2C", @@ -1932,6 +1957,7 @@ def set_products( residuals=False, sampling_step=None, ignore_cached_models=False, + do_not_interpolate_models=False, ): """Set the combination of products to retrieve. @@ -1946,6 +1972,7 @@ def set_products( residuals (bool): True if only returning measurement-model residual sampling_step (str): ISO_8601 duration, e.g. 
@@ -2620,3 +2648,253 @@
             )
 
         return response
+
+    def eval_model(
+        self,
+        models,
+        time,
+        latitude,
+        longitude,
+        radius,
+        time_precision="ns",
+        show_progress=True,
+        temp_dir=".",
+        input_prefix="_model_eval_input_",
+        output_prefix="_model_eval_output_",
+    ):
+        """Evaluate models for the given times and locations.
+
+        Args:
+            models (list(str)/dict): from .available_models() or definable with custom expressions
+            time (datetime64): array of times
+            latitude (float64): array of geocentric latitudes (deg)
+            longitude (float64): array of geocentric longitudes (deg)
+            radius (float64): array of radii (m)
+            time_precision (str): optional time precision: ns (default) | us | ms | s
+            show_progress (bool): show download progress (default True)
+
+        Returns:
+            (dict, list): dictionary of arrays with the model values, and list of model source identifiers
+        """
+        # FIXME: show download progress
+
+        def _write_hdf5_file(filename, data):
+            with h5py.File(filename, "w") as hdf:
+                for key, array in data.items():
+                    options = (
+                        {}
+                        if array.ndim == 0
+                        else {
+                            "compression": "gzip",
+                            "compression_opts": 9,
+                        }
+                    )
+                    hdf.create_dataset(key, data=array, **options)
+
+        def _read_hdf5_file(filename):
+            with h5py.File(filename, "r") as hdf:
+                data = {key: hdf[key][...] for key in hdf}
+                sources = hdf.attrs["sources"].tolist()
+            if "Timestamp" in data:
+                data["Timestamp"] = data["Timestamp"].astype(time_type)
+            return data, sources
+
+        def _response_handler(filename, chunksize=1024 * 1024):
+            def _handler(file_obj):
+                # save the received HDF5 file
+                with open(filename, "wb") as file:
+                    shutil.copyfileobj(file_obj, file, chunksize)
+                # read results from the HDF5 file
+                return _read_hdf5_file(filename)
+
+            return _handler
+
+        # FIXME: temp. file handling
+        request_id = uuid.uuid4()
+        input_filename = os.path.join(temp_dir, f"{input_prefix}{request_id}.hdf5")
+        output_filename = os.path.join(temp_dir, f"{output_prefix}{request_id}.hdf5")
+
+        _, model_expression_string = self._parse_models_input(models)
+
+        time_type = f"datetime64[{time_precision}]"
+
+        time = asarray(time, time_type)
+        latitude = asarray(latitude, "float64")
+        longitude = asarray(longitude, "float64")
+        radius = asarray(radius, "float64")
+
+        # the XML request and binary data are sent as a multipart/related request
+        # see https://en.wikipedia.org/wiki/MIME#Multipart_messages
+        multipart_boundary = "part-delimiter"
+
+        # build XML request
+        templatefile = TEMPLATE_FILES["eval_model_mp"]
+        template = JINJA2_ENVIRONMENT.get_template(templatefile)
+        request = template.render(
+            model_expression=model_expression_string,
+            input_content_id=request_id,
+            input_time_format=time_type,
+            input_mime_type="application/x-hdf5",
+            output_time_format=time_type,
+            output_mime_type="application/x-hdf5",
+        ).encode("UTF-8")
+
+        try:
+            # write input HDF5 file
+            _write_hdf5_file(
+                input_filename,
+                {
+                    "Timestamp": time.astype("int64"),
+                    "Latitude": latitude,
+                    "Longitude": longitude,
+                    "Radius": radius,
+                },
+            )
+
+            # streaming request from the input HDF5 file
+            with open(input_filename, "rb") as input_file:
+                parts = [
+                    (
+                        request,
+                        {
+                            "Content-Type": "application/xml; charset=utf-8",
+                        },
+                    ),
+                    (
+                        input_file,
+                        {
+                            "Content-Id": request_id,
+                            "Content-Type": "application/x-hdf5",
+                        },
+                    ),
+                ]
+
+                # Due to Django limitations we must aggregate the request
+                # chunks into one block.
+                # payload_size = get_multipart_request_size(parts, multipart_boundary)
+                # payload = generate_multipart_request(parts, multipart_boundary)
+                payload = b"".join(
+                    generate_multipart_request(parts, multipart_boundary)
+                )
+
+                result, sources = self._get(
+                    payload,
+                    response_handler=_response_handler(output_filename),
+                    asynchronous=False,
+                    show_progress=show_progress,
+                    content_type=(f"multipart/related; boundary={multipart_boundary}"),
+                    headers={
+                        "MIME-Version": "1.0",
+                        # "Content-Length": payload_size,
+                    },
+                )
+
+        finally:
+            for filename in [input_filename, output_filename]:
+                if os.path.exists(filename):
+                    os.remove(filename)
+
+        return result, sources
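A usage sketch for eval_model, following the signature above (times as datetime64, geocentric coordinates, radius in metres); the keys of the returned dictionary depend on the requested models:

import numpy as np
from viresclient import SwarmRequest

request = SwarmRequest()
data, sources = request.eval_model(
    models=["CHAOS-Core"],
    time=np.array(["2025-01-01T00:00", "2025-01-01T12:00"], dtype="datetime64[ns]"),
    latitude=np.array([45.0, -30.0]),   # geocentric latitude (deg)
    longitude=np.array([10.0, 120.0]),  # longitude (deg)
    radius=np.full(2, 6_871_200.0),     # geocentric radius (m), ~500 km altitude
)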
+
+    def eval_model_for_cdf_file(
+        self,
+        models,
+        input_cdf_filename,
+        output_cdf_filename,
+        show_progress=True,
+    ):
+        """Evaluate models for the coordinates given in a Swarm-like CDF file.
+
+        Args:
+            models (list(str)/dict): from .available_models() or definable with custom expressions
+            input_cdf_filename (str): input CDF file
+            output_cdf_filename (str): output CDF file
+            show_progress (bool): show download progress (default True)
+
+        Returns:
+            output_cdf_filename (the path of the written file)
+
+        """
+        # FIXME: show download progress
+
+        def _response_handler(filename, chunksize=1024 * 1024):
+            def _handler(file_obj):
+                # save the received file
+                with open(filename, "wb") as file:
+                    shutil.copyfileobj(file_obj, file, chunksize)
+                return filename
+
+            return _handler
+
+        request_id = uuid.uuid4()
+
+        _, model_expression_string = self._parse_models_input(models)
+
+        # the XML request and binary data are sent as a multipart/related request
+        # see https://en.wikipedia.org/wiki/MIME#Multipart_messages
+        multipart_boundary = "part-delimiter"
+
+        # build XML request
+        templatefile = TEMPLATE_FILES["eval_model_mp"]
+        template = JINJA2_ENVIRONMENT.get_template(templatefile)
+        request = template.render(
+            model_expression=model_expression_string,
+            input_content_id=request_id,
+            input_time_format=None,  # use the format-specific default
+            input_mime_type="application/x-cdf",
+            output_time_format=None,  # keep the input time format
+            output_mime_type="application/x-cdf",
+        ).encode("UTF-8")
+
+        temp_cdf_filename = f".{output_cdf_filename}.tmp.cdf"
+
+        if os.path.exists(temp_cdf_filename):
+            os.remove(temp_cdf_filename)
+
+        try:
+            # streaming request from the input CDF file
+            with open(input_cdf_filename, "rb") as input_file:
+                parts = [
+                    (
+                        request,
+                        {
+                            "Content-Type": "application/xml; charset=utf-8",
+                        },
+                    ),
+                    (
+                        input_file,
+                        {
+                            "Content-Id": request_id,
+                            "Content-Type": "application/x-cdf",
+                        },
+                    ),
+                ]
+
+                # Due to Django limitations we must aggregate the request
+                # chunks into one block.
+                # payload_size = get_multipart_request_size(parts, multipart_boundary)
+                # payload = generate_multipart_request(parts, multipart_boundary)
+                payload = b"".join(
+                    generate_multipart_request(parts, multipart_boundary)
+                )
+
+                self._get(
+                    payload,
+                    response_handler=_response_handler(temp_cdf_filename),
+                    asynchronous=False,
+                    show_progress=show_progress,
+                    content_type=(f"multipart/related; boundary={multipart_boundary}"),
+                    headers={
+                        "MIME-Version": "1.0",
+                        # "Content-Length": payload_size,
+                    },
+                )
+
+            os.rename(temp_cdf_filename, output_cdf_filename)
+
+        finally:
+            if os.path.exists(temp_cdf_filename):
+                os.remove(temp_cdf_filename)
+
+        return output_cdf_filename
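And the file-based variant, which round-trips a Swarm-like CDF file through the server (filenames are illustrative):

from viresclient import SwarmRequest

request = SwarmRequest()
request.eval_model_for_cdf_file(
    models=["CHAOS-Core", "CHAOS-Static"],
    input_cdf_filename="coordinates.cdf",
    output_cdf_filename="coordinates_with_models.cdf",
)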
diff --git a/src/viresclient/_wps/multipart.py b/src/viresclient/_wps/multipart.py
new file mode 100644
index 0000000..7d4ac8f
--- /dev/null
+++ b/src/viresclient/_wps/multipart.py
@@ -0,0 +1,104 @@
+# -------------------------------------------------------------------------------
+#
+# multi-part request handling
+#
+# Author: Martin Paces
+#
+# -------------------------------------------------------------------------------
+# Copyright (C) 2025 EOX IT Services GmbH
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies of this Software or works derived from this Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+# -------------------------------------------------------------------------------
+
+from io import BytesIO
+
+CHUNK_SIZE = 64 * 1024  # 64kB chunk size
+SEEK_SET = 0
+SEEK_END = 2
+CRLF = b"\r\n"
+
+
+def generate_multipart_request(parts, boundary, chunksize=CHUNK_SIZE):
+    """Generate multi-part payload from the given parts (pairs of the part
+    payload and header dictionaries) and boundary string.
+    """
+    for part, headers in parts:
+        yield _get_part_head(boundary, part, headers)
+        yield from _generate_part(part, chunksize=chunksize)
+    yield _get_multipart_tail(boundary)
+
+
+def get_multipart_request_size(parts, boundary):
+    """Get byte-size of the multi-part payload for the given parts
+    (pairs of the part payload and header dictionaries) and boundary string.
+    """
+    size = 0
+    for part, headers in parts:
+        size += len(_get_part_head(boundary, part, headers))
+        size += _get_part_byte_size(part)
+    size += len(_get_multipart_tail(boundary))
+    return size
+
+
+def _get_part_head(boundary, part, headers):
+    headers = {
+        **headers,
+        "Content-Length": _get_part_byte_size(part),
+    }
+
+    def _generate_part_head():
+        yield ""
+        yield f"--{boundary}"
+        for key, value in headers.items():
+            yield f"{key}: {value}"
+        yield ""
+        yield ""
+
+    return CRLF.join(s.encode("ascii") for s in _generate_part_head())
+
+
+def _get_multipart_tail(boundary):
+    def _generate_multipart_tail():
+        yield ""
+        yield f"--{boundary}--"
+        yield ""
+
+    return CRLF.join(s.encode("ascii") for s in _generate_multipart_tail())
+
+
+def _get_part_byte_size(part):
+    if isinstance(part, bytes):
+        return len(part)
+    # assuming a seekable binary file-like object
+    part.seek(0, SEEK_END)
+    size = part.tell()
+    part.seek(0, SEEK_SET)
+    return size
+
+
+def _generate_part(part, chunksize=CHUNK_SIZE):
+    if isinstance(part, bytes):
+        part = BytesIO(part)
+    # assuming a seekable binary file-like object
+    part.seek(0, SEEK_SET)
+    while True:
+        chunk = part.read(chunksize)
+        if not chunk:
+            break
+        yield chunk
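The generator can be exercised standalone; this sketch builds a two-part payload (a bytes part and a file-like part) and checks it against the size helper:

from io import BytesIO

from viresclient._wps.multipart import (
    generate_multipart_request,
    get_multipart_request_size,
)

parts = [
    (b"<xml>request</xml>", {"Content-Type": "application/xml; charset=utf-8"}),
    (BytesIO(b"binary payload"), {"Content-Id": "data-part"}),
]
boundary = "part-delimiter"
payload = b"".join(generate_multipart_request(parts, boundary))
assert len(payload) == get_multipart_request_size(parts, boundary)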
diff --git a/src/viresclient/_wps/templates/model_eval_multipart_payload.xml b/src/viresclient/_wps/templates/model_eval_multipart_payload.xml
new file mode 100644
index 0000000..6ee3756
--- /dev/null
+++ b/src/viresclient/_wps/templates/model_eval_multipart_payload.xml
@@ -0,0 +1,37 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<wps:Execute version="1.0.0" service="WPS"
+    xmlns:wps="http://www.opengis.net/wps/1.0.0"
+    xmlns:ows="http://www.opengis.net/ows/1.1"
+    xmlns:xlink="http://www.w3.org/1999/xlink">
+  <ows:Identifier>vires:eval_model_at_time_and_location</ows:Identifier>
+  <wps:DataInputs>
+    <wps:Input>
+      <ows:Identifier>model_ids</ows:Identifier>
+      <wps:Data>
+        <wps:ComplexData>{{model_expression | cdata}}</wps:ComplexData>
+      </wps:Data>
+    </wps:Input>
+    <wps:Input>
+      <ows:Identifier>input</ows:Identifier>
+      <wps:Reference mimeType="{{input_mime_type}}" xlink:href="cid:{{input_content_id}}"/>
+    </wps:Input>
+{% if input_time_format -%}
+    <wps:Input>
+      <ows:Identifier>input_time_format</ows:Identifier>
+      <wps:Data>
+        <wps:LiteralData>{{input_time_format}}</wps:LiteralData>
+      </wps:Data>
+    </wps:Input>
+{% endif -%}
+{% if output_time_format -%}
+    <wps:Input>
+      <ows:Identifier>output_time_format</ows:Identifier>
+      <wps:Data>
+        <wps:LiteralData>{{output_time_format}}</wps:LiteralData>
+      </wps:Data>
+    </wps:Input>
+{% endif -%}
+  </wps:DataInputs>
+  <wps:ResponseForm>
+    <wps:RawDataOutput mimeType="{{output_mime_type}}">
+      <ows:Identifier>output</ows:Identifier>
+    </wps:RawDataOutput>
+  </wps:ResponseForm>
+</wps:Execute>
diff --git a/src/viresclient/_wps/templates/vires_fetch_filtered_data.xml b/src/viresclient/_wps/templates/vires_fetch_filtered_data.xml
index b7eadc6..12ebfb1 100644
--- a/src/viresclient/_wps/templates/vires_fetch_filtered_data.xml
+++ b/src/viresclient/_wps/templates/vires_fetch_filtered_data.xml
@@ -23,6 +23,14 @@
       {% endif -%}
+      {% if do_not_interpolate_models -%}
+      <wps:Input>
+        <ows:Identifier>do_not_interpolate_models</ows:Identifier>
+        <wps:Data>
+          <wps:LiteralData>true</wps:LiteralData>
+        </wps:Data>
+      </wps:Input>
+      {% endif -%}
     {% endif -%}
     {% if custom_shc -%}
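The template change is easiest to see in isolation. This standalone Jinja2 snippet (not the actual template file, which carries more surrounding markup) shows how the input block is only emitted when the flag is set:

from jinja2 import Template

template = Template(
    "{% if do_not_interpolate_models -%}"
    "<wps:Input><ows:Identifier>do_not_interpolate_models</ows:Identifier>"
    "<wps:Data><wps:LiteralData>true</wps:LiteralData></wps:Data></wps:Input>"
    "{%- endif %}"
)
print(template.render(do_not_interpolate_models=True))   # emits the input block
print(template.render(do_not_interpolate_models=False))  # emits nothing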
diff --git a/src/viresclient/_wps/templates/vires_fetch_filtered_data_async.xml b/src/viresclient/_wps/templates/vires_fetch_filtered_data_async.xml
index aebd23d..e820d69 100644
--- a/src/viresclient/_wps/templates/vires_fetch_filtered_data_async.xml
+++ b/src/viresclient/_wps/templates/vires_fetch_filtered_data_async.xml
@@ -23,6 +23,14 @@
       {% endif -%}
+      {% if do_not_interpolate_models -%}
+      <wps:Input>
+        <ows:Identifier>do_not_interpolate_models</ows:Identifier>
+        <wps:Data>
+          <wps:LiteralData>true</wps:LiteralData>
+        </wps:Data>
+      </wps:Input>
+      {% endif -%}
     {% endif -%}
     {% if custom_shc -%}
diff --git a/src/viresclient/_wps/wps.py b/src/viresclient/_wps/wps.py
index 6cf7464..81c4be6 100644
--- a/src/viresclient/_wps/wps.py
+++ b/src/viresclient/_wps/wps.py
@@ -89,6 +89,50 @@ def __init__(self, text=None):
         )
 
 
+def retry(n_retries, retry_time_seconds, label):
+    """Request re-try decorator."""
+
+    def _retry(method):
+
+        def _retry_wrapper(self, *args, **kwargs):
+
+            for index in range(n_retries + 1):
+                if index == 0:
+                    self.logger.debug("sending %s.", label)
+                else:
+                    self.logger.debug("sending %s. Retry attempt #%s.", label, index)
+
+                try:
+                    return method(self, *args, **kwargs)
+
+                except WPSError:
+                    raise
+
+                except Exception as error:
+                    if index < n_retries:
+                        self.logger.error(
+                            "%s failed. Retrying in %s seconds. %s: %s",
+                            label,
+                            retry_time_seconds,
+                            error.__class__.__name__,
+                            error,
+                        )
+                    else:
+                        self.logger.error(
+                            "%s failed. No more retries. %s: %s",
+                            label,
+                            error.__class__.__name__,
+                            error,
+                        )
+                        raise
+
+                sleep(retry_time_seconds)
+
+        return _retry_wrapper
+
+    return _retry
+
+
 class WPS10Service:
     """WPS 1.0 service proxy class.
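A toy illustration of how the new decorator composes with a method on an object exposing a logger attribute, as WPS10Service does (the import path targets a private module and is shown for illustration only):

import logging

from viresclient._wps.wps import retry

class FlakyClient:
    def __init__(self):
        self.logger = logging.getLogger("demo")
        self.calls = 0

    @retry(3, 1, "flaky request")  # up to 3 retries, 1 second apart
    def fetch(self):
        self.calls += 1
        if self.calls < 3:
            raise ConnectionError("transient failure")
        return "ok"

logging.basicConfig(level=logging.DEBUG)
print(FlakyClient().fetch())  # succeeds on the third attempt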
@@ -101,8 +145,8 @@ class WPS10Service:
     """
 
     DEFAULT_CONTENT_TYPE = "application/xml; charset=utf-8"
-    RETRY_TIME = 20  # seconds
-    STATUS_POLL_RETRIES = 3  # re-try attempts
+    RETRY_TIME = 20  # re-try wait period in seconds
+    REQUEST_RETRIES = 3  # re-try attempts
 
     STATUS = {
         "{http://www.opengis.net/wps/1.0.0}ProcessAccepted": "ACCEPTED",
@@ -120,15 +164,15 @@ def __init__(self, url, headers=None, logger=None):
         self.headers = headers or {}
         self.logger = self._LoggerAdapter(logger or getLogger(__name__), {})
 
-    def retrieve(self, request, handler=None, content_type=None):
+    def retrieve(self, request, handler=None, content_type=None, headers=None):
         """Send a synchronous POST WPS request to a server and retrieve
         the output.
         """
         headers = {
             **self.headers,
+            **self._headers_to_bytes(headers or {}),
             "Content-Type": content_type or self.DEFAULT_CONTENT_TYPE,
         }
-
         return self._retrieve(
             Request(self.url, request, headers), handler, self.error_handler
         )
@@ -142,6 +186,7 @@ def retrieve_async(
         polling_interval=1,
         output_name="output",
         content_type=None,
+        headers=None,
     ):
         """Send an asynchronous POST WPS request to a server and retrieve
         the output.
         """
@@ -150,6 +195,7 @@ def retrieve_async(
         status, percentCompleted, status_url, execute_response = self.submit_async(
             request,
             content_type=content_type,
+            headers=headers,
         )
         wpsstatus = WPSStatus()
         wpsstatus.update(
@@ -202,6 +248,7 @@ def retrieve_async(
 
         return output
 
+    @retry(REQUEST_RETRIES, RETRY_TIME, "asynchronous output request")
     def retrieve_async_output(self, status_url, output_name, handler=None):
         """Retrieve asynchronous job output reference."""
         self.logger.debug("Retrieving asynchronous job output '%s'.", output_name)
@@ -226,13 +273,14 @@ def parse_output_reference(xml, identifier):
             or elm_reference.attrib["href"]
         )
 
-    def submit_async(self, request, content_type=None):
+    def submit_async(self, request, content_type=None, headers=None):
         """Send a POST WPS asynchronous request to a server and retrieve
         the status URL.
         """
         self.logger.debug("Submitting asynchronous job.")
         headers = {
             **self.headers,
+            **self._headers_to_bytes(headers or {}),
             "Content-Type": content_type or self.DEFAULT_CONTENT_TYPE,
         }
         return self._retrieve(
             Request(self.url, request, headers),
             self.parse_status,
             self.error_handler,
         )
@@ -241,40 +289,12 @@ def submit_async(self, request, content_type=None):
+    @retry(REQUEST_RETRIES, RETRY_TIME, "status poll request")
     def poll_status(self, status_url):
         """Poll status of an asynchronous WPS job."""
-        self.logger.debug("Polling asynchronous job status.")
-
-        for index in range(self.STATUS_POLL_RETRIES + 1):
-
-            if index == 0:
-                self.logger.debug("Polling asynchronous job status.")
-            else:
-                self.logger.debug(
-                    "Polling asynchronous job status. Retry attempt #%s.", index
-                )
-
-            try:
-                return self._retrieve(
-                    Request(status_url, None, self.headers), self.parse_status
-                )
-            except Exception as error:
-                if index < self.STATUS_POLL_RETRIES:
-                    self.logger.error(
-                        "Status poll failed. Retrying in %s seconds. %s: %s",
-                        self.RETRY_TIME,
-                        error.__class__.__name__,
-                        error,
-                    )
-                else:
-                    self.logger.error(
-                        "Status poll failed. No more retries. %s: %s",
-                        error.__class__.__name__,
-                        error,
-                    )
-                    raise
-
-            sleep(self.RETRY_TIME)
+        return self._retrieve(
+            Request(status_url, None, self.headers), self.parse_status
+        )
 
     @classmethod
     def parse_status(cls, response):
@@ -331,7 +351,7 @@ def error_handler(cls, http_error):
             xml = ElementTree.parse(http_error)
             ows_exception, namespace = cls.find_exception(xml)
         except ElementTree.ParseError:
-            raise http_error
+            raise http_error from None
         raise cls.parse_ows_exception(ows_exception, namespace)
 
     @classmethod
@@ -378,3 +398,12 @@ def _default_cleanup_handler(self, status_url):
 
     @staticmethod
     def _default_handler(file_obj):
         return file_obj.read()
+
+    @staticmethod
+    def _headers_to_bytes(headers, encoding="ascii"):
+        def _to_bytes(value):
+            if isinstance(value, bytes):
+                return value
+            return str(value).encode(encoding)
+
+        return {_to_bytes(key): _to_bytes(value) for key, value in headers.items()}
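Finally, a quick check of the new helper's behaviour: header names and values are normalised to bytes before being merged into the request headers, so callers may pass strings, numbers, or bytes:

from viresclient._wps.wps import WPS10Service

print(WPS10Service._headers_to_bytes({"MIME-Version": "1.0", "Content-Length": 1024}))
# {b'MIME-Version': b'1.0', b'Content-Length': b'1024'}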