diff --git a/sacred/experiment.py b/sacred/experiment.py index ef68ec81..fb311fbb 100755 --- a/sacred/experiment.py +++ b/sacred/experiment.py @@ -384,6 +384,7 @@ def add_artifact( self, filename: PathType, name: Optional[str] = None, + recursive: bool = False, metadata: Optional[dict] = None, content_type: Optional[str] = None, ): @@ -411,7 +412,8 @@ def add_artifact( This only has an effect when using the MongoObserver. """ assert self.current_run is not None, "Can only be called during a run." - self.current_run.add_artifact(filename, name, metadata, content_type) + self.current_run.add_artifact(filename, name, recursive, + metadata, content_type) @property def info(self): diff --git a/sacred/observers/file_storage.py b/sacred/observers/file_storage.py index 1b27c72d..99a73f22 100644 --- a/sacred/observers/file_storage.py +++ b/sacred/observers/file_storage.py @@ -8,7 +8,7 @@ from typing import Optional import warnings -from shutil import copyfile +from shutil import copyfile, copytree from sacred.commandline_options import CommandLineOption from sacred.dependencies import get_digest @@ -197,6 +197,10 @@ def save_file(self, filename, target_name=None): target_name = target_name or os.path.basename(filename) copyfile(filename, os.path.join(self.dir, target_name)) + def save_dir(self, filename, target_name=None): + target_name = target_name or os.path.basename(filename) + copytree(filename, os.path.join(self.dir, target_name)) + def save_cout(self): with open(os.path.join(self.dir, "cout.txt"), "ab") as f: f.write(self.cout[self.cout_write_cursor :].encode("utf-8")) @@ -259,6 +263,11 @@ def artifact_event(self, name, filename, metadata=None, content_type=None): self.run_entry["artifacts"].append(name) self.save_json(self.run_entry, "run.json") + def artifact_directory_event(self, name, filename): + self.save_dir(filename, name) + self.run_entry['artifacts'].append(name + "/") + self.save_json(self.run_entry, 'run.json') + def log_metrics(self, metrics_by_name, info): """Store new measurements into metrics.json. """ diff --git a/sacred/run.py b/sacred/run.py index 7a9fa578..766f4b13 100755 --- a/sacred/run.py +++ b/sacred/run.py @@ -166,7 +166,15 @@ def add_resource(self, filename): filename = os.path.abspath(filename) self._emit_resource_added(filename) - def add_artifact(self, filename, name=None, metadata=None, content_type=None): + def add_artifact( + self, + filename, + name=None, + recursive=False, + metadata=None, + content_type=None, + ): + """Add a file as an artifact. In Sacred terminology an artifact is a file produced by the experiment @@ -191,7 +199,8 @@ def add_artifact(self, filename, name=None, metadata=None, content_type=None): """ filename = os.path.abspath(filename) name = os.path.basename(filename) if name is None else name - self._emit_artifact_added(name, filename, metadata, content_type) + self._emit_artifact_added(name, filename, recursive, + metadata, content_type) def __call__(self, *args): r"""Start this run. @@ -403,16 +412,32 @@ def _emit_resource_added(self, filename): for observer in self.observers: self._safe_call(observer, "resource_event", filename=filename) - def _emit_artifact_added(self, name, filename, metadata, content_type): + def _emit_artifact_added(self, name, filename, recursive, + metadata, content_type): for observer in self.observers: - self._safe_call( - observer, - "artifact_event", - name=name, - filename=filename, - metadata=metadata, - content_type=content_type, - ) + + if recursive: + if hasattr(observer, 'artifact_directory_event'): + self._safe_call( + observer, + 'artifact_directory_event', + name=name, + filename=filename) + else: + self.run_logger.warning("Observer of type {} " + " doesn't support" + " recursive artifacts".format( + type(observer) + )) + else: + self._safe_call( + observer, + 'artifact_event', + name=name, + filename=filename, + metadata=metadata, + content_type=content_type, + ) def _safe_call(self, obs, method, **kwargs): if obs not in self._failed_observers: