diff --git a/client/src/types/fsm.ts b/client/src/types/fsm.ts index 943938b..f7c726c 100644 --- a/client/src/types/fsm.ts +++ b/client/src/types/fsm.ts @@ -2,7 +2,7 @@ export type StateMachine = { className: string; - initial: string; + initialStates: string[]; states: string[]; transitions: { from: string; to: string; label: string }[]; }; diff --git a/client/src/webview/mermaid.ts b/client/src/webview/mermaid.ts index 63114f0..d0dfd5e 100644 --- a/client/src/webview/mermaid.ts +++ b/client/src/webview/mermaid.ts @@ -16,12 +16,24 @@ export function createMermaidDiagram(sm: StateMachine): string { lines.push('---'); lines.push('stateDiagram-v2'); - // initial state - lines.push(` [*] --> ${sm.initial}`); + // initial states + sm.initialStates.forEach(state => { + lines.push(` [*] --> ${state}`); + }); - // transitions + // group transitions by from/to states and merge labels + const transitionMap = new Map(); sm.transitions.forEach(transition => { - lines.push(` ${transition.from} --> ${transition.to} : ${transition.label}`); + const key = `${transition.from}|${transition.to}`; + if (!transitionMap.has(key)) transitionMap.set(key, []); + transitionMap.get(key).push(transition.label); + }); + + // add transitions + transitionMap.forEach((labels, key) => { + const [from, to] = key.split('|'); + const mergedLabel = labels.join(', '); + lines.push(` ${from} --> ${to} : ${mergedLabel}`); }); return lines.join('\n'); diff --git a/client/src/webview/views/diagram.ts b/client/src/webview/views/diagram.ts index 190034e..b8d3fa2 100644 --- a/client/src/webview/views/diagram.ts +++ b/client/src/webview/views/diagram.ts @@ -12,7 +12,7 @@ export function renderStateMachineView(sm: StateMachine, diagram: string, select

States: ${sm.states.join(', ')}

-

Initial state: ${sm.initial}

+

Initial state${sm.initialStates.length > 1 ? 's' : ''}: ${sm.initialStates.join(', ')}

Number of states: ${sm.states.length}

Number of transitions: ${sm.transitions.length + 1}

diff --git a/server/src/main/java/fsm/StateMachine.java b/server/src/main/java/fsm/StateMachine.java index 8e988e8..c26d07c 100644 --- a/server/src/main/java/fsm/StateMachine.java +++ b/server/src/main/java/fsm/StateMachine.java @@ -7,7 +7,7 @@ */ public record StateMachine( String className, - String initial, + List initialStates, List states, List transitions ) { } diff --git a/server/src/main/java/fsm/StateMachineParser.java b/server/src/main/java/fsm/StateMachineParser.java index e33b842..adb78f0 100644 --- a/server/src/main/java/fsm/StateMachineParser.java +++ b/server/src/main/java/fsm/StateMachineParser.java @@ -2,7 +2,9 @@ import java.net.URI; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; +import java.util.Set; import liquidjava.rj_language.ast.*; import liquidjava.rj_language.parsing.RefinementsParser; @@ -52,20 +54,20 @@ public static StateMachine parse(String uri) { String className = getClassName(ctType); // extract initial state and transitions - String initial; + List initialStates; List transitions; if (ctType instanceof CtClass ctClass) { - initial = getInitialStateFromClass(ctClass, states); + initialStates = getInitialStatesFromClass(ctClass, states); transitions = getTransitionsFromClass(ctClass, states); } else if (ctType instanceof CtInterface ctInterface) { - initial = getInitialStateFromInterface(ctInterface, className, states); + initialStates = getInitialStatesFromInterface(ctInterface, className, states); transitions = getTransitionsFromInterface(ctInterface, className, states); } else { return null; } if (transitions.isEmpty()) return null; // no transitions found - return new StateMachine(className, initial, states, transitions); + return new StateMachine(className, initialStates, states, transitions); } catch (Exception e) { e.printStackTrace(); @@ -119,25 +121,24 @@ private static List getStates(CtType ctType) { } /** - * Gets the initial state from a class + * Gets the initial states from a class * If not explicitely defined, uses the first state in the state set * @param ctClass the CtClass * @param states the list of states - * @return initial state + * @return initial states */ - private static String getInitialStateFromClass(CtClass ctClass, List states) { + private static List getInitialStatesFromClass(CtClass ctClass, List states) { + Set initialStates = new HashSet<>(); for (CtConstructor constructor : ctClass.getConstructors()) { for (CtAnnotation annotation : constructor.getAnnotations()) { if (annotation.getAnnotationType().getSimpleName().equals(STATE_REFINEMENT_ANNOTATION)) { String to = annotation.getValueAsString("to"); List parsedStates = parseStateExpression(to, states); - if (!parsedStates.isEmpty()) { - return parsedStates.getFirst(); - } + initialStates.addAll(parsedStates); } } } - return states.getFirst(); + return initialStates.isEmpty() ? List.of(states.get(0)) : initialStates.stream().toList(); } /** @@ -145,23 +146,22 @@ private static String getInitialStateFromClass(CtClass ctClass, List * If not explicitely defined, uses the first state in the state set * @param ctInterface the CtInterface * @param className the class name - * @return initial state + * @return initial states */ - private static String getInitialStateFromInterface(CtInterface ctInterface, String className, List states) { + private static List getInitialStatesFromInterface(CtInterface ctInterface, String className, List states) { + Set initialStates = new HashSet<>(); for (CtMethod method : ctInterface.getMethods()) { if (method.getSimpleName().equals(className)) { for (CtAnnotation annotation : method.getAnnotations()) { if (annotation.getAnnotationType().getSimpleName().equals(STATE_REFINEMENT_ANNOTATION)) { String to = annotation.getValueAsString("to"); List parsedStates = parseStateExpression(to, states); - if (!parsedStates.isEmpty()) { - return parsedStates.getFirst(); - } + initialStates.addAll(parsedStates); } } } } - return states.isEmpty() ? null : states.getFirst(); + return initialStates.isEmpty() ? List.of(states.get(0)) : initialStates.stream().toList(); } /** @@ -261,29 +261,36 @@ private static List parseStateExpression(String expr, List state */ private static List getStateExpressions(Expression expr, List states) { List stateExpressions = new ArrayList<>(); - if (expr instanceof Var var) { - stateExpressions.add(var.getName()); - } else if (expr instanceof FunctionInvocation func) { - stateExpressions.add(func.getName()); - } else if (expr instanceof GroupExpression group) { - stateExpressions.addAll(getStateExpressions(group.getExpression(), states)); - } else if (expr instanceof BinaryExpression bin) { - String op = bin.getOperator(); - if (op.equals("||")) { - // combine states from both operands - stateExpressions.addAll(getStateExpressions(bin.getFirstOperand(), states)); - stateExpressions.addAll(getStateExpressions(bin.getSecondOperand(), states)); + switch (expr) { + case Var var -> stateExpressions.add(var.getName()); + case FunctionInvocation func -> stateExpressions.add(func.getName()); + case GroupExpression group -> stateExpressions.addAll(getStateExpressions(group.getExpression(), states)); + case BinaryExpression bin -> { + String op = bin.getOperator(); + if (op.equals("||")) { + // combine states from both operands + stateExpressions.addAll(getStateExpressions(bin.getFirstOperand(), states)); + stateExpressions.addAll(getStateExpressions(bin.getSecondOperand(), states)); + } } - } else if (expr instanceof UnaryExpression unary) { - if (unary.getOp().equals("!")) { - // all except those in the expression - List negatedStates = getStateExpressions(unary.getExpression(), states); - for (String state : states) { - if (!negatedStates.contains(state)) { - stateExpressions.add(state); + case UnaryExpression unary -> { + if (unary.getOp().equals("!")) { + // all except those in the expression + List negatedStates = getStateExpressions(unary.getExpression(), states); + for (String state : states) { + if (!negatedStates.contains(state)) { + stateExpressions.add(state); + } } } } + case Ite ite -> { + // combine states from then and else branches + // TODO: handle conditional transitions + stateExpressions.addAll(getStateExpressions(ite.getThen(), states)); + stateExpressions.addAll(getStateExpressions(ite.getElse(), states)); + } + default -> {} } return stateExpressions; }