diff --git a/Benchmarks/IxVM.lean b/Benchmarks/IxVM.lean new file mode 100644 index 00000000..bae83fed --- /dev/null +++ b/Benchmarks/IxVM.lean @@ -0,0 +1,45 @@ +import Ix.Meta +import Ix.IxVM +import Ix.Aiur.Simple +import Ix.Aiur.Compile +import Ix.Aiur.Protocol +import Ix.Benchmark.Bench + +def commitmentParameters : Aiur.CommitmentParameters := { + logBlowup := 1 +} + +def friParameters : Aiur.FriParameters := { + logFinalPolyLen := 0 + numQueries := 100 + commitProofOfWorkBits := 20 + queryProofOfWorkBits := 0 +} + +def main : IO Unit := do + let .ok toplevel := IxVM.ixVM + | throw (IO.userError "Merging failed") + let some funIdx := toplevel.getFuncIdx `ixon_serde_blake3_bench + | throw (IO.userError "Aiur function not found") + let .ok decls := toplevel.checkAndSimplify + | throw (IO.userError "Simplification failed") + let .ok bytecode := decls.compile + | throw (IO.userError "Compilation failed") + let aiurSystem := Aiur.AiurSystem.build bytecode commitmentParameters + + let env ← get_env! + let natAddCommName := ``Nat.add_comm + let constList := Lean.collectDependencies natAddCommName env.constants + let rawEnv ← Ix.CompileM.rsCompileEnvFFI constList + let ixonEnv := rawEnv.toEnv + let ixonConsts := ixonEnv.consts.valuesIter + let (ioBuffer, n) := ixonConsts.fold (init := (default, 0)) fun (ioBuffer, i) c => + let (_, bytes) := Ixon.Serialize.put c |>.run default + (ioBuffer.extend #[.ofNat i] (bytes.data.map .ofUInt8), i + 1) + + let _report ← oneShotBench "IxVM benchmarks" + (bench "serde/blake3 Nat.add_comm" + (aiurSystem.prove friParameters funIdx #[.ofNat n]) + ioBuffer) + { oneShot := true } + return diff --git a/Ix/Aiur/Meta.lean b/Ix/Aiur/Meta.lean index 6281075a..40e8cb94 100644 --- a/Ix/Aiur/Meta.lean +++ b/Ix/Aiur/Meta.lean @@ -97,6 +97,7 @@ declare_syntax_cat trm syntax ("." noWs)? ident : trm -- syntax "cast" "(" trm ", " typ ")" : trm syntax num : trm +syntax "(" ")" : trm syntax "(" trm (", " trm)* ")" : trm syntax "[" trm (", " trm)* "]" : trm syntax "[" trm "; " num "]" : trm @@ -151,6 +152,7 @@ partial def elabTrm : ElabStxCat `trm | `(trm| $n:num) => do let data ← mkAppM ``Data.field #[← elabG n] mkAppM ``Term.data #[data] + | `(trm| ()) => pure $ mkConst ``Term.unit | `(trm| ($t:trm $[, $ts:trm]*)) => do if ts.isEmpty then elabTrm t else diff --git a/Ix/Aiur/Protocol.lean b/Ix/Aiur/Protocol.lean index 36482f90..aee0b44d 100644 --- a/Ix/Aiur/Protocol.lean +++ b/Ix/Aiur/Protocol.lean @@ -43,8 +43,18 @@ structure IOBuffer where map : Std.HashMap (Array G) IOKeyInfo deriving Inhabited +def IOBuffer.extend (ioBuffer : IOBuffer) (key data : Array G) : IOBuffer := + let idx := ioBuffer.data.size + let len := data.size + { ioBuffer with + data := ioBuffer.data ++ data + map := ioBuffer.map.insert key { idx, len } } + instance : BEq IOBuffer where - beq x y := x.data == y.data && x.map.toArray == y.map.toArray + beq x y := + x.data == y.data && + x.map.size == y.map.size && + x.map.all fun k v => y.map.get? k == some v namespace AiurSystem diff --git a/Ix/Common.lean b/Ix/Common.lean index 2aedc497..ea202220 100644 --- a/Ix/Common.lean +++ b/Ix/Common.lean @@ -307,6 +307,62 @@ def runFrontend (input : String) (filePath : FilePath) : IO Environment := do (← msgs.toList.mapM (·.toString)).map (String.trimAscii · |>.toString) else return s.commandState.env +abbrev ConstList := List (Lean.Name × Lean.ConstantInfo) +private abbrev CollectM := StateM Lean.NameHashSet + +private partial def collectDependenciesAux (const : Lean.ConstantInfo) + (consts : Lean.ConstMap) (acc : ConstList) : CollectM ConstList := do + modify (·.insert const.name) + match const with + | .ctorInfo val => + let acc ← collectNames [val.induct] acc + goExpr consts acc val.type + | .axiomInfo val | .quotInfo val => goExpr consts acc val.type + | .inductInfo val => + let acc ← collectNames val.all acc + let acc ← collectNames val.ctors acc + goExpr consts acc val.type + | .defnInfo val | .thmInfo val | .opaqueInfo val => + let acc ← collectNames val.all acc + let acc ← goExpr consts acc val.type + goExpr consts acc val.value + | .recInfo val => + let acc ← collectNames val.all acc + let acc ← goExpr consts acc val.type + val.rules.foldlM (init := acc) fun acc rule => goExpr consts acc rule.rhs +where + collectNames all acc := do + let visited ← get + all.foldlM (init := acc) fun acc name => + if visited.contains name then pure acc + else + let const := consts.find! name + collectDependenciesAux const consts $ (name, const) :: acc + goExpr (consts : Lean.ConstMap) (acc : ConstList) : Lean.Expr → CollectM ConstList + | .bvar _ | .fvar _ | .mvar _ | .sort _ | .lit _ => pure acc + | .const name _ => do + let visited ← get + if visited.contains name then pure acc + else + let const := consts.find! name + collectDependenciesAux const consts $ (name, const) :: acc + | .app f a => do + let acc ← goExpr consts acc f + goExpr consts acc a + | .lam _ t b _ | .forallE _ t b _ => do + let acc ← goExpr consts acc t + goExpr consts acc b + | .letE _ t v b _ => do + let acc ← goExpr consts acc t + let acc ← goExpr consts acc v + goExpr consts acc b + | .mdata _ e | .proj _ _ e => goExpr consts acc e + +def collectDependencies (name : Lean.Name) (consts : Lean.ConstMap) : ConstList := + let const := consts.find! name + let (constList, _) := collectDependenciesAux const consts [(name, const)] default + constList + end Lean /-- Format a duration in milliseconds with appropriate unit suffix. diff --git a/Ix/Environment.lean b/Ix/Environment.lean index 6e0beeb7..85c22d6c 100644 --- a/Ix/Environment.lean +++ b/Ix/Environment.lean @@ -110,6 +110,11 @@ partial def toStringAux : Name → String instance : ToString Name where toString := toStringAux +def fromLeanName : Lean.Name → Name + | .anonymous => .mkAnon + | .str pre s => .mkStr (.fromLeanName pre) s + | .num pre n => .mkNat (.fromLeanName pre) n + end Name /-- Compare Ix.Name by hash for ordered collections. -/ diff --git a/Ix/IxVM.lean b/Ix/IxVM.lean index 005838ff..4841cb4e 100644 --- a/Ix/IxVM.lean +++ b/Ix/IxVM.lean @@ -13,26 +13,37 @@ namespace IxVM def entrypoints := ⟦ /- # Test entrypoints -/ - -- fn ixon_blake3_test(h: [[G; 4]; 8]) { - -- let key = [ - -- h[0][0], h[0][1], h[0][2], h[0][3], - -- h[1][0], h[1][1], h[1][2], h[1][3], - -- h[2][0], h[2][1], h[2][2], h[2][3], - -- h[3][0], h[3][1], h[3][2], h[3][3], - -- h[4][0], h[4][1], h[4][2], h[4][3], - -- h[5][0], h[5][1], h[5][2], h[5][3], - -- h[6][0], h[6][1], h[6][2], h[6][3], - -- h[7][0], h[7][1], h[7][2], h[7][3] - -- ]; - -- let (idx, len) = io_get_info(key); - -- let bytes_unconstrained = read_byte_stream(idx, len); - -- let ixon_unconstrained = deserialize(bytes_unconstrained); - -- let bytes = serialize(ixon_unconstrained); - -- let bytes_hash = blake3(bytes); - -- assert_eq!(h, bytes_hash); - -- } + fn ixon_serde_test(n: G) { + match n { + 0 => (), + _ => + let n_minus_1 = n - 1; + let (idx, len) = io_get_info([n_minus_1]); + let bytes = read_byte_stream(idx, len); + let (const, rest) = get_constant(bytes); + assert_eq!(rest, ByteStream.Nil); + let bytes2 = put_constant(const, ByteStream.Nil); + assert_eq!(bytes, bytes2); + ixon_serde_test(n_minus_1), + } + } /- # Benchmark entrypoints -/ + + fn ixon_serde_blake3_bench(n: G) { + match n { + 0 => (), + _ => + let n_minus_1 = n - 1; + let (idx, len) = io_get_info([n_minus_1]); + let bytes = read_byte_stream(idx, len); + let (const, rest) = get_constant(bytes); + assert_eq!(rest, ByteStream.Nil); + let bytes2 = put_constant(const, ByteStream.Nil); + assert_eq!(blake3(bytes), blake3(bytes2)); + ixon_serde_blake3_bench(n_minus_1), + } + } ⟧ def ixVM : Except Aiur.Global Aiur.Toplevel := do diff --git a/Ix/IxVM/ByteStream.lean b/Ix/IxVM/ByteStream.lean index c63e27d6..a227471f 100644 --- a/Ix/IxVM/ByteStream.lean +++ b/Ix/IxVM/ByteStream.lean @@ -45,7 +45,9 @@ def byteStream := ⟦ } } - -- Count bytes needed to represent a u64 (0-8) + -- Count bytes needed to represent a u64. + -- Important: this implementation differs from the Lean and Rust ones, returning + -- 1 for [0; 8] instead of 0. fn u64_byte_count(x: [G; 8]) -> G { match x { [_, 0, 0, 0, 0, 0, 0, 0] => 1, @@ -189,6 +191,37 @@ def byteStream := ⟦ Nil } + -- Computes the predecessor of an `u64` assumed to be properly represented in + -- little-endian bytes. If that's not the case, this implementation has UB. + fn relaxed_u64_pred(bytes: [G; 8]) -> [G; 8] { + let [b0, b1, b2, b3, b4, b5, b6, b7] = bytes; + match b0 { + 0 => match b1 { + 0 => match b2 { + 0 => match b3 { + 0 => match b4 { + 0 => match b5 { + 0 => match b6 { + 0 => match b7 { + 0 => [0, 0, 0, 0, 0, 0, 0, 0], + _ => [255, 255, 255, 255, 255, 255, 255, b7 - 1], + }, + _ => [255, 255, 255, 255, 255, 255, b6 - 1, b7], + }, + _ => [255, 255, 255, 255, 255, b5 - 1, b6, b7], + }, + _ => [255, 255, 255, 255, b4 - 1, b5, b6, b7], + }, + _ => [255, 255, 255, b3 - 1, b4, b5, b6, b7], + }, + _ => [255, 255, b2 - 1, b3, b4, b5, b6, b7], + }, + _ => [255, b1 - 1, b2, b3, b4, b5, b6, b7], + }, + _ => [b0 - 1, b1, b2, b3, b4, b5, b6, b7], + } + } + fn u64_list_length(xs: U64List) -> [G; 8] { match xs { U64List.Nil => [0; 8], diff --git a/Ix/IxVM/IxonDeserialize.lean b/Ix/IxVM/IxonDeserialize.lean index 618e9b01..e1dc75bd 100644 --- a/Ix/IxVM/IxonDeserialize.lean +++ b/Ix/IxVM/IxonDeserialize.lean @@ -6,6 +6,603 @@ public section namespace IxVM def ixonDeserialize := ⟦ + -- ============================================================================ + -- Byte reading primitives + -- ============================================================================ + + fn read_byte(stream: ByteStream) -> (G, ByteStream) { + match stream { + ByteStream.Cons(byte, &rest) => (byte, rest), + ByteStream.Nil => (0, ByteStream.Nil), + } + } + + -- Read num_bytes little-endian bytes into a u64 + fn get_u64_le(stream: ByteStream, num_bytes: G) -> ([G; 8], ByteStream) { + match num_bytes { + 0 => ([0; 8], stream), + _ => + let (byte, s) = read_byte(stream); + let (rest_bytes, s2) = get_u64_le(s, num_bytes - 1); + let [r0, r1, r2, r3, r4, r5, r6, _] = rest_bytes; + ([byte, r0, r1, r2, r3, r4, r5, r6], s2), + } + } + + -- ============================================================================ + -- Tag parsing + -- ============================================================================ + + -- Tag0: [large:1][size:7] + fn get_tag0(stream: ByteStream) -> ([G; 8], ByteStream) { + let (byte, s) = read_byte(stream); + let bits = u8_bit_decomposition(byte); + let [b0, b1, b2, b3, b4, b5, b6, b7] = bits; + let small_size = b0 + 2 * b1 + 4 * b2 + 8 * b3 + 16 * b4 + 32 * b5 + 64 * b6; + match b7 { + 0 => + ([small_size, 0, 0, 0, 0, 0, 0, 0], s), + _ => + let num_bytes = small_size + 1; + get_u64_le(s, num_bytes), + } + } + + -- Tag2: [flag:2][large:1][size:5] + fn get_tag2(stream: ByteStream) -> ((G, [G; 8]), ByteStream) { + let (byte, s) = read_byte(stream); + let bits = u8_bit_decomposition(byte); + let [b0, b1, b2, b3, b4, b5, b6, b7] = bits; + let flag = b6 + 2 * b7; + let small_size = b0 + 2 * b1 + 4 * b2 + 8 * b3 + 16 * b4; + match b5 { + 0 => + ((flag, [small_size, 0, 0, 0, 0, 0, 0, 0]), s), + _ => + let num_bytes = small_size + 1; + let (size, s2) = get_u64_le(s, num_bytes); + ((flag, size), s2), + } + } + + -- Tag4: [flag:4][large:1][size:3] + fn get_tag4(stream: ByteStream) -> ((G, [G; 8]), ByteStream) { + let (byte, s) = read_byte(stream); + let bits = u8_bit_decomposition(byte); + let [b0, b1, b2, b3, b4, b5, b6, b7] = bits; + let flag = b4 + 2 * b5 + 4 * b6 + 8 * b7; + let small_size = b0 + 2 * b1 + 4 * b2; + match b3 { + 0 => + ((flag, [small_size, 0, 0, 0, 0, 0, 0, 0]), s), + _ => + let num_bytes = small_size + 1; + let (size, s2) = get_u64_le(s, num_bytes); + ((flag, size), s2), + } + } + + -- ============================================================================ + -- U64 list deserialization + -- ============================================================================ + + fn get_u64_list(stream: ByteStream, count: [G; 8]) -> (U64List, ByteStream) { + let is_zero = u64_is_zero(count); + match is_zero { + 1 => (U64List.Nil, stream), + 0 => + let (val, s) = get_tag0(stream); + let (rest, s2) = get_u64_list(s, relaxed_u64_pred(count)); + (U64List.Cons(val, store(rest)), s2), + } + } + + -- ============================================================================ + -- Expression deserialization + -- ============================================================================ + + -- App telescope: read count args, wrapping func in App nodes + fn get_app_telescope(func: Expr, stream: ByteStream, count: [G; 8]) -> (Expr, ByteStream) { + let is_zero = u64_is_zero(count); + match is_zero { + 1 => (func, stream), + 0 => + let (arg, s) = get_expr(stream); + let app = Expr.App(store(func), store(arg)); + get_app_telescope(app, s, relaxed_u64_pred(count)), + } + } + + -- Lam telescope: read count types then body, wrap as nested Lams + fn get_lam_telescope(stream: ByteStream, count: [G; 8]) -> (Expr, ByteStream) { + let is_zero = u64_is_zero(count); + match is_zero { + 1 => + -- No more types, read the body + get_expr(stream), + 0 => + -- Read one type, recurse for remaining types + body + let (ty, s) = get_expr(stream); + let (inner, s2) = get_lam_telescope(s, relaxed_u64_pred(count)); + (Expr.Lam(store(ty), store(inner)), s2), + } + } + + -- All telescope: read count types then body, wrap as nested Alls + fn get_all_telescope(stream: ByteStream, count: [G; 8]) -> (Expr, ByteStream) { + let is_zero = u64_is_zero(count); + match is_zero { + 1 => + -- No more types, read the body + get_expr(stream), + 0 => + -- Read one type, recurse for remaining types + body + let (ty, s) = get_expr(stream); + let (inner, s2) = get_all_telescope(s, relaxed_u64_pred(count)); + (Expr.All(store(ty), store(inner)), s2), + } + } + + fn get_expr(stream: ByteStream) -> (Expr, ByteStream) { + let (tag, s) = get_tag4(stream); + let (flag, size) = tag; + match flag { + -- Srt: Tag4(0x0, univ_idx) + 0x0 => (Expr.Srt(size), s), + + -- Var: Tag4(0x1, idx) + 0x1 => (Expr.Var(size), s), + + -- Ref: Tag4(0x2, len) + Tag0(ref_idx) + univ_list + 0x2 => + let (ref_idx, s2) = get_tag0(s); + let (univ_list, s3) = get_u64_list(s2, size); + (Expr.Ref(ref_idx, store(univ_list)), s3), + + -- Rec: Tag4(0x3, len) + Tag0(rec_idx) + univ_list + 0x3 => + let (rec_idx, s2) = get_tag0(s); + let (univ_list, s3) = get_u64_list(s2, size); + (Expr.Rec(rec_idx, store(univ_list)), s3), + + -- Prj: Tag4(0x4, field_idx) + Tag0(type_ref_idx) + expr(val) + 0x4 => + let (type_ref_idx, s2) = get_tag0(s); + let (val, s3) = get_expr(s2); + (Expr.Prj(type_ref_idx, size, store(val)), s3), + + -- Str: Tag4(0x5, ref_idx) + 0x5 => (Expr.Str(size), s), + + -- Nat: Tag4(0x6, ref_idx) + 0x6 => (Expr.Nat(size), s), + + -- App: Tag4(0x7, count) + func + args... + 0x7 => + let (func, s2) = get_expr(s); + get_app_telescope(func, s2, size), + + -- Lam: Tag4(0x8, count) + types... + body + 0x8 => get_lam_telescope(s, size), + + -- All: Tag4(0x9, count) + types... + body + 0x9 => get_all_telescope(s, size), + + -- Let: Tag4(0xA, non_dep) + expr(ty) + expr(val) + expr(body) + 0xA => + let (ty, s2) = get_expr(s); + let (val, s3) = get_expr(s2); + let (body, s4) = get_expr(s3); + (Expr.Let(size, store(ty), store(val), store(body)), s4), + + -- Share: Tag4(0xB, idx) + 0xB => (Expr.Share(size), s), + } + } + + -- ============================================================================ + -- Universe deserialization + -- ============================================================================ + + -- Build a chain of Succ constructors around a base universe + fn build_succ_chain(base: Univ, count: [G; 8]) -> Univ { + let is_zero = u64_is_zero(count); + match is_zero { + 1 => base, + 0 => + let inner = build_succ_chain(base, relaxed_u64_pred(count)); + Univ.Succ(store(inner)), + } + } + + fn get_univ(stream: ByteStream) -> (Univ, ByteStream) { + let (tag, s) = get_tag2(stream); + let (flag, size) = tag; + match flag { + -- Zero/Succ: Tag2(0, count) + 0 => + let is_zero = u64_is_zero(size); + match is_zero { + 1 => (Univ.Zero, s), + 0 => + let (base, s2) = get_univ(s); + (build_succ_chain(base, size), s2), + }, + + -- Max: Tag2(1, 0) + univ(a) + univ(b) + 1 => + let (a, s2) = get_univ(s); + let (b, s3) = get_univ(s2); + (Univ.Max(store(a), store(b)), s3), + + -- IMax: Tag2(2, 0) + univ(a) + univ(b) + 2 => + let (a, s2) = get_univ(s); + let (b, s3) = get_univ(s2); + (Univ.IMax(store(a), store(b)), s3), + + -- Var: Tag2(3, idx) + 3 => (Univ.Var(size), s), + } + } + + -- ============================================================================ + -- Address deserialization (32 bytes) + -- ============================================================================ + + fn get_address(stream: ByteStream) -> ([G; 32], ByteStream) { + let (b0, s) = read_byte(stream); + let (b1, s) = read_byte(s); + let (b2, s) = read_byte(s); + let (b3, s) = read_byte(s); + let (b4, s) = read_byte(s); + let (b5, s) = read_byte(s); + let (b6, s) = read_byte(s); + let (b7, s) = read_byte(s); + let (b8, s) = read_byte(s); + let (b9, s) = read_byte(s); + let (b10, s) = read_byte(s); + let (b11, s) = read_byte(s); + let (b12, s) = read_byte(s); + let (b13, s) = read_byte(s); + let (b14, s) = read_byte(s); + let (b15, s) = read_byte(s); + let (b16, s) = read_byte(s); + let (b17, s) = read_byte(s); + let (b18, s) = read_byte(s); + let (b19, s) = read_byte(s); + let (b20, s) = read_byte(s); + let (b21, s) = read_byte(s); + let (b22, s) = read_byte(s); + let (b23, s) = read_byte(s); + let (b24, s) = read_byte(s); + let (b25, s) = read_byte(s); + let (b26, s) = read_byte(s); + let (b27, s) = read_byte(s); + let (b28, s) = read_byte(s); + let (b29, s) = read_byte(s); + let (b30, s) = read_byte(s); + let (b31, s) = read_byte(s); + ([b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15, + b16, b17, b18, b19, b20, b21, b22, b23, b24, b25, b26, b27, b28, b29, b30, b31], s) + } + + -- ============================================================================ + -- List deserialization + -- ============================================================================ + + fn get_expr_list(stream: ByteStream, count: [G; 8]) -> (ExprList, ByteStream) { + let is_zero = u64_is_zero(count); + match is_zero { + 1 => (ExprList.Nil, stream), + 0 => + let (expr, s) = get_expr(stream); + let (rest, s2) = get_expr_list(s, relaxed_u64_pred(count)); + (ExprList.Cons(store(expr), store(rest)), s2), + } + } + + fn get_univ_list(stream: ByteStream, count: [G; 8]) -> (UnivList, ByteStream) { + let is_zero = u64_is_zero(count); + match is_zero { + 1 => (UnivList.Nil, stream), + 0 => + let (u, s) = get_univ(stream); + let (rest, s2) = get_univ_list(s, relaxed_u64_pred(count)); + (UnivList.Cons(store(u), store(rest)), s2), + } + } + + fn get_address_list(stream: ByteStream, count: [G; 8]) -> (AddressList, ByteStream) { + let is_zero = u64_is_zero(count); + match is_zero { + 1 => (AddressList.Nil, stream), + 0 => + let (addr, s) = get_address(stream); + let (rest, s2) = get_address_list(s, relaxed_u64_pred(count)); + (AddressList.Cons(addr, store(rest)), s2), + } + } + + -- ============================================================================ + -- Sharing, refs, univs table deserialization + -- ============================================================================ + + fn get_sharing(stream: ByteStream) -> (ExprList, ByteStream) { + let (len, s) = get_tag0(stream); + get_expr_list(s, len) + } + + fn get_refs(stream: ByteStream) -> (AddressList, ByteStream) { + let (len, s) = get_tag0(stream); + get_address_list(s, len) + } + + fn get_univs(stream: ByteStream) -> (UnivList, ByteStream) { + let (len, s) = get_tag0(stream); + get_univ_list(s, len) + } + + -- ============================================================================ + -- Constant structure deserialization + -- ============================================================================ + + -- Unpack DefKind and DefinitionSafety from packed byte + -- Encoding: kind * 4 + safety + -- kind: Definition=0, Opaque=1, Theorem=2 + -- safety: Unsafe=0, Safe=1, Partial=2 + fn unpack_def_kind_safety(byte: G) -> (DefKind, DefinitionSafety) { + match byte { + 0 => (DefKind.Definition, DefinitionSafety.Unsafe), + 1 => (DefKind.Definition, DefinitionSafety.Safe), + 2 => (DefKind.Definition, DefinitionSafety.Partial), + 4 => (DefKind.Opaque, DefinitionSafety.Unsafe), + 5 => (DefKind.Opaque, DefinitionSafety.Safe), + 6 => (DefKind.Opaque, DefinitionSafety.Partial), + 8 => (DefKind.Theorem, DefinitionSafety.Unsafe), + 9 => (DefKind.Theorem, DefinitionSafety.Safe), + 10 => (DefKind.Theorem, DefinitionSafety.Partial), + } + } + + -- Definition: byte(packed_kind_safety) + Tag0(lvls) + expr(typ) + expr(value) + fn get_definition(stream: ByteStream) -> (Definition, ByteStream) { + let (packed, s) = read_byte(stream); + let (kind, safety) = unpack_def_kind_safety(packed); + let (lvls, s2) = get_tag0(s); + let (typ, s3) = get_expr(s2); + let (value, s4) = get_expr(s3); + (Definition.Mk(kind, safety, lvls, store(typ), store(value)), s4) + } + + -- RecursorRule: Tag0(fields) + expr(rhs) + fn get_recursor_rule(stream: ByteStream) -> (RecursorRule, ByteStream) { + let (fields, s) = get_tag0(stream); + let (rhs, s2) = get_expr(s); + (RecursorRule.Mk(fields, store(rhs)), s2) + } + + fn get_recursor_rule_list(stream: ByteStream, count: [G; 8]) -> (RecursorRuleList, ByteStream) { + let is_zero = u64_is_zero(count); + match is_zero { + 1 => (RecursorRuleList.Nil, stream), + 0 => + let (rule, s) = get_recursor_rule(stream); + let (rest, s2) = get_recursor_rule_list(s, relaxed_u64_pred(count)); + (RecursorRuleList.Cons(rule, store(rest)), s2), + } + } + + -- Recursor: byte(bools) + Tag0(lvls) + Tag0(params) + Tag0(indices) + + -- Tag0(motives) + Tag0(minors) + expr(typ) + Tag0(rules_len) + rules... + fn get_recursor(stream: ByteStream) -> (Recursor, ByteStream) { + let (bools_byte, s) = read_byte(stream); + let bits = u8_bit_decomposition(bools_byte); + let k = bits[0]; + let is_unsafe = bits[1]; + let (lvls, s2) = get_tag0(s); + let (params, s3) = get_tag0(s2); + let (indices, s4) = get_tag0(s3); + let (motives, s5) = get_tag0(s4); + let (minors, s6) = get_tag0(s5); + let (typ, s7) = get_expr(s6); + let (rules_len, s8) = get_tag0(s7); + let (rules, s9) = get_recursor_rule_list(s8, rules_len); + (Recursor.Mk(k, is_unsafe, lvls, params, indices, motives, minors, store(typ), store(rules)), s9) + } + + -- Axiom: byte(is_unsafe) + Tag0(lvls) + expr(typ) + fn get_axiom(stream: ByteStream) -> (Axiom, ByteStream) { + let (is_unsafe, s) = read_byte(stream); + let (lvls, s2) = get_tag0(s); + let (typ, s3) = get_expr(s2); + (Axiom.Mk(is_unsafe, lvls, store(typ)), s3) + } + + -- QuotKind: byte(0=Typ, 1=Ctor, 2=Lift, 3=Ind) + fn get_quot_kind(byte: G) -> QuotKind { + match byte { + 0 => QuotKind.Typ, + 1 => QuotKind.Ctor, + 2 => QuotKind.Lift, + 3 => QuotKind.Ind, + } + } + + -- Quotient: byte(kind) + Tag0(lvls) + expr(typ) + fn get_quotient(stream: ByteStream) -> (Quotient, ByteStream) { + let (kind_byte, s) = read_byte(stream); + let kind = get_quot_kind(kind_byte); + let (lvls, s2) = get_tag0(s); + let (typ, s3) = get_expr(s2); + (Quotient.Mk(kind, lvls, store(typ)), s3) + } + + -- Constructor: byte(is_unsafe) + Tag0(lvls) + Tag0(cidx) + Tag0(params) + + -- Tag0(fields) + expr(typ) + fn get_constructor(stream: ByteStream) -> (Constructor, ByteStream) { + let (is_unsafe, s) = read_byte(stream); + let (lvls, s2) = get_tag0(s); + let (cidx, s3) = get_tag0(s2); + let (params, s4) = get_tag0(s3); + let (fields, s5) = get_tag0(s4); + let (typ, s6) = get_expr(s5); + (Constructor.Mk(is_unsafe, lvls, cidx, params, fields, store(typ)), s6) + } + + fn get_constructor_list(stream: ByteStream, count: [G; 8]) -> (ConstructorList, ByteStream) { + let is_zero = u64_is_zero(count); + match is_zero { + 1 => (ConstructorList.Nil, stream), + 0 => + let (ctor, s) = get_constructor(stream); + let (rest, s2) = get_constructor_list(s, relaxed_u64_pred(count)); + (ConstructorList.Cons(ctor, store(rest)), s2), + } + } + + -- Inductive: byte(bools) + Tag0(lvls) + Tag0(params) + Tag0(indices) + + -- Tag0(nested) + expr(typ) + Tag0(ctors_len) + ctors... + fn get_inductive(stream: ByteStream) -> (Inductive, ByteStream) { + let (bools_byte, s) = read_byte(stream); + let bits = u8_bit_decomposition(bools_byte); + let recr = bits[0]; + let refl = bits[1]; + let is_unsafe = bits[2]; + let (lvls, s2) = get_tag0(s); + let (params, s3) = get_tag0(s2); + let (indices, s4) = get_tag0(s3); + let (nested, s5) = get_tag0(s4); + let (typ, s6) = get_expr(s5); + let (ctors_len, s7) = get_tag0(s6); + let (ctors, s8) = get_constructor_list(s7, ctors_len); + (Inductive.Mk(recr, refl, is_unsafe, lvls, params, indices, nested, store(typ), store(ctors)), s8) + } + + -- ============================================================================ + -- Projection deserialization + -- ============================================================================ + + -- InductiveProj: Tag0(idx) + address(block) + fn get_inductive_proj(stream: ByteStream) -> (InductiveProj, ByteStream) { + let (idx, s) = get_tag0(stream); + let (block, s2) = get_address(s); + (InductiveProj.Mk(idx, block), s2) + } + + -- ConstructorProj: Tag0(idx) + Tag0(cidx) + address(block) + fn get_constructor_proj(stream: ByteStream) -> (ConstructorProj, ByteStream) { + let (idx, s) = get_tag0(stream); + let (cidx, s2) = get_tag0(s); + let (block, s3) = get_address(s2); + (ConstructorProj.Mk(idx, cidx, block), s3) + } + + -- RecursorProj: Tag0(idx) + address(block) + fn get_recursor_proj(stream: ByteStream) -> (RecursorProj, ByteStream) { + let (idx, s) = get_tag0(stream); + let (block, s2) = get_address(s); + (RecursorProj.Mk(idx, block), s2) + } + + -- DefinitionProj: Tag0(idx) + address(block) + fn get_definition_proj(stream: ByteStream) -> (DefinitionProj, ByteStream) { + let (idx, s) = get_tag0(stream); + let (block, s2) = get_address(s); + (DefinitionProj.Mk(idx, block), s2) + } + + -- ============================================================================ + -- Mutual constant deserialization + -- ============================================================================ + + -- MutConst: byte(tag) + payload + fn get_mut_const(stream: ByteStream) -> (MutConst, ByteStream) { + let (tag, s) = read_byte(stream); + match tag { + 0 => + let (defn, s2) = get_definition(s); + (MutConst.Defn(defn), s2), + 1 => + let (indc, s2) = get_inductive(s); + (MutConst.Indc(indc), s2), + 2 => + let (recr, s2) = get_recursor(s); + (MutConst.Recr(recr), s2), + } + } + + fn get_mut_const_list(stream: ByteStream, count: [G; 8]) -> (MutConstList, ByteStream) { + let is_zero = u64_is_zero(count); + match is_zero { + 1 => (MutConstList.Nil, stream), + 0 => + let (mc, s) = get_mut_const(stream); + let (rest, s2) = get_mut_const_list(s, relaxed_u64_pred(count)); + (MutConstList.Cons(mc, store(rest)), s2), + } + } + + -- ============================================================================ + -- Constant info deserialization + -- ============================================================================ + + -- Dispatch on variant number (0-7) to deserialize the appropriate ConstantInfo + fn get_constant_info_by_variant(variant: G, stream: ByteStream) -> (ConstantInfo, ByteStream) { + match variant { + 0 => + let (defn, s) = get_definition(stream); + (ConstantInfo.Defn(defn), s), + 1 => + let (recr, s) = get_recursor(stream); + (ConstantInfo.Recr(recr), s), + 2 => + let (axim, s) = get_axiom(stream); + (ConstantInfo.Axio(axim), s), + 3 => + let (quot, s) = get_quotient(stream); + (ConstantInfo.Quot(quot), s), + 4 => + let (prj, s) = get_constructor_proj(stream); + (ConstantInfo.CPrj(prj), s), + 5 => + let (prj, s) = get_recursor_proj(stream); + (ConstantInfo.RPrj(prj), s), + 6 => + let (prj, s) = get_inductive_proj(stream); + (ConstantInfo.IPrj(prj), s), + 7 => + let (prj, s) = get_definition_proj(stream); + (ConstantInfo.DPrj(prj), s), + } + } + + -- Parse ConstantInfo from flag (0xC for Muts, 0xD for non-Muts) and size + fn get_constant_info(flag: G, size: [G; 8], stream: ByteStream) -> (ConstantInfo, ByteStream) { + match flag { + -- Muts: flag=0xC, size is the entry count + 0xC => + let (mutuals, s) = get_mut_const_list(stream, size); + (ConstantInfo.Muts(store(mutuals)), s), + -- Non-Muts: flag=0xD, size[0] is the variant number + 0xD => + get_constant_info_by_variant(size[0], stream), + } + } + + -- ============================================================================ + -- Top-level constant deserialization + -- ============================================================================ + + fn get_constant(stream: ByteStream) -> (Constant, ByteStream) { + let (tag, s) = get_tag4(stream); + let (flag, size) = tag; + let (info, s2) = get_constant_info(flag, size, s); + let (sharing, s3) = get_sharing(s2); + let (refs, s4) = get_refs(s3); + let (univs, s5) = get_univs(s4); + (Constant.Mk(info, store(sharing), store(refs), store(univs)), s5) + } ⟧ end IxVM diff --git a/Ix/IxVM/IxonSerialize.lean b/Ix/IxVM/IxonSerialize.lean index 583844e8..f7827c9a 100644 --- a/Ix/IxVM/IxonSerialize.lean +++ b/Ix/IxVM/IxonSerialize.lean @@ -80,6 +80,23 @@ def ixonSerialize := ⟦ } } + -- Tag2: 2-bit flag, variable size + -- Format: [flag:2][large:1][size:5] or [flag:2][large:1][size_bytes...] + fn put_tag2(flag: G, size: [G; 8], rest: ByteStream) -> ByteStream { + let byte_count = u64_byte_count(size); + let small = u8_less_than(size[0], 32); + match (byte_count, small) { + (1, 1) => + -- Single byte: flag in bits 6-7, size in bits 0-4 + let head = flag * 64 + size[0]; + ByteStream.Cons(head, store(rest)), + _ => + -- Multi-byte: flag in bits 6-7, large=1 in bit 5, size_bytes-1 in bits 0-4 + let head = flag * 64 + 32 + (byte_count - 1); + ByteStream.Cons(head, store(put_u64_le(size, byte_count, rest))), + } + } + fn put_tag4(flag: G, bs: [G; 8], rest: ByteStream) -> ByteStream { let byte_count = u64_byte_count(bs); let small = u8_less_than(bs[0], 8); @@ -193,13 +210,13 @@ def ixonSerialize := ⟦ fn pack_def_kind_safety(kind: DefKind, safety: DefinitionSafety) -> G { match (kind, safety) { (DefKind.Definition, DefinitionSafety.Unsafe) => 0, - (DefKind.Opaque, DefinitionSafety.Unsafe) => 4, - (DefKind.Theorem, DefinitionSafety.Unsafe) => 8, (DefKind.Definition, DefinitionSafety.Safe) => 1, - (DefKind.Opaque, DefinitionSafety.Safe) => 5, - (DefKind.Theorem, DefinitionSafety.Safe) => 9, (DefKind.Definition, DefinitionSafety.Partial) => 2, + (DefKind.Opaque, DefinitionSafety.Unsafe) => 4, + (DefKind.Opaque, DefinitionSafety.Safe) => 5, (DefKind.Opaque, DefinitionSafety.Partial) => 6, + (DefKind.Theorem, DefinitionSafety.Unsafe) => 8, + (DefKind.Theorem, DefinitionSafety.Safe) => 9, (DefKind.Theorem, DefinitionSafety.Partial) => 10, } } @@ -252,23 +269,6 @@ def ixonSerialize := ⟦ } } - -- Tag2: 2-bit flag, variable size - -- Format: [flag:2][large:1][size:5] or [flag:2][large:1][size_bytes...] - fn put_tag2(flag: G, size: [G; 8], rest: ByteStream) -> ByteStream { - let byte_count = u64_byte_count(size); - let small = u8_less_than(size[0], 32); - match (byte_count, small) { - (1, 1) => - -- Single byte: flag in bits 6-7, size in bits 0-4 - let head = flag * 64 + size[0]; - ByteStream.Cons(head, store(rest)), - _ => - -- Multi-byte: flag in bits 6-7, large=1 in bit 5, size_bytes-1 in bits 0-4 - let head = flag * 64 + 32 + (byte_count - 1); - ByteStream.Cons(head, store(put_u64_le(size, byte_count, rest))), - } - } - -- ============================================================================ -- List serialization -- ============================================================================ @@ -499,22 +499,6 @@ def ixonSerialize := ⟦ ConstantInfo.RPrj(prj) => put_recursor_proj(prj, rest), ConstantInfo.IPrj(prj) => put_inductive_proj(prj, rest), ConstantInfo.DPrj(prj) => put_definition_proj(prj, rest), - -- Muts is never called here - handled separately in put_constant - ConstantInfo.Muts(_) => rest, - } - } - - fn constant_info_variant(info: ConstantInfo) -> [G; 8] { - match info { - ConstantInfo.Defn(_) => [0; 8], -- CONST_DEFN - ConstantInfo.Recr(_) => [1; 8], -- CONST_RECR - ConstantInfo.Axio(_) => [2; 8], -- CONST_AXIO - ConstantInfo.Quot(_) => [3; 8], -- CONST_QUOT - ConstantInfo.CPrj(_) => [4; 8], -- CONST_CPRJ - ConstantInfo.RPrj(_) => [5; 8], -- CONST_RPRJ - ConstantInfo.IPrj(_) => [6; 8], -- CONST_IPRJ - ConstantInfo.DPrj(_) => [7; 8], -- CONST_DPRJ - ConstantInfo.Muts(_) => [0; 8], -- Not used (handled separately) } } @@ -536,23 +520,29 @@ def ixonSerialize := ⟦ fn put_constant(cnst: Constant, rest: ByteStream) -> ByteStream { match cnst { Constant.Mk(info, &sharing, &refs, &univs) => + let up_to_sharing = put_sharing(sharing, put_refs(refs, put_univs(univs, rest))); match info { ConstantInfo.Muts(&mutuals) => -- Use FLAG_MUTS (0xC) with entry count in size field let count = mut_const_list_length(mutuals); - put_tag4(0xC, count, - put_mut_const_list(mutuals, - put_sharing(sharing, - put_refs(refs, - put_univs(univs, rest))))), - _ => - -- Use FLAG (0xD) with variant in size field - let variant = constant_info_variant(info); - put_tag4(0xD, variant, - put_constant_info(info, - put_sharing(sharing, - put_refs(refs, - put_univs(univs, rest))))), + put_tag4(0xC, count, put_mut_const_list(mutuals, up_to_sharing)), + -- Use FLAG (0xD) with variant in size field + ConstantInfo.Defn(_) => + put_tag4(0xD, [0; 8], put_constant_info(info, up_to_sharing)), + ConstantInfo.Recr(_) => + put_tag4(0xD, [1, 0, 0, 0, 0, 0, 0, 0], put_constant_info(info, up_to_sharing)), + ConstantInfo.Axio(_) => + put_tag4(0xD, [2, 0, 0, 0, 0, 0, 0, 0], put_constant_info(info, up_to_sharing)), + ConstantInfo.Quot(_) => + put_tag4(0xD, [3, 0, 0, 0, 0, 0, 0, 0], put_constant_info(info, up_to_sharing)), + ConstantInfo.CPrj(_) => + put_tag4(0xD, [4, 0, 0, 0, 0, 0, 0, 0], put_constant_info(info, up_to_sharing)), + ConstantInfo.RPrj(_) => + put_tag4(0xD, [5, 0, 0, 0, 0, 0, 0, 0], put_constant_info(info, up_to_sharing)), + ConstantInfo.IPrj(_) => + put_tag4(0xD, [6, 0, 0, 0, 0, 0, 0, 0], put_constant_info(info, up_to_sharing)), + ConstantInfo.DPrj(_) => + put_tag4(0xD, [7, 0, 0, 0, 0, 0, 0, 0], put_constant_info(info, up_to_sharing)), }, } } diff --git a/Tests/Ix/IxVM.lean b/Tests/Ix/IxVM.lean new file mode 100644 index 00000000..7a7b49b8 --- /dev/null +++ b/Tests/Ix/IxVM.lean @@ -0,0 +1,16 @@ +module + +public import Ix.Meta +public import Tests.Aiur.Common + +public def serdeNatAddComm : IO AiurTestCase := do + let env ← get_env! + let natAddCommName := ``Nat.add_comm + let constList := Lean.collectDependencies natAddCommName env.constants + let rawEnv ← Ix.CompileM.rsCompileEnvFFI constList + let ixonEnv := rawEnv.toEnv + let ixonConsts := ixonEnv.consts.valuesIter + let (ioBuffer, n) := ixonConsts.fold (init := (default, 0)) fun (ioBuffer, i) c => + let (_, bytes) := Ixon.Serialize.put c |>.run default + (ioBuffer.extend #[.ofNat i] (bytes.data.map .ofUInt8), i + 1) + pure ⟨`ixon_serde_test, "Ixon serde test", #[.ofNat n], #[], ioBuffer, ioBuffer⟩ diff --git a/Tests/Main.lean b/Tests/Main.lean index 1f282294..e80952e5 100644 --- a/Tests/Main.lean +++ b/Tests/Main.lean @@ -1,6 +1,7 @@ import Tests.Aiur import Tests.ByteArray import Tests.Ix.Ixon +import Tests.Ix.IxVM import Tests.Ix.Claim import Tests.Ix.Commit import Tests.Ix.Compile @@ -67,7 +68,8 @@ def ignoredRunners : List (String × IO UInt32) := [ | IO.eprintln "SHA256 setup failed"; return 1 let r2 ← LSpec.lspecEachIO sha256TestCases fun tc => pure (sha256Env.runTestCase tc) return if r1 == 0 && r2 == 0 then 0 else 1), - ("ixvm", do LSpec.lspecIO (.ofList [("ixvm", [mkAiurTests IxVM.ixVM []])]) []), + ("ixvm", do + LSpec.lspecIO (.ofList [("ixvm", [mkAiurTests IxVM.ixVM [← serdeNatAddComm]])]) []), ] def main (args : List String) : IO UInt32 := do diff --git a/lakefile.lean b/lakefile.lean index f6d1fc85..87cb7a6a 100644 --- a/lakefile.lean +++ b/lakefile.lean @@ -45,6 +45,10 @@ lean_exe «bench-blake3» where lean_exe «bench-sha256» where root := `Benchmarks.Sha256 +lean_exe «bench-ixvm» where + root := `Benchmarks.IxVM + supportInterpreter := true + lean_exe «bench-shardmap» where root := `Benchmarks.ShardMap