|
4 | 4 |
|
5 | 5 | from consts import MODAL_CUDA_INCLUDE_DIRS, MODAL_PATH |
6 | 6 | from modal import App, Image, Mount |
7 | | -from run_eval import run_cuda_script, run_pytorch_script |
| 7 | +from run_eval import run_cuda_script, run_pytorch_script, CompileResult, RunResult, FullResult |
8 | 8 |
|
9 | 9 | # Create a stub for the Modal app |
10 | 10 | # IMPORTANT: This has to stay in separate file or modal breaks |
@@ -96,57 +96,20 @@ def modal_run_cuda_script( # # noqa: C901 |
96 | 96 | submission_content: str = None, |
97 | 97 | timeout_seconds: int = 600, |
98 | 98 | arch: int = None, |
99 | | -) -> tuple[str, float]: |
| 99 | +) -> FullResult: |
100 | 100 | """Modal version of run_cuda_script, handling timeouts""" |
101 | 101 | try: |
102 | 102 | with timeout(timeout_seconds): |
103 | | - compile_result, run_result = run_cuda_script( |
| 103 | + comp, run = run_cuda_script( |
104 | 104 | script_content, |
105 | 105 | reference_content=reference_content, |
106 | 106 | submission_content=submission_content, |
107 | 107 | arch=arch, |
108 | 108 | include_dirs=MODAL_CUDA_INCLUDE_DIRS, |
109 | 109 | ) |
110 | | - |
111 | | - if not compile_result.success: |
112 | | - if not compile_result.nvcc_found: |
113 | | - return ( |
114 | | - "Error executing script: NVCC not found:\n" |
115 | | - + f"command `{compile_result.command}` failed with exit code {compile_result.exit_code}:\n" |
116 | | - + compile_result.stderr, |
117 | | - 0.0, |
118 | | - ) |
119 | | - return ( |
120 | | - "Error executing script: CUDA compilation failed with return code " |
121 | | - + f"{compile_result.exit_code}:\n{compile_result.stderr}\n" |
122 | | - + f"compile command: `{compile_result.command}`", |
123 | | - 0.0, |
124 | | - ) |
125 | | - |
126 | | - if not run_result.success: |
127 | | - # exit code 1 encodes failed tests |
128 | | - if run_result.exit_code == 1: |
129 | | - return f"check_implementation failed:\n{run_result.stderr}", 0.0 |
130 | | - else: |
131 | | - return ( |
132 | | - f"Script failed with exit code ({run_result.exit_code}):\n{run_result.stderr}", |
133 | | - 0.0, |
134 | | - ) |
135 | | - |
136 | | - print("run process stdout:", run_result.stdout) |
137 | | - print("run process stderr:", run_result.stderr) |
138 | | - |
139 | | - score = float(run_result.result.get("duration.mean", "0.0")) / 1e9 |
140 | | - passed = run_result.result.get("check", "") == "pass" |
141 | | - if not passed: |
142 | | - return "check_implementation failed", 0.0 |
143 | | - |
144 | | - if score is None: |
145 | | - return run_result.stdout, run_result.duration |
146 | | - |
147 | | - return run_result.stdout, score |
148 | | - |
| 110 | + return FullResult(success=True, error="", compile=comp, run=run) |
| 111 | + # TODO fixup error handling! |
149 | 112 | except TimeoutException as e: |
150 | | - return f"Timeout Error: {str(e)}", 0.0 |
| 113 | + return FullResult(success=False, error=f"Timeout Error: {str(e)}", compile=None, run=None) |
151 | 114 | except Exception as e: |
152 | | - return f"Error executing script: {str(e)}", 0.0 |
| 115 | + return FullResult(success=False, error=f"Error executing script: {str(e)}", compile=None, run=None) |
0 commit comments