Skip to content

Commit 335770c

Browse files
authored
State Machine Diagram Fixes (#49)
1 parent d28024a commit 335770c

File tree

5 files changed

+62
-43
lines changed

5 files changed

+62
-43
lines changed

client/src/types/fsm.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
export type StateMachine = {
44
className: string;
5-
initial: string;
5+
initialStates: string[];
66
states: string[];
77
transitions: { from: string; to: string; label: string }[];
88
};

client/src/webview/mermaid.ts

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,24 @@ export function createMermaidDiagram(sm: StateMachine, orientation: "LR" | "TB")
1717
lines.push('stateDiagram-v2');
1818
lines.push(` direction ${orientation}`);
1919

20-
// initial state
21-
lines.push(` [*] --> ${sm.initial}`);
20+
// initial states
21+
sm.initialStates.forEach(state => {
22+
lines.push(` [*] --> ${state}`);
23+
});
2224

23-
// transitions
25+
// group transitions by from/to states and merge labels
26+
const transitionMap = new Map<string, string[]>();
2427
sm.transitions.forEach(transition => {
25-
lines.push(` ${transition.from} --> ${transition.to} : ${transition.label}`);
28+
const key = `${transition.from}|${transition.to}`;
29+
if (!transitionMap.has(key)) transitionMap.set(key, []);
30+
transitionMap.get(key).push(transition.label);
31+
});
32+
33+
// add transitions
34+
transitionMap.forEach((labels, key) => {
35+
const [from, to] = key.split('|');
36+
const mergedLabel = labels.join(', ');
37+
lines.push(` ${from} --> ${to} : ${mergedLabel}`);
2638
});
2739

2840
return lines.join('\n');

client/src/webview/views/diagram.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ export function renderStateMachineView(sm: StateMachine, diagram: string, select
1515
</div>
1616
<div>
1717
<p><strong>States:</strong> ${sm.states.join(', ')}</p>
18-
<p><strong>Initial state:</strong> ${sm.initial}</p>
18+
<p><strong>Initial state${sm.initialStates.length > 1 ? 's' : ''}:</strong> ${sm.initialStates.join(', ')}</p>
1919
<p><strong>Number of states:</strong> ${sm.states.length}</p>
2020
<p><strong>Number of transitions:</strong> ${sm.transitions.length + 1}</p>
2121
</div>

server/src/main/java/fsm/StateMachine.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
*/
88
public record StateMachine(
99
String className,
10-
String initial,
10+
List<String> initialStates,
1111
List<String> states,
1212
List<StateMachineTransition> transitions
1313
) { }

server/src/main/java/fsm/StateMachineParser.java

Lines changed: 43 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
import java.net.URI;
44
import java.util.ArrayList;
5+
import java.util.HashSet;
56
import java.util.List;
7+
import java.util.Set;
68

79
import liquidjava.rj_language.ast.*;
810
import liquidjava.rj_language.parsing.RefinementsParser;
@@ -52,20 +54,20 @@ public static StateMachine parse(String uri) {
5254
String className = getClassName(ctType);
5355

5456
// extract initial state and transitions
55-
String initial;
57+
List<String> initialStates;
5658
List<StateMachineTransition> transitions;
5759
if (ctType instanceof CtClass<?> ctClass) {
58-
initial = getInitialStateFromClass(ctClass, states);
60+
initialStates = getInitialStatesFromClass(ctClass, states);
5961
transitions = getTransitionsFromClass(ctClass, states);
6062
} else if (ctType instanceof CtInterface<?> ctInterface) {
61-
initial = getInitialStateFromInterface(ctInterface, className, states);
63+
initialStates = getInitialStatesFromInterface(ctInterface, className, states);
6264
transitions = getTransitionsFromInterface(ctInterface, className, states);
6365
} else {
6466
return null;
6567
}
6668
if (transitions.isEmpty()) return null; // no transitions found
6769

68-
return new StateMachine(className, initial, states, transitions);
70+
return new StateMachine(className, initialStates, states, transitions);
6971

7072
} catch (Exception e) {
7173
e.printStackTrace();
@@ -119,49 +121,47 @@ private static List<String> getStates(CtType<?> ctType) {
119121
}
120122

121123
/**
122-
* Gets the initial state from a class
124+
* Gets the initial states from a class
123125
* If not explicitely defined, uses the first state in the state set
124126
* @param ctClass the CtClass
125127
* @param states the list of states
126-
* @return initial state
128+
* @return initial states
127129
*/
128-
private static String getInitialStateFromClass(CtClass<?> ctClass, List<String> states) {
130+
private static List<String> getInitialStatesFromClass(CtClass<?> ctClass, List<String> states) {
131+
Set<String> initialStates = new HashSet<>();
129132
for (CtConstructor<?> constructor : ctClass.getConstructors()) {
130133
for (CtAnnotation<?> annotation : constructor.getAnnotations()) {
131134
if (annotation.getAnnotationType().getSimpleName().equals(STATE_REFINEMENT_ANNOTATION)) {
132135
String to = annotation.getValueAsString("to");
133136
List<String> parsedStates = parseStateExpression(to, states);
134-
if (!parsedStates.isEmpty()) {
135-
return parsedStates.getFirst();
136-
}
137+
initialStates.addAll(parsedStates);
137138
}
138139
}
139140
}
140-
return states.getFirst();
141+
return initialStates.isEmpty() ? List.of(states.get(0)) : initialStates.stream().toList();
141142
}
142143

143144
/**
144145
* Gets the initial state from an interface
145146
* If not explicitely defined, uses the first state in the state set
146147
* @param ctInterface the CtInterface
147148
* @param className the class name
148-
* @return initial state
149+
* @return initial states
149150
*/
150-
private static String getInitialStateFromInterface(CtInterface<?> ctInterface, String className, List<String> states) {
151+
private static List<String> getInitialStatesFromInterface(CtInterface<?> ctInterface, String className, List<String> states) {
152+
Set<String> initialStates = new HashSet<>();
151153
for (CtMethod<?> method : ctInterface.getMethods()) {
152154
if (method.getSimpleName().equals(className)) {
153155
for (CtAnnotation<?> annotation : method.getAnnotations()) {
154156
if (annotation.getAnnotationType().getSimpleName().equals(STATE_REFINEMENT_ANNOTATION)) {
155157
String to = annotation.getValueAsString("to");
156158
List<String> parsedStates = parseStateExpression(to, states);
157-
if (!parsedStates.isEmpty()) {
158-
return parsedStates.getFirst();
159-
}
159+
initialStates.addAll(parsedStates);
160160
}
161161
}
162162
}
163163
}
164-
return states.isEmpty() ? null : states.getFirst();
164+
return initialStates.isEmpty() ? List.of(states.get(0)) : initialStates.stream().toList();
165165
}
166166

167167
/**
@@ -261,29 +261,36 @@ private static List<String> parseStateExpression(String expr, List<String> state
261261
*/
262262
private static List<String> getStateExpressions(Expression expr, List<String> states) {
263263
List<String> stateExpressions = new ArrayList<>();
264-
if (expr instanceof Var var) {
265-
stateExpressions.add(var.getName());
266-
} else if (expr instanceof FunctionInvocation func) {
267-
stateExpressions.add(func.getName());
268-
} else if (expr instanceof GroupExpression group) {
269-
stateExpressions.addAll(getStateExpressions(group.getExpression(), states));
270-
} else if (expr instanceof BinaryExpression bin) {
271-
String op = bin.getOperator();
272-
if (op.equals("||")) {
273-
// combine states from both operands
274-
stateExpressions.addAll(getStateExpressions(bin.getFirstOperand(), states));
275-
stateExpressions.addAll(getStateExpressions(bin.getSecondOperand(), states));
264+
switch (expr) {
265+
case Var var -> stateExpressions.add(var.getName());
266+
case FunctionInvocation func -> stateExpressions.add(func.getName());
267+
case GroupExpression group -> stateExpressions.addAll(getStateExpressions(group.getExpression(), states));
268+
case BinaryExpression bin -> {
269+
String op = bin.getOperator();
270+
if (op.equals("||")) {
271+
// combine states from both operands
272+
stateExpressions.addAll(getStateExpressions(bin.getFirstOperand(), states));
273+
stateExpressions.addAll(getStateExpressions(bin.getSecondOperand(), states));
274+
}
276275
}
277-
} else if (expr instanceof UnaryExpression unary) {
278-
if (unary.getOp().equals("!")) {
279-
// all except those in the expression
280-
List<String> negatedStates = getStateExpressions(unary.getExpression(), states);
281-
for (String state : states) {
282-
if (!negatedStates.contains(state)) {
283-
stateExpressions.add(state);
276+
case UnaryExpression unary -> {
277+
if (unary.getOp().equals("!")) {
278+
// all except those in the expression
279+
List<String> negatedStates = getStateExpressions(unary.getExpression(), states);
280+
for (String state : states) {
281+
if (!negatedStates.contains(state)) {
282+
stateExpressions.add(state);
283+
}
284284
}
285285
}
286286
}
287+
case Ite ite -> {
288+
// combine states from then and else branches
289+
// TODO: handle conditional transitions
290+
stateExpressions.addAll(getStateExpressions(ite.getThen(), states));
291+
stateExpressions.addAll(getStateExpressions(ite.getElse(), states));
292+
}
293+
default -> {}
287294
}
288295
return stateExpressions;
289296
}

0 commit comments

Comments
 (0)