diff --git a/train.py b/train.py
index 7c12912..8ae9a91 100644
--- a/train.py
+++ b/train.py
@@ -194,7 +194,7 @@ def train(
             else:
                 loss = model.loss(ypred, label, adj, batch_num_nodes)
             loss.backward()
-            nn.utils.clip_grad_norm(model.parameters(), args.clip)
+            nn.utils.clip_grad_norm_(model.parameters(), args.clip)
             optimizer.step()
             iter += 1
             avg_loss += loss
@@ -294,7 +294,7 @@ def train_node_classifier(G, labels, model, args, writer=None):
         else:
             loss = model.loss(ypred_train, labels_train)
         loss.backward()
-        nn.utils.clip_grad_norm(model.parameters(), args.clip)
+        nn.utils.clip_grad_norm_(model.parameters(), args.clip)
         optimizer.step()

         #for param_group in optimizer.param_groups:
@@ -424,7 +424,7 @@ def train_node_classifier_multigraph(G_list, labels, model, args, writer=None):
         else:
             loss = model.loss(ypred_train_cmp, labels_train)
         loss.backward()
-        nn.utils.clip_grad_norm(model.parameters(), args.clip)
+        nn.utils.clip_grad_norm_(model.parameters(), args.clip)
         optimizer.step()

         #for param_group in optimizer.param_groups:
@@ -516,7 +516,7 @@ def evaluate(dataset, model, args, name="Validation", max_num_examples=None):
     preds = np.hstack(preds)

     result = {
-        "prec": metrics.precision_score(labels, preds, average="macro"),
+        "prec": metrics.precision_score(labels, preds, average="macro", zero_division=0),
         "recall": metrics.recall_score(labels, preds, average="macro"),
         "acc": metrics.accuracy_score(labels, preds),
     }
@@ -534,13 +534,13 @@ def evaluate_node(ypred, labels, train_idx, test_idx):
     labels_test = np.ravel(labels[:, test_idx])

     result_train = {
-        "prec": metrics.precision_score(labels_train, pred_train, average="macro"),
+        "prec": metrics.precision_score(labels_train, pred_train, average="macro", zero_division=0),
         "recall": metrics.recall_score(labels_train, pred_train, average="macro"),
         "acc": metrics.accuracy_score(labels_train, pred_train),
         "conf_mat": metrics.confusion_matrix(labels_train, pred_train),
     }
     result_test = {
-        "prec": metrics.precision_score(labels_test, pred_test, average="macro"),
+        "prec": metrics.precision_score(labels_test, pred_test, average="macro", zero_division=0),
         "recall": metrics.recall_score(labels_test, pred_test, average="macro"),
         "acc": metrics.accuracy_score(labels_test, pred_test),
         "conf_mat": metrics.confusion_matrix(labels_test, pred_test),
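
Note (not part of the patch): the diff makes two kinds of one-line API fixes. torch.nn.utils.clip_grad_norm is the deprecated name for in-place gradient-norm clipping, now spelled clip_grad_norm_, and zero_division=0 makes sklearn's precision_score report 0.0 for classes that never appear among the predictions instead of raising UndefinedMetricWarning. The snippet below is a minimal, self-contained sketch of both calls on a toy model and random data; everything except the library functions themselves is assumed for illustration.

# Illustrative sketch only -- toy model/data invented here, not taken from train.py.
import torch
import torch.nn as nn
from sklearn import metrics

model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
x, label = torch.randn(8, 4), torch.randint(0, 2, (8,))

loss = nn.functional.cross_entropy(model(x), label)
loss.backward()
# In-place clipping: the trailing underscore is the non-deprecated spelling;
# the call rescales the gradients so their total norm is at most max_norm.
nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
optimizer.step()

# zero_division=0 returns 0.0 for any class with no predicted samples instead
# of emitting scikit-learn's UndefinedMetricWarning.
with torch.no_grad():
    preds = model(x).argmax(dim=1).numpy()
print(metrics.precision_score(label.numpy(), preds, average="macro", zero_division=0))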