diff --git a/src/main/java/com/hubspot/jinjava/JinjavaConfig.java b/src/main/java/com/hubspot/jinjava/JinjavaConfig.java index 3cff4787c..343df63a8 100644 --- a/src/main/java/com/hubspot/jinjava/JinjavaConfig.java +++ b/src/main/java/com/hubspot/jinjava/JinjavaConfig.java @@ -88,6 +88,7 @@ public class JinjavaConfig { private final ExecutionMode executionMode; private final LegacyOverrides legacyOverrides; private final boolean enablePreciseDivideFilter; + private final boolean enableFilterChainOptimization; private final ObjectMapper objectMapper; private final Features features; @@ -151,6 +152,7 @@ private JinjavaConfig(Builder builder) { legacyOverrides = builder.legacyOverrides; dateTimeProvider = builder.dateTimeProvider; enablePreciseDivideFilter = builder.enablePreciseDivideFilter; + enableFilterChainOptimization = builder.enableFilterChainOptimization; objectMapper = setupObjectMapper(builder.objectMapper); objectUnwrapper = builder.objectUnwrapper; processors = builder.processors; @@ -307,6 +309,10 @@ public boolean getEnablePreciseDivideFilter() { return enablePreciseDivideFilter; } + public boolean isEnableFilterChainOptimization() { + return enableFilterChainOptimization; + } + public DateTimeProvider getDateTimeProvider() { return dateTimeProvider; } @@ -349,6 +355,7 @@ public static class Builder { private ExecutionMode executionMode = DefaultExecutionMode.instance(); private LegacyOverrides legacyOverrides = LegacyOverrides.NONE; private boolean enablePreciseDivideFilter = false; + private boolean enableFilterChainOptimization = false; private ObjectMapper objectMapper = null; private ObjectUnwrapper objectUnwrapper = new JinjavaObjectUnwrapper(); @@ -520,6 +527,13 @@ public Builder withEnablePreciseDivideFilter(boolean enablePreciseDivideFilter) return this; } + public Builder withEnableFilterChainOptimization( + boolean enableFilterChainOptimization + ) { + this.enableFilterChainOptimization = enableFilterChainOptimization; + return this; + } + public Builder withObjectMapper(ObjectMapper objectMapper) { this.objectMapper = objectMapper; return this; diff --git a/src/main/java/com/hubspot/jinjava/el/ext/AstFilterChain.java b/src/main/java/com/hubspot/jinjava/el/ext/AstFilterChain.java new file mode 100644 index 000000000..30dfc4435 --- /dev/null +++ b/src/main/java/com/hubspot/jinjava/el/ext/AstFilterChain.java @@ -0,0 +1,205 @@ +package com.hubspot.jinjava.el.ext; + +import com.hubspot.jinjava.interpret.DisabledException; +import com.hubspot.jinjava.interpret.JinjavaInterpreter; +import com.hubspot.jinjava.interpret.TemplateError; +import com.hubspot.jinjava.interpret.TemplateError.ErrorItem; +import com.hubspot.jinjava.interpret.TemplateError.ErrorReason; +import com.hubspot.jinjava.interpret.TemplateError.ErrorType; +import com.hubspot.jinjava.lib.filter.Filter; +import com.hubspot.jinjava.objects.SafeString; +import de.odysseus.el.tree.Bindings; +import de.odysseus.el.tree.impl.ast.AstNode; +import de.odysseus.el.tree.impl.ast.AstParameters; +import de.odysseus.el.tree.impl.ast.AstRightValue; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import javax.el.ELContext; +import javax.el.ELException; + +/** + * AST node for a chain of filters applied to an input expression. + * Instead of creating nested AstMethod calls for each filter in a chain like: + * filter:length.filter(filter:lower.filter(filter:trim.filter(input))) + * + * This node represents the entire chain as a single evaluation unit: + * input|trim|lower|length + * + * This optimization reduces: + * - Filter lookups (done once per filter instead of per AST node traversal) + * - Method invocation overhead + * - Object wrapping/unwrapping between filters + * - Context operations + */ +public class AstFilterChain extends AstRightValue { + + protected final AstNode input; + protected final List filterSpecs; + + public AstFilterChain(AstNode input, List filterSpecs) { + this.input = Objects.requireNonNull(input, "Input node cannot be null"); + this.filterSpecs = Objects.requireNonNull(filterSpecs, "Filter specs cannot be null"); + if (filterSpecs.isEmpty()) { + throw new IllegalArgumentException("Filter chain must have at least one filter"); + } + } + + public AstNode getInput() { + return input; + } + + public List getFilterSpecs() { + return filterSpecs; + } + + @Override + public Object eval(Bindings bindings, ELContext context) { + JinjavaInterpreter interpreter = getInterpreter(context); + + if (interpreter.getContext().isValidationMode()) { + return ""; + } + + Object value = input.eval(bindings, context); + + for (FilterSpec spec : filterSpecs) { + String filterKey = ExtendedParser.FILTER_PREFIX + spec.getName(); + interpreter.getContext().addResolvedValue(filterKey); + + Filter filter; + try { + filter = interpreter.getContext().getFilter(spec.getName()); + } catch (DisabledException e) { + interpreter.addError( + new TemplateError( + ErrorType.FATAL, + ErrorReason.DISABLED, + ErrorItem.FILTER, + e.getMessage(), + spec.getName(), + interpreter.getLineNumber(), + -1, + e + ) + ); + return null; + } + if (filter == null) { + return null; + } + + Object[] args = evaluateFilterArgs(spec, bindings, context); + Map kwargs = extractNamedParams(args); + Object[] positionalArgs = extractPositionalArgs(args); + + boolean wasSafeString = value instanceof SafeString; + if (wasSafeString) { + value = value.toString(); + } + + try { + value = filter.filter(value, interpreter, positionalArgs, kwargs); + } catch (ELException e) { + throw e; + } catch (RuntimeException e) { + throw new ELException( + String.format("Error in filter '%s': %s", spec.getName(), e.getMessage()), + e + ); + } + + if (wasSafeString && filter.preserveSafeString() && value instanceof String) { + value = new SafeString((String) value); + } + } + + return value; + } + + protected JinjavaInterpreter getInterpreter(ELContext context) { + return (JinjavaInterpreter) context + .getELResolver() + .getValue(context, null, ExtendedParser.INTERPRETER); + } + + protected Object[] evaluateFilterArgs( + FilterSpec spec, + Bindings bindings, + ELContext context + ) { + AstParameters params = spec.getParams(); + if (params == null || params.getCardinality() == 0) { + return new Object[0]; + } + + Object[] args = new Object[params.getCardinality()]; + for (int i = 0; i < params.getCardinality(); i++) { + args[i] = params.getChild(i).eval(bindings, context); + } + return args; + } + + private Map extractNamedParams(Object[] args) { + Map kwargs = new LinkedHashMap<>(); + for (Object arg : args) { + if (arg instanceof NamedParameter) { + NamedParameter namedParam = (NamedParameter) arg; + kwargs.put(namedParam.getName(), namedParam.getValue()); + } + } + return kwargs; + } + + private Object[] extractPositionalArgs(Object[] args) { + List positional = new ArrayList<>(); + for (Object arg : args) { + if (!(arg instanceof NamedParameter)) { + positional.add(arg); + } + } + return positional.toArray(); + } + + @Override + public void appendStructure(StringBuilder builder, Bindings bindings) { + input.appendStructure(builder, bindings); + for (FilterSpec spec : filterSpecs) { + builder.append('|').append(spec.getName()); + AstParameters params = spec.getParams(); + if (params != null && params.getCardinality() > 0) { + params.appendStructure(builder, bindings); + } + } + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(input.toString()); + for (FilterSpec spec : filterSpecs) { + sb.append('|').append(spec.toString()); + } + return sb.toString(); + } + + @Override + public int getCardinality() { + return 1 + filterSpecs.size(); + } + + @Override + public AstNode getChild(int i) { + if (i == 0) { + return input; + } + int filterIndex = i - 1; + if (filterIndex < filterSpecs.size()) { + FilterSpec spec = filterSpecs.get(filterIndex); + return spec.getParams(); + } + return null; + } +} diff --git a/src/main/java/com/hubspot/jinjava/el/ext/ExtendedParser.java b/src/main/java/com/hubspot/jinjava/el/ext/ExtendedParser.java index 4f7f741a0..543726a7b 100644 --- a/src/main/java/com/hubspot/jinjava/el/ext/ExtendedParser.java +++ b/src/main/java/com/hubspot/jinjava/el/ext/ExtendedParser.java @@ -531,30 +531,11 @@ protected AstNode value() throws ScanException, ParseException { private AstNode parseOperators(AstNode left) throws ScanException, ParseException { if ("|".equals(getToken().getImage()) && lookahead(0).getSymbol() == IDENTIFIER) { - AstNode v = left; - - do { - consumeToken(); // '|' - String filterName = consumeToken().getImage(); - List filterParams = Lists.newArrayList(v, interpreter()); - - // optional filter args - if (getToken().getSymbol() == Symbol.LPAREN) { - AstParameters astParameters = params(); - for (int i = 0; i < astParameters.getCardinality(); i++) { - filterParams.add(astParameters.getChild(i)); - } - } - - AstProperty filterProperty = createAstDot( - identifier(FILTER_PREFIX + filterName), - "filter", - true - ); - v = createAstMethod(filterProperty, createAstParameters(filterParams)); // function("filter:" + filterName, new AstParameters(filterParams)); - } while ("|".equals(getToken().getImage())); - - return v; + if (shouldUseFilterChainOptimization()) { + return parseFiltersAsChain(left); + } else { + return parseFiltersAsNestedMethods(left); + } } else if ( "is".equals(getToken().getImage()) && "not".equals(lookahead(0).getImage()) && @@ -577,6 +558,68 @@ protected AstParameters createAstParameters(List nodes) { return new AstParameters(nodes); } + protected AstFilterChain createAstFilterChain( + AstNode input, + List filterSpecs + ) { + return new AstFilterChain(input, filterSpecs); + } + + private AstNode parseFiltersAsChain(AstNode left) throws ScanException, ParseException { + List filterSpecs = new ArrayList<>(); + + do { + consumeToken(); // '|' + String filterName = consumeToken().getImage(); + AstParameters filterParams = null; + + // optional filter args + if (getToken().getSymbol() == Symbol.LPAREN) { + filterParams = params(); + } + + filterSpecs.add(new FilterSpec(filterName, filterParams)); + } while ("|".equals(getToken().getImage())); + + return createAstFilterChain(left, filterSpecs); + } + + protected AstNode parseFiltersAsNestedMethods(AstNode left) + throws ScanException, ParseException { + AstNode v = left; + + do { + consumeToken(); // '|' + String filterName = consumeToken().getImage(); + List filterParams = Lists.newArrayList(v, interpreter()); + + // optional filter args + if (getToken().getSymbol() == Symbol.LPAREN) { + AstParameters astParameters = params(); + for (int i = 0; i < astParameters.getCardinality(); i++) { + filterParams.add(astParameters.getChild(i)); + } + } + + AstProperty filterProperty = createAstDot( + identifier(FILTER_PREFIX + filterName), + "filter", + true + ); + v = createAstMethod(filterProperty, createAstParameters(filterParams)); + } while ("|".equals(getToken().getImage())); + + return v; + } + + protected boolean shouldUseFilterChainOptimization() { + return JinjavaInterpreter + .getCurrentMaybe() + .map(JinjavaInterpreter::getConfig) + .map(JinjavaConfig::isEnableFilterChainOptimization) + .orElse(false); + } + private boolean isPossibleExpTest(Symbol symbol) { return VALID_SYMBOLS_FOR_EXP_TEST.contains(symbol); } diff --git a/src/main/java/com/hubspot/jinjava/el/ext/FilterSpec.java b/src/main/java/com/hubspot/jinjava/el/ext/FilterSpec.java new file mode 100644 index 000000000..175016913 --- /dev/null +++ b/src/main/java/com/hubspot/jinjava/el/ext/FilterSpec.java @@ -0,0 +1,48 @@ +package com.hubspot.jinjava.el.ext; + +import de.odysseus.el.tree.impl.ast.AstParameters; +import java.util.Objects; + +/** + * Specification for a filter in a filter chain. + * Holds the filter name and optional parameters. + */ +public class FilterSpec { + + private final String name; + private final AstParameters params; + + public FilterSpec(String name, AstParameters params) { + this.name = Objects.requireNonNull(name, "Filter name cannot be null"); + this.params = params; + } + + public String getName() { + return name; + } + + public AstParameters getParams() { + return params; + } + + public boolean hasParams() { + return params != null && params.getCardinality() > 0; + } + + @Override + public String toString() { + if (hasParams()) { + StringBuilder sb = new StringBuilder(name); + sb.append('('); + for (int i = 0; i < params.getCardinality(); i++) { + if (i > 0) { + sb.append(", "); + } + sb.append(params.getChild(i)); + } + sb.append(')'); + return sb.toString(); + } + return name; + } +} diff --git a/src/main/java/com/hubspot/jinjava/el/ext/eager/EagerExtendedParser.java b/src/main/java/com/hubspot/jinjava/el/ext/eager/EagerExtendedParser.java index 5ca383afd..29f53e2b7 100644 --- a/src/main/java/com/hubspot/jinjava/el/ext/eager/EagerExtendedParser.java +++ b/src/main/java/com/hubspot/jinjava/el/ext/eager/EagerExtendedParser.java @@ -198,4 +198,9 @@ protected AstList createAstList(AstParameters parameters) protected AstParameters createAstParameters(List nodes) { return new EagerAstParameters(nodes); } + + @Override + protected boolean shouldUseFilterChainOptimization() { + return false; + } } diff --git a/src/test/java/com/hubspot/jinjava/el/ext/AstFilterChainParityTest.java b/src/test/java/com/hubspot/jinjava/el/ext/AstFilterChainParityTest.java new file mode 100644 index 000000000..e95e9ab74 --- /dev/null +++ b/src/test/java/com/hubspot/jinjava/el/ext/AstFilterChainParityTest.java @@ -0,0 +1,531 @@ +package com.hubspot.jinjava.el.ext; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.hubspot.jinjava.Jinjava; +import com.hubspot.jinjava.JinjavaConfig; +import com.hubspot.jinjava.LegacyOverrides; +import com.hubspot.jinjava.interpret.JinjavaInterpreter; +import com.hubspot.jinjava.interpret.RenderResult; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.junit.Before; +import org.junit.Test; + +public class AstFilterChainParityTest { + + private Jinjava jinjavaOptimized; + private Jinjava jinjavaUnoptimized; + private Map context; + + @Before + public void setup() { + LegacyOverrides legacyOverrides = LegacyOverrides + .newBuilder() + .withUsePyishObjectMapper(true) + .withKeepNullableLoopValues(true) + .build(); + + jinjavaOptimized = + new Jinjava( + JinjavaConfig + .newBuilder() + .withEnableFilterChainOptimization(true) + .withLegacyOverrides(legacyOverrides) + .build() + ); + + jinjavaUnoptimized = + new Jinjava( + JinjavaConfig + .newBuilder() + .withEnableFilterChainOptimization(false) + .withLegacyOverrides(legacyOverrides) + .build() + ); + + context = new HashMap<>(); + context.put("name", " Hello World "); + context.put("text", "the quick brown fox jumps over the lazy dog"); + context.put("number", 12345); + context.put("float_num", 3.14159); + context.put("negative", -42); + context.put("items", Arrays.asList("apple", "banana", "cherry")); + context.put("empty_list", ImmutableList.of()); + context.put("numbers", Arrays.asList(3, 1, 4, 1, 5, 9, 2, 6)); + context.put("html", ""); + context.put("null_value", null); + context.put( + "nested", + ImmutableMap.of("key", "value", "num", 100, "list", Arrays.asList(1, 2, 3)) + ); + context.put( + "objects", + Arrays.asList( + ImmutableMap.of("name", "Alice", "age", 30), + ImmutableMap.of("name", "Bob", "age", 25), + ImmutableMap.of("name", "Charlie", "age", 35) + ) + ); + context.put("mixed_case", "HeLLo WoRLd"); + context.put("whitespace", " lots of spaces "); + context.put("unicode", "héllo wörld 你好"); + context.put("special_chars", "a&bd\"e'f"); + context.put("json_string", "{\"key\": \"value\", \"num\": 42}"); + context.put("long_text", "word ".repeat(100)); + context.put("arg_value", 10); + } + + @Test + public void itProducesSameResultsForSingleFilters() { + List templates = ImmutableList.of( + "{{ name|trim }}", + "{{ name|lower }}", + "{{ name|upper }}", + "{{ name|length }}", + "{{ number|string }}", + "{{ number|abs }}", + "{{ float_num|round }}", + "{{ float_num|int }}", + "{{ items|first }}", + "{{ items|last }}", + "{{ items|length }}", + "{{ items|reverse }}", + "{{ items|sort }}", + "{{ html|escape }}", + "{{ html|e }}", + "{{ text|capitalize }}", + "{{ text|title }}", + "{{ text|wordcount }}", + "{{ negative|abs }}", + "{{ mixed_case|lower }}", + "{{ mixed_case|upper }}", + "{{ whitespace|trim }}", + "{{ unicode|upper }}", + "{{ unicode|lower }}" + ); + + assertParityForTemplates(templates); + } + + @Test + public void itProducesSameResultsForChainedFilters() { + List templates = ImmutableList.of( + "{{ name|trim|lower }}", + "{{ name|trim|upper }}", + "{{ name|trim|lower|capitalize }}", + "{{ name|trim|lower|upper }}", + "{{ text|upper|lower|capitalize }}", + "{{ text|capitalize|lower|upper }}", + "{{ number|string|length }}", + "{{ number|string|upper }}", + "{{ items|first|upper }}", + "{{ items|last|lower }}", + "{{ items|reverse|first }}", + "{{ items|sort|last }}", + "{{ items|sort|reverse|first }}", + "{{ html|escape|upper }}", + "{{ float_num|round|string|length }}", + "{{ whitespace|trim|lower|capitalize }}" + ); + + assertParityForTemplates(templates); + } + + @Test + public void itProducesSameResultsForFiltersWithPositionalArgs() { + List templates = ImmutableList.of( + "{{ text|truncate(20) }}", + "{{ text|truncate(20, True) }}", + "{{ text|truncate(20, True, '...') }}", + "{{ text|truncate(10, False) }}", + "{{ items|join(', ') }}", + "{{ items|join(' - ') }}", + "{{ items|join('') }}", + "{{ text|replace('the', 'a') }}", + "{{ text|replace('o', '0') }}", + "{{ text|split(' ') }}", + "{{ text|split(' ', 3) }}", + "{{ number|default(0) }}", + "{{ null_value|default('fallback') }}", + "{{ null_value|default(42) }}", + "{{ float_num|round(2) }}", + "{{ float_num|round(0) }}", + "{{ text|center(50) }}", + "{{ text|center(50, '-') }}", + "{{ numbers|batch(3) }}", + "{{ numbers|slice(3) }}" + ); + + assertParityForTemplates(templates); + } + + @Test + public void itProducesSameResultsForFiltersWithNamedParams() { + List templates = ImmutableList.of( + "{{ text|truncate(length=20) }}", + "{{ text|truncate(length=20, killwords=True) }}", + "{{ text|truncate(length=20, end='!!!') }}", + "{{ text|truncate(length=15, killwords=False, end='...') }}" + ); + + assertParityForTemplates(templates); + } + + @Test + public void itProducesSameResultsForMixedPositionalAndNamedParams() { + List templates = ImmutableList.of( + "{{ text|truncate(20, killwords=True) }}", + "{{ text|truncate(20, end='!') }}", + "{{ items|join(', ')|truncate(length=15) }}" + ); + + assertParityForTemplates(templates); + } + + @Test + public void itProducesSameResultsForChainedFiltersWithArgs() { + List templates = ImmutableList.of( + "{{ text|truncate(20)|upper }}", + "{{ text|upper|truncate(20) }}", + "{{ text|replace('the', 'a')|upper }}", + "{{ text|upper|replace('THE', 'a') }}", + "{{ text|truncate(30)|replace('...', '!')|upper }}", + "{{ items|join(', ')|upper }}", + "{{ items|join(', ')|truncate(10) }}", + "{{ items|sort|join(' - ')|upper }}", + "{{ items|reverse|join(', ')|lower }}", + "{{ numbers|sort|join('-') }}", + "{{ numbers|reverse|join(', ')|length }}" + ); + + assertParityForTemplates(templates); + } + + @Test + public void itProducesSameResultsForFilterArgsWithExpressions() { + List templates = ImmutableList.of( + "{{ text|truncate(arg_value) }}", + "{{ text|truncate(arg_value + 5) }}", + "{{ text|truncate(arg_value * 2) }}", + "{{ items|join(name|trim) }}", + "{{ text|replace(items|first, items|last) }}" + ); + + assertParityForTemplates(templates); + } + + @Test + public void itProducesSameResultsForNullAndUndefinedHandling() { + List templates = ImmutableList.of( + "{{ null_value|default('fallback') }}", + "{{ null_value|default('fallback')|upper }}", + "{{ undefined_var|default('missing') }}", + "{{ undefined_var|default('missing')|lower }}", + "{{ null_value|string }}", + "{{ null_value|e }}", + "{{ nested.missing|default('not found') }}", + "{{ nested.missing|default('')|length }}" + ); + + assertParityForTemplates(templates); + } + + @Test + public void itProducesSameResultsForSafeStringHandling() { + context.put("safe_html", "Bold"); + + List templates = ImmutableList.of( + "{{ safe_html|safe }}", + "{{ safe_html|safe|upper }}", + "{{ safe_html|upper|safe }}", + "{{ safe_html|safe|length }}", + "{{ safe_html|safe|trim }}", + "{{ safe_html|safe|lower|capitalize }}" + ); + + assertParityForTemplates(templates); + } + + @Test + public void itProducesSameResultsForCollectionFilters() { + List templates = ImmutableList.of( + "{{ items|list }}", + "{{ items|unique }}", + "{{ numbers|sum }}", + "{{ numbers|sort }}", + "{{ numbers|sort|reverse }}", + "{{ objects|map(attribute='name') }}", + "{{ objects|map(attribute='name')|join(', ') }}", + "{{ objects|selectattr('age', '>', 28) }}", + "{{ objects|rejectattr('age', '<', 30) }}", + "{{ numbers|select('>', 3) }}", + "{{ numbers|reject('==', 1) }}", + "{{ items|batch(2)|list }}", + "{{ numbers|slice(3)|list }}" + ); + + assertParityForTemplates(templates); + } + + @Test + public void itProducesSameResultsForStringManipulationFilters() { + List templates = ImmutableList.of( + "{{ text|format }}", + "{{ text|striptags }}", + "{{ html|striptags }}", + "{{ text|urlize }}", + "{{ special_chars|escape }}", + "{{ special_chars|urlencode }}", + "{{ text|regex_replace('\\\\s+', '_') }}", + "{{ text|replace(' ', '_') }}", + "{{ name|trim|replace(' ', '-')|lower }}" + ); + + assertParityForTemplates(templates); + } + + @Test + public void itProducesSameResultsForNumericFilters() { + List templates = ImmutableList.of( + "{{ number|filesizeformat }}", + "{{ float_num|round }}", + "{{ float_num|round(2) }}", + "{{ float_num|round(2, 'floor') }}", + "{{ float_num|round(2, 'ceil') }}", + "{{ negative|abs }}", + "{{ number|float }}", + "{{ float_num|int }}", + "{{ number|divide(100) }}", + "{{ number|multiply(2) }}", + "{{ float_num|log }}", + "{{ number|root }}" + ); + + assertParityForTemplates(templates); + } + + @Test + public void itProducesSameResultsForDateTimeFilters() { + context.put("timestamp", 1609459200000L); + context.put("date_string", "2021-01-01"); + + List templates = ImmutableList.of( + "{{ timestamp|datetimeformat }}", + "{{ timestamp|unixtimestamp }}" + ); + + assertParityForTemplates(templates); + } + + @Test + public void itProducesSameResultsForJsonFilters() { + List templates = ImmutableList.of( + "{{ nested|tojson }}", + "{{ items|tojson }}", + "{{ json_string|fromjson }}", + "{{ json_string|fromjson|tojson }}" + ); + + assertParityForTemplates(templates); + } + + @Test + public void itProducesSameResultsForMultipleFilterChainsInTemplate() { + List templates = ImmutableList.of( + "{{ name|trim|lower }} and {{ text|upper|truncate(10) }}", + "Hello {{ name|trim }}, you have {{ items|length }} items", + "{{ items|first|upper }} - {{ items|last|lower }}", + "{{ number|string }} is {{ number|string|length }} digits", + "Name: {{ name|trim|lower|capitalize }}, Count: {{ items|length }}" + ); + + assertParityForTemplates(templates); + } + + @Test + public void itProducesSameResultsForNestedPropertyAccess() { + List templates = ImmutableList.of( + "{{ nested.key|upper }}", + "{{ nested.num|string }}", + "{{ nested.list|first }}", + "{{ nested.list|join('-') }}", + "{{ nested.key|upper|lower|capitalize }}", + "{{ objects[0].name|upper }}", + "{{ objects[0].name|upper|truncate(3) }}" + ); + + assertParityForTemplates(templates); + } + + @Test + public void itProducesSameResultsForFilterChainInConditions() { + List templates = ImmutableList.of( + "{% if name|trim|length > 5 %}long{% else %}short{% endif %}", + "{% if items|length > 2 %}many{% else %}few{% endif %}", + "{% if name|trim|lower == 'hello world' %}match{% else %}no match{% endif %}" + ); + + assertParityForTemplates(templates); + } + + @Test + public void itProducesSameResultsForFilterChainInLoops() { + List templates = ImmutableList.of( + "{% for item in items|sort %}{{ item|upper }}{% endfor %}", + "{% for item in items|reverse %}{{ item|capitalize }}{% endfor %}", + "{% for n in numbers|sort|unique %}{{ n }}{% endfor %}" + ); + + assertParityForTemplates(templates); + } + + @Test + public void itProducesSameResultsForLongFilterChains() { + List templates = ImmutableList.of( + "{{ text|upper|lower|capitalize|trim }}", + "{{ text|trim|lower|upper|lower|capitalize }}", + "{{ name|trim|lower|upper|lower|upper|lower }}", + "{{ text|replace('the', 'a')|upper|lower|capitalize|trim }}", + "{{ items|sort|reverse|join(', ')|upper|truncate(20) }}" + ); + + assertParityForTemplates(templates); + } + + @Test + public void itTracksResolvedValuesConsistently() { + String template = "{{ name|trim|lower|upper }}"; + + RenderResult optimizedResult = jinjavaOptimized.renderForResult(template, context); + RenderResult unoptimizedResult = jinjavaUnoptimized.renderForResult( + template, + context + ); + + assertThat(optimizedResult.getOutput()) + .as("Output should match") + .isEqualTo(unoptimizedResult.getOutput()); + + Set optimizedResolved = optimizedResult.getContext().getResolvedValues(); + Set unoptimizedResolved = unoptimizedResult.getContext().getResolvedValues(); + + assertThat(optimizedResolved).as("Resolved filter:trim").contains("filter:trim"); + assertThat(optimizedResolved).as("Resolved filter:lower").contains("filter:lower"); + assertThat(optimizedResolved).as("Resolved filter:upper").contains("filter:upper"); + + assertThat(unoptimizedResolved) + .as("Unoptimized resolved filter:trim") + .contains("filter:trim"); + assertThat(unoptimizedResolved) + .as("Unoptimized resolved filter:lower") + .contains("filter:lower"); + assertThat(unoptimizedResolved) + .as("Unoptimized resolved filter:upper") + .contains("filter:upper"); + } + + @Test + public void itHandlesUnknownFiltersConsistently() { + String template = "{{ name|unknownfilter }}"; + + RenderResult optimizedResult = jinjavaOptimized.renderForResult(template, context); + RenderResult unoptimizedResult = jinjavaUnoptimized.renderForResult( + template, + context + ); + + assertThat(optimizedResult.getOutput()) + .as("Both paths should return empty for unknown filter") + .isEqualTo(unoptimizedResult.getOutput()); + } + + @Test + public void itProducesSameResultsForEmptyInputs() { + context.put("empty_string", ""); + + List templates = ImmutableList.of( + "{{ empty_string|upper }}", + "{{ empty_string|trim }}", + "{{ empty_string|default('fallback') }}", + "{{ empty_string|length }}", + "{{ empty_list|join(', ') }}", + "{{ empty_list|first }}", + "{{ empty_list|last }}", + "{{ empty_list|length }}" + ); + + assertParityForTemplates(templates); + } + + @Test + public void itProducesSameResultsForSpecialCharacters() { + List templates = ImmutableList.of( + "{{ special_chars|escape }}", + "{{ special_chars|escape|upper }}", + "{{ special_chars|urlencode }}", + "{{ special_chars|replace('&', 'and') }}", + "{{ unicode|upper }}", + "{{ unicode|lower }}", + "{{ unicode|length }}" + ); + + assertParityForTemplates(templates); + } + + @Test + public void itProducesSameResultsForBase64Filters() { + context.put("plain_text", "Hello, World!"); + context.put("base64_text", "SGVsbG8sIFdvcmxkIQ=="); + + List templates = ImmutableList.of( + "{{ plain_text|b64encode }}", + "{{ base64_text|b64decode }}", + "{{ plain_text|b64encode|b64decode }}" + ); + + assertParityForTemplates(templates); + } + + @Test + public void itProducesSameResultsForSelectAndRejectFilters() { + List templates = ImmutableList.of( + "{{ numbers|select('even')|list }}", + "{{ numbers|select('odd')|list }}", + "{{ numbers|reject('even')|list }}", + "{{ numbers|select('>', 3)|list }}", + "{{ numbers|select('>=', 4)|list }}", + "{{ numbers|reject('>', 5)|list }}" + ); + + assertParityForTemplates(templates); + } + + @Test + public void itProducesSameResultsForAttrFilters() { + List templates = ImmutableList.of( + "{{ objects|map(attribute='name')|list }}", + "{{ objects|map(attribute='age')|list }}", + "{{ objects|selectattr('age', '>', 28)|map(attribute='name')|list }}", + "{{ objects|rejectattr('age', '<', 30)|map(attribute='name')|list }}", + "{{ objects|groupby('age') }}" + ); + + assertParityForTemplates(templates); + } + + private void assertParityForTemplates(List templates) { + for (String template : templates) { + String optimizedResult = jinjavaOptimized.render(template, context); + String unoptimizedResult = jinjavaUnoptimized.render(template, context); + assertThat(optimizedResult) + .as("Template: %s", template) + .isEqualTo(unoptimizedResult); + } + } +} diff --git a/src/test/java/com/hubspot/jinjava/el/ext/AstFilterChainPerformanceTest.java b/src/test/java/com/hubspot/jinjava/el/ext/AstFilterChainPerformanceTest.java new file mode 100644 index 000000000..b5cc4b8b7 --- /dev/null +++ b/src/test/java/com/hubspot/jinjava/el/ext/AstFilterChainPerformanceTest.java @@ -0,0 +1,170 @@ +package com.hubspot.jinjava.el.ext; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.hubspot.jinjava.Jinjava; +import com.hubspot.jinjava.JinjavaConfig; +import java.util.HashMap; +import java.util.Map; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Test; + +/** + * Performance tests for the filter chain optimization. + * + * Run manually with: mvn test -Dtest=AstFilterChainPerformanceTest + * Or run the main() method directly for detailed output. + */ +public class AstFilterChainPerformanceTest { + + private Jinjava jinjavaOptimized; + private Jinjava jinjavaUnoptimized; + private Map context; + + @Before + public void setup() { + jinjavaOptimized = + new Jinjava( + JinjavaConfig.newBuilder().withEnableFilterChainOptimization(true).build() + ); + + jinjavaUnoptimized = + new Jinjava( + JinjavaConfig.newBuilder().withEnableFilterChainOptimization(false).build() + ); + + context = new HashMap<>(); + context.put("name", " Hello World "); + context.put("text", "the quick brown fox jumps over the lazy dog"); + context.put("number", 12345); + context.put("items", new String[] { "apple", "banana", "cherry" }); + context.put("content", Map.of("text", "the quick brown fox jumps over the lazy dog")); + } + + public static void main(String[] args) { + AstFilterChainPerformanceTest test = new AstFilterChainPerformanceTest(); + test.setup(); + test.runPerformanceComparison(); + } + + /** + * Run this test manually to see detailed performance comparison. + * Use main() method or run with -Dtest=AstFilterChainPerformanceTest#runPerformanceComparison + */ + @Test + @Ignore("Manual performance test - run explicitly when needed") + public void runPerformanceComparison() { + int warmupIterations = 10000; + int testIterations = 100000; + + System.out.println("=== Filter Chain Performance Test ===\n"); + System.out.println("Warming up..."); + + runFilterTests(jinjavaOptimized, warmupIterations); + runFilterTests(jinjavaUnoptimized, warmupIterations); + + System.out.println( + "Running performance tests with " + testIterations + " iterations each\n" + ); + + comparePerformance("Single filter: {{ name|trim }}", testIterations); + comparePerformance("Two filters: {{ name|trim|lower }}", testIterations); + comparePerformance("Three filters: {{ name|trim|lower|capitalize }}", testIterations); + comparePerformance( + "Five filters: {{ text|upper|replace('THE', 'a')|trim|lower|title }}", + testIterations + ); + comparePerformance( + "Filters with args: {{ text|truncate(20)|upper }}", + testIterations + ); + comparePerformance( + "Multiple chains: {{ name|trim|lower }} and {{ text|upper|truncate(10) }}", + testIterations + ); + } + + @Test + public void optimizedVersionShouldBeFaster() { + int warmupIterations = 100; + int testIterations = 1000; + String template = "{{ content.text|upper|replace('THE', 'a')|trim|lower|title }}"; + + for (int i = 0; i < warmupIterations; i++) { + jinjavaOptimized.render(template, context); + jinjavaUnoptimized.render(template, context); + } + + long totalOptimizedTime = 0; + long totalUnoptimizedTime = 0; + int rounds = 3; + + for (int round = 0; round < rounds; round++) { + totalUnoptimizedTime += timeExecution(jinjavaUnoptimized, template, testIterations); + totalOptimizedTime += timeExecution(jinjavaOptimized, template, testIterations); + } + + long avgUnoptimizedTime = totalUnoptimizedTime / rounds; + long avgOptimizedTime = totalOptimizedTime / rounds; + + System.out.printf( + "Performance test: Optimized=%d ms, Unoptimized=%d ms, Speedup=%.2fx%n", + avgOptimizedTime, + avgUnoptimizedTime, + (1.0 * avgUnoptimizedTime) / avgOptimizedTime + ); + + assertThat(avgOptimizedTime) + .as( + "Optimized (%d ms) should be faster than unoptimized (%d ms)", + avgOptimizedTime, + avgUnoptimizedTime + ) + .isLessThan((avgUnoptimizedTime * 95) / 100); + } + + private void comparePerformance(String description, int iterations) { + String template = description.substring(description.indexOf("{{")); + if (description.contains(":")) { + template = description.substring(description.indexOf(":") + 2); + } + + System.out.println(description); + + long optimizedTime = timeExecution(jinjavaOptimized, template, iterations); + long unoptimizedTime = timeExecution(jinjavaUnoptimized, template, iterations); + + double speedup = (1.0 * unoptimizedTime) / optimizedTime; + System.out.printf( + " Optimized: %d ms, Unoptimized: %d ms, Speedup: %.2fx%n%n", + optimizedTime, + unoptimizedTime, + speedup + ); + } + + private long timeExecution(Jinjava jinjava, String template, int iterations) { + long startTime = System.currentTimeMillis(); + for (int i = 0; i < iterations; i++) { + jinjava.render(template, context); + } + return System.currentTimeMillis() - startTime; + } + + private void runFilterTests(Jinjava jinjava, int iterations) { + String[] templates = { + "{{ name|trim }}", + "{{ name|trim|lower }}", + "{{ name|trim|lower|capitalize }}", + "{{ text|upper|replace('THE', 'a')|trim|lower|title }}", + "{{ text|truncate(20)|upper }}", + }; + + for (String template : templates) { + for (int i = 0; i < iterations; i++) { + jinjava.render(template, context); + } + } + } +} diff --git a/src/test/java/com/hubspot/jinjava/el/ext/AstFilterChainTest.java b/src/test/java/com/hubspot/jinjava/el/ext/AstFilterChainTest.java new file mode 100644 index 000000000..f0a0931a3 --- /dev/null +++ b/src/test/java/com/hubspot/jinjava/el/ext/AstFilterChainTest.java @@ -0,0 +1,70 @@ +package com.hubspot.jinjava.el.ext; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.hubspot.jinjava.Jinjava; +import com.hubspot.jinjava.JinjavaConfig; +import java.util.HashMap; +import java.util.Map; +import org.junit.Before; +import org.junit.Test; + +public class AstFilterChainTest { + + private Jinjava jinjava; + private Map context; + + @Before + public void setup() { + jinjava = + new Jinjava( + JinjavaConfig.newBuilder().withEnableFilterChainOptimization(true).build() + ); + + context = new HashMap<>(); + context.put("name", " Hello World "); + context.put("text", "the quick brown fox jumps over the lazy dog"); + context.put("number", 12345); + context.put("items", new String[] { "apple", "banana", "cherry" }); + } + + @Test + public void itHandlesSingleFilter() { + String result = jinjava.render("{{ name|trim }}", context); + assertThat(result).isEqualTo("Hello World"); + } + + @Test + public void itHandlesChainedFilters() { + String result = jinjava.render("{{ name|trim|lower }}", context); + assertThat(result).isEqualTo("hello world"); + } + + @Test + public void itHandlesFiltersWithArguments() { + String result = jinjava.render("{{ text|truncate(20)|upper }}", context); + assertThat(result).isNotEmpty(); + assertThat(result).isUpperCase(); + } + + @Test + public void itHandlesComplexFilterChain() { + String result = jinjava.render( + "{{ text|upper|replace('THE', 'a')|trim|lower|capitalize }}", + context + ); + assertThat(result).isNotEmpty(); + } + + @Test + public void itHandlesFilterWithJoin() { + String result = jinjava.render("{{ items|join(', ')|upper }}", context); + assertThat(result).isEqualTo("APPLE, BANANA, CHERRY"); + } + + @Test + public void itHandlesFilterWithStringConversion() { + String result = jinjava.render("{{ number|string|length }}", context); + assertThat(result).isEqualTo("5"); + } +}