diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6c40bfa0..72be78d9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -37,4 +37,4 @@ jobs: run: | python -m pip install .[tests] python -m birdnet_analyzer.utils - python -m pytest + python -m pytest -s tests/test_utils.py tests/train/ diff --git a/birdnet_analyzer/audio.py b/birdnet_analyzer/audio.py index 45ccb361..decc97d0 100644 --- a/birdnet_analyzer/audio.py +++ b/birdnet_analyzer/audio.py @@ -27,6 +27,7 @@ def open_audio_file(path: str, sample_rate=48000, offset=0.0, duration=None, fmi Returns: Returns the audio time series and the sampling rate. """ + # Open file with librosa (uses ffmpeg or libav) if speed == 1.0: sig, rate = librosa.load( diff --git a/birdnet_analyzer/model.py b/birdnet_analyzer/model.py index 063ecd3f..e11d3162 100644 --- a/birdnet_analyzer/model.py +++ b/birdnet_analyzer/model.py @@ -1197,7 +1197,6 @@ def embeddings(sample): Returns: The embeddings. """ - load_model(False) sample = np.array(sample, dtype="float32") diff --git a/birdnet_analyzer/train/utils.py b/birdnet_analyzer/train/utils.py index 10d56d13..ca88a2d3 100644 --- a/birdnet_analyzer/train/utils.py +++ b/birdnet_analyzer/train/utils.py @@ -88,6 +88,7 @@ def _load_audio_file(f, label_vector, config): else: sig_splits = audio.split_signal(sig, rate, cfg.SIG_LENGTH, cfg.SIG_OVERLAP, cfg.SIG_MINLEN) + # Get feature embeddings batch_size = 1 # turns out that batch size 1 is the fastest, probably because of having to resize the model input when the number of samples in a batch changes for i in range(0, len(sig_splits), batch_size): diff --git a/tests/data b/tests/data index b43d4283..d6871b77 160000 --- a/tests/data +++ b/tests/data @@ -1 +1 @@ -Subproject commit b43d4283fe0d24f63d5d460584e608094ead2879 +Subproject commit d6871b77a0a1e8396d96cbbb29c4dc2a2292f2ab diff --git a/tests/embeddings/test_embeddings.py b/tests/embeddings/test_embeddings.py index 9b8a4e5b..7ad8db7f 100644 --- a/tests/embeddings/test_embeddings.py +++ b/tests/embeddings/test_embeddings.py @@ -4,9 +4,11 @@ import tempfile from unittest.mock import MagicMock, patch +import numpy as np import pytest import birdnet_analyzer.config as cfg +from birdnet_analyzer import model from birdnet_analyzer.cli import embeddings_parser from birdnet_analyzer.embeddings.core import embeddings @@ -53,3 +55,16 @@ def test_embeddings_cli(mock_run_embeddings: MagicMock, mock_ensure_model: Magic mock_ensure_model.assert_called_once() threads = min(8, max(1, multiprocessing.cpu_count() // 2)) mock_run_embeddings.assert_called_once_with(env["input_dir"], env["output_dir"], 0, 1.0, 0, 15000, threads, 1, None) + + +def test_model_embeddings_function_returns_expected_shape(): + # Create a dummy sample (e.g., 1D numpy array of audio data) + sample = np.zeros(144000).astype(np.float32) + # Reshape the sample to (1, 144000) as expected by the model + sample = sample.reshape(1, 144000) + # Call the embeddings function + result = model.embeddings(sample) + + # Check that result is a numpy array and has expected shape (depends on model, e.g., (1, embedding_dim)) + assert isinstance(result, np.ndarray) + assert result.ndim == 2 diff --git a/tests/train/test_train.py b/tests/train/test_train.py index b270c2a1..bbc89fb8 100644 --- a/tests/train/test_train.py +++ b/tests/train/test_train.py @@ -6,6 +6,7 @@ import pytest import birdnet_analyzer.config as cfg +from birdnet_analyzer.analyze.core import analyze from birdnet_analyzer.cli import train_parser from birdnet_analyzer.train.core import train @@ -17,10 +18,11 @@ def setup_test_environment(): input_dir = os.path.join(test_dir, "input") output_dir = os.path.join(test_dir, "output") - os.makedirs(input_dir, exist_ok=True) - os.makedirs(output_dir, exist_ok=True) + # Directory should not exist, so no exist_ok=True + os.makedirs(input_dir) + os.makedirs(output_dir) - classifier_output = os.path.join(output_dir, "classifier_output") + classifier_output = os.path.join(output_dir, "classifier_output", "custom_classifier.tflite") # Store original config values original_config = { @@ -55,3 +57,32 @@ def test_train_cli(mock_train_model, mock_ensure_model, setup_test_environment): mock_ensure_model.assert_called_once() mock_train_model.assert_called_once_with() + +@pytest.mark.timeout(400) # Increase timeout for training, 400s should be sufficient, win is by far the slowest +def test_training(setup_test_environment): + """Test the training process and prediction with dummy data.""" + env = setup_test_environment + training_data_input = "tests/data/training" + + # Read class names from subfolders in the input directory, filtering out background classes + dummy_classes = [ + d for d in os.listdir(training_data_input) + if os.path.isdir(os.path.join(training_data_input, d)) and d.lower() not in cfg.NON_EVENT_CLASSES + ] + + train(training_data_input, env["classifier_output"]) + + assert os.path.isfile(env["classifier_output"]), "Classifier output file was not created." + assert os.path.exists(env["classifier_output"].replace(".tflite", "_Labels.txt")), "Labels file was not created." + assert os.path.exists(env["classifier_output"].replace(".tflite", "_Params.csv")), "Params file was not created." + assert os.path.exists(env["classifier_output"].replace(".tflite", ".tflite_sample_counts.csv")), "Params file was not created." + + soundscape_path = "birdnet_analyzer/example/soundscape.wav" + analyze(soundscape_path, env["output_dir"], top_n=1, classifier=env["classifier_output"]) + + output_file = os.path.join(env["output_dir"], "soundscape.BirdNET.selection.table.txt") + with open(output_file) as f: + lines = f.readlines()[1:] + for line in lines: + parts = line.strip().split("\t") + assert parts[7] in dummy_classes, f"Detected class {parts[7]} not in expected classes {dummy_classes}"