diff --git a/src/main/java/net/fabricmc/tinyremapper/TinyRemapper.java b/src/main/java/net/fabricmc/tinyremapper/TinyRemapper.java index a7f0c090..0d0a240b 100644 --- a/src/main/java/net/fabricmc/tinyremapper/TinyRemapper.java +++ b/src/main/java/net/fabricmc/tinyremapper/TinyRemapper.java @@ -27,6 +27,7 @@ import java.nio.file.attribute.BasicFileAttributes; import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; @@ -43,10 +44,12 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.ForkJoinPool; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; +import java.util.function.Consumer; import java.util.function.Supplier; import java.util.regex.Pattern; import java.util.stream.Collectors; @@ -864,92 +867,124 @@ private void handleConflicts(MrjState state) { throw new RuntimeException("Unfixable conflicts"); } } - + + public interface InputConsumer { + void accept(InputTag[] tags, String internalName, byte[] bytecode); + } + public void apply(final BiConsumer outputConsumer) { apply(outputConsumer, (InputTag[]) null); } public void apply(final BiConsumer outputConsumer, InputTag... inputTags) { - // We expect apply() to be invoked only once if the user didn't request any input tags. Invoking it multiple - // times still works with keepInputData=true, but wastes some time by redoing most processing. - // With input tags the first apply invocation computes the entire output, but yields only what matches the given - // input tags. The output data is being kept for eventual further apply() outputs, only finish() clears it. - boolean hasInputTags = !singleInputTags.get().isEmpty(); - - synchronized (this) { // guard against concurrent apply invocations + this.apply((tags, internalName, bytecode) -> outputConsumer.accept(internalName, bytecode), inputTags); + } + + public void apply(InputConsumer inputConsumer) { + this.apply(inputConsumer, null); + } + + public void apply(InputConsumer inputConsumer, InputTag[] inputs) { + Set tags = singleInputTags.get().keySet(); + boolean isAll = inputs == null || (tags.containsAll(Arrays.asList(inputs)) && tags.size() == inputs.length); + boolean hasInputTags = !tags.isEmpty(); + boolean consumedAll = false; + BiConsumer consumer = (cls, data) -> inputConsumer.accept(cls.getInputTags(), ClassInstance.getMrjName(cls.getContext().remapper.map(cls.getName()), cls.getMrjVersion()), data); + + synchronized(this) { refresh(); - - if (outputBuffer == null) { // first (inputTags present) or full (no input tags) output invocation, process everything but don't output if input tags are present - BiConsumer immediateOutputConsumer; - - if (fixPackageAccess || hasInputTags) { // need re-processing or output buffering for repeated applies + + if(outputBuffer == null) { + BiConsumer immediateConsumer = null; + boolean hasPostProcess = fixPackageAccess; + if (hasInputTags || hasPostProcess) { // need re-processing or output buffering for repeated applies outputBuffer = new ConcurrentHashMap<>(); - immediateOutputConsumer = outputBuffer::put; - } else { - immediateOutputConsumer = (cls, data) -> outputConsumer.accept(ClassInstance.getMrjName(cls.getContext().remapper.map(cls.getName()), cls.getMrjVersion()), data); + immediateConsumer = outputBuffer::put; } - - List> futures = new ArrayList<>(); - + + if(!hasPostProcess && (isAll || !hasInputTags)) { + if(immediateConsumer != null) { + immediateConsumer = immediateConsumer.andThen(consumer); + } else { + immediateConsumer = consumer; + } + consumedAll = true; + } + for (MrjState state : mrjStates.values()) { mrjRefresh(state); - - for (final ClassInstance cls : state.classes.values()) { - if (!cls.isInput) continue; - + + BiConsumer finalImmediateConsumer = immediateConsumer; + this.executeThreaded(state.classes.values(), cls -> { + if (!cls.isInput) return; + if (cls.data == null) { if (!hasInputTags && !keepInputData) throw new IllegalStateException("invoking apply multiple times without input tags or hasInputData"); throw new IllegalStateException("data for input class " + cls + " is missing?!"); } - - futures.add(threadPool.submit(() -> immediateOutputConsumer.accept(cls, apply(cls)))); - } + + finalImmediateConsumer.accept(cls, apply(cls)); + }); } - - waitForAll(futures); - + boolean needsFixes = !classesToMakePublic.isEmpty() || !membersToMakePublic.isEmpty(); - + if (fixPackageAccess) { if (needsFixes) { System.out.printf("Fixing access for %d classes and %d members.%n", classesToMakePublic.size(), membersToMakePublic.size()); } - - for (Map.Entry entry : outputBuffer.entrySet()) { + + this.executeThreaded(this.outputBuffer.entrySet(), entry -> { ClassInstance cls = entry.getKey(); byte[] data = entry.getValue(); - + if (needsFixes) { data = fixClass(cls, data); } - + if (hasInputTags) { entry.setValue(data); - } else { - outputConsumer.accept(ClassInstance.getMrjName(cls.getContext().remapper.map(cls.getName()), cls.getMrjVersion()), data); } + if(isAll) { + consumer.accept(cls, data); + } + }); + if(isAll) { + consumedAll = true; } - + if (!hasInputTags) outputBuffer = null; // don't expect repeat invocations - + classesToMakePublic.clear(); membersToMakePublic.clear(); } else if (needsFixes) { throw new RuntimeException(String.format("%d classes and %d members need access fixes", classesToMakePublic.size(), membersToMakePublic.size())); } } - - assert hasInputTags == (outputBuffer != null); - - if (outputBuffer != null) { // partial output selected by input tags - for (Map.Entry entry : outputBuffer.entrySet()) { - ClassInstance cls = entry.getKey(); - - if (inputTags == null || cls.hasAnyInputTag(inputTags)) { - outputConsumer.accept(ClassInstance.getMrjName(cls.getContext().remapper.map(cls.getName()), cls.getMrjVersion()), entry.getValue()); - } + } + + // this can be done outside synchronize + if (!consumedAll) { // partial output selected by input tags + this.executeThreaded(this.outputBuffer.entrySet(), entry -> { + ClassInstance key = entry.getKey(); + if(inputs == null || key.hasAnyInputTag(inputs)) { + consumer.accept(key, entry.getValue()); } + }); + } + } + + private void executeThreaded(Collection list, Consumer consumer) { + // the pool used by Stream#parallel, it's best to leave the threading to java + // made in preparation for custom thread pools + if(this.threadPool == ForkJoinPool.commonPool()) { + list.parallelStream().forEach(consumer); + } else { + List> futures = new ArrayList<>(outputBuffer.size()); + for(T entry : list) { + futures.add(this.threadPool.submit(() -> consumer.accept(entry))); } + waitForAll(futures); } }