Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ Cargo.lock
**/.env
.DS_Store

# Log outputs
*.log

.cache/
rustc-*

Expand Down
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion crates/sdk/src/keygen/dummy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ use crate::{
/// - trace heights ordered by AIR ID
///
/// All trace heights are rounded to the next power of two (or 0 -> 0).
pub(super) fn compute_root_proof_heights(
pub fn compute_root_proof_heights(
root_vm: &mut VirtualMachine<BabyBearPoseidon2RootEngine, NativeCpuBuilder>,
root_committed_exe: &VmCommittedExe<BabyBearPoseidon2RootConfig>,
dummy_internal_proof: &Proof<SC>,
Expand Down
2 changes: 1 addition & 1 deletion crates/sdk/src/keygen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ use crate::{
};

pub mod asm;
pub(crate) mod dummy;
pub mod dummy;
pub mod perm;
#[cfg(feature = "evm-prove")]
pub mod static_verifier;
Expand Down
4 changes: 2 additions & 2 deletions crates/sdk/src/keygen/perm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::cmp::Reverse;
use openvm_continuations::verifier::common::types::SpecialAirIds;

/// Permutation of the AIR IDs to order them by forced trace heights.
pub(crate) struct AirIdPermutation {
pub struct AirIdPermutation {
pub perm: Vec<usize>,
}

Expand Down Expand Up @@ -47,7 +47,7 @@ impl AirIdPermutation {
ret
}
/// arr[i] <- arr[perm[i]]
pub(crate) fn permute<T>(&self, arr: &mut [T]) {
pub fn permute<T>(&self, arr: &mut [T]) {
debug_assert_eq!(arr.len(), self.perm.len());
let mut perm = self.perm.clone();
for i in 0..perm.len() {
Expand Down
6 changes: 3 additions & 3 deletions crates/sdk/src/prover/agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ where
E: StarkFriEngine<SC = SC>,
NativeBuilder: VmBuilder<E, VmConfig = NativeConfig>,
{
leaf_prover: VmInstance<E, NativeBuilder>,
leaf_controller: LeafProvingController,
pub leaf_prover: VmInstance<E, NativeBuilder>,
pub leaf_controller: LeafProvingController,

pub internal_prover: VmInstance<E, NativeBuilder>,
#[cfg(feature = "evm-prove")]
root_prover: RootVerifierLocalProver,
pub root_prover: RootVerifierLocalProver,
pub num_children_internal: usize,
pub max_internal_wrapper_layers: usize,
}
Expand Down
8 changes: 7 additions & 1 deletion crates/vm/src/arch/vm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -910,7 +910,7 @@ where
#[getset(get = "pub")]
exe: Arc<VmExe<Val<E::SC>>>,
#[getset(get = "pub", get_mut = "pub")]
state: Option<VmState<Val<E::SC>, GuestMemory>>,
pub state: Option<VmState<Val<E::SC>, GuestMemory>>,
}

impl<E, VB> VmInstance<E, VB>
Expand Down Expand Up @@ -1056,6 +1056,8 @@ where
let mut trace_heights = trace_heights.to_vec();
trace_heights[PUBLIC_VALUES_AIR_ID] = vm.config().as_ref().num_public_values as u32;
let state = self.state.take().expect("State should always be present");
#[cfg(feature = "metrics")]
let debug_infos = state.metrics.debug_infos.clone();
let num_custom_pvs = state.custom_pvs.len();
let (proof, final_memory) = vm.prove(&mut self.interpreter, state, None, &trace_heights)?;
let final_memory = final_memory.ok_or(ExecutionError::DidNotTerminate)?;
Expand All @@ -1068,6 +1070,10 @@ where
DEFAULT_RNG_SEED,
num_custom_pvs,
));
#[cfg(feature = "metrics")]
{
self.state.as_mut().unwrap().metrics.debug_infos = debug_infos;
}
Ok(proof)
}
}
Expand Down
44 changes: 36 additions & 8 deletions crates/vm/src/metrics/cycle_tracker/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
/// Stats for a nested span in the execution segment that is tracked by the [`CycleTracker`].
#[derive(Clone, Debug, Default)]
pub struct SpanInfo {
/// The name of the span.
pub tag: String,
/// The cycle count at which the span starts.
pub start: usize,
}

#[derive(Clone, Debug, Default)]
pub struct CycleTracker {
/// Stack of span names, with most recent at the end
stack: Vec<String>,
stack: Vec<SpanInfo>,
/// Depth of the stack.
depth: usize,
}

impl CycleTracker {
Expand All @@ -10,29 +21,42 @@ impl CycleTracker {
}

pub fn top(&self) -> Option<&String> {
self.stack.last()
match self.stack.last() {
Some(span) => Some(&span.tag),
_ => None,
}
}

/// Starts a new cycle tracker span for the given name.
/// If a span already exists for the given name, it ends the existing span and pushes a new one
/// to the vec.
pub fn start(&mut self, mut name: String) {
pub fn start(&mut self, mut name: String, cycles_count: usize) {
// hack to remove "CT-" prefix
if name.starts_with("CT-") {
name = name.split_off(3);
}
self.stack.push(name);
self.stack.push(SpanInfo {
tag: name.clone(),
start: cycles_count,
});
let padding = "│ ".repeat(self.depth);
tracing::info!("{}┌╴{}", padding, name);
self.depth += 1;
}

/// Ends the cycle tracker span for the given name.
/// If no span exists for the given name, it panics.
pub fn end(&mut self, mut name: String) {
pub fn end(&mut self, mut name: String, cycles_count: usize) {
// hack to remove "CT-" prefix
if name.starts_with("CT-") {
name = name.split_off(3);
}
let stack_top = self.stack.pop();
assert_eq!(stack_top.unwrap(), name, "Stack top does not match name");
let SpanInfo { tag, start } = self.stack.pop().unwrap();
assert_eq!(tag, name, "Stack top does not match name");
self.depth -= 1;
let padding = "│ ".repeat(self.depth);
let span_cycles = cycles_count - start;
tracing::info!("{}└╴{} cycles", padding, span_cycles);
}

/// Ends the current cycle tracker span.
Expand All @@ -42,7 +66,11 @@ impl CycleTracker {

/// Get full name of span with all parent names separated by ";" in flamegraph format
pub fn get_full_name(&self) -> String {
self.stack.join(";")
self.stack
.iter()
.map(|span_info| span_info.tag.clone())
.collect::<Vec<String>>()
.join(";")
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/vm/src/metrics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ impl VmMetrics {
.map(|(_, func)| (*func).clone())
.unwrap();
if pc == self.current_fn.start {
self.cycle_tracker.start(self.current_fn.name.clone());
self.cycle_tracker.start(self.current_fn.name.clone(), 0);
} else {
while let Some(name) = self.cycle_tracker.top() {
if name == &self.current_fn.name {
Expand Down
2 changes: 1 addition & 1 deletion crates/vm/src/system/memory/offline_checker/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ impl<const N: usize, T> MemoryWriteAuxCols<T, N> {
#[repr(C)]
#[derive(Clone, Copy, Debug, AlignedBorrow)]
pub struct MemoryReadAuxCols<T> {
pub(in crate::system::memory) base: MemoryBaseAuxCols<T>,
pub base: MemoryBaseAuxCols<T>,
}

impl<F: PrimeField32> MemoryReadAuxCols<F> {
Expand Down
32 changes: 31 additions & 1 deletion extensions/native/circuit/cuda/include/native/poseidon2.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,42 @@ template <typename T> struct SimplePoseidonSpecificCols {
MemoryWriteAuxCols<T, CHUNK> write_data_2;
};

template <typename T> struct MultiObserveCols {
T pc;
T final_timestamp_increment;
T state_ptr;
T input_ptr;
T init_pos;
T len;
T input_register_1;
T input_register_2;
T input_register_3;
T output_register;
T is_first;
T is_last;
T curr_len;
T start_idx;
T end_idx;
T aux_after_start[CHUNK];
T aux_before_end[CHUNK];
T aux_read_enabled[CHUNK];
MemoryReadAuxCols<T> read_data[CHUNK];
MemoryWriteAuxCols<T, 1> write_data[CHUNK];
T data[CHUNK];
T should_permute;
MemoryWriteAuxCols<T, CHUNK * 2> write_sponge_state;
MemoryWriteAuxCols<T, 1> write_final_idx;
};

template <typename T> constexpr T constexpr_max(T a, T b) { return a > b ? a : b; }

constexpr size_t COL_SPECIFIC_WIDTH = constexpr_max(
sizeof(TopLevelSpecificCols<uint8_t>),
constexpr_max(
sizeof(InsideRowSpecificCols<uint8_t>),
sizeof(SimplePoseidonSpecificCols<uint8_t>)
constexpr_max(
sizeof(SimplePoseidonSpecificCols<uint8_t>),
sizeof(MultiObserveCols<uint8_t>)
)
)
);
89 changes: 89 additions & 0 deletions extensions/native/circuit/cuda/include/native/sumcheck.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
#pragma once

#include "primitives/constants.h"
#include "system/memory/offline_checker.cuh"

using namespace native;

template <typename T> struct HeaderSpecificCols {
T pc;
T registers[5];
T prod_evals_id;
T logup_evals_id;
MemoryReadAuxCols<T> read_records[8];
MemoryWriteAuxCols<T, EXT_DEG> write_records;
};

template <typename T> struct ProdSpecificCols {
T data_ptr;
T p[EXT_DEG * 2];
T p_evals[EXT_DEG];
MemoryWriteAuxCols<T, EXT_DEG> write_record;
MemoryWriteAuxCols<T, EXT_DEG * 2> ps_record;
T eval_rlc[EXT_DEG];
};

template <typename T> struct LogupSpecificCols {
T data_ptr;
T pq[EXT_DEG * 4];
T p_evals[EXT_DEG];
T q_evals[EXT_DEG];
MemoryWriteAuxCols<T, EXT_DEG * 4> pqs_record;
MemoryWriteAuxCols<T, EXT_DEG> write_records[2];
T eval_rlc[EXT_DEG];
};

template <typename T> constexpr T constexpr_max(T a, T b) {
return a > b ? a : b;
}

constexpr size_t COL_SPECIFIC_WIDTH = constexpr_max(
sizeof(HeaderSpecificCols<uint8_t>),
constexpr_max(sizeof(ProdSpecificCols<uint8_t>), sizeof(LogupSpecificCols<uint8_t>))
);

template <typename T> struct NativeSumcheckCols {
T header_row;
T prod_row;
T logup_row;
T is_end;

T prod_continued;
T logup_continued;

T prod_in_round_evaluation;
T prod_next_round_evaluation;
T logup_in_round_evaluation;
T logup_next_round_evaluation;

T prod_acc;
T logup_acc;

T first_timestamp;
T start_timestamp;
T last_timestamp;

T register_ptrs[5];

T ctx[EXT_DEG * 2];

T prod_nested_len;
T logup_nested_len;

T curr_prod_n;
T curr_logup_n;

T alpha[EXT_DEG];
T challenges[EXT_DEG * 4];

T max_round;
T within_round_limit;
T should_acc;

T eval_acc[EXT_DEG];

T is_hint_src_id;

T specific[COL_SPECIFIC_WIDTH];
};

13 changes: 13 additions & 0 deletions extensions/native/circuit/cuda/include/native/utils.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#pragma once

#include "primitives/trace_access.h"
#include "system/memory/controller.cuh"

__device__ __forceinline__ void mem_fill_base(
MemoryAuxColsFactory &mem_helper,
uint32_t timestamp,
RowSlice base_aux
) {
uint32_t prev = base_aux[COL_INDEX(MemoryBaseAuxCols, prev_timestamp)].asUInt32();
mem_helper.fill(base_aux, prev, timestamp);
}
Loading
Loading