Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
260 changes: 186 additions & 74 deletions src/test/java/org/apache/sysds/test/AutomatedTestBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,15 @@ public abstract class AutomatedTestBase {
public static boolean TEST_GPU = false;
public static final double GPU_TOLERANCE = 1e-9;

// ms wait time
public static final int FED_WORKER_WAIT = 3000;
public static final int FED_MONITOR_WAIT = 10000;
/**
* Default upper bound (ms) passed to federated worker readiness waits. The wait returns as soon
* as the worker's TCP port accepts a connection, so this value only affects the deadline used
* when a worker never becomes ready. {@link FederatedWorkerUtils} clamps caller values below its
* enforced floor up to that floor, so the effective ceiling is at least that floor regardless
* of this constant.
*/
public static final int FED_WORKER_WAIT = 3000;
public static final int FED_MONITOR_WAIT = 10000;
public static final int FED_WORKER_WAIT_S = 50;


Expand Down Expand Up @@ -1642,13 +1648,14 @@ protected Process startLocalFedWorker(int port){

/**
* Start a new JVM for a federated worker at the port.
*
* @param port Port to use for the JVM
* @param sleep The sleep time to wait for the worker to start
*
* @param port Port to use for the JVM
* @param timeoutMs Upper bound on the wait for the worker to become ready, in ms; raised to a
* minimum value enforced inside {@link FederatedWorkerUtils}.
* @return The process containing the worker
*/
protected Process startLocalFedWorker(int port, int sleep){
return startLocalFedWorker(port, null, sleep);
protected Process startLocalFedWorker(int port, int timeoutMs){
return startLocalFedWorker(port, null, timeoutMs);
}

/**
Expand All @@ -1665,18 +1672,64 @@ protected Process startLocalFedWorker(int port, String[] addArgs) {

/**
* Start new JVM for a federated worker at the port.
*
* @param port Port to use for the JVM
* @param addArgs The arguments to add
* @param sleep The time to wait for the process to start
*
* <p>Returns once the worker's TCP port accepts connections (the worker opens the port after
* Netty's bind completes), or throws a {@link RuntimeException} after {@code timeoutMs} elapses.
*
* @param port Port to use for the JVM
* @param addArgs The arguments to add
* @param timeoutMs Upper bound on the wait for the worker to become ready, in ms; raised to a
* minimum value enforced inside {@link FederatedWorkerUtils}.
* @return the process associated with the worker.
*/
protected static Process startLocalFedWorker(int port, String[] addArgs, int sleep) {
Process process = null;
protected static Process startLocalFedWorker(int port, String[] addArgs, int timeoutMs) {
Process process = spawnLocalFedWorker(port, addArgs);
FederatedWorkerUtils.waitForWorker(process, port, timeoutMs);
return process;
}

/**
* Start N federated worker JVMs back to back, then wait for all of them to become ready in one
* shared poll loop. The wall-clock wait scales with the slowest worker rather than the sum of the
* per-worker waits.
*
* @param ports Ports to use, one per worker
* @return The process per port, in the same order as {@code ports}.
*/
protected static Process[] startLocalFedWorkers(int[] ports) {
return startLocalFedWorkers(ports, null, FED_WORKER_WAIT);
}

/** @see #startLocalFedWorkers(int[], String[], int) */
protected static Process[] startLocalFedWorkers(int[] ports, String[] addArgs) {
return startLocalFedWorkers(ports, addArgs, FED_WORKER_WAIT);
}

/**
* Start N federated worker JVMs back to back, then wait for all of them to become ready in one
* shared poll loop.
*
* @param ports Ports to use, one per worker
* @param addArgs Extra worker CLI args (applied to every worker), or null
* @param timeoutMs Upper bound on the wait, in ms; raised to a minimum value enforced inside
* {@link FederatedWorkerUtils}.
* @return The process per port, in the same order as {@code ports}.
*/
protected static Process[] startLocalFedWorkers(int[] ports, String[] addArgs, int timeoutMs) {
Process[] processes = new Process[ports.length];
for(int i = 0; i < ports.length; i++) {
processes[i] = spawnLocalFedWorker(ports[i], addArgs);
}
FederatedWorkerUtils.waitForWorkers(processes, ports, timeoutMs);
return processes;
}

/** Spawn a federated worker JVM and return without waiting for the port to bind. */
private static Process spawnLocalFedWorker(int port, String[] addArgs) {
String separator = System.getProperty("file.separator");
String classpath = System.getProperty("java.class.path");
String path = System.getProperty("java.home") + separator + "bin" + separator + "java";
String[] args = new String[] {path, "-Xmx1000m", "-Xms1000m", "-Xmn100m",
String[] args = new String[] {path, "-Xmx1000m", "-Xms1000m", "-Xmn100m",
"--add-opens=java.base/java.nio=ALL-UNNAMED" ,
"--add-opens=java.base/java.io=ALL-UNNAMED" ,
"--add-opens=java.base/java.util=ALL-UNNAMED" ,
Expand All @@ -1701,19 +1754,14 @@ protected static Process startLocalFedWorker(int port, String[] addArgs, int sle
DMLScript.class.getName(), "-w", Integer.toString(port), "-stats"});
if(addArgs != null)
args = ArrayUtils.addAll(args, addArgs);

ProcessBuilder processBuilder = new ProcessBuilder(args).inheritIO();

ProcessBuilder processBuilder = new ProcessBuilder(args).inheritIO();
try {
process = processBuilder.start();
// Give some time to startup the worker.
sleep(sleep);
return processBuilder.start();
}
catch(IOException | InterruptedException e) {
e.printStackTrace();
catch(IOException e) {
throw new RuntimeException("Failed to launch federated worker process on port " + port, e);
}
isAlive(process);
return process;
}

/**
Expand Down Expand Up @@ -1743,7 +1791,7 @@ protected Process startLocalFedMonitoring(int port, String[] addArgs) {
}

/**
* Start a thread for a worker. This will share the same JVM, so all static variables will be shared.!
* Start a thread for a worker. This will share the same JVM, so all static variables will be shared.
*
* Also when using the local Fed Worker thread the statistics printing, and clearing from the worker is disabled.
*
Expand All @@ -1769,63 +1817,112 @@ public static Thread startLocalFedWorkerThread(int port, String[] otherArgs) {
}

/**
* Start a thread for a worker. This will share the same JVM, so all static variables will be shared.!
* Start a thread for a worker. This will share the same JVM, so all static variables will be shared.
*
* Also when using the local Fed Worker thread the statistics printing, and clearing from the worker is disabled.
*
* @param port Port to use
* @param sleep The amount of time to wait for the worker startup. in Milliseconds
* @param port Port to use
* @param timeoutMs Upper bound on the wait for the worker to become ready, in ms; raised to a
* minimum value enforced inside {@link FederatedWorkerUtils}.
* @return The thread associated with the worker.
*/
public static Thread startLocalFedWorkerThread(int port, int sleep) {
return startLocalFedWorkerThread(port, null, sleep);
public static Thread startLocalFedWorkerThread(int port, int timeoutMs) {
return startLocalFedWorkerThread(port, null, timeoutMs);
}

/**
* Start a thread for a worker. This will share the same JVM, so all static variables will be shared.!
*
* Also when using the local Fed Worker thread the statistics printing, and clearing from the worker is disabled.
*
* Start a thread for a worker. This will share the same JVM, so all static variables will be shared.
*
* <p>Also when using the local Fed Worker thread the statistics printing, and clearing from the worker is
* disabled.
*
* <p>Returns once the worker's TCP port accepts connections (the worker opens the port after Netty's bind
* completes), or throws a {@link RuntimeException} after {@code timeoutMs} elapses.
*
* @param port Port to use
* @param otherArgs The command line arguments to start the worker with
* @param sleep The amount of time to wait for the worker startup. in Milliseconds
* @param timeoutMs Upper bound on the wait for the worker to become ready, in ms; raised to a
* minimum value enforced inside {@link FederatedWorkerUtils}.
* @return The thread associated with the worker.
*/
public static Thread startLocalFedWorkerThread(int port, String[] otherArgs, int sleep) {
public static Thread startLocalFedWorkerThread(int port, String[] otherArgs, int timeoutMs) {
Thread t = spawnLocalFedWorkerThread(port, otherArgs);
FederatedWorkerUtils.waitForWorker(t, port, timeoutMs);
return t;
}

/**
* Start N federated worker threads in the same JVM back to back, then wait for all of them to
* become ready in one shared poll loop. The wall-clock wait scales with the slowest worker rather
* than the sum of the per-worker waits.
*
* @param ports Ports to use, one per worker
* @return The thread per port, in the same order as {@code ports}.
*/
public static Thread[] startLocalFedWorkerThreads(int[] ports) {
return startLocalFedWorkerThreads(ports, null, FED_WORKER_WAIT);
}

/** @see #startLocalFedWorkerThreads(int[], String[], int) */
public static Thread[] startLocalFedWorkerThreads(int[] ports, String[] otherArgs) {
return startLocalFedWorkerThreads(ports, otherArgs, FED_WORKER_WAIT);
}

/**
* Start N federated worker threads in the same JVM back to back, then wait for all of them to
* become ready in one shared poll loop.
*
* @param ports Ports to use, one per worker
* @param otherArgs Extra worker CLI args (applied to every worker), or null
* @param timeoutMs Upper bound on the wait, in ms; raised to a minimum value enforced inside
* {@link FederatedWorkerUtils}.
* @return The thread per port, in the same order as {@code ports}.
*/
public static Thread[] startLocalFedWorkerThreads(int[] ports, String[] otherArgs, int timeoutMs) {
Thread[] threads = new Thread[ports.length];
for(int i = 0; i < ports.length; i++) {
threads[i] = spawnLocalFedWorkerThread(ports[i], otherArgs);
// Sleep THREAD_SPAWN_STAGGER_MS between in-JVM thread spawns to reduce contention on
// shared static initialization in DMLScript / FederatedWorker (e.g. LineageCacheConfig
// setters) when multiple worker threads enter main() concurrently.
if(i + 1 < ports.length) {
try {
java.util.concurrent.TimeUnit.MILLISECONDS.sleep(THREAD_SPAWN_STAGGER_MS);
}
catch(InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException("Interrupted while spawning federated worker threads", e);
}
}
}
FederatedWorkerUtils.waitForWorkers(threads, ports, timeoutMs);
return threads;
}

private static final int THREAD_SPAWN_STAGGER_MS = 25;

/** Spawn a federated worker thread in this JVM and return without waiting for the port to bind. */
private static Thread spawnLocalFedWorkerThread(int port, String[] otherArgs) {
ArrayList<String> args = new ArrayList<>();

args.add("-w");
args.add(Integer.toString(port));

if(otherArgs != null)
for( String s : otherArgs)
for(String s : otherArgs)
args.add(s);

String[] finalArguments = args.toArray(new String[args.size()]);
Statistics.allowWorkerStatistics = false;

try {
Thread t = new Thread(() -> {
try {
main(finalArguments);
}
catch(Exception e) {
LOG.error("Exception in startup of federated worker", e);
}
});
t.start();
java.util.concurrent.TimeUnit.MILLISECONDS.sleep(sleep);
if(!t.isAlive())
throw new RuntimeException("Failed starting federated worker");
return t;
}
catch(InterruptedException e) {
e.printStackTrace();
fail("Failed to start federated worker : " + e);
// should never happen
return null;
}
Thread t = new Thread(() -> {
try {
main(finalArguments);
}
catch(Exception e) {
LOG.error("Exception in startup of federated worker", e);
}
});
t.start();
return t;
}

public static boolean isAlive(Thread... threads){
Expand All @@ -1846,28 +1943,43 @@ public static boolean isAlive(Process... processes) {

/**
* Start java worker in same JVM.
*
*
* <p>Returns once the worker's TCP port accepts connections (the worker opens the port after
* Netty's bind completes), or throws a {@link RuntimeException} after the default federated worker
* timeout elapses. The port is extracted from {@code args}, which must contain {@code "-w" <port>}.
*
* @param args the command line arguments
* @return the thread associated with the process.s
* @return the thread associated with the worker.
*/
public static Thread startLocalFedWorkerWithArgs(String[] args) {
Thread t = null;
final int port = extractWorkerPort(args);
Thread t = new Thread(() -> {
try {
main(args);
}
catch(IOException e) {
LOG.error("Exception in startup of federated worker on port " + port, e);
}
});
t.start();
FederatedWorkerUtils.waitForWorker(t, port, FED_WORKER_WAIT);
return t;
}

try {
t = new Thread(() -> {
private static int extractWorkerPort(String[] args) {
for(int i = 0; i < args.length - 1; i++) {
if("-w".equals(args[i])) {
try {
main(args);
return Integer.parseInt(args[i + 1]);
}
catch(IOException e) {
catch(NumberFormatException e) {
throw new IllegalArgumentException(
"Federated worker args contain non-numeric port after -w: " + args[i + 1], e);
}
});
t.start();
java.util.concurrent.TimeUnit.MILLISECONDS.sleep(FED_WORKER_WAIT);
}
catch(InterruptedException e) {
// Should happen at closing of the worker so don't print
}
}
return t;
throw new IllegalArgumentException("Federated worker args must contain '-w <port>': "
+ Arrays.toString(args));
}

private boolean rCompareException(boolean exceptionExpected, String errMessage, Throwable e, boolean result) {
Expand Down
Loading
Loading