|
| 1 | +#!/usr/bin/env python3 |
| 2 | +""" |
| 3 | +Example: register_function_cpp() with TTree::Draw-like functionality |
| 4 | +
|
| 5 | +Phase 13.5.B — C++ Function Registration |
| 6 | +
|
| 7 | +This example demonstrates: |
| 8 | +1. Registering custom C++ functions |
| 9 | +2. Using them in RDataFrame expressions |
| 10 | +3. TTree::Draw-like plotting with dsl.draw() |
| 11 | +""" |
| 12 | + |
| 13 | +import ROOT |
| 14 | + |
| 15 | +# Suppress ROOT startup messages |
| 16 | +ROOT.gROOT.SetBatch(True) |
| 17 | + |
| 18 | +from RDataFrameDSL import DSLCompiler |
| 19 | + |
| 20 | + |
| 21 | +def example_basic(): |
| 22 | + """Basic example: register and use a simple function.""" |
| 23 | + print("=" * 60) |
| 24 | + print("Example 1: Basic Function Registration") |
| 25 | + print("=" * 60) |
| 26 | + |
| 27 | + # Create synthetic data |
| 28 | + rdf = ROOT.RDataFrame(1000) |
| 29 | + rdf = rdf.Define("px", "gRandom->Gaus(0, 10)") |
| 30 | + rdf = rdf.Define("py", "gRandom->Gaus(0, 10)") |
| 31 | + rdf = rdf.Define("pz", "gRandom->Gaus(0, 50)") |
| 32 | + |
| 33 | + # Create DSL compiler with schema |
| 34 | + dsl = DSLCompiler({ |
| 35 | + "px": "double", |
| 36 | + "py": "double", |
| 37 | + "pz": "double", |
| 38 | + }) |
| 39 | + |
| 40 | + # Register custom C++ function |
| 41 | + dsl.register_function_cpp(''' |
| 42 | + double pt(double px, double py) { |
| 43 | + return sqrt(px*px + py*py); |
| 44 | + } |
| 45 | + ''') |
| 46 | + |
| 47 | + # Check registration |
| 48 | + func = dsl.get_registered_function("pt") |
| 49 | + print(f"✅ Registered: {func.name}") |
| 50 | + print(f" C++ name: {func.cpp_name}") |
| 51 | + print(f" Hash: {func.hash}") |
| 52 | + print(f" Declared: {func.declared}") |
| 53 | + print(f" Headers: {func.headers}") |
| 54 | + |
| 55 | + # Apply to RDataFrame |
| 56 | + rdf = dsl.apply(rdf) |
| 57 | + |
| 58 | + # Use registered function in Define |
| 59 | + rdf = rdf.Define("track_pt", f"{func.cpp_name}(px, py)") |
| 60 | + |
| 61 | + # Verify it works |
| 62 | + mean_pt = rdf.Mean("track_pt").GetValue() |
| 63 | + print(f"\n Mean pt: {mean_pt:.2f}") |
| 64 | + |
| 65 | + return dsl, rdf |
| 66 | + |
| 67 | + |
| 68 | +def example_multiple_functions(): |
| 69 | + """Register multiple functions and chain them.""" |
| 70 | + print("\n" + "=" * 60) |
| 71 | + print("Example 2: Multiple Functions") |
| 72 | + print("=" * 60) |
| 73 | + |
| 74 | + # Create synthetic data |
| 75 | + rdf = ROOT.RDataFrame(1000) |
| 76 | + rdf = rdf.Define("px", "gRandom->Gaus(0, 10)") |
| 77 | + rdf = rdf.Define("py", "gRandom->Gaus(0, 10)") |
| 78 | + rdf = rdf.Define("pz", "gRandom->Gaus(0, 50)") |
| 79 | + rdf = rdf.Define("energy", "gRandom->Gaus(100, 20)") |
| 80 | + |
| 81 | + dsl = DSLCompiler({ |
| 82 | + "px": "double", |
| 83 | + "py": "double", |
| 84 | + "pz": "double", |
| 85 | + "energy": "double", |
| 86 | + }) |
| 87 | + |
| 88 | + # Register multiple functions with chaining |
| 89 | + dsl.register_function_cpp(''' |
| 90 | + double pt(double px, double py) { |
| 91 | + return sqrt(px*px + py*py); |
| 92 | + } |
| 93 | + ''').register_function_cpp(''' |
| 94 | + double p(double px, double py, double pz) { |
| 95 | + return sqrt(px*px + py*py + pz*pz); |
| 96 | + } |
| 97 | + ''').register_function_cpp(''' |
| 98 | + double eta(double px, double py, double pz) { |
| 99 | + double pmag = sqrt(px*px + py*py + pz*pz); |
| 100 | + if (pmag == pz) return 0; |
| 101 | + return 0.5 * log((pmag + pz) / (pmag - pz)); |
| 102 | + } |
| 103 | + ''').register_function_cpp(''' |
| 104 | + double phi(double px, double py) { |
| 105 | + return atan2(py, px); |
| 106 | + } |
| 107 | + ''') |
| 108 | + |
| 109 | + # List registered functions |
| 110 | + print(f"✅ Registered functions: {dsl.list_registered_functions()}") |
| 111 | + |
| 112 | + # Apply and use |
| 113 | + rdf = dsl.apply(rdf) |
| 114 | + |
| 115 | + pt_func = dsl.get_registered_function("pt") |
| 116 | + eta_func = dsl.get_registered_function("eta") |
| 117 | + phi_func = dsl.get_registered_function("phi") |
| 118 | + |
| 119 | + rdf = rdf.Define("track_pt", f"{pt_func.cpp_name}(px, py)") |
| 120 | + rdf = rdf.Define("track_eta", f"{eta_func.cpp_name}(px, py, pz)") |
| 121 | + rdf = rdf.Define("track_phi", f"{phi_func.cpp_name}(px, py)") |
| 122 | + |
| 123 | + # Stats |
| 124 | + print(f"\n Mean pt: {rdf.Mean('track_pt').GetValue():.2f}") |
| 125 | + print(f" Mean eta: {rdf.Mean('track_eta').GetValue():.2f}") |
| 126 | + print(f" Mean phi: {rdf.Mean('track_phi').GetValue():.2f}") |
| 127 | + |
| 128 | + return dsl, rdf |
| 129 | + |
| 130 | + |
| 131 | +def example_rvec(): |
| 132 | + """Register function working with RVec.""" |
| 133 | + print("\n" + "=" * 60) |
| 134 | + print("Example 3: RVec Functions") |
| 135 | + print("=" * 60) |
| 136 | + |
| 137 | + # Create synthetic data with RVec columns |
| 138 | + rdf = ROOT.RDataFrame(100) |
| 139 | + rdf = rdf.Define("nTracks", "(int)(gRandom->Uniform(1, 10))") |
| 140 | + rdf = rdf.Define("trackPt", |
| 141 | + "ROOT::VecOps::RVec<double> v(nTracks); " |
| 142 | + "for(int i=0; i<nTracks; i++) v[i] = gRandom->Gaus(10, 3); " |
| 143 | + "return v;") |
| 144 | + |
| 145 | + dsl = DSLCompiler({ |
| 146 | + "nTracks": "int", |
| 147 | + "trackPt": "RVec<double>", |
| 148 | + }) |
| 149 | + |
| 150 | + # Register RVec function |
| 151 | + dsl.register_function_cpp(''' |
| 152 | + double sum_pt(const RVec<double>& pts) { |
| 153 | + return Sum(pts); |
| 154 | + } |
| 155 | + ''') |
| 156 | + |
| 157 | + dsl.register_function_cpp(''' |
| 158 | + double max_pt(const RVec<double>& pts) { |
| 159 | + return pts.size() > 0 ? Max(pts) : 0.0; |
| 160 | + } |
| 161 | + ''') |
| 162 | + |
| 163 | + func_sum = dsl.get_registered_function("sum_pt") |
| 164 | + func_max = dsl.get_registered_function("max_pt") |
| 165 | + |
| 166 | + print(f"✅ Registered: sum_pt → {func_sum.cpp_name}") |
| 167 | + print(f"✅ Registered: max_pt → {func_max.cpp_name}") |
| 168 | + |
| 169 | + rdf = dsl.apply(rdf) |
| 170 | + rdf = rdf.Define("total_pt", f"{func_sum.cpp_name}(trackPt)") |
| 171 | + rdf = rdf.Define("leading_pt", f"{func_max.cpp_name}(trackPt)") |
| 172 | + |
| 173 | + print(f"\n Mean total pt: {rdf.Mean('total_pt').GetValue():.2f}") |
| 174 | + print(f" Mean leading pt: {rdf.Mean('leading_pt').GetValue():.2f}") |
| 175 | + |
| 176 | + return dsl, rdf |
| 177 | + |
| 178 | + |
| 179 | +def example_draw_like(): |
| 180 | + """TTree::Draw-like functionality.""" |
| 181 | + print("\n" + "=" * 60) |
| 182 | + print("Example 4: TTree::Draw-like Usage") |
| 183 | + print("=" * 60) |
| 184 | + |
| 185 | + # Create synthetic data |
| 186 | + rdf = ROOT.RDataFrame(1000) |
| 187 | + rdf = rdf.Define("px", "gRandom->Gaus(0, 10)") |
| 188 | + rdf = rdf.Define("py", "gRandom->Gaus(0, 10)") |
| 189 | + rdf = rdf.Define("eta", "gRandom->Uniform(-2.5, 2.5)") |
| 190 | + rdf = rdf.Define("isGood", "abs(eta) < 1.0") |
| 191 | + |
| 192 | + dsl = DSLCompiler({ |
| 193 | + "px": "double", |
| 194 | + "py": "double", |
| 195 | + "eta": "double", |
| 196 | + "isGood": "bool", |
| 197 | + }) |
| 198 | + |
| 199 | + # Register pt function |
| 200 | + dsl.register_function_cpp(''' |
| 201 | + double pt(double px, double py) { |
| 202 | + return sqrt(px*px + py*py); |
| 203 | + } |
| 204 | + ''') |
| 205 | + |
| 206 | + # Define computed column using registered function |
| 207 | + func = dsl.get_registered_function("pt") |
| 208 | + dsl.define("track_pt", f"{func.cpp_name}(px, py)") |
| 209 | + |
| 210 | + # Apply |
| 211 | + rdf = dsl.apply(rdf) |
| 212 | + |
| 213 | + print("✅ Ready for TTree::Draw-like plotting") |
| 214 | + print("\n Available expressions:") |
| 215 | + print(" - dsl.draw('track_pt', rdf)") |
| 216 | + print(" - dsl.draw('track_pt:eta', rdf, type='hist2d')") |
| 217 | + print(" - dsl.draw('track_pt', rdf, selection='isGood')") |
| 218 | + |
| 219 | + # Test if dfdraw is available |
| 220 | + try: |
| 221 | + import dfdraw |
| 222 | + print("\n dfdraw available - can run draw()") |
| 223 | + |
| 224 | + # 1D histogram |
| 225 | + fig, ax, stats = dsl.draw("track_pt", rdf) |
| 226 | + print(f" ✅ 1D histogram: mean={stats.get('mean', 'N/A'):.2f}") |
| 227 | + |
| 228 | + except ImportError: |
| 229 | + print("\n ⚠️ dfdraw not installed - skipping draw() demo") |
| 230 | + print(" Install with: pip install dfdraw") |
| 231 | + |
| 232 | + return dsl, rdf |
| 233 | + |
| 234 | + |
| 235 | +def example_lambda_rejection(): |
| 236 | + """Demonstrate FROZEN RULE #1: Lambda rejection.""" |
| 237 | + print("\n" + "=" * 60) |
| 238 | + print("Example 5: FROZEN RULE #1 - Lambda Rejection") |
| 239 | + print("=" * 60) |
| 240 | + |
| 241 | + dsl = DSLCompiler({"x": "double"}) |
| 242 | + |
| 243 | + # Try to register a lambda (should fail) |
| 244 | + print("Attempting to register lambda expression...") |
| 245 | + try: |
| 246 | + dsl.register_function_cpp("[](double x) { return x * 2; }") |
| 247 | + print("❌ Lambda was accepted (BUG!)") |
| 248 | + except ValueError as e: |
| 249 | + print("✅ Lambda correctly rejected:") |
| 250 | + print(f" {str(e)[:60]}...") |
| 251 | + |
| 252 | + # Named function works |
| 253 | + print("\nRegistering named function instead...") |
| 254 | + dsl.register_function_cpp("double double_it(double x) { return x * 2; }") |
| 255 | + func = dsl.get_registered_function("double_it") |
| 256 | + print(f"✅ Named function accepted: {func.cpp_name}") |
| 257 | + |
| 258 | + |
| 259 | +def main(): |
| 260 | + """Run all examples.""" |
| 261 | + print("\n" + "=" * 60) |
| 262 | + print("RDataFrameDSL — register_function_cpp() Examples") |
| 263 | + print("Phase 13.5.B Implementation") |
| 264 | + print("=" * 60) |
| 265 | + |
| 266 | + example_basic() |
| 267 | + example_multiple_functions() |
| 268 | + example_rvec() |
| 269 | + example_draw_like() |
| 270 | + example_lambda_rejection() |
| 271 | + |
| 272 | + print("\n" + "=" * 60) |
| 273 | + print("✅ All examples completed successfully!") |
| 274 | + print("=" * 60 + "\n") |
| 275 | + |
| 276 | + |
| 277 | +if __name__ == "__main__": |
| 278 | + main() |
0 commit comments