From f4f57f433641a1b4fccda4e8bb0faeb25daa6876 Mon Sep 17 00:00:00 2001 From: naglepuff Date: Thu, 19 Mar 2026 12:02:52 -0400 Subject: [PATCH 1/2] Modify ingest script to check a local directory --- .../commands/load_public_dataset.py | 52 +++++++++++++++---- 1 file changed, 42 insertions(+), 10 deletions(-) diff --git a/bats_ai/core/management/commands/load_public_dataset.py b/bats_ai/core/management/commands/load_public_dataset.py index 7b9ee72f..8fcf252e 100644 --- a/bats_ai/core/management/commands/load_public_dataset.py +++ b/bats_ai/core/management/commands/load_public_dataset.py @@ -1,11 +1,11 @@ from __future__ import annotations import contextlib -from csv import DictReader -from datetime import date import hashlib import logging import os +from csv import DictReader +from datetime import date from pathlib import Path from typing import Any @@ -116,6 +116,7 @@ def _ingest_files_from_manifest( offset: int | None, file_key: str = "file_key", tag_keys: list[str] | None = None, + data_dir: Path | None = None, ): if tag_keys is None: tag_keys = [] @@ -137,6 +138,7 @@ def _ingest_files_from_manifest( filename = None try: + local = False s3_key = line[file_key] existing_recording = Recording.objects.filter(name=s3_key).first() if existing_recording: @@ -146,12 +148,28 @@ logger.info("Ingesting %s...", s3_key) object_exists = _try_head_s3_object(s3_client, bucket, s3_key) if not object_exists: - logger.warning("Could not HEAD object with key %s. Skipping...", s3_key) - continue - filename = _create_filename(s3_key) - logger.info("Downloading to temporary file %s...", filename) - s3_client.download_file(bucket, s3_key, filename) - logger.info("Creating recording for %s", s3_key) + if not data_dir: + logger.warning("Could not HEAD object with key %s. Skipping...", s3_key) + else: + logger.info( + "Could not HEAD object with key %s. Checking local directory %s", + s3_key, + data_dir, + ) + local = True + if not local: + filename = _create_filename(s3_key) + logger.info("Downloading to temporary file %s...", filename) + s3_client.download_file(bucket, s3_key, filename) + logger.info("Creating recording for %s", s3_key) + else: + assert data_dir + filename = str(data_dir / s3_key) + if Path(filename).exists(): + logger.info("Found local file match for %s.", s3_key) + else: + logger.warning("Could not find a local match for %s, skipping...", s3_key) + continue metadata = _get_metadata(filename, line) with open(filename, "rb") as f: recording = Recording.objects.create( @@ -188,7 +206,7 @@ def _ingest_files_from_manifest( ) recording_compute_spectrogram.delay(recording.pk) finally: - if filename: + if not local and filename: # Delete the file (this may run on a machine with limited resources) try: logger.info("Cleaning up by removing temporary file %s...", filename) @@ -198,7 +216,7 @@ class Command(BaseCommand): - help = "Create recordings and spectrograms from WAV files in a public s3 bucket" + help = "Ingest recordings from local filesystem and public s3 according to a manifest file."
def add_arguments(self, parser): parser.add_argument( @@ -212,6 +230,9 @@ def add_arguments(self, parser): # Assume columns "Key" and "Tags" help="Manifest CSV file with file keys and tags", ) + parser.add_argument( + "--data-dir", type=str, help="The directory where local recordings are located" + ) parser.add_argument( "--owner", type=str, @@ -253,6 +274,16 @@ def handle(self, *args, **options): except ClientError: self.stdout.write(self.style.ERROR(f"Could not access bucket {bucket}")) return + + data_dir = options.get("data_dir") + if data_dir: + data_dir = Path(data_dir) + if not data_dir.exists(): + self.stdout.write( + self.style.ERROR(f"Specified data directory {data_dir} does not exist") + ) + return + manifest = Path(options["manifest"]) if not manifest.exists(): self.stdout.write(self.style.ERROR(f"Could not find manifest file {manifest}")) @@ -290,4 +321,5 @@ def handle(self, *args, **options): offset=offset, file_key=file_key, tag_keys=tag_keys, + data_dir=data_dir, ) From 2605300f0a0ebb7e66b9d0503fc1e3f5120a9812 Mon Sep 17 00:00:00 2001 From: naglepuff Date: Fri, 3 Apr 2026 12:57:55 -0400 Subject: [PATCH 2/2] Check data_dir existence instead of assert --- bats_ai/core/management/commands/load_public_dataset.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/bats_ai/core/management/commands/load_public_dataset.py b/bats_ai/core/management/commands/load_public_dataset.py index 8fcf252e..22ebd867 100644 --- a/bats_ai/core/management/commands/load_public_dataset.py +++ b/bats_ai/core/management/commands/load_public_dataset.py @@ -1,11 +1,11 @@ from __future__ import annotations import contextlib +from csv import DictReader +from datetime import date import hashlib import logging import os -from csv import DictReader -from datetime import date from pathlib import Path from typing import Any @@ -150,6 +150,7 @@ def _ingest_files_from_manifest( if not object_exists: if not data_dir: logger.warning("Could not HEAD object with key %s. Skipping...", s3_key) + continue else: logger.info( "Could not HEAD object with key %s. Checking local directory %s", s3_key, data_dir, ) @@ -163,7 +164,9 @@ def _ingest_files_from_manifest( s3_client.download_file(bucket, s3_key, filename) logger.info("Creating recording for %s", s3_key) else: - assert data_dir + if not data_dir: + logger.warning("No local data directory specified. Skipping...") + continue filename = str(data_dir / s3_key) if Path(filename).exists(): logger.info("Found local file match for %s.", s3_key)