diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index f38d5d27..b9a249ed 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -84,11 +84,14 @@ jobs:
       - name: Set up slurm
         if: ${{ matrix.cluster_type == 'slurm' }}
+        # docker build can lead to a race condition -> image "docker.io/library/ipp-cluster:slurm": already exists
+        # see https://github.com/mlflow/mlflow/pull/20779
+        # workaround: run docker compose a second time if the first call failed
         run: |
           export DOCKER_BUILDKIT=1
           export COMPOSE_DOCKER_CLI_BUILD=1
           cd ci/slurm
-          docker compose up -d --build
+          docker compose up -d --build || docker compose up -d --build

       - name: Install Python (conda) ${{ matrix.python }}
         if: ${{ matrix.cluster_type == 'mpi' }}
@@ -128,6 +131,23 @@
           pip install distributed joblib
           pip install --only-binary :all: matplotlib

+      - name: Start MongoDB
+        if: ${{ (! matrix.runs_on) && (! matrix.cluster_type) }} # only on Linux with no cluster
+        uses: supercharge/mongodb-github-action@1.12.1 # uses the latest mongodb by default
+
+      - name: Install pymongo package
+        if: ${{ (! matrix.runs_on) && (! matrix.cluster_type) }} # only on Linux with no cluster
+        run: pip install pymongo
+
+      - name: Try to connect to MongoDB
+        if: ${{ (! matrix.runs_on) && (! matrix.cluster_type) }} # only on Linux with no cluster
+        run: |
+          python3 <
diff --git a/ipyparallel/controller/mongodb.py b/ipyparallel/controller/mongodb.py
--- a/ipyparallel/controller/mongodb.py
+++ b/ipyparallel/controller/mongodb.py
@@ ... @@
-        if pymongo.version_tuple[0] >= 4:
-            # mimic the old API 3.x
-            self._records.insert = self._records.insert_one
-            self._records.update = self._records.update_one
-            self._records.ensure_index = self._records.create_index
-            self._records.remove = self._records.delete_many
-
-        self._records.ensure_index('msg_id', unique=True)
-        self._records.ensure_index('submitted')  # for sorting history
+        self._records = self._db.get_collection("task_records", options)
+        # self._records = self._db['task_records']
+        self._records.create_index('msg_id', unique=True)
+        self._records.create_index('submitted')  # for sorting history
         # for rec in self._records.find

     def _binary_buffers(self, rec):
@@ -82,7 +83,7 @@ def add_record(self, msg_id, rec):
         """Add a new Task Record, by msg_id."""
         # print rec
         rec = self._binary_buffers(rec)
-        self._records.insert(rec)
+        self._records.insert_one(rec)

     def get_record(self, msg_id):
         """Get a specific Task Record, by msg_id."""
@@ -96,15 +97,15 @@ def update_record(self, msg_id, rec):
         """Update the data in an existing record."""
         rec = self._binary_buffers(rec)
-        self._records.update({'msg_id': msg_id}, {'$set': rec})
+        self._records.update_one({'msg_id': msg_id}, {'$set': rec})

     def drop_matching_records(self, check):
         """Remove a record from the DB."""
-        self._records.remove(check)
+        self._records.delete_many(check)

     def drop_record(self, msg_id):
         """Remove a record from the DB."""
-        self._records.remove({'msg_id': msg_id})
+        self._records.delete_many({'msg_id': msg_id})

     def find_records(self, check, keys=None):
         """Find records matching a query dict, optionally extracting subset of keys.
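Reviewer note on the mongodb.py hunks above: pymongo 4 removed the legacy 3.x collection methods (`insert`, `update`, `remove`, `ensure_index`), so the compatibility shim is dropped and each call site is ported to its one-to-one replacement. A minimal sketch of the mapping, assuming a local mongod; the connection string, database name, and document fields are illustrative, not taken from ipyparallel's config:

```python
from pymongo import MongoClient

client = MongoClient("mongodb://localhost:27017")  # assumed local server
records = client["example_db"].get_collection("task_records")

# pymongo >= 4 calls, with the removed 3.x equivalents in comments:
records.create_index("msg_id", unique=True)  # 3.x: ensure_index(...)
records.insert_one({"msg_id": "abc", "label": ""})  # 3.x: insert(...)
records.update_one({"msg_id": "abc"}, {"$set": {"label": "demo"}})  # 3.x: update(...)
records.delete_many({"msg_id": "abc"})  # 3.x: remove(...)
```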
diff --git a/ipyparallel/tests/test_label.py b/ipyparallel/tests/test_label.py
new file mode 100644
index 00000000..8820b863
--- /dev/null
+++ b/ipyparallel/tests/test_label.py
@@ -0,0 +1,102 @@
+"""Tests for task label functionality"""
+
+# Copyright (c) IPython Development Team.
+# Distributed under the terms of the Modified BSD License.
+from unittest import TestCase
+
+import ipyparallel as ipp
+from ipyparallel.cluster.launcher import LocalControllerLauncher
+
+
+def pseudo_wait(t):
+    import time
+
+    tic = time.time()
+    print(f"waiting for {t}s...")
+    # time.sleep(t)  # do NOT actually wait for t seconds, to keep the tests fast
+    print("done")
+    return time.time() - tic
+
+
+class TaskLabelTest:
+    def setUp(self):
+        self.cluster = ipp.Cluster(
+            n=2, log_level=10, controller=self.get_controller_launcher()
+        )
+        self.cluster.start_cluster_sync()
+
+        self.rc = self.cluster.connect_client_sync()
+        self.rc.wait_for_engines(n=2)
+
+    def get_controller_launcher(self):
+        raise NotImplementedError
+
+    def tearDown(self):
+        # stop engines and controller together
+        self.cluster.stop_cluster_sync()
+
+    def run_tasks(self, view):
+        ar_list = []
+        # use a temp_flags context to set the label
+        with view.temp_flags(label="mylabel_map"):
+            ar_list.append(view.map_async(pseudo_wait, [1.1, 1.2, 1.3, 1.4, 1.5]))
+        # use set_flags to set the label
+        ar_list.extend(
+            [
+                view.set_flags(label=f"mylabel_apply_{i:02}").apply_async(
+                    pseudo_wait, 2 + i / 10
+                )
+                for i in range(5)
+            ]
+        )
+        view.wait(ar_list)
+
+        # build the list of labels that were used
+        map_labels = ["mylabel_map"]
+        apply_labels = [f"mylabel_apply_{i:02}" for i in range(5)]
+        return map_labels, apply_labels
+
+    def check_labels(self, labels):
+        # query the task database for all records that carry a label
+        data = self.rc.db_query({'label': {"$ne": ""}}, keys=['msg_id', 'label'])
+        for d in data:
+            label = d['label']
+            assert label in labels
+            labels.remove(label)
+
+        assert len(labels) == 0
+
+    def clear_db(self):
+        self.rc.purge_everything()
+
+    def test_balanced_view(self):
+        bview = self.rc.load_balanced_view()
+        map_labels, apply_labels = self.run_tasks(bview)
+        # each of the 5 map tasks gets its own record with the same label
+        labels = map_labels * 5 + apply_labels
+        self.check_labels(labels)
+        self.clear_db()
+
+    def test_direct_view(self):
+        dview = self.rc[:]
+        map_labels, apply_labels = self.run_tasks(dview)
+        # 2 engines: map splits into 2 chunks, each apply runs on both engines
+        labels = map_labels * 2 + apply_labels * 2
+        self.check_labels(labels)
+        self.clear_db()
+
+
+class TestLabelDictDB(TaskLabelTest, TestCase):
+    def get_controller_launcher(self):
+        class dictDB(LocalControllerLauncher):
+            controller_args = ["--dictdb"]
+
+        return dictDB
+
+
+class TestLabelSqliteDB(TaskLabelTest, TestCase):
+    def get_controller_launcher(self):
+        class sqliteDB(LocalControllerLauncher):
+            controller_args = ["--sqlitedb"]
+
+        return sqliteDB
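For reviewers, a short sketch of the user-facing workflow these tests exercise. It assumes the `label` task flag introduced by this PR; `Cluster`, `temp_flags`, and `db_query` are existing ipyparallel APIs, and the label string is made up:

```python
import ipyparallel as ipp

with ipp.Cluster(n=2) as rc:
    view = rc.load_balanced_view()
    # every task submitted inside this block is tagged with the label
    with view.temp_flags(label="demo_label"):
        ar = view.map_async(lambda x: x * x, range(4))
        ar.get()
    # the label is stored in the controller's task database and can be queried
    for rec in rc.db_query({"label": "demo_label"}, keys=["msg_id", "label"]):
        print(rec["msg_id"], rec["label"])
```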