Skip to content

Commit f3a0917

Browse files
Add files via upload
1 parent 5fd4776 commit f3a0917

File tree

3 files changed

+11265
-0
lines changed

3 files changed

+11265
-0
lines changed

DGP.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import numpy as np
2+
import pandas as pd
3+
4+
5+
def expit(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` of ``x``.

    Works element-wise on anything ``np.exp`` accepts (scalars or arrays).

    For very negative ``x`` the intermediate ``np.exp(-x)`` overflows to
    ``inf``. The quotient is then ``0.0``, which is the correct limit, so the
    overflow signal is silenced locally instead of surfacing as a
    RuntimeWarning (or a FloatingPointError when the caller has set
    ``over="raise"``). Returned values are the same as the plain formula.
    """
    with np.errstate(over="ignore"):
        return 1.0 / (1.0 + np.exp(-x))
7+
8+
9+
def g_func(X, active_idx, beta_coeffs):
    """Evaluate a nonlinear score on a subset of the columns of ``X``.

    With ``X_sub = X[:, active_idx]`` the score for each row is the sum of:

    * ``sum(tanh(X_sub) * beta_coeffs)``,
    * ``0.2 * sum(X_sub ** 2 * beta_coeffs)``,
    * ``0.3 * X_sub[:, 0] * X_sub[:, 1]`` (product of the first two selected
      columns).

    Args:
        X: 2-D array-like; rows are observations, columns are features.
        active_idx: Column selector applied as ``X[:, active_idx]``.
        beta_coeffs: Coefficients multiplied against the selected columns
            (NumPy broadcasting rules apply).

    Returns:
        1-D array with one score per row of ``X``.
    """
    X_sub = np.asarray(X)[:, active_idx]
    beta = np.asarray(beta_coeffs)

    base = np.sum(np.tanh(X_sub) * beta, axis=1)
    poly = 0.2 * np.sum((X_sub ** 2) * beta, axis=1)

    # The interaction needs two selected columns. With fewer, it contributes
    # nothing rather than failing with an IndexError on the column lookup.
    if X_sub.shape[1] >= 2:
        inter = 0.3 * X_sub[:, 0] * X_sub[:, 1]
    else:
        inter = 0.0

    return base + poly + inter
16+
17+
18+
def dgp_linear_cross_section(n=10000, p=20, seed=42):
    """Simulate a cross-sectional dataset with a binary treatment.

    Covariates are i.i.d. standard normal. Treatment is a Bernoulli draw whose
    probability is ``expit`` of a ``g_func`` score (on the first ``4 * p // 5``
    columns) plus normal noise. The individual treatment effect is ``2.0``
    plus a linear combination (coefficients ``+/-0.15``) of the columns in
    ``[p // 10, 9 * p // 10)``, then shifted so its sample mean is ``2.0``.
    The outcome is a ``g_func`` baseline on those same columns, plus
    ``t * ite``, plus normal noise.

    Args:
        n: Number of rows to generate.
        p: Number of covariate columns.
        seed: Seed passed to ``np.random.default_rng``.

    Returns:
        ``pandas.DataFrame`` with columns ``y``, ``t``, ``ite`` followed by
        ``X1`` ... ``Xp``.
    """
    rng = np.random.default_rng(seed)

    # NOTE: the sequence of rng calls below fixes the output for a given
    # seed; keep their order when editing.
    covariates = rng.normal(0.0, 1.0, size=(n, p))

    # Column subsets that drive treatment assignment and the outcome.
    treat_cols = np.arange(4 * p // 5)
    outcome_cols = np.arange(p // 10, 9 * p // 10)
    n_outcome = len(outcome_cols)

    coef_treat = rng.normal(0.0, 0.3, size=len(treat_cols))
    coef_base = rng.normal(0.0, 0.5, size=n_outcome)

    # Effect coefficients: half +0.15, the rest -0.15, in shuffled order.
    n_pos = n_outcome // 2
    coef_effect = rng.permutation(
        np.concatenate([
            np.full(n_pos, 0.15),
            np.full(n_outcome - n_pos, -0.15),
        ])
    )

    # Treatment assignment.
    logit_noise = rng.normal(0.0, 0.3, size=n)
    propensity = expit(g_func(covariates, treat_cols, coef_treat) + logit_noise)
    treatment = rng.binomial(1, propensity, size=n)

    # Uncalibrated treatment effect and the resulting outcome.
    effect_raw = 2.0 + np.sum(covariates[:, outcome_cols] * coef_effect, axis=1)
    baseline = g_func(covariates, outcome_cols, coef_base)
    outcome_noise = rng.normal(0.0, 0.2, size=n)
    outcome = baseline + treatment * effect_raw + outcome_noise

    # Shift the effect so its sample mean is 2.0 and apply the same shift to
    # the treated outcomes, keeping y consistent with the reported ite.
    shift = np.mean(effect_raw) - 2.0
    effect = effect_raw - shift
    outcome = outcome - treatment * shift

    frame = pd.DataFrame(covariates, columns=[f"X{j + 1}" for j in range(p)])
    leading = (("y", outcome), ("t", treatment), ("ite", effect))
    for position, (label, values) in enumerate(leading):
        frame.insert(position, label, values)
    return frame
59+
60+
61+
if __name__ == "__main__":
    # Script entry point: generate the default dataset (10000 rows, 20
    # covariates, seed 42) and write it to data.csv without the index column.
    dgp_linear_cross_section(n=10000, p=20, seed=42).to_csv(
        "data.csv", index=False
    )

0 commit comments

Comments
 (0)