diff --git a/examples/Makefile b/examples/Makefile index 3f0dd36f..71b0c703 100644 --- a/examples/Makefile +++ b/examples/Makefile @@ -1,4 +1,4 @@ -PYTHON=uv run python3 +PYTHON=uv run PY_FILES=$(wildcard *.py) IPYNB_FILES=$(addprefix ../ipynb-examples/, $(PY_FILES:.py=.ipynb)) diff --git a/examples/example3-statemachine.py b/examples/example3-statemachine.py index 78a78215..1b535671 100644 --- a/examples/example3-statemachine.py +++ b/examples/example3-statemachine.py @@ -15,10 +15,8 @@ dispense = pyrtl.Output(1, "dispense") refund = pyrtl.Output(1, "refund") -state = pyrtl.Register(3, "state") - -# First new step, let's enumerate a set of constants to serve as our states +# First new step, let's enumerate a set of constants for all possible states. class State(enum.IntEnum): WAIT = 0 # Waiting for first token. TOK1 = 1 # Received first token, waiting for second token. @@ -28,6 +26,11 @@ class State(enum.IntEnum): RFND = 5 # Issue refund. +# Define a `Register`, that calculates its bitwidth from the largest possible `State`. +# By default, `State` names like `WAIT` will display in traces, instead of state numbers +# like `0`. +state = pyrtl.Register(name="state", States=State) + # Now we could build a state machine using just the `Registers` and logic discussed in # prior examples, but doing operations **conditionally** on some input is a pretty # fundamental operation in hardware design. PyRTL provides `conditional_assignment` to @@ -114,11 +117,9 @@ class State(enum.IntEnum): sim.step_multiple(sim_inputs) # Also, to make our input/output easy to reason about let's specify an order to the -# traces with `trace_list`. We also use `enum_name` to display the state names (`WAIT`, -# `TOK1`, ...) rather than their numbers (0, 1, ...). +# traces with `trace_list`. sim.tracer.render_trace( - trace_list=["token_in", "req_refund", "state", "dispense", "refund"], - repr_per_name={"state": pyrtl.enum_name(State)}, + trace_list=["token_in", "req_refund", "state", "dispense", "refund"] ) # Finally, suppose you want to simulate your design and verify its output matches your diff --git a/ipynb-examples/example3-statemachine.ipynb b/ipynb-examples/example3-statemachine.ipynb index b2cd00c2..a5003693 100644 --- a/ipynb-examples/example3-statemachine.ipynb +++ b/ipynb-examples/example3-statemachine.ipynb @@ -48,16 +48,14 @@ "req_refund = pyrtl.Input(1, \"req_refund\")\n", "\n", "dispense = pyrtl.Output(1, \"dispense\")\n", - "refund = pyrtl.Output(1, \"refund\")\n", - "\n", - "state = pyrtl.Register(3, \"state\")\n" + "refund = pyrtl.Output(1, \"refund\")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - " First new step, let's enumerate a set of constants to serve as our states\n" + " First new step, let's enumerate a set of constants for all possible states.\n" ] }, { @@ -77,6 +75,26 @@ " RFND = 5 # Issue refund.\n" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " Define a [Register](https://pyrtl.readthedocs.io/en/latest/basic.html#pyrtl.Register), that calculates its bitwidth from the largest possible `State`.\n", + " By default, `State` names like `WAIT` will display in traces, instead of state numbers\n", + " like `0`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "state = pyrtl.Register(name=\"state\", States=State)\n" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -216,8 +234,7 @@ "metadata": {}, "source": [ " Also, to make our input/output easy to reason about let's specify an order to the\n", - " traces with `trace_list`. We also use `enum_name` to display the state names (`WAIT`,\n", - " `TOK1`, ...) rather than their numbers (0, 1, ...).\n" + " traces with `trace_list`.\n" ] }, { @@ -229,8 +246,7 @@ "outputs": [], "source": [ "sim.tracer.render_trace(\n", - " trace_list=[\"token_in\", \"req_refund\", \"state\", \"dispense\", \"refund\"],\n", - " repr_per_name={\"state\": pyrtl.enum_name(State)},\n", + " trace_list=[\"token_in\", \"req_refund\", \"state\", \"dispense\", \"refund\"]\n", ")\n" ] }, diff --git a/pyrtl/simulation.py b/pyrtl/simulation.py index 045aa3b8..9a5346e8 100644 --- a/pyrtl/simulation.py +++ b/pyrtl/simulation.py @@ -1106,6 +1106,8 @@ def invoke_f(f, value): if f is not None: return invoke_f(f, value) + if isinstance(wire, Register) and wire.States is not None: + return invoke_f(enum_name(wire.States), value) return invoke_f(repr_func, value) def render_val( @@ -1141,20 +1143,19 @@ def render_val( _prev_line* fields in RendererConstants. :param is_last: If True, current_val is in the last cycle. """ - if len(w) > 1 or w.name in repr_per_name: + is_state_register = isinstance(w, Register) and w.States is not None + if len(w) > 1 or w.name in repr_per_name or is_state_register: # Render values in boxes for multi-bit wires ("bus"), or single-bit wires # with a specific representation. # # We display multi-wire zero values as a centered horizontal line when a # specific `repr_per_name` is not requested for this trace, and a standard # numeric format is requested. - flat_zero = w.name not in repr_per_name and ( - repr_func is hex - or repr_func is oct - or repr_func is int - or repr_func is str - or repr_func is bin - or repr_func is val_to_signed_integer + numeric_formats = [hex, oct, int, str, bin, val_to_signed_integer] + flat_zero = ( + w.name not in repr_per_name + and not is_state_register + and repr_func in numeric_formats ) if prev_line: # Bus wires are currently never rendered across multiple lines. @@ -1956,7 +1957,7 @@ def print_perf_counters(self, *trace_names: str, file=sys.stdout): def enum_name(EnumClass: type) -> Callable[[int], str]: - """Returns a function that returns the name of an :class:`enum.IntEnum` value. + """Returns a function that returns the name of an :class:`~enum.IntEnum` value. .. doctest only:: @@ -1965,32 +1966,44 @@ def enum_name(EnumClass: type) -> Callable[[int], str]: >>> pyrtl.reset_working_block() Use ``enum_name`` as a ``repr_func`` or ``repr_per_name`` for - :meth:`SimulationTrace.render_trace` to display :class:`enum.IntEnum` names in + :meth:`~SimulationTrace.render_trace` to display :class:`~enum.IntEnum` names in traces, instead of their numeric value. Example:: - >>> class State(enum.IntEnum): + >>> class Option(enum.IntEnum): ... FOO = 0 ... BAR = 1 - >>> state = pyrtl.Input(name="state", bitwidth=1) + >>> pyrtl.enum_name(Option)(1) + 'BAR' + + >>> option = pyrtl.Input(name="option", bitwidth=1) >>> sim = pyrtl.Simulation() - >>> sim.step_multiple({"state": [State.FOO, State.BAR]}) - >>> sim.tracer.render_trace(repr_per_name={"state": pyrtl.enum_name(State)}) + >>> sim.step_multiple({"option": [Option.FOO, Option.BAR]}) + >>> sim.tracer.render_trace(repr_per_name={"option": pyrtl.enum_name(Option)}) Which prints:: - │0 │1 + │0 │1 + + option FOO│BAR + + .. note:: - state FOO│BAR + When using ``enum_name`` with a :class:`.Register`, consider constructing + :class:`.Register` with a ``State`` instead. See :meth:`.Register.__init__`. - :param EnumClass: ``enum`` to convert. This is the enum class, like ``State``, not - an enum value, like ``State.FOO`` or ``1``. + :param EnumClass: ``enum`` to convert. This is the enum class, like ``Option``, not + an enum value, like ``Option.FOO`` or ``1``. - :return: A function that accepts an enum value, like ``State.FOO`` or ``1``, and - returns the value's name as a string, like ``"FOO"``. + :return: A function that accepts an enum value, like ``Option.FOO`` or ``1``, and + returns the value's name as a string, like ``"FOO"``. Unknown values will + be converted to string with :class:`hex`. """ def value_to_name(value: int) -> str: - return EnumClass(value).name + try: + return EnumClass(value).name + except ValueError: + return hex(value) return value_to_name diff --git a/pyrtl/wire.py b/pyrtl/wire.py index 8c93e297..43af56b2 100644 --- a/pyrtl/wire.py +++ b/pyrtl/wire.py @@ -16,6 +16,7 @@ from __future__ import annotations +import enum import numbers import re import traceback @@ -1661,15 +1662,15 @@ def __ior__(self, _): class Register(WireVector): - """A WireVector with an embedded register state element. + """A :class:`WireVector` with an embedded register state element. - Registers only update their outputs on the rising edges of an implicit clock signal. - The "value" in the current cycle can be accessed by referencing the Register itself. - To set the value for the next cycle (after the next rising clock edge), set the - :attr:`Register.next` property with the ``<<=`` (:meth:`~WireVector.__ilshift__`) - operator. + ``Registers`` only update their outputs on the rising edges of an implicit clock + signal. The "value" in the current cycle can be accessed by referencing the + ``Register`` itself. To set the value for the next cycle (after the next rising + clock edge), set the :attr:`Register.next` property with the ``<<=`` + (:meth:`~WireVector.__ilshift__`) operator. - Registers reset to zero by default, and reside in the same clock domain. + ``Registers`` reset to zero by default, and reside in the same clock domain. .. doctest only:: @@ -1806,28 +1807,81 @@ def __bool__(self): def __init__( self, - bitwidth: int, + bitwidth: int | None = None, name: str = "", reset_value: int | None = None, - block: Block = None, + block: Block | None = None, + States: type[enum.IntEnum] | None = None, ): - """Construct a register. + """Construct a ``Register``. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example with ``States``:: + + >>> import enum + >>> class MyStates(enum.IntEnum): + ... ZERO = 0 + ... ONE = 1 + ... TWO = 2 + ... THREE = 3 + + >>> state = pyrtl.Register( + ... name="state", States=MyStates, reset_value=MyStates.ONE + ... ) + >>> state.bitwidth + 2 - It is an error if the ``reset_value`` cannot fit into the specified ``bitwidth`` - for this register. + >>> state.next <<= state + 1 - :param bitwidth: Number of bits to represent this register. - :param name: The name of the register's current value (``reg``, not + >>> sim = pyrtl.Simulation() + >>> sim.step_multiple(nsteps=4) + >>> sim.tracer.render_trace() + + Which prints:: + + │0 │1 │2 │3 + state ONE │TWO │THREE│ZERO + + :param bitwidth: Number of bits to represent this ``Register``. + :param name: The name of the ``Register``'s current value (``reg``, not ``reg.next``). Must be unique. If none is provided, one will be autogenerated. - :param reset_value: Value to initialize this register to during simulation and - in any code (e.g. Verilog) that is exported. Defaults to 0. Can be + :param reset_value: Value to initialize this ``Register`` to during simulation + and in any code (e.g. Verilog) that is exported. Defaults to 0. Can be overridden at simulation time. - :param block: The block under which the wire should be placed. Defaults to the - :ref:`working_block`. + :param block: The :class:`Block` under which the wire should be placed. Defaults + to the :ref:`working_block`. + :param States: An :class:`~enum.IntEnum` defining all possible states for the + ``Register``. This should be an :class:`~enum.IntEnum` class, like + ``MyStates`` in the example above. If ``bitwidth`` is ``None``, the largest + value in the :class:`~enum.IntEnum` determines the ``Register``'s + ``bitwidth``. When ``States`` is not ``None``, + :meth:`~.SimulationTrace.render_trace` defaults to displaying enumeration + names rather than hex values. + + :raises PyrtlError: If the ``reset_value`` or ``States`` cannot fit into the + specified ``bitwidth`` for this register. """ from pyrtl.helperfuncs import infer_val_and_bitwidth + self.States = States + if States is not None: + largest_state = max(States) + inferred_bitwidth = infer_val_and_bitwidth(largest_state).bitwidth + if bitwidth is None: + bitwidth = inferred_bitwidth + + if bitwidth < inferred_bitwidth: + msg = ( + f"The largest State {largest_state.name} ({largest_state}) cannot " + f"fit in the specified {bitwidth} bits for this register" + ) + raise PyrtlError(msg) + super().__init__(bitwidth=bitwidth, name=name, block=block) self.reg_in = None # wire vector setting self.next if reset_value is not None: diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 63e1e183..35e57a0a 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -264,6 +264,11 @@ class State(enum.IntEnum): FOO = 0 BAR = 1 + state_name = pyrtl.enum_name(State) + self.assertEqual(state_name(0), "FOO") + self.assertEqual(state_name(1), "BAR") + self.assertEqual(state_name(2), "0x2") + state = pyrtl.Input(name="state", bitwidth=1) sim = pyrtl.Simulation() sim.step_multiple({state.name: [State.FOO, State.BAR]}) @@ -271,9 +276,33 @@ class State(enum.IntEnum): sim.tracer.render_trace( file=buff, renderer=self.renderer, - repr_per_name={state.name: pyrtl.enum_name(State)}, + repr_per_name={state.name: state_name}, ) - expected = " |0 |1 \n \nstate FOO|BAR\n" + expected = ( + " |0 |1 \n" + " \n" + "state FOO|BAR\n" + ) # fmt: skip + self.assertEqual(buff.getvalue(), expected) + + def test_state_register(self): + class State(enum.IntEnum): + A = 0 + B = 1 + C = 2 + D = 3 + + state = pyrtl.Register(name="state", States=State, reset_value=State.B) + state.next <<= state + 1 + sim = pyrtl.Simulation() + sim.step_multiple(nsteps=4) + buff = io.StringIO() + sim.tracer.render_trace(file=buff, renderer=self.renderer) + expected = ( + " |0|1|2|3\n" + " \n" + "state B|C|D|A\n" + ) # fmt: skip self.assertEqual(buff.getvalue(), expected) def test_val_to_signed_integer(self): @@ -286,7 +315,11 @@ def test_val_to_signed_integer(self): sim.tracer.render_trace( file=buff, renderer=self.renderer, repr_func=pyrtl.val_to_signed_integer ) - expected = " |0 |1 |2 |3 \n \ncounter --|1 |-2|-1\n" + expected = ( + " |0 |1 |2 |3 \n" + " \n" + "counter --|1 |-2|-1\n" + ) # fmt: skip self.assertEqual(buff.getvalue(), expected) def test_custom_repr_per_wire(self): diff --git a/tests/test_wire.py b/tests/test_wire.py index a076596d..52511afd 100644 --- a/tests/test_wire.py +++ b/tests/test_wire.py @@ -1,4 +1,5 @@ import doctest +import enum import unittest import pyrtl @@ -298,6 +299,35 @@ def test_invalid_reset_value_not_an_integer(self): pyrtl.Register(4, reset_value="hello") +class TestStateRegister(unittest.TestCase): + def setUp(self): + pyrtl.reset_working_block() + + def test_bitwidth(self): + class PowerOfTwoStates(enum.IntEnum): + ZERO = 0 + ONE = 1 + TWO = 2 + THREE = 3 + + state = pyrtl.Register(States=PowerOfTwoStates) + self.assertEqual(state.bitwidth, 2) + + class NotPowerOfTwoStates(enum.IntEnum): + ZERO = 0 + ONE = 1 + TWO = 2 + FOUR = 4 + THREE = 3 + + state = pyrtl.Register(States=NotPowerOfTwoStates) + self.assertEqual(state.bitwidth, 3) + + with self.assertRaises(pyrtl.PyrtlError): + # Bitwidth 1 is too small to fit PowerOfTwoStates.THREE. + state = pyrtl.Register(bitwidth=1, States=PowerOfTwoStates) + + class TestConst(unittest.TestCase): def setUp(self): pyrtl.reset_working_block()