Commit cdca52e: [Feat] Low overhead (#135)

Squashed commit message:

* [WIP] refactor: move proxy_wrapper to within the instrumentor
* WIP: strengthen instrumentation logic
* remove unused proxy config; rename ML_DAIKON to TRAINCHECK
* subclass selective instrumentation impl
* fix instrumentation logic to get the parent class of a method definition
* fix: only use positional arguments for function_wrapper
* fix: respect configured tracker type during selective instrumentation
* fix: unify registry implementation for proxy and subclass
* fix: step incrementing logic
* add: richer error msg for unchanged var check in contain relation
* fix: subclass registry updating process
* fix: remove unproxy scanning for subclass to further reduce overhead
* add: refined logging for observer and registry
* fix: selective dumping for the proxy class
* add: monkey patch __setattr__ at the module level when using subclass, to ensure submodule assignments are captured (a sketch follows the change summary below)
* add: support sampling and warmup instrumentation policies

Parent: 10033f3

35 files changed: 619 additions and 583 deletions. Five of the changed files are shown below.
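The `__setattr__` item above is the most subtle mechanism in this commit: a subclass-based instrumentor (unlike the old proxy wrapper) cannot see late assignments such as `model.head = nn.Linear(...)` unless attribute setting itself is intercepted. Below is a minimal sketch under one plausible reading of the commit message (patching `torch.nn.Module.__setattr__`); `register_with_tracker` is a hypothetical stand-in for TrainCheck's real registry hook, whose API this page does not show.

```python
import torch.nn as nn


def register_with_tracker(module: nn.Module) -> None:
    """Hypothetical stand-in for TrainCheck's subclass registry hook."""
    print(f"tracking new submodule: {type(module).__name__}")


_orig_setattr = nn.Module.__setattr__


def _tracking_setattr(self, name, value):
    # Let PyTorch's own __setattr__ run first: it is what files the value
    # into self._modules when the value is an nn.Module.
    _orig_setattr(self, name, value)
    # Late assignments such as `model.head = nn.Linear(...)` would bypass a
    # subclass-based instrumentor; intercept them here and hand the new
    # submodule to the tracker.
    if isinstance(value, nn.Module):
        register_with_tracker(value)


nn.Module.__setattr__ = _tracking_setattr

model = nn.Sequential()
model.head = nn.Linear(8, 2)  # prints "tracking new submodule: Linear"
```

The delegation-first order matters: PyTorch's own `__setattr__` registers the value in `self._modules` before the tracker sees it, so the tracker always observes a fully registered submodule.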

.github/workflows/eval-overhead-e2e.yml

Lines changed: 0 additions & 2 deletions
```diff
@@ -6,13 +6,11 @@ on:
     paths:
       - '.github/workflows/**'
       - 'traincheck/instrumentor/**'
-      - 'traincheck/proxy_wrapper/**'
       - 'traincheck/collect_trace.py'
   pull_request:
     paths:
       - '.github/workflows/**'
       - 'traincheck/instrumentor/**'
-      - 'traincheck/proxy_wrapper/**'
       - 'traincheck/collect_trace.py'
 
 
```

docs/5-min-tutorial.md

Lines changed: 2 additions & 2 deletions
```diff
@@ -246,7 +246,7 @@ For example, the "`optimizer.zero_grad` did **not** reset `.grad` from non-zero
     "var_type": NaN,
     "mode": NaN,
     "dump_loc": NaN,
-    "attributes._ML_DAIKON_data_ID": NaN,
+    "attributes._TRAINCHECK_data_ID": NaN,
     "attributes.data": NaN,
     "attributes.dtype": NaN,
     "attributes.grad": NaN,
@@ -274,7 +274,7 @@ For example, the "`optimizer.zero_grad` did **not** reset `.grad` from non-zero
     "attributes.requires_grad": NaN,
     "attributes.retains_grad": NaN,
     "attributes.shape": NaN,
-    "attributes._ML_DAIKON_grad_ID": NaN,
+    "attributes._TRAINCHECK_grad_ID": NaN,
     "exception": NaN,
     "exception_msg": NaN,
     "proxy_obj_names": NaN
```

docs/ae-eval-s5.1-silent-issue-detection.md

Lines changed: 16 additions & 16 deletions
```diff
@@ -145,9 +145,9 @@ diff --color -r checker_output/trace_pytorch-104336/failed.log reference_checker
 >     "process_id": 9591,
 >     "thread_id": 140324043503424,
 86c86
-<     "attributes._ML_DAIKON_data_ID": 140704882109040,
+<     "attributes._TRAINCHECK_data_ID": 140704882109040,
 ---
->     "attributes._ML_DAIKON_data_ID": 140317529048544,
+>     "attributes._TRAINCHECK_data_ID": 140317529048544,
 116,117c116,117
 <     "time": 2437523672783,
 <     "meta_vars._DATA_PARALLEL_RANK": 4.0,
@@ -161,9 +161,9 @@ diff --color -r checker_output/trace_pytorch-104336/failed.log reference_checker
 >     "process_id": 9747,
 >     "thread_id": 140028492969792,
 128c128
-<     "attributes._ML_DAIKON_data_ID": 140043703504144,
+<     "attributes._TRAINCHECK_data_ID": 140043703504144,
 ---
->     "attributes._ML_DAIKON_data_ID": 140021978318304,
+>     "attributes._TRAINCHECK_data_ID": 140021978318304,
 158,159c158,159
 <     "time": 2437502499438,
 <     "meta_vars._DATA_PARALLEL_RANK": 2.0,
@@ -182,9 +182,9 @@ diff --color -r checker_output/trace_pytorch-115607/failed.log reference_checker
 <     "exception_msg": NaN,
 <     "proxy_obj_names": NaN,
 113c110,113
-<     "attributes._ML_DAIKON_grad_ID": NaN
+<     "attributes._TRAINCHECK_grad_ID": NaN
 ---
->     "attributes._ML_DAIKON_grad_ID": NaN,
+>     "attributes._TRAINCHECK_grad_ID": NaN,
 >     "exception": NaN,
 >     "exception_msg": NaN,
 >     "proxy_obj_names": NaN
@@ -193,9 +193,9 @@ diff --color -r checker_output/trace_pytorch-115607/failed.log reference_checker
 <     "exception_msg": NaN,
 <     "proxy_obj_names": NaN,
 215c212,215
-<     "attributes._ML_DAIKON_grad_ID": NaN
+<     "attributes._TRAINCHECK_grad_ID": NaN
 ---
->     "attributes._ML_DAIKON_grad_ID": NaN,
+>     "attributes._TRAINCHECK_grad_ID": NaN,
 >     "exception": NaN,
 >     "exception_msg": NaN,
 >     "proxy_obj_names": NaN
@@ -210,9 +210,9 @@ diff --color -r checker_output/trace_pytorch-115607/failed.log reference_checker
 <     "exception_msg": NaN,
 <     "proxy_obj_names": NaN,
 331c328,331
-<     "attributes._ML_DAIKON_grad_ID": NaN
+<     "attributes._TRAINCHECK_grad_ID": NaN
 ---
->     "attributes._ML_DAIKON_grad_ID": NaN,
+>     "attributes._TRAINCHECK_grad_ID": NaN,
 >     "exception": NaN,
 >     "exception_msg": NaN,
 >     "proxy_obj_names": NaN
@@ -247,10 +247,10 @@ diff --color -r checker_output/trace_pytorch-51800/failed.log reference_checker_
 >     "time": 19876858668088743,
 >     "meta_vars.step": 0,
 89c70,89
-<     "attributes._ML_DAIKON_grad_ID": NaN
+<     "attributes._TRAINCHECK_grad_ID": NaN
 ---
 >     "type": "function_call (pre)",
->     "attributes._ML_DAIKON_grad_ID": NaN,
+>     "attributes._TRAINCHECK_grad_ID": NaN,
 >     "func_call_id": "b39a4a81b2c24473ba916ab1832fbf12_19876858668012869",
 >     "function": "torch.nn.modules.module.Module.eval",
 >     "is_bound_method": true,
@@ -290,9 +290,9 @@ diff --color -r checker_output/trace_x-jxmnop-ddp-out-of-sync/failed.log referen
 ---
 >     "meta_vars._DATA_PARALLEL_RANK": "1",
 87c87
-<     "attributes._ML_DAIKON_data_ID": 140656561409856,
+<     "attributes._TRAINCHECK_data_ID": 140656561409856,
 ---
->     "attributes._ML_DAIKON_data_ID": 140621279056480,
+>     "attributes._TRAINCHECK_data_ID": 140621279056480,
 117c117
 <     "time": 123297988837864,
 ---
@@ -308,9 +308,9 @@ diff --color -r checker_output/trace_x-jxmnop-ddp-out-of-sync/failed.log referen
 ---
 >     "meta_vars._DATA_PARALLEL_RANK": "0",
 129c129
-<     "attributes._ML_DAIKON_data_ID": 140621279058160,
+<     "attributes._TRAINCHECK_data_ID": 140621279058160,
 ---
->     "attributes._ML_DAIKON_data_ID": 140656561411776,
+>     "attributes._TRAINCHECK_data_ID": 140656561411776,
 159c159
 <     "time": 123299970638648,
 ---
```
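All sixteen changed lines in this file are the `_ML_DAIKON_*` to `_TRAINCHECK_*` rename inside recorded `diff` output. The values that differ between the compared logs (`time`, `process_id`, `thread_id`, and the `*_ID` attributes, which appear to be object IDs, i.e. memory addresses) vary from run to run. As an illustration only (the regex and function below are not TrainCheck code), such volatile fields could be masked before diffing two checker outputs:

```python
import re

# Fields whose values legitimately differ across runs: timestamps, PIDs,
# thread IDs, call IDs, and TRAINCHECK object IDs (memory addresses).
_VOLATILE = re.compile(
    r'"(time|process_id|thread_id|func_call_id|attributes\._TRAINCHECK_\w+_ID)": ("[^"]*"|[\d.]+)'
)


def normalize(line: str) -> str:
    """Mask run-specific values so two failed.log files diff structurally."""
    return _VOLATILE.sub(r'"\1": <VOLATILE>', line)


assert normalize('"attributes._TRAINCHECK_data_ID": 140704882109040,') == \
    '"attributes._TRAINCHECK_data_ID": <VOLATILE>,'
```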

docs/assets/code/mnist.py

Lines changed: 11 additions & 11 deletions
```diff
@@ -8,9 +8,9 @@
 from torchvision import datasets, transforms
 
 from traincheck import annotate_stage
-from traincheck.instrumentor import meta_vars
+from traincheck.instrumentor import META_VARS
 
-meta_vars["step"] = -1
+META_VARS["step"] = -1
 
 
 class Net(nn.Module):
@@ -40,10 +40,10 @@ def forward(self, x):
 
 
 def train(args, model, device, train_loader, optimizer, epoch):
-    annotate_stage("training")  # ML_DAIKON: stage annotation
+    annotate_stage("training")  # TRAINCHECK: stage annotation
     model.train()
     for batch_idx, (data, target) in enumerate(train_loader):
-        meta_vars["step"] += 1
+        META_VARS["step"] += 1
         data, target = data.to(device), target.to(device)
         optimizer.zero_grad()
         output = model(data)
@@ -63,13 +63,13 @@ def train(args, model, device, train_loader, optimizer, epoch):
         if args.dry_run:
             break
 
-        # ML_DAIKON: break after 100 batches
+        # TRAINCHECK: break after 100 batches
         if batch_idx == 50:
             break
 
 
 def test(model, device, test_loader):
-    annotate_stage("testing")  # ML_DAIKON: stage annotation
+    annotate_stage("testing")  # TRAINCHECK: stage annotation
     model.eval()
     test_loss = 0
     correct = 0
@@ -87,7 +87,7 @@ def test(model, device, test_loader):
             correct += pred.eq(target.view_as(pred)).sum().item()
 
             data_idx += 1
-            # ML_DAIKON: break after 10 batches
+            # TRAINCHECK: break after 10 batches
             if data_idx == 10:
                 break
 
@@ -174,7 +174,7 @@ def main():
     )
     args = parser.parse_args()
 
-    annotate_stage("init")  # ML_DAIKON: stage annotation
+    annotate_stage("init")  # TRAINCHECK: stage annotation
     use_cuda = not args.no_cuda and torch.cuda.is_available()
     use_mps = not args.no_mps and torch.backends.mps.is_available()
 
@@ -191,7 +191,7 @@ def main():
     test_kwargs = {"batch_size": args.test_batch_size}
     if use_cuda:
         cuda_kwargs = {"num_workers": 2, "pin_memory": True, "shuffle": True}
-        # ML_DAIKON: set num_workers to 0 to avoid dataloader related invariants
+        # TRAINCHECK: set num_workers to 0 to avoid dataloader related invariants
         # cuda_kwargs = {'num_workers': 0, 'pin_memory': True, 'shuffle': True}
         train_kwargs.update(cuda_kwargs)
         test_kwargs.update(cuda_kwargs)
@@ -212,11 +212,11 @@ def main():
         train(args, model, device, train_loader, optimizer, epoch)
         test(model, device, test_loader)
 
-        annotate_stage("training")  # ML_DAIKON: stage annotation
+        annotate_stage("training")  # TRAINCHECK: stage annotation
         scheduler.step()
 
     if args.save_model:
-        annotate_stage("checkpointing")  # ML_DAIKON: stage annotation
+        annotate_stage("checkpointing")  # TRAINCHECK: stage annotation
         torch.save(model.state_dict(), "mnist_cnn.pt")
 
 
```
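Taken together, the hunks above amount to a single user-facing API change: the step-counter dict exported by `traincheck.instrumentor` is now `META_VARS` rather than `meta_vars`, and the comment markers switch from `ML_DAIKON` to `TRAINCHECK`. Condensed from the diff (model and data setup elided):

```python
from traincheck import annotate_stage
from traincheck.instrumentor import META_VARS

META_VARS["step"] = -1  # surfaces in trace events as meta_vars.step


def train(args, model, device, train_loader, optimizer, epoch):
    annotate_stage("training")  # TRAINCHECK: stage annotation
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        META_VARS["step"] += 1  # bump once per batch so events are step-indexed
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        ...
```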
docs/assets/examples/traincheck-collect/mnist-config/mnist.py

Lines changed: 11 additions & 11 deletions
```diff
@@ -8,9 +8,9 @@
 from torchvision import datasets, transforms
 
 from traincheck import annotate_stage
-from traincheck.instrumentor import meta_vars
+from traincheck.instrumentor import META_VARS
 
-meta_vars["step"] = -1
+META_VARS["step"] = -1
 
 
 class Net(nn.Module):
@@ -40,10 +40,10 @@ def forward(self, x):
 
 
 def train(args, model, device, train_loader, optimizer, epoch):
-    annotate_stage("training")  # ML_DAIKON: stage annotation
+    annotate_stage("training")  # TRAINCHECK: stage annotation
     model.train()
     for batch_idx, (data, target) in enumerate(train_loader):
-        meta_vars["step"] += 1
+        META_VARS["step"] += 1
         data, target = data.to(device), target.to(device)
         optimizer.zero_grad()
         output = model(data)
@@ -63,13 +63,13 @@ def train(args, model, device, train_loader, optimizer, epoch):
         if args.dry_run:
             break
 
-        # ML_DAIKON: break after 100 batches
+        # TRAINCHECK: break after 100 batches
         if batch_idx == 50:
             break
 
 
 def test(model, device, test_loader):
-    annotate_stage("testing")  # ML_DAIKON: stage annotation
+    annotate_stage("testing")  # TRAINCHECK: stage annotation
     model.eval()
     test_loss = 0
     correct = 0
@@ -87,7 +87,7 @@ def test(model, device, test_loader):
             correct += pred.eq(target.view_as(pred)).sum().item()
 
             data_idx += 1
-            # ML_DAIKON: break after 10 batches
+            # TRAINCHECK: break after 10 batches
            if data_idx == 10:
                 break
 
@@ -174,7 +174,7 @@ def main():
     )
     args = parser.parse_args()
 
-    annotate_stage("init")  # ML_DAIKON: stage annotation
+    annotate_stage("init")  # TRAINCHECK: stage annotation
     use_cuda = not args.no_cuda and torch.cuda.is_available()
     use_mps = not args.no_mps and torch.backends.mps.is_available()
 
@@ -191,7 +191,7 @@ def main():
     test_kwargs = {"batch_size": args.test_batch_size}
     if use_cuda:
         cuda_kwargs = {"num_workers": 2, "pin_memory": True, "shuffle": True}
-        # ML_DAIKON: set num_workers to 0 to avoid dataloader related invariants
+        # TRAINCHECK: set num_workers to 0 to avoid dataloader related invariants
         # cuda_kwargs = {'num_workers': 0, 'pin_memory': True, 'shuffle': True}
         train_kwargs.update(cuda_kwargs)
         test_kwargs.update(cuda_kwargs)
@@ -212,11 +212,11 @@ def main():
         train(args, model, device, train_loader, optimizer, epoch)
         test(model, device, test_loader)
 
-        annotate_stage("training")  # ML_DAIKON: stage annotation
+        annotate_stage("training")  # TRAINCHECK: stage annotation
         scheduler.step()
 
     if args.save_model:
-        annotate_stage("checkpointing")  # ML_DAIKON: stage annotation
+        annotate_stage("checkpointing")  # TRAINCHECK: stage annotation
         torch.save(model.state_dict(), "mnist_cnn.pt")
 
 
```