Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion sacred/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ def add_artifact(
self,
filename: PathType,
name: Optional[str] = None,
recursive: bool = False,
metadata: Optional[dict] = None,
content_type: Optional[str] = None,
):
Expand Down Expand Up @@ -411,7 +412,8 @@ def add_artifact(
This only has an effect when using the MongoObserver.
"""
assert self.current_run is not None, "Can only be called during a run."
self.current_run.add_artifact(filename, name, metadata, content_type)
self.current_run.add_artifact(filename, name, recursive,
metadata, content_type)

@property
def info(self):
Expand Down
11 changes: 10 additions & 1 deletion sacred/observers/file_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Optional
import warnings

from shutil import copyfile
from shutil import copyfile, copytree

from sacred.commandline_options import CommandLineOption
from sacred.dependencies import get_digest
Expand Down Expand Up @@ -197,6 +197,10 @@ def save_file(self, filename, target_name=None):
target_name = target_name or os.path.basename(filename)
copyfile(filename, os.path.join(self.dir, target_name))

def save_dir(self, filename, target_name=None):
target_name = target_name or os.path.basename(filename)
copytree(filename, os.path.join(self.dir, target_name))

def save_cout(self):
with open(os.path.join(self.dir, "cout.txt"), "ab") as f:
f.write(self.cout[self.cout_write_cursor :].encode("utf-8"))
Expand Down Expand Up @@ -259,6 +263,11 @@ def artifact_event(self, name, filename, metadata=None, content_type=None):
self.run_entry["artifacts"].append(name)
self.save_json(self.run_entry, "run.json")

def artifact_directory_event(self, name, filename):
self.save_dir(filename, name)
self.run_entry['artifacts'].append(name + "/")
self.save_json(self.run_entry, 'run.json')

def log_metrics(self, metrics_by_name, info):
"""Store new measurements into metrics.json.
"""
Expand Down
47 changes: 36 additions & 11 deletions sacred/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,15 @@ def add_resource(self, filename):
filename = os.path.abspath(filename)
self._emit_resource_added(filename)

def add_artifact(self, filename, name=None, metadata=None, content_type=None):
def add_artifact(
self,
filename,
name=None,
recursive=False,
metadata=None,
content_type=None,
):

"""Add a file as an artifact.

In Sacred terminology an artifact is a file produced by the experiment
Expand All @@ -191,7 +199,8 @@ def add_artifact(self, filename, name=None, metadata=None, content_type=None):
"""
filename = os.path.abspath(filename)
name = os.path.basename(filename) if name is None else name
self._emit_artifact_added(name, filename, metadata, content_type)
self._emit_artifact_added(name, filename, recursive,
metadata, content_type)

def __call__(self, *args):
r"""Start this run.
Expand Down Expand Up @@ -403,16 +412,32 @@ def _emit_resource_added(self, filename):
for observer in self.observers:
self._safe_call(observer, "resource_event", filename=filename)

def _emit_artifact_added(self, name, filename, metadata, content_type):
def _emit_artifact_added(self, name, filename, recursive,
metadata, content_type):
for observer in self.observers:
self._safe_call(
observer,
"artifact_event",
name=name,
filename=filename,
metadata=metadata,
content_type=content_type,
)

if recursive:
if hasattr(observer, 'artifact_directory_event'):
self._safe_call(
observer,
'artifact_directory_event',
name=name,
filename=filename)
else:
self.run_logger.warning("Observer of type {} "
" doesn't support"
" recursive artifacts".format(
type(observer)
))
else:
self._safe_call(
observer,
'artifact_event',
name=name,
filename=filename,
metadata=metadata,
content_type=content_type,
)

def _safe_call(self, obs, method, **kwargs):
if obs not in self._failed_observers:
Expand Down