Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added JEPA/.DS_Store
Binary file not shown.
Binary file added JEPA/configs/.DS_Store
Binary file not shown.
118 changes: 118 additions & 0 deletions JEPA/configs/base.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# configs/base.yaml

seed: 42

run:
save_dir: /mimer/NOBACKUP/groups/naiss2025-5-243/youya/CodeRepair_JEPA/e2_checkpoints_jepa
run_name: "" # Override with --set run.run_name=... in sbatch/array jobs.


data:
train_subset_size: 0
# train_subset_seed: 42
# train_subset_indices_path: /mimer/NOBACKUP/groups/naiss2025-5-243/Embeddings_RBR/Decoder_RBR/JEPA/sampled_train_idx.npy
# val_subset_indices_path: /mimer/NOBACKUP/groups/naiss2025-5-243/Embeddings_RBR/Decoder_RBR/JEPA/diagnosis/sample_global_indices.npy
exp1_emb_dir: "/mimer/NOBACKUP/groups/naiss2025-5-243/Embeddings_RBR/buggy_fixed_embeddings"
exp1_chunk_pattern: "buggy_fixed_embeddings_chunk_{:04d}.pkl"
exp1_fixed_keys: ["fixed_embeddings", "fixed_emb", "fixed", "target_embeddings"]

source: hf # hf | jsonl
hf:
dataset_id: "ASSERT-KTH/RunBugRun-Final"
split: "train"
fields:
buggy: "buggy_code"
fixed: "fixed_code"
language: "language"
problem_id: "problem_id"
buggy_submission_id: "buggy_submission_id"
fixed_submission_id: "fixed_submission_id"

indices:
dir: "/mimer/NOBACKUP/groups/naiss2025-5-243/Embeddings_RBR/Decoder_RBR/saved_indices"
global_target: "global_target_indices.npy"
train: "train_idx.npy"
val: "val_idx.npy"
test: "test_idx.npy"

num_workers: 4


encoder:
name: answerdotai/ModernBERT-large
max_len: 512
train_mode: frozen # frozen | full | lora
lora:
enabled: false
r: 16
alpha: 32
dropout: 0.05
bias: none
target_modules: [] # Required for exp2. See the exp2 config.


predictor:
name: vit1d
vit:
patch: 16
model_dim: 256
layers: 4
heads: 8
mlp_ratio: 4.0
dropout: 0.1
activation: relu
norm_first: false
mlp:
hidden_sizes: [4096, 2048, 1024]
activation: relu
dropout: 0.0
use_layernorm: false
residual: false
out_layernorm: false

ema:
tau: 0.996

loss:
w_align: 1.0
w_var: 1.0
w_mse: 0.1
w_tgt_exp1: 0.1
align_eps: 1e-8
var_target_std: 1.0
var_eps: 1e-4

train:
epochs: 3
batch_size: 16
grad_accum: 1
fp16: true
lr: 2e-5
lr_encoder: 1e-5
lr_predictor: 1e-4
weight_decay: 0.01
log_every: 100
save_every_epoch: false
eval_every_steps: 800 # Run validation and save the best checkpoint every 800 steps.
save_every_steps: 3000
save_ckpt: true
resume_strict: true
# resume_from: /mimer/NOBACKUP/groups/naiss2025-5-243/youya/CodeRepair_JEPA/e2_checkpoints_jepa/r2_wmse0.01_lrE8e-5_lrP2e-4_tau0.999_r16_20260315_203434/checkpoints/ckpt_step8000.pt

optim:
betas: [0.9, 0.999]
eps: 1e-8

ddp:
enabled: true
backend: nccl
find_unused_parameters: false


wandb:
enabled: true
resume: never
entity: assert-kth
project: CodeRepair_JEPA
group: e2_encoder_vit_diagnosis
run_name: e2_lora_vit
27 changes: 27 additions & 0 deletions JEPA/configs/exp2_lora_encoder_predictor.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# configs/exp2_lora_encoder_predictor.yaml

run:
run_name: exp2_lora_vit1d

encoder:
train_mode: lora
lora:
enabled: true
r: 16
alpha: 32
dropout: 0.05

# Set the LoRA target linear layers for this encoder.
# Confirm the exact names with print(model) or named_modules().
# Common examples:
# - BERT/RoBERTa style: ["query", "key", "value", "dense"]
# - LLaMA style: ["q_proj", "k_proj", "v_proj", "o_proj"]
target_modules: ["Wqkv", "Wo"]

predictor:
name: vit1d

# train:
# epochs: 2
# batch_size: 16
# lr: 2e-5
Binary file added JEPA/scripts/.DS_Store
Binary file not shown.
12 changes: 12 additions & 0 deletions JEPA/scripts/_bootstrap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from __future__ import annotations

import sys
from pathlib import Path


def bootstrap() -> None:
    """Prepend the repository's ``src`` directory to ``sys.path``.

    Idempotent: the path is inserted only if it is not already present,
    so calling this from several entry-point scripts is safe.
    """
    # This file lives in <repo>/scripts/, so parents[1] is the repo root.
    src_path = str(Path(__file__).resolve().parents[1] / "src")
    if src_path not in sys.path:
        sys.path.insert(0, src_path)
14 changes: 14 additions & 0 deletions JEPA/scripts/decoder/test_lora.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Thin CLI wrapper: fixes up sys.path, then runs jepa.tasks.decoder.test_lora.
import sys
from pathlib import Path

# Make the scripts/ directory importable so `_bootstrap` resolves
# regardless of the caller's working directory.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from _bootstrap import bootstrap

# Puts the repo's src/ directory on sys.path (required before importing jepa).
bootstrap()

from jepa.tasks.decoder.test_lora import main


if __name__ == "__main__":
    main()
14 changes: 14 additions & 0 deletions JEPA/scripts/decoder/test_projector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Thin CLI wrapper: fixes up sys.path, then runs jepa.tasks.decoder.test_projector.
import sys
from pathlib import Path

# Make the scripts/ directory importable so `_bootstrap` resolves
# regardless of the caller's working directory.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from _bootstrap import bootstrap

# Puts the repo's src/ directory on sys.path (required before importing jepa).
bootstrap()

from jepa.tasks.decoder.test_projector import main


if __name__ == "__main__":
    main()
14 changes: 14 additions & 0 deletions JEPA/scripts/decoder/train_lora.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Thin CLI wrapper: fixes up sys.path, then runs jepa.tasks.decoder.train_lora.
import sys
from pathlib import Path

# Make the scripts/ directory importable so `_bootstrap` resolves
# regardless of the caller's working directory.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from _bootstrap import bootstrap

# Puts the repo's src/ directory on sys.path (required before importing jepa).
bootstrap()

from jepa.tasks.decoder.train_lora import parse_args, train


if __name__ == "__main__":
    train(parse_args())
14 changes: 14 additions & 0 deletions JEPA/scripts/decoder/train_projector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Thin CLI wrapper: fixes up sys.path, then runs jepa.tasks.decoder.train_projector.
import sys
from pathlib import Path

# Make the scripts/ directory importable so `_bootstrap` resolves
# regardless of the caller's working directory.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from _bootstrap import bootstrap

# Puts the repo's src/ directory on sys.path (required before importing jepa).
bootstrap()

from jepa.tasks.decoder.train_projector import parse_args, train


if __name__ == "__main__":
    train(parse_args())
14 changes: 14 additions & 0 deletions JEPA/scripts/encoder/embed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Thin CLI wrapper: fixes up sys.path, then runs jepa.tasks.encoder.embed.
import sys
from pathlib import Path

# Make the scripts/ directory importable so `_bootstrap` resolves
# regardless of the caller's working directory.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from _bootstrap import bootstrap

# Puts the repo's src/ directory on sys.path (required before importing jepa).
bootstrap()

from jepa.tasks.encoder.embed import main


if __name__ == "__main__":
    main()
14 changes: 14 additions & 0 deletions JEPA/scripts/encoder/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Thin CLI wrapper: fixes up sys.path, then runs jepa.tasks.encoder.train.
import sys
from pathlib import Path

# Make the scripts/ directory importable so `_bootstrap` resolves
# regardless of the caller's working directory.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from _bootstrap import bootstrap

# Puts the repo's src/ directory on sys.path (required before importing jepa).
bootstrap()

from jepa.tasks.encoder.train import main


if __name__ == "__main__":
    main()
Binary file added JEPA/src/.DS_Store
Binary file not shown.
1 change: 1 addition & 0 deletions JEPA/src/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""JEPA research codebase."""
123 changes: 123 additions & 0 deletions JEPA/src/losses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# losses.py
# -*- coding: utf-8 -*-
"""
Losses for JEPA-style training.

Exp2 minimal set:
- cosine alignment: align(z_pred, z_tgt)
- variance regularizer (VICReg variance term): avoid collapse without negatives

Also provide simple training-time metrics:
- retrieval_top1_acc: batch retrieval acc z_pred -> z_tgt
- emb_std_mean: mean std across embedding dimensions (collapse indicator)
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


# -------------------------
# Core losses
# -------------------------
class CosineAlignLoss(nn.Module):
    """Cosine-distance loss: ``1 - cos(z_pred, z_tgt)``, averaged over the batch."""

    def __init__(self, eps: float = 1e-8):
        super().__init__()
        # Numerical floor used when normalizing near-zero vectors.
        self.eps = float(eps)

    def forward(self, z_pred: torch.Tensor, z_tgt: torch.Tensor) -> torch.Tensor:
        # After unit-normalization the row-wise dot product is the cosine.
        pred_hat = F.normalize(z_pred, dim=-1, eps=self.eps)
        tgt_hat = F.normalize(z_tgt, dim=-1, eps=self.eps)
        cos_sim = (pred_hat * tgt_hat).sum(dim=-1)
        return 1.0 - cos_sim.mean()


class VarianceLoss(nn.Module):
    """
    VICReg-style variance regularizer.

    Penalizes embedding dimensions whose batch std falls below
    ``target_std``:

        std  = sqrt(var + eps)
        loss = mean(relu(target_std - std))
    """

    def __init__(self, target_std: float = 1.0, eps: float = 1e-4):
        super().__init__()
        self.target_std = float(target_std)
        # Stabilizes the sqrt when the batch variance is near zero.
        self.eps = float(eps)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # z: [B, D] -> per-dimension std across the batch (biased estimator).
        per_dim_std = (z.var(dim=0, unbiased=False) + self.eps).sqrt()
        # clamp(min=0) is relu: only dimensions below the target contribute.
        shortfall = torch.clamp(self.target_std - per_dim_std, min=0.0)
        return shortfall.mean()


class EmaPredictiveLoss(nn.Module):
    """
    Minimal JEPA-style objective for exp2 (EMA target + predictor):

        L = w_align * align(z_pred, z_tgt) + w_var * (var(z_ctx) + var(z_tgt))

    ``forward`` returns a dict with the weighted total under ``"loss"``
    plus the unweighted ``"align"`` and ``"var"`` components for logging.
    """

    def __init__(
        self,
        w_align: float = 1.0,
        w_var: float = 1.0,
        align_eps: float = 1e-8,
        var_target_std: float = 1.0,
        var_eps: float = 1e-4,
    ):
        super().__init__()
        self.w_align = float(w_align)
        self.w_var = float(w_var)
        self.align = CosineAlignLoss(eps=align_eps)
        self.var = VarianceLoss(target_std=var_target_std, eps=var_eps)

    def forward(self, z_ctx, z_pred, z_tgt):
        # Alignment pulls predictions toward the (EMA) targets ...
        align_term = self.align(z_pred, z_tgt)
        # ... while the variance terms keep both branches from collapsing.
        var_term = self.var(z_ctx) + self.var(z_tgt)
        total = self.w_align * align_term + self.w_var * var_term
        return {"loss": total, "align": align_term, "var": var_term}


def build_loss(cfg: Dict[str, Any]) -> EmaPredictiveLoss:
    """Construct the exp2 loss from the ``loss`` section of a config dict.

    Missing keys fall back to the module defaults; every value is coerced
    to float so YAML scalars parsed as strings (e.g. ``"1e-8"``) still work.
    """
    section = cfg.get("loss", {})

    def _f(key: str, default: float) -> float:
        # Lookup-with-default plus float coercion, shared by all fields.
        return float(section.get(key, default))

    return EmaPredictiveLoss(
        w_align=_f("w_align", 1.0),
        w_var=_f("w_var", 1.0),
        align_eps=_f("align_eps", 1e-8),
        var_target_std=_f("var_target_std", 1.0),
        var_eps=_f("var_eps", 1e-4),
    )


# -------------------------
# Metrics (no grad)
# -------------------------
@torch.no_grad()
def retrieval_top1_acc(z_pred: torch.Tensor, z_tgt: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    In-batch retrieval accuracy.

    For each row i, the cosine-nearest target over the batch must be
    z_tgt[i]; returns the fraction of rows where that holds.
    AMP-safe: computed in float32.
    """
    pred_hat = F.normalize(z_pred.float(), dim=-1, eps=eps)
    tgt_hat = F.normalize(z_tgt.float(), dim=-1, eps=eps)

    scores = torch.matmul(pred_hat, tgt_hat.transpose(0, 1))  # [B, B] cosines
    best = scores.argmax(dim=1)
    expected = torch.arange(scores.shape[0], device=scores.device)
    return best.eq(expected).float().mean()


@torch.no_grad()
def emb_std_mean(z: torch.Tensor) -> torch.Tensor:
    """Mean per-dimension std across the batch (collapse indicator), in float32."""
    return z.float().std(dim=0, unbiased=False).mean()
Loading