From 987e2f242b01d0c0e8c889619a59f1498bb71a3a Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Thu, 25 Jul 2019 14:47:53 -0700 Subject: [PATCH 1/5] add recursive option to artifact --- sacred/observers/file_storage.py | 13 ++++++++++--- sacred/run.py | 6 ++++-- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/sacred/observers/file_storage.py b/sacred/observers/file_storage.py index febacbf3..55623fcb 100644 --- a/sacred/observers/file_storage.py +++ b/sacred/observers/file_storage.py @@ -5,7 +5,7 @@ import os import os.path -from shutil import copyfile +from shutil import copyfile, copytree from sacred.commandline_options import CommandLineOption from sacred.dependencies import get_digest @@ -160,6 +160,10 @@ def save_file(self, filename, target_name=None): target_name = target_name or os.path.basename(filename) copyfile(filename, os.path.join(self.dir, target_name)) + def save_dir(self, filename, target_name=None): + target_name = target_name or os.path.basename(filename) + copytree(filename, os.path.join(self.dir, target_name)) + def save_cout(self): with open(os.path.join(self.dir, 'cout.txt'), 'ab') as f: f.write(self.cout[self.cout_write_cursor:].encode("utf-8")) @@ -214,8 +218,11 @@ def resource_event(self, filename): self.run_entry['resources'].append([filename, store_path]) self.save_json(self.run_entry, 'run.json') - def artifact_event(self, name, filename, metadata=None, content_type=None): - self.save_file(filename, name) + def artifact_event(self, name, filename, metadata=None, content_type=None, recursive=False): + if recursive: + self.save_dir(filename, name) + else: + self.save_file(filename, name) self.run_entry['artifacts'].append(name) self.save_json(self.run_entry, 'run.json') diff --git a/sacred/run.py b/sacred/run.py index 10cc6d25..14db4b73 100755 --- a/sacred/run.py +++ b/sacred/run.py @@ -160,6 +160,7 @@ def add_artifact( self, filename, name=None, + recursive=False, metadata=None, content_type=None, ): @@ -187,7 +188,7 @@ def add_artifact( """ filename = os.path.abspath(filename) name = os.path.basename(filename) if name is None else name - self._emit_artifact_added(name, filename, metadata, content_type) + self._emit_artifact_added(name, filename, recursive, metadata, content_type) def __call__(self, *args): r"""Start this run. @@ -385,11 +386,12 @@ def _emit_resource_added(self, filename): for observer in self.observers: self._safe_call(observer, 'resource_event', filename=filename) - def _emit_artifact_added(self, name, filename, metadata, content_type): + def _emit_artifact_added(self, name, filename, recursive, metadata, content_type): for observer in self.observers: self._safe_call(observer, 'artifact_event', name=name, filename=filename, + recursive=recursive, metadata=metadata, content_type=content_type) From 04327a1d1ed08212a3d86e1f7747de44ca5f2247 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Thu, 25 Jul 2019 14:54:04 -0700 Subject: [PATCH 2/5] working version of recursive artifact saving in FileStorageObserver --- sacred/experiment.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sacred/experiment.py b/sacred/experiment.py index aa1014ef..5682f8b0 100755 --- a/sacred/experiment.py +++ b/sacred/experiment.py @@ -332,6 +332,7 @@ def add_artifact( self, filename, name=None, + recursive=False, metadata=None, content_type=None, ): @@ -359,7 +360,7 @@ def add_artifact( This only has an effect when using the MongoObserver. """ assert self.current_run is not None, "Can only be called during a run." - self.current_run.add_artifact(filename, name, metadata, content_type) + self.current_run.add_artifact(filename, name, recursive, metadata, content_type) @property def info(self): From 0a2ac54dfee4696e066381fcc66663720d04d37b Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Thu, 25 Jul 2019 15:07:18 -0700 Subject: [PATCH 3/5] make sure other observers add_artifacts aren't impacted by a different artifact_event signature --- sacred/observers/file_storage.py | 12 +++++++----- sacred/run.py | 18 ++++++++++++------ 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/sacred/observers/file_storage.py b/sacred/observers/file_storage.py index 55623fcb..d3f4964b 100644 --- a/sacred/observers/file_storage.py +++ b/sacred/observers/file_storage.py @@ -218,14 +218,16 @@ def resource_event(self, filename): self.run_entry['resources'].append([filename, store_path]) self.save_json(self.run_entry, 'run.json') - def artifact_event(self, name, filename, metadata=None, content_type=None, recursive=False): - if recursive: - self.save_dir(filename, name) - else: - self.save_file(filename, name) + def artifact_event(self, name, filename, metadata=None, content_type=None): + self.save_file(filename, name) self.run_entry['artifacts'].append(name) self.save_json(self.run_entry, 'run.json') + def artifact_directory_event(self, name, filename, metadata=None, content_type=None): + self.save_dir(filename, name) + self.run_entry['artifacts'].append(name + "/") + self.save_json(self.run_entry, 'run.json') + def log_metrics(self, metrics_by_name, info): """Store new measurements into metrics.json. """ diff --git a/sacred/run.py b/sacred/run.py index 14db4b73..698f4ef4 100755 --- a/sacred/run.py +++ b/sacred/run.py @@ -388,12 +388,18 @@ def _emit_resource_added(self, filename): def _emit_artifact_added(self, name, filename, recursive, metadata, content_type): for observer in self.observers: - self._safe_call(observer, 'artifact_event', - name=name, - filename=filename, - recursive=recursive, - metadata=metadata, - content_type=content_type) + if recursive: + self._safe_call(observer, 'artifact_directory_event', + name=name, + filename=filename, + metadata=metadata, + content_type=content_type) + else: + self._safe_call(observer, 'artifact_event', + name=name, + filename=filename, + metadata=metadata, + content_type=content_type) def _safe_call(self, obs, method, **kwargs): if obs not in self._failed_observers and hasattr(obs, method): From e2b2d36e43e96ab69ace526da2b6c3f0eb2f2f92 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Thu, 25 Jul 2019 15:42:37 -0700 Subject: [PATCH 4/5] Clean up line lengths --- sacred/experiment.py | 3 ++- sacred/observers/file_storage.py | 2 +- sacred/run.py | 6 ++++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/sacred/experiment.py b/sacred/experiment.py index 5682f8b0..c9188dc6 100755 --- a/sacred/experiment.py +++ b/sacred/experiment.py @@ -360,7 +360,8 @@ def add_artifact( This only has an effect when using the MongoObserver. """ assert self.current_run is not None, "Can only be called during a run." - self.current_run.add_artifact(filename, name, recursive, metadata, content_type) + self.current_run.add_artifact(filename, name, recursive, + metadata, content_type) @property def info(self): diff --git a/sacred/observers/file_storage.py b/sacred/observers/file_storage.py index d3f4964b..133fd40f 100644 --- a/sacred/observers/file_storage.py +++ b/sacred/observers/file_storage.py @@ -223,7 +223,7 @@ def artifact_event(self, name, filename, metadata=None, content_type=None): self.run_entry['artifacts'].append(name) self.save_json(self.run_entry, 'run.json') - def artifact_directory_event(self, name, filename, metadata=None, content_type=None): + def artifact_directory_event(self, name, filename): self.save_dir(filename, name) self.run_entry['artifacts'].append(name + "/") self.save_json(self.run_entry, 'run.json') diff --git a/sacred/run.py b/sacred/run.py index 698f4ef4..6dc07f89 100755 --- a/sacred/run.py +++ b/sacred/run.py @@ -188,7 +188,8 @@ def add_artifact( """ filename = os.path.abspath(filename) name = os.path.basename(filename) if name is None else name - self._emit_artifact_added(name, filename, recursive, metadata, content_type) + self._emit_artifact_added(name, filename, recursive, + metadata, content_type) def __call__(self, *args): r"""Start this run. @@ -386,7 +387,8 @@ def _emit_resource_added(self, filename): for observer in self.observers: self._safe_call(observer, 'resource_event', filename=filename) - def _emit_artifact_added(self, name, filename, recursive, metadata, content_type): + def _emit_artifact_added(self, name, filename, recursive, + metadata, content_type): for observer in self.observers: if recursive: self._safe_call(observer, 'artifact_directory_event', From f4257992d01bc68e4815369b2687c9adbb69ac13 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Tue, 30 Jul 2019 14:13:26 -0700 Subject: [PATCH 5/5] informative logging on observers that don't implement recursive artifact saving --- sacred/run.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/sacred/run.py b/sacred/run.py index 6dc07f89..ed9b4a7d 100755 --- a/sacred/run.py +++ b/sacred/run.py @@ -391,11 +391,16 @@ def _emit_artifact_added(self, name, filename, recursive, metadata, content_type): for observer in self.observers: if recursive: - self._safe_call(observer, 'artifact_directory_event', - name=name, - filename=filename, - metadata=metadata, - content_type=content_type) + if hasattr(observer, 'artifact_directory_event'): + self._safe_call(observer, 'artifact_directory_event', + name=name, + filename=filename) + else: + self.run_logger.warning("Observer of type {} " + " doesn't support" + " recursive artifacts".format( + type(observer) + )) else: self._safe_call(observer, 'artifact_event', name=name,