-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathhandler.py
More file actions
105 lines (86 loc) · 3.5 KB
/
handler.py
File metadata and controls
105 lines (86 loc) · 3.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from collections import OrderedDict
from pathlib import Path
import pandas as pd
import yaml
class ExpHandler:
    """Keep the on-disk (and optionally wandb) record of one experiment run.

    The run directory is ``args.output_dir``; its last path component is the
    run name. Two files are kept in it: ``args.yaml`` (the run's arguments)
    and ``log.csv`` (one row per :meth:`write` call).

    ``args`` is expected to be an ``argparse.Namespace``-like object exposing
    at least ``output_dir``, ``en_wandb`` and ``no_resume``. The handler sets
    ``args.commit`` (git HEAD SHA or ``'not_set'``) and, for a new run with
    wandb enabled, ``args.wandb_id``.

    A run is resumed (logged rows reloaded, wandb id reused) when
    ``args.no_resume`` is false and ``args.yaml`` already exists.
    """

    def __init__(self, args):
        self._save_dir = Path(args.output_dir)
        self._run_name = self._save_dir.name
        self.csv_path = self._save_dir / 'log.csv'
        self.cfg_path = self._save_dir / 'args.yaml'
        # args.yaml / log.csv are written straight into this directory.
        self._save_dir.mkdir(parents=True, exist_ok=True)

        if args.en_wandb:
            import wandb
            # Directory layout is .../<project>/<group>/<run_name>.
            wandb_kwargs = dict(project=self._save_dir.parents[1].name,
                                group=self._save_dir.parents[0].name,
                                name=self._run_name,
                                save_code=True,
                                resume='allow')
        args.commit = self._get_commit_hash()

        if not args.no_resume and self.cfg_path.exists():
            # Resume: reuse the stored wandb id and the rows already logged.
            with open(self.cfg_path, 'r') as f:
                config = yaml.safe_load(f) or {}  # empty YAML loads as None
            if args.en_wandb:
                if 'wandb_id' in config:
                    wandb_kwargs['id'] = config['wandb_id']
                else:
                    # No stored id: start a new wandb run and persist its id so
                    # the next resume re-attaches instead of creating another.
                    wandb_kwargs['id'] = wandb.util.generate_id()
                    config['wandb_id'] = wandb_kwargs['id']
                    with open(self.cfg_path, 'w') as f:
                        yaml.safe_dump(config, f, sort_keys=False)
            self.log_data = (pd.read_csv(self.csv_path).to_dict('records')
                             if self.csv_path.exists() else [])
        else:  # new run
            if args.en_wandb:
                wandb_kwargs['id'] = wandb.util.generate_id()
                # self.wandb_run does not exist yet, so _save_config cannot read
                # the id from it; put it on args so it lands in args.yaml.
                args.wandb_id = wandb_kwargs['id']
            self._save_config(args)
            self.log_data = []

        if args.en_wandb:
            self.wandb_run = wandb.init(**wandb_kwargs)

    @staticmethod
    def _get_commit_hash():
        """Return the HEAD commit SHA of the enclosing git repository.

        Returns ``'not_set'`` instead of raising when GitPython cannot be
        imported/initialised, when the working directory is not inside a git
        repository, or when HEAD does not resolve to a commit (e.g. a freshly
        initialised repository).
        """
        try:
            import git
        except ImportError:
            return 'not_set'
        try:
            repo = git.Repo(search_parent_directories=True)
            return repo.head.object.hexsha
        except (git.InvalidGitRepositoryError, ValueError):
            # ValueError: HEAD points at a reference that has no commit yet.
            return 'not_set'

    @property
    def save_dir(self):
        """Run directory (``Path(args.output_dir)``)."""
        return self._save_dir

    def _save_config(self, args):
        """Print ``args`` and dump them to ``args.yaml``.

        ``vars(args)`` is the namespace's own ``__dict__``, so a ``wandb_id``
        added here also becomes an attribute of ``args``.
        """
        conf = vars(args)
        if hasattr(self, 'wandb_run'):
            conf['wandb_id'] = self.wandb_run.id
        print('=' * 40)
        for k, v in conf.items():
            print(f'{k}: {v}')
        print('=' * 40)
        with open(self.cfg_path, 'w') as f:
            yaml.safe_dump(conf, f, sort_keys=False)

    def write(self, eval_metrics=None, train_metrics=None, **kwargs):
        """Append one row to the run log and persist it.

        The row contains ``kwargs`` as given, then ``train_metrics`` keys
        prefixed with ``train/`` and ``eval_metrics`` keys prefixed with
        ``eval/``. The full log is rewritten to ``log.csv`` on every call, so
        the CSV header always covers every column seen so far. The row is also
        logged to wandb when a wandb run is active.
        """
        rowd = OrderedDict(kwargs)
        if train_metrics:
            rowd.update((f'train/{k}', v) for k, v in train_metrics.items())
        if eval_metrics:
            rowd.update((f'eval/{k}', v) for k, v in eval_metrics.items())
        self.log_data.append(rowd)
        pd.DataFrame(self.log_data).to_csv(self.csv_path, index=False)
        if hasattr(self, 'wandb_run'):
            self.wandb_run.log(rowd)

    def finish(self):
        """Mark the run complete with a ``finished`` file; close wandb if active."""
        (self._save_dir / 'finished').touch()
        if hasattr(self, 'wandb_run'):
            self.wandb_run.finish()

    @property
    def wandb_obj(self):
        """The active wandb run object, or ``None`` when wandb is disabled."""
        return getattr(self, 'wandb_run', None)