|
2 | 2 |
|
3 | 3 | import java.net.URI; |
4 | 4 | import java.util.ArrayList; |
| 5 | +import java.util.HashSet; |
5 | 6 | import java.util.List; |
| 7 | +import java.util.Set; |
6 | 8 |
|
7 | 9 | import liquidjava.rj_language.ast.*; |
8 | 10 | import liquidjava.rj_language.parsing.RefinementsParser; |
@@ -52,20 +54,20 @@ public static StateMachine parse(String uri) { |
52 | 54 | String className = getClassName(ctType); |
53 | 55 |
|
54 | 56 | // extract initial state and transitions |
55 | | - String initial; |
| 57 | + List<String> initialStates; |
56 | 58 | List<StateMachineTransition> transitions; |
57 | 59 | if (ctType instanceof CtClass<?> ctClass) { |
58 | | - initial = getInitialStateFromClass(ctClass, states); |
| 60 | + initialStates = getInitialStatesFromClass(ctClass, states); |
59 | 61 | transitions = getTransitionsFromClass(ctClass, states); |
60 | 62 | } else if (ctType instanceof CtInterface<?> ctInterface) { |
61 | | - initial = getInitialStateFromInterface(ctInterface, className, states); |
| 63 | + initialStates = getInitialStatesFromInterface(ctInterface, className, states); |
62 | 64 | transitions = getTransitionsFromInterface(ctInterface, className, states); |
63 | 65 | } else { |
64 | 66 | return null; |
65 | 67 | } |
66 | 68 | if (transitions.isEmpty()) return null; // no transitions found |
67 | 69 |
|
68 | | - return new StateMachine(className, initial, states, transitions); |
| 70 | + return new StateMachine(className, initialStates, states, transitions); |
69 | 71 |
|
70 | 72 | } catch (Exception e) { |
71 | 73 | e.printStackTrace(); |
@@ -119,49 +121,47 @@ private static List<String> getStates(CtType<?> ctType) { |
119 | 121 | } |
120 | 122 |
|
121 | 123 | /** |
122 | | - * Gets the initial state from a class |
| 124 | + * Gets the initial states from a class |
123 | 125 | * If not explicitely defined, uses the first state in the state set |
124 | 126 | * @param ctClass the CtClass |
125 | 127 | * @param states the list of states |
126 | | - * @return initial state |
| 128 | + * @return initial states |
127 | 129 | */ |
128 | | - private static String getInitialStateFromClass(CtClass<?> ctClass, List<String> states) { |
| 130 | + private static List<String> getInitialStatesFromClass(CtClass<?> ctClass, List<String> states) { |
| 131 | + Set<String> initialStates = new HashSet<>(); |
129 | 132 | for (CtConstructor<?> constructor : ctClass.getConstructors()) { |
130 | 133 | for (CtAnnotation<?> annotation : constructor.getAnnotations()) { |
131 | 134 | if (annotation.getAnnotationType().getSimpleName().equals(STATE_REFINEMENT_ANNOTATION)) { |
132 | 135 | String to = annotation.getValueAsString("to"); |
133 | 136 | List<String> parsedStates = parseStateExpression(to, states); |
134 | | - if (!parsedStates.isEmpty()) { |
135 | | - return parsedStates.getFirst(); |
136 | | - } |
| 137 | + initialStates.addAll(parsedStates); |
137 | 138 | } |
138 | 139 | } |
139 | 140 | } |
140 | | - return states.getFirst(); |
| 141 | + return initialStates.isEmpty() ? List.of(states.get(0)) : initialStates.stream().toList(); |
141 | 142 | } |
142 | 143 |
|
143 | 144 | /** |
144 | 145 | * Gets the initial state from an interface |
145 | 146 | * If not explicitely defined, uses the first state in the state set |
146 | 147 | * @param ctInterface the CtInterface |
147 | 148 | * @param className the class name |
148 | | - * @return initial state |
| 149 | + * @return initial states |
149 | 150 | */ |
150 | | - private static String getInitialStateFromInterface(CtInterface<?> ctInterface, String className, List<String> states) { |
| 151 | + private static List<String> getInitialStatesFromInterface(CtInterface<?> ctInterface, String className, List<String> states) { |
| 152 | + Set<String> initialStates = new HashSet<>(); |
151 | 153 | for (CtMethod<?> method : ctInterface.getMethods()) { |
152 | 154 | if (method.getSimpleName().equals(className)) { |
153 | 155 | for (CtAnnotation<?> annotation : method.getAnnotations()) { |
154 | 156 | if (annotation.getAnnotationType().getSimpleName().equals(STATE_REFINEMENT_ANNOTATION)) { |
155 | 157 | String to = annotation.getValueAsString("to"); |
156 | 158 | List<String> parsedStates = parseStateExpression(to, states); |
157 | | - if (!parsedStates.isEmpty()) { |
158 | | - return parsedStates.getFirst(); |
159 | | - } |
| 159 | + initialStates.addAll(parsedStates); |
160 | 160 | } |
161 | 161 | } |
162 | 162 | } |
163 | 163 | } |
164 | | - return states.isEmpty() ? null : states.getFirst(); |
| 164 | + return initialStates.isEmpty() ? List.of(states.get(0)) : initialStates.stream().toList(); |
165 | 165 | } |
166 | 166 |
|
167 | 167 | /** |
@@ -261,29 +261,36 @@ private static List<String> parseStateExpression(String expr, List<String> state |
261 | 261 | */ |
262 | 262 | private static List<String> getStateExpressions(Expression expr, List<String> states) { |
263 | 263 | List<String> stateExpressions = new ArrayList<>(); |
264 | | - if (expr instanceof Var var) { |
265 | | - stateExpressions.add(var.getName()); |
266 | | - } else if (expr instanceof FunctionInvocation func) { |
267 | | - stateExpressions.add(func.getName()); |
268 | | - } else if (expr instanceof GroupExpression group) { |
269 | | - stateExpressions.addAll(getStateExpressions(group.getExpression(), states)); |
270 | | - } else if (expr instanceof BinaryExpression bin) { |
271 | | - String op = bin.getOperator(); |
272 | | - if (op.equals("||")) { |
273 | | - // combine states from both operands |
274 | | - stateExpressions.addAll(getStateExpressions(bin.getFirstOperand(), states)); |
275 | | - stateExpressions.addAll(getStateExpressions(bin.getSecondOperand(), states)); |
| 264 | + switch (expr) { |
| 265 | + case Var var -> stateExpressions.add(var.getName()); |
| 266 | + case FunctionInvocation func -> stateExpressions.add(func.getName()); |
| 267 | + case GroupExpression group -> stateExpressions.addAll(getStateExpressions(group.getExpression(), states)); |
| 268 | + case BinaryExpression bin -> { |
| 269 | + String op = bin.getOperator(); |
| 270 | + if (op.equals("||")) { |
| 271 | + // combine states from both operands |
| 272 | + stateExpressions.addAll(getStateExpressions(bin.getFirstOperand(), states)); |
| 273 | + stateExpressions.addAll(getStateExpressions(bin.getSecondOperand(), states)); |
| 274 | + } |
276 | 275 | } |
277 | | - } else if (expr instanceof UnaryExpression unary) { |
278 | | - if (unary.getOp().equals("!")) { |
279 | | - // all except those in the expression |
280 | | - List<String> negatedStates = getStateExpressions(unary.getExpression(), states); |
281 | | - for (String state : states) { |
282 | | - if (!negatedStates.contains(state)) { |
283 | | - stateExpressions.add(state); |
| 276 | + case UnaryExpression unary -> { |
| 277 | + if (unary.getOp().equals("!")) { |
| 278 | + // all except those in the expression |
| 279 | + List<String> negatedStates = getStateExpressions(unary.getExpression(), states); |
| 280 | + for (String state : states) { |
| 281 | + if (!negatedStates.contains(state)) { |
| 282 | + stateExpressions.add(state); |
| 283 | + } |
284 | 284 | } |
285 | 285 | } |
286 | 286 | } |
| 287 | + case Ite ite -> { |
| 288 | + // combine states from then and else branches |
| 289 | + // TODO: handle conditional transitions |
| 290 | + stateExpressions.addAll(getStateExpressions(ite.getThen(), states)); |
| 291 | + stateExpressions.addAll(getStateExpressions(ite.getElse(), states)); |
| 292 | + } |
| 293 | + default -> {} |
287 | 294 | } |
288 | 295 | return stateExpressions; |
289 | 296 | } |
|
0 commit comments