Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;

import static org.apache.iotdb.it.env.cluster.ClusterConstant.AI_NODE_NAME;
Expand All @@ -52,7 +53,8 @@ public class AINodeWrapper extends AbstractNodeWrapper {
private final String seedConfigNode;
private final int clusterIngressPort;

private static final String SCRIPT_FILE = "start-ainode.sh";
private static final String START_SCRIPT_FILE = "start-ainode.sh";
private static final String STOP_SCRIPT_FILE = "stop-ainode.sh";

private static final String SHELL_COMMAND = "bash";

Expand Down Expand Up @@ -165,8 +167,8 @@ public void start() {
// start AINode
List<String> startCommand = new ArrayList<>();
startCommand.add(SHELL_COMMAND);
startCommand.add(filePrefix + File.separator + SCRIPT_PATH + File.separator + SCRIPT_FILE);
startCommand.add("-r");
startCommand.add(
filePrefix + File.separator + SCRIPT_PATH + File.separator + START_SCRIPT_FILE);

ProcessBuilder processBuilder =
new ProcessBuilder(startCommand)
Expand All @@ -179,6 +181,48 @@ public void start() {
}
}

@Override
public void stop() {
if (this.instance == null) {
return;
}
try {
// stop AINode
File stdoutFile = new File(getLogPath());
String filePrefix = getNodePath();
List<String> stopCommand = new ArrayList<>();
stopCommand.add(SHELL_COMMAND);
stopCommand.add(
filePrefix + File.separator + SCRIPT_PATH + File.separator + STOP_SCRIPT_FILE);
ProcessBuilder processBuilder =
new ProcessBuilder(stopCommand)
.redirectOutput(ProcessBuilder.Redirect.appendTo(stdoutFile))
.redirectError(ProcessBuilder.Redirect.appendTo(stdoutFile));
Process stopProcess = processBuilder.inheritIO().start();
if (!stopProcess.waitFor(20, TimeUnit.SECONDS)) {
logger.warn("Node {} does not exit within 20s, killing it", getId());
if (!this.instance.destroyForcibly().waitFor(10, TimeUnit.SECONDS)) {
logger.error("Cannot forcibly stop node {}", getId());
}
}
int exitCode = stopProcess.exitValue();
if (exitCode != 0) {
logger.warn("Node {}'s stop script exited with code {}", getId(), exitCode);
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
logger.error("Waiting node to shutdown error.", e);
} catch (IOException e) {
logger.error("Waiting node to shutdown error.", e);
}
logger.info("In test {} {} stopped.", getTestLogDirName(), getId());
}

@Override
public void stopForcibly() {
this.stop();
}

@Override
public int getMetricPort() {
// no metric currently
Expand Down
20 changes: 20 additions & 0 deletions iotdb-core/ainode/iotdb/ainode/core/device/backend/cuda_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
#
import time

import torch

Expand All @@ -24,6 +25,9 @@
class CUDABackend(BackendAdapter):
type = BackendType.CUDA

def __init__(self) -> None:
self._safe_cuda_init()

def is_available(self) -> bool:
return torch.cuda.is_available()

Expand All @@ -37,3 +41,19 @@ def make_device(self, index: int | None) -> torch.device:

def set_device(self, index: int) -> None:
torch.cuda.set_device(index)

def _safe_cuda_init(self) -> None:
# Safe CUDA initialization to avoid potential deadlocks
# This is a workaround for certain PyTorch versions where the first CUDA call can cause a long delay
# By calling a simple CUDA operation at startup, we can ensure that the CUDA context is initialized early
# and avoid unexpected delays during actual model loading or inference.
attempt_cnt = 3
for attempt in range(attempt_cnt):
try:
if self.is_available():
return
raise RuntimeError("CUDA not available")
except Exception as e:
print(f"CUDA init attempt {attempt + 1} failed: {e}")
if attempt < attempt_cnt:
time.sleep(1.5)
27 changes: 15 additions & 12 deletions iotdb-core/ainode/iotdb/ainode/core/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,26 @@
# specific language governing permissions and limitations
# under the License.
#

import multiprocessing
import sys

# PyInstaller multiprocessing support
# freeze_support() is essential for PyInstaller frozen executables on all platforms
# It detects if the current process is a multiprocessing child process
# If it is, it executes the child process target function and exits
# If it's not, it returns immediately and continues with main() execution
# This prevents child processes from executing the main application logic
if getattr(sys, "frozen", False):
# Call freeze_support() for both standard multiprocessing and torch.multiprocessing
multiprocessing.freeze_support()
multiprocessing.set_start_method("spawn", force=True)

import torch.multiprocessing as mp

mp.freeze_support()
mp.set_start_method("spawn", force=True)

from iotdb.ainode.core.ai_node import AINode
from iotdb.ainode.core.log import Logger

Expand All @@ -42,7 +57,6 @@ def main():
command = arguments[1]
if command == "start":
try:
mp.set_start_method("spawn", force=True)
logger.info(f"Current multiprocess start method: {mp.get_start_method()}")
logger.info("IoTDB-AINode is starting...")
ai_node = AINode()
Expand All @@ -55,15 +69,4 @@ def main():


if __name__ == "__main__":
# PyInstaller multiprocessing support
# freeze_support() is essential for PyInstaller frozen executables on all platforms
# It detects if the current process is a multiprocessing child process
# If it is, it executes the child process target function and exits
# If it's not, it returns immediately and continues with main() execution
# This prevents child processes from executing the main application logic
if getattr(sys, "frozen", False):
# Call freeze_support() for both standard multiprocessing and torch.multiprocessing
multiprocessing.freeze_support()
mp.freeze_support()

main()
2 changes: 1 addition & 1 deletion iotdb-core/ainode/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ exclude = [
python = ">=3.11.0,<3.12.0"

# ---- DL / HF stack ----
torch = "^2.8.0,<2.9.0"
torch = "^2.9.0,<2.10.0"
torchmetrics = "^1.8.0"
transformers = "==4.56.2"
tokenizers = ">=0.22.0,<=0.23.0"
Expand Down