Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion client/src/types/fsm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

export type StateMachine = {
className: string;
initial: string;
initialStates: string[];
states: string[];
transitions: { from: string; to: string; label: string }[];
};
20 changes: 16 additions & 4 deletions client/src/webview/mermaid.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,24 @@ export function createMermaidDiagram(sm: StateMachine): string {
lines.push('---');
lines.push('stateDiagram-v2');

// initial state
lines.push(` [*] --> ${sm.initial}`);
// initial states
sm.initialStates.forEach(state => {
lines.push(` [*] --> ${state}`);
});

// transitions
// group transitions by from/to states and merge labels
const transitionMap = new Map<string, string[]>();
sm.transitions.forEach(transition => {
lines.push(` ${transition.from} --> ${transition.to} : ${transition.label}`);
const key = `${transition.from}|${transition.to}`;
if (!transitionMap.has(key)) transitionMap.set(key, []);
transitionMap.get(key).push(transition.label);
});

// add transitions
transitionMap.forEach((labels, key) => {
const [from, to] = key.split('|');
const mergedLabel = labels.join(', ');
lines.push(` ${from} --> ${to} : ${mergedLabel}`);
});

return lines.join('\n');
Expand Down
2 changes: 1 addition & 1 deletion client/src/webview/views/diagram.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ export function renderStateMachineView(sm: StateMachine, diagram: string, select
</div>
<div>
<p><strong>States:</strong> ${sm.states.join(', ')}</p>
<p><strong>Initial state:</strong> ${sm.initial}</p>
<p><strong>Initial state${sm.initialStates.length > 1 ? 's' : ''}:</strong> ${sm.initialStates.join(', ')}</p>
<p><strong>Number of states:</strong> ${sm.states.length}</p>
<p><strong>Number of transitions:</strong> ${sm.transitions.length + 1}</p>
</div>
Expand Down
2 changes: 1 addition & 1 deletion server/src/main/java/fsm/StateMachine.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
*/
public record StateMachine(
String className,
String initial,
List<String> initialStates,
List<String> states,
List<StateMachineTransition> transitions
) { }
79 changes: 43 additions & 36 deletions server/src/main/java/fsm/StateMachineParser.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import java.net.URI;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import liquidjava.rj_language.ast.*;
import liquidjava.rj_language.parsing.RefinementsParser;
Expand Down Expand Up @@ -52,20 +54,20 @@ public static StateMachine parse(String uri) {
String className = getClassName(ctType);

// extract initial state and transitions
String initial;
List<String> initialStates;
List<StateMachineTransition> transitions;
if (ctType instanceof CtClass<?> ctClass) {
initial = getInitialStateFromClass(ctClass, states);
initialStates = getInitialStatesFromClass(ctClass, states);
transitions = getTransitionsFromClass(ctClass, states);
} else if (ctType instanceof CtInterface<?> ctInterface) {
initial = getInitialStateFromInterface(ctInterface, className, states);
initialStates = getInitialStatesFromInterface(ctInterface, className, states);
transitions = getTransitionsFromInterface(ctInterface, className, states);
} else {
return null;
}
if (transitions.isEmpty()) return null; // no transitions found

return new StateMachine(className, initial, states, transitions);
return new StateMachine(className, initialStates, states, transitions);

} catch (Exception e) {
e.printStackTrace();
Expand Down Expand Up @@ -119,49 +121,47 @@ private static List<String> getStates(CtType<?> ctType) {
}

/**
* Gets the initial state from a class
* Gets the initial states from a class
* If not explicitely defined, uses the first state in the state set
* @param ctClass the CtClass
* @param states the list of states
* @return initial state
* @return initial states
*/
private static String getInitialStateFromClass(CtClass<?> ctClass, List<String> states) {
private static List<String> getInitialStatesFromClass(CtClass<?> ctClass, List<String> states) {
Set<String> initialStates = new HashSet<>();
for (CtConstructor<?> constructor : ctClass.getConstructors()) {
for (CtAnnotation<?> annotation : constructor.getAnnotations()) {
if (annotation.getAnnotationType().getSimpleName().equals(STATE_REFINEMENT_ANNOTATION)) {
String to = annotation.getValueAsString("to");
List<String> parsedStates = parseStateExpression(to, states);
if (!parsedStates.isEmpty()) {
return parsedStates.getFirst();
}
initialStates.addAll(parsedStates);
}
}
}
return states.getFirst();
return initialStates.isEmpty() ? List.of(states.get(0)) : initialStates.stream().toList();
}

/**
* Gets the initial state from an interface
* If not explicitely defined, uses the first state in the state set
* @param ctInterface the CtInterface
* @param className the class name
* @return initial state
* @return initial states
*/
private static String getInitialStateFromInterface(CtInterface<?> ctInterface, String className, List<String> states) {
private static List<String> getInitialStatesFromInterface(CtInterface<?> ctInterface, String className, List<String> states) {
Set<String> initialStates = new HashSet<>();
for (CtMethod<?> method : ctInterface.getMethods()) {
if (method.getSimpleName().equals(className)) {
for (CtAnnotation<?> annotation : method.getAnnotations()) {
if (annotation.getAnnotationType().getSimpleName().equals(STATE_REFINEMENT_ANNOTATION)) {
String to = annotation.getValueAsString("to");
List<String> parsedStates = parseStateExpression(to, states);
if (!parsedStates.isEmpty()) {
return parsedStates.getFirst();
}
initialStates.addAll(parsedStates);
}
}
}
}
return states.isEmpty() ? null : states.getFirst();
return initialStates.isEmpty() ? List.of(states.get(0)) : initialStates.stream().toList();
}

/**
Expand Down Expand Up @@ -261,29 +261,36 @@ private static List<String> parseStateExpression(String expr, List<String> state
*/
private static List<String> getStateExpressions(Expression expr, List<String> states) {
List<String> stateExpressions = new ArrayList<>();
if (expr instanceof Var var) {
stateExpressions.add(var.getName());
} else if (expr instanceof FunctionInvocation func) {
stateExpressions.add(func.getName());
} else if (expr instanceof GroupExpression group) {
stateExpressions.addAll(getStateExpressions(group.getExpression(), states));
} else if (expr instanceof BinaryExpression bin) {
String op = bin.getOperator();
if (op.equals("||")) {
// combine states from both operands
stateExpressions.addAll(getStateExpressions(bin.getFirstOperand(), states));
stateExpressions.addAll(getStateExpressions(bin.getSecondOperand(), states));
switch (expr) {
case Var var -> stateExpressions.add(var.getName());
case FunctionInvocation func -> stateExpressions.add(func.getName());
case GroupExpression group -> stateExpressions.addAll(getStateExpressions(group.getExpression(), states));
case BinaryExpression bin -> {
String op = bin.getOperator();
if (op.equals("||")) {
// combine states from both operands
stateExpressions.addAll(getStateExpressions(bin.getFirstOperand(), states));
stateExpressions.addAll(getStateExpressions(bin.getSecondOperand(), states));
}
}
} else if (expr instanceof UnaryExpression unary) {
if (unary.getOp().equals("!")) {
// all except those in the expression
List<String> negatedStates = getStateExpressions(unary.getExpression(), states);
for (String state : states) {
if (!negatedStates.contains(state)) {
stateExpressions.add(state);
case UnaryExpression unary -> {
if (unary.getOp().equals("!")) {
// all except those in the expression
List<String> negatedStates = getStateExpressions(unary.getExpression(), states);
for (String state : states) {
if (!negatedStates.contains(state)) {
stateExpressions.add(state);
}
}
}
}
case Ite ite -> {
// combine states from then and else branches
// TODO: handle conditional transitions
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets add this as an issue so we don't forget about it, one way to do it is adding the expression to the label like its done in symbolic automata

stateExpressions.addAll(getStateExpressions(ite.getThen(), states));
stateExpressions.addAll(getStateExpressions(ite.getElse(), states));
}
default -> {}
}
return stateExpressions;
}
Expand Down