diff --git a/bats_ai/core/management/commands/load_public_dataset.py b/bats_ai/core/management/commands/load_public_dataset.py index 7b9ee72f..22ebd867 100644 --- a/bats_ai/core/management/commands/load_public_dataset.py +++ b/bats_ai/core/management/commands/load_public_dataset.py @@ -116,6 +116,7 @@ def _ingest_files_from_manifest( offset: int | None, file_key: str = "file_key", tag_keys: list[str] | None = None, + data_dir: Path | None = None, ): if tag_keys is None: tag_keys = [] @@ -137,6 +138,7 @@ def _ingest_files_from_manifest( filename = None try: + local = False s3_key = line[file_key] existing_recording = Recording.objects.filter(name=s3_key).first() if existing_recording: @@ -146,12 +148,31 @@ def _ingest_files_from_manifest( logger.info("Ingesting %s...", s3_key) object_exists = _try_head_s3_object(s3_client, bucket, s3_key) if not object_exists: - logger.warning("Could not HEAD object with key %s. Skipping...", s3_key) - continue - filename = _create_filename(s3_key) - logger.info("Downloading to temporary file %s...", filename) - s3_client.download_file(bucket, s3_key, filename) - logger.info("Creating recording for %s", s3_key) + if not data_dir: + logger.warning("Could not HEAD object with key %s. Skipping...", s3_key) + continue + else: + logger.info( + "Could not HEAD object with key %s. Checking local directory %s", + s3_key, + data_dir, + ) + local = True + if not local: + filename = _create_filename(s3_key) + logger.info("Downloading to temporary file %s...", filename) + s3_client.download_file(bucket, s3_key, filename) + logger.info("Creating recording for %s", s3_key) + else: + if not data_dir: + logger.warning("No local data directory specified. Skipping...") + continue + filename = str(data_dir / s3_key) + if Path(filename).exists(): + logger.info("Found local file match for %s.", s3_key) + else: + logger.warning("Could not find a local match for %s, skipping...", s3_key) + continue metadata = _get_metadata(filename, line) with open(filename, "rb") as f: recording = Recording.objects.create( @@ -188,7 +209,7 @@ def _ingest_files_from_manifest( ) recording_compute_spectrogram.delay(recording.pk) finally: - if filename: + if not local and filename: # Delete the file (this may run on a machine with limited resources) try: logger.info("Cleaning up by removing temporary file %s...", filename) @@ -198,7 +219,7 @@ def _ingest_files_from_manifest( class Command(BaseCommand): - help = "Create recordings and spectrograms from WAV files in a public s3 bucket" + help = "Ingest recordings from local filesystem and public s3 according to a manifest file." def add_arguments(self, parser): parser.add_argument( @@ -212,6 +233,9 @@ def add_arguments(self, parser): # Assume columns "Key" and "Tags" help="Manifest CSV file with file keys and tags", ) + parser.add_argument( + "--data-dir", type=str, help="The directory where local recordings are located" + ) parser.add_argument( "--owner", type=str, @@ -253,6 +277,16 @@ def handle(self, *args, **options): except ClientError: self.stdout.write(self.style.ERROR(f"Could not access bucket {bucket}")) return + + data_dir = options.get("data_dir") + if data_dir: + data_dir = Path(data_dir) + if not data_dir.exists(): + self.stdout.write( + self.style.ERROR(f"Specified data directory {data_dir} does not exist") + ) + return + manifest = Path(options["manifest"]) if not manifest.exists(): self.stdout.write(self.style.ERROR(f"Could not find manifest file {manifest}")) @@ -290,4 +324,5 @@ def handle(self, *args, **options): offset=offset, file_key=file_key, tag_keys=tag_keys, + data_dir=data_dir, )