diff --git a/src/struct.zig b/src/struct.zig index d1b9f29..2e9e1c6 100644 --- a/src/struct.zig +++ b/src/struct.zig @@ -144,7 +144,7 @@ pub fn packStruct(writer: *std.Io.Writer, comptime T: type, value_or_maybe_null: if (has_custom_write_fn) { return try value.msgpackWrite(packer(writer)); } else { - const format = comptime if (std.meta.hasFn(Type, "msgpackFormat")) T.msgpackFormat() else default_struct_format; + const format = comptime if (std.meta.hasFn(Type, "msgpackFormat")) Type.msgpackFormat() else default_struct_format; switch (format) { .as_map => |opts| { return packStructAsMap(writer, Type, value, opts); @@ -160,7 +160,7 @@ pub fn unpackStructFromMapBody(reader: *std.Io.Reader, allocator: std.mem.Alloca const Type = NonOptional(T); const type_info = @typeInfo(Type); const fields = type_info.@"struct".fields; - const FieldEnum = std.meta.FieldEnum(T); + const FieldEnum = std.meta.FieldEnum(Type); var fields_seen = std.bit_set.StaticBitSet(fields.len).initEmpty(); @@ -208,10 +208,10 @@ pub fn unpackStructFromMapBody(reader: *std.Io.Reader, allocator: std.mem.Alloca } }, .custom => { - const KeyType = comptime @typeInfo(@TypeOf(T.msgpackFieldKey)).@"fn".return_type.?; + const KeyType = comptime @typeInfo(@TypeOf(Type.msgpackFieldKey)).@"fn".return_type.?; const key = try unpackAny(reader, allocator, KeyType); inline for (fields, 0..) |field, i| { - if (T.msgpackFieldKey(@field(FieldEnum, field.name)) == key) { + if (Type.msgpackFieldKey(@field(FieldEnum, field.name)) == key) { fields_seen.set(i); @field(result, field.name) = try unpackAny(reader, allocator, field.type); break; @@ -280,9 +280,9 @@ pub fn unpackStruct(reader: *std.Io.Reader, allocator: std.mem.Allocator, compti const has_custom_read_fn = std.meta.hasFn(Type, "msgpackRead"); if (has_custom_read_fn) { - return try T.msgpackRead(unpacker(reader, allocator)); + return try Type.msgpackRead(unpacker(reader, allocator)); } else { - const format = comptime if (std.meta.hasFn(Type, "msgpackFormat")) T.msgpackFormat() else default_struct_format; + const format = comptime if (std.meta.hasFn(Type, "msgpackFormat")) Type.msgpackFormat() else default_struct_format; switch (format) { .as_map => |opts| { return try unpackStructAsMap(reader, allocator, T, opts); @@ -318,6 +318,68 @@ test "writeStruct: map_by_index" { }, writer.buffered()); } +test "writeStruct: optional custom format" { + const Msg = struct { + a: u32, + + pub fn msgpackFormat() StructFormat { + return .{ .as_map = .{ .key = .field_index } }; + } + }; + + var buffer: [100]u8 = undefined; + var writer = std.Io.Writer.fixed(&buffer); + try packStruct(&writer, ?Msg, Msg{ .a = 1 }); + + try std.testing.expectEqualSlices(u8, &.{ + 0x81, + 0x00, + 0x01, + }, writer.buffered()); +} + +test "readStruct: optional custom field key" { + const Msg = struct { + a: u32, + + pub fn msgpackFormat() StructFormat { + return .{ .as_map = .{ .key = .custom } }; + } + + pub fn msgpackFieldKey(field: std.meta.FieldEnum(@This())) u8 { + return switch (field) { + .a => 1, + }; + } + }; + + const packed_msg = [_]u8{ + 0x81, + 0x01, + 0x2a, + }; + var reader = std.Io.Reader.fixed(&packed_msg); + const decoded = try unpackStruct(&reader, std.testing.allocator, ?Msg); + + try std.testing.expectEqual(@as(?Msg, Msg{ .a = 42 }), decoded); +} + +test "readStruct: optional custom reader" { + const Msg = struct { + a: u32, + + pub fn msgpackRead(unpacker_value: anytype) !@This() { + return .{ .a = try unpacker_value.readInt(u32) }; + } + }; + + const packed_msg = [_]u8{0x2a}; + var reader = std.Io.Reader.fixed(&packed_msg); + const decoded = try unpackStruct(&reader, std.testing.allocator, ?Msg); + + try std.testing.expectEqual(@as(?Msg, Msg{ .a = 42 }), decoded); +} + test "writeStruct: map by field_name" { const Msg = struct { a: u32, diff --git a/src/union.zig b/src/union.zig index 1758529..d473614 100644 --- a/src/union.zig +++ b/src/union.zig @@ -139,7 +139,7 @@ pub fn packUnion(writer: *std.Io.Writer, comptime T: type, value_or_maybe_null: @compileError("Expected union type"); } - const format = if (std.meta.hasFn(Type, "msgpackFormat")) T.msgpackFormat() else default_union_format; + const format = if (std.meta.hasFn(Type, "msgpackFormat")) Type.msgpackFormat() else default_union_format; switch (format) { .as_map => |opts| { return packUnionAsMap(writer, Type, value, opts); @@ -281,7 +281,7 @@ pub fn unpackUnionAsTagged(reader: *std.Io.Reader, allocator: std.mem.Allocator, pub fn unpackUnion(reader: *std.Io.Reader, allocator: std.mem.Allocator, comptime T: type) !T { const Type = NonOptional(T); - const format = if (std.meta.hasFn(Type, "msgpackFormat")) T.msgpackFormat() else default_union_format; + const format = if (std.meta.hasFn(Type, "msgpackFormat")) Type.msgpackFormat() else default_union_format; switch (format) { .as_map => |opts| { return try unpackUnionAsMap(reader, allocator, T, opts); @@ -329,6 +329,48 @@ test "writeUnion: int field" { try std.testing.expectEqualSlices(u8, &msg1_packed, writer.buffered()); } +test "writeUnion: optional custom format" { + const Msg = union(enum) { + a: u32, + b: u64, + + pub fn msgpackFormat() UnionFormat { + return .{ .as_map = .{ .key = .field_index } }; + } + }; + + var buffer: [100]u8 = undefined; + var writer = std.Io.Writer.fixed(&buffer); + try packUnion(&writer, ?Msg, Msg{ .a = 1 }); + + try std.testing.expectEqualSlices(u8, &.{ + 0x81, + 0x00, + 0x01, + }, writer.buffered()); +} + +test "readUnion: optional custom format" { + const Msg = union(enum) { + a: u32, + b: u64, + + pub fn msgpackFormat() UnionFormat { + return .{ .as_map = .{ .key = .field_index } }; + } + }; + + const packed_msg = [_]u8{ + 0x81, + 0x00, + 0x2a, + }; + var reader = std.Io.Reader.fixed(&packed_msg); + const decoded = try unpackUnion(&reader, std.testing.allocator, ?Msg); + + try std.testing.expectEqualDeep(@as(?Msg, Msg{ .a = 42 }), decoded); +} + test "writeUnion: void field" { var buffer: [100]u8 = undefined; var writer = std.Io.Writer.fixed(&buffer);