Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions ext/MicrofloatsExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module MicrofloatsExt

import cuTile as ct
import Microfloats

using Microfloats: Float8_E4M3FN, Float8_E5M2, Float8_E8M0FNU, Float4_E2M1FN

Expand All @@ -14,6 +15,10 @@ ct.julia_to_tile_dtype!(table::ct.TypeTable, ::Type{Float8_E5M2}) = ct.F8E5M2(
ct.julia_to_tile_dtype!(table::ct.TypeTable, ::Type{Float8_E8M0FNU}) = ct.F8E8M0FNU(table)
ct.julia_to_tile_dtype!(table::ct.TypeTable, ::Type{Float4_E2M1FN}) = ct.F4E2M1FN(table)

# Microfloats are byte-storage primitives, so cuTile's default
# `bitwidth` (8 * sizeof) over-counts the sub-byte formats.
ct.bitwidth(::Type{T}) where {T<:Microfloats.Microfloat} = Microfloats.bitwidth(T)

# E8M0FNU has no sign bit and represents a power of two; tileiras rejects
# nearest-even on f32→E8M0FNU (only `zero` and `positive_inf` are valid).
ct.ftof_rounding_mode(::Type{Float8_E8M0FNU}) = ct.RoundingMode.Zero
Expand Down
32 changes: 32 additions & 0 deletions src/bytecode/encodings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ module Opcode
const XOrIOp = 108
const YieldOp = 109
const Atan2Op = 110 # since 13.2
const PackOp = 111 # since 13.3
const UnpackOp = 112 # since 13.3
end

# Enums for operation attributes
Expand Down Expand Up @@ -1796,6 +1798,36 @@ function encode_BitcastOp!(cb::CodeBuilder, result_type::TypeId, source::Value)
return new_op!(cb)
end

"""
encode_PackOp!(cb, result_type, source) -> Value

Pack a rank-1 numeric tile into a rank-1 `tile<i8>`. Unlike `bitcast`, this is
not element-wise: the whole tile is reinterpreted as a byte array, so the result
length is the input's total byte count. The source must not be an 8-bit type
(use `bitcast`). Since 13.3. Opcode: 111
"""
function encode_PackOp!(cb::CodeBuilder, result_type::TypeId, source::Value)
encode_varint!(cb.buf, Opcode.PackOp)
encode_typeid!(cb.buf, result_type)
encode_operand!(cb.buf, source)
return new_op!(cb)
end

"""
encode_UnpackOp!(cb, result_type, source) -> Value

Unpack a rank-1 `tile<i8>` into a rank-1 numeric tile (the inverse of
[`encode_PackOp!`](@ref)). The input byte count must equal the output's total
byte count. The result must not be an 8-bit type (use `bitcast`). Since 13.3.
Opcode: 112
"""
function encode_UnpackOp!(cb::CodeBuilder, result_type::TypeId, source::Value)
encode_varint!(cb.buf, Opcode.UnpackOp)
encode_typeid!(cb.buf, result_type)
encode_operand!(cb.buf, source)
return new_op!(cb)
end

"""
encode_BroadcastOp!(cb, result_type, source) -> Value

Expand Down
99 changes: 99 additions & 0 deletions src/compiler/intrinsics/conversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,105 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.bitcast), args)
CGVal(result_v, result_type_id, result_jltype, source.shape)
end

@inline lookup_bitwidth(@nospecialize(T::Type)) =
Base.invokelatest(bitwidth, T)::Int

"""
Intrinsics.pack(x::Tile{S,Tuple{N}}) -> Tile{UInt8,Tuple{N*bitwidth(S)÷8}}

Pack a rank-1 numeric tile into a rank-1 `UInt8` tile (the tile's bits viewed as
a byte array); lowers to `cuda_tile.pack`. `S` must not be 8-bit (use `bitcast`).
Requires Tile IR bytecode v13.3+.
"""
@intrinsic pack(x)
function tfunc(𝕃, ::typeof(Intrinsics.pack), @nospecialize(x))
src = CC.widenconst(x)
src <: Tile || return nothing
S = src.parameters[1]
Shape = src.parameters[2]
(S isa Type && Shape isa Type) || return nothing
dims = Shape.parameters
length(dims) == 1 || return nothing
n = dims[1]::Int
bs = lookup_bitwidth(S)
return Tile{UInt8, Tuple{fld(n * bs, 8)}}
end
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.pack), args)
cb = ctx.cb
tt = ctx.tt

source = @something emit_value!(ctx, args[1]) throw(IRError("pack: cannot resolve source"))
tt.version >= v"13.3" ||
throw(IRError("cuda_tile.pack requires Tile IR bytecode v13.3+, got v$(tt.version)"))
length(source.shape) == 1 ||
throw(IRError("pack: requires a rank-1 tile, got a $(length(source.shape))-D tile"))

src_type = CC.widenconst(source.jltype)
S = eltype(src_type)
sbits = lookup_bitwidth(S)
sbits == 8 &&
throw(IRError("pack: 8-bit element type $S should be reinterpreted via bitcast, not packed"))
n = source.shape[1]
(n * sbits) % 8 == 0 ||
throw(IRError("pack: a $n-element $S tile ($(n * sbits) bits) is not a whole number of bytes"))
new_n = (n * sbits) ÷ 8

new_shape = RowMajorShape([new_n])
result_type_id = tile_type!(tt, lookup_dtype!(tt, UInt8), new_shape)
result_v = encode_PackOp!(cb, result_type_id, source.v)
CGVal(result_v, result_type_id, Tile{UInt8, Tuple{new_n}}, new_shape)
end

"""
Intrinsics.unpack(x::Tile{UInt8,Tuple{N}}, ::Type{T}) -> Tile{T,Tuple{N*8÷bitwidth(T)}}

Unpack a rank-1 `UInt8` tile into a rank-1 numeric tile of element type `T` (the
inverse of [`pack`](@ref Intrinsics.pack)); lowers to `cuda_tile.unpack`. `T`
must be a compile-time constant and must not be 8-bit (use `bitcast`). Requires
Tile IR bytecode v13.3+.
"""
@intrinsic unpack(x, ::Type{T}) where {T}
function tfunc(𝕃, ::typeof(Intrinsics.unpack), @nospecialize(x), @nospecialize(target_type))
T = instanceof_tfunc(target_type)
T === nothing && return nothing
src = CC.widenconst(x)
src <: Tile || return nothing
Shape = src.parameters[2]
Shape isa Type || return nothing
dims = Shape.parameters
length(dims) == 1 || return nothing
n = dims[1]::Int
bt = lookup_bitwidth(T)
return Tile{T, Tuple{fld(n * 8, bt)}}
end
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.unpack), args)
cb = ctx.cb
tt = ctx.tt

source = @something emit_value!(ctx, args[1]) throw(IRError("unpack: cannot resolve source"))
target_type = @something get_constant(ctx, args[2]) throw(IRError("unpack: requires compile-time target type"))
tt.version >= v"13.3" ||
throw(IRError("cuda_tile.unpack requires Tile IR bytecode v13.3+, got v$(tt.version)"))
length(source.shape) == 1 ||
throw(IRError("unpack: requires a rank-1 tile, got a $(length(source.shape))-D tile"))

src_type = CC.widenconst(source.jltype)
eltype(src_type) === UInt8 ||
throw(IRError("unpack: requires a UInt8 tile, got $(eltype(src_type))"))
tbits = lookup_bitwidth(target_type)
tbits == 8 &&
throw(IRError("unpack: 8-bit target $target_type should be reinterpreted via bitcast, not unpacked"))
n = source.shape[1]
(n * 8) % tbits == 0 ||
throw(IRError("unpack: $n bytes ($(n * 8) bits) do not evenly divide into $target_type ($tbits-bit) elements"))
new_n = (n * 8) ÷ tbits

new_shape = RowMajorShape([new_n])
result_type_id = tile_type!(tt, lookup_dtype!(tt, target_type), new_shape)
result_v = encode_UnpackOp!(cb, result_type_id, source.v)
CGVal(result_v, result_type_id, Tile{target_type, Tuple{new_n}}, new_shape)
end

"""
Intrinsics.exti(x::Tile{<:Integer}, ::Type{T}, s::Signedness.T) -> Tile{T} where {T<:Integer}

Expand Down
76 changes: 76 additions & 0 deletions src/language/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,82 @@ Equivalent to single-arg `permutedims`.
end
end

# Width-convert a rank-1 tile to element type `T` (rank-1 in, rank-1 out).
@inline function reinterpret_width(::Type{T}, flat::Tile{S}) where {T, S}
bs = bitwidth(S)
bt = bitwidth(T)
if bs == bt
return Intrinsics.bitcast(flat, T) # same width
elseif bt == 8
return Intrinsics.bitcast(Intrinsics.pack(flat), T) # S → bytes → T8
elseif bs == 8
return Intrinsics.unpack(Intrinsics.bitcast(flat, UInt8), T) # S8 → bytes → T
else
return Intrinsics.unpack(Intrinsics.pack(flat), T) # S → bytes → T
end
end

# Result shape for `reinterpret(T, x)`: rescale the leading (column-major)
# dimension by the element-width ratio, like `reinterpret(T, ::AbstractArray)`.
@inline function reinterpret_scaled_shape(::Type{T}, ::Type{S}, sz::NTuple{N, Int}) where {T, S, N}
bs = bitwidth(S)
bt = bitwidth(T)
N == 0 && return () # 0-D: only equal-width is valid; cross-width caught at emit
return (fld(sz[1] * bs, bt), Base.tail(sz)...)
end

# Result shape for `reinterpret(reshape, T, x)`: drop the leading dim on widening
# (it must equal the ratio), prepend one on narrowing, like the array version.
@inline function reinterpret_reshape_shape(::Type{T}, ::Type{S}, sz::NTuple{N, Int}) where {T, S, N}
bs = bitwidth(S)
bt = bitwidth(T)
bs == bt && return sz
N == 0 && return () # cross-width on a 0-D tile is invalid; caught at emit
return bt > bs ? Base.tail(sz) : (div(bs, bt), sz...)
end

"""
Base.reinterpret(::Type{T}, x::Tile) -> Tile{T}

Reinterpret the *whole tile* `x` as a tile of element type `T`, like
`reinterpret(T, ::AbstractArray)`: the underlying bits are viewed as a contiguous
(column-major) block and the leading dimension is rescaled by the ratio of
element widths. Lowers to `cuda_tile.bitcast` for equal widths and to
`cuda_tile.pack`/`unpack` (via `reshape` to rank-1) when widths differ.

This is how sub-byte formats move through global memory: a `Tile{UInt8,(N,)}`
reinterprets to a `Tile{Float4_E2M1FN,(2N,)}` and back, so FP4 data can be stored
in a `UInt8` array. The total bit-width is preserved, so it must divide evenly.

Note `reinterpret.(T, x)` (with a dot) is the unrelated *element-wise* broadcast,
which keeps the shape and requires `T` to be the same width as `eltype(x)`.

```julia
bytes = ct.load(a, pid, (8,)) # Tile{UInt8,(8,)}
fp4 = reinterpret(Float4_E2M1FN, bytes) # Tile{Float4_E2M1FN,(16,)}
vals = convert(ct.Tile{Float32}, fp4) # widen for compute
```
"""
@inline function Base.reinterpret(::Type{T}, x::Tile) where {T}
rshape = reinterpret_scaled_shape(T, eltype(x), size(x))
flat = Intrinsics.reshape(x, (prod(size(x)),))
return Intrinsics.reshape(reinterpret_width(T, flat), rshape)
end

"""
Base.reinterpret(reshape, ::Type{T}, x::Tile) -> Tile{T}

The `reshape`-form whole-tile reinterpret, mirroring
`reinterpret(reshape, T, ::AbstractArray)`: instead of rescaling the leading
dimension it *removes* it when widening (the leading dim must equal
`bitwidth(T) ÷ bitwidth(eltype(x))`) and *prepends* one when narrowing.
"""
@inline function Base.reinterpret(::typeof(reshape), ::Type{T}, x::Tile) where {T}
rshape = reinterpret_reshape_shape(T, eltype(x), size(x))
flat = Intrinsics.reshape(x, (prod(size(x)),))
return Intrinsics.reshape(reinterpret_width(T, flat), rshape)
end

@inline Base.convert(::Type{Tile{T}}, tile::Tile{T}) where {T} = tile
@inline Base.convert(::Type{Tile{T2}}, tile::Tile{T1, Shape}) where {T1, T2, Shape} =
map(T2, tile)
Expand Down
17 changes: 17 additions & 0 deletions src/language/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,23 @@ similar_type(::Type{Tile{T, Shape}}, ::Type{U}, new_shape::Tuple) where {T, Shap
similar_type(::Type{<:Tile{T}}, ::Type{U}) where {T, U} = Tile{U}
similar_type(::Type, ::Type{T}) where {T} = T # fallback for non-Tile types

"""
bitwidth(::Type{T}) -> Int

Number of bits a single element of `T` occupies in a Tile IR tile. Used by the
whole-tile [`reinterpret`](@ref Base.reinterpret(::Type, ::Tile)) to scale the
tile shape across a change of element width (e.g. `UInt8` ↔ `Float4_E2M1FN`,
8 bits ↔ 4 bits).

The default is `8 * sizeof(T)`, which is correct for the standard integer and
floating-point types and for the byte-wide `Float8_*` formats. Sub-byte formats
whose `sizeof` rounds up to a whole byte (e.g. `Float4_E2M1FN`, 4 bits but
`sizeof == 1`) override this; the `Microfloats` extension forwards to
`Microfloats.bitwidth`, which derives the true width from the format's bit
fields. Matches the `bitwidth` convention used by `Microfloats`/`Narrow`.
"""
bitwidth(::Type{T}) where {T} = 8 * sizeof(T)


"""
TFloat32 <: AbstractFloat
Expand Down
74 changes: 74 additions & 0 deletions test/codegen/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1140,6 +1140,80 @@ end
end
end

@testset "reinterpret (whole-tile)" begin
# Equal width, different Tile IR dtype (Float16 -> Int16): whole-tile
# reinterpret is a plain bitcast — no pack/unpack, shape preserved.
@test @filecheck begin
@check_label "entry"
code_tiled(Tuple{ct.TileArray{Float16,1,spec1d}, ct.TileArray{Int16,1,spec1d}}) do a, b
pid = ct.bid(1)
tile = ct.load(a, pid, (16,))
@check "bitcast"
@check_not "pack"
ct.store(b, pid, reinterpret(Int16, tile))
return
end
end

# Widen UInt8 -> UInt16 (1D): lowers to a single unpack, identity reshapes
# folded away.
@test @filecheck begin
@check_label "entry"
code_tiled(Tuple{ct.TileArray{UInt8,1,spec1d}, ct.TileArray{UInt16,1,spec1d}}) do a, b
pid = ct.bid(1)
tile = ct.load(a, pid, (16,))
@check "unpack"
ct.store(b, pid, reinterpret(UInt16, tile))
return
end
end

# Narrow UInt16 -> UInt8 (1D): lowers to a single pack.
@test @filecheck begin
@check_label "entry"
code_tiled(Tuple{ct.TileArray{UInt16,1,spec1d}, ct.TileArray{UInt8,1,spec1d}}) do a, b
pid = ct.bid(1)
tile = ct.load(a, pid, (8,))
@check "pack"
ct.store(b, pid, reinterpret(UInt8, tile))
return
end
end

# pack/unpack require v13.3 — older bytecode rejects with a clear error.
# (`literal` since the `+` in the message is a regex metachar to FileCheck.)
@test @filecheck throws=ct.IRError begin
@check literal=true "v13.3+"
code_tiled(Tuple{ct.TileArray{UInt8,1,spec1d}, ct.TileArray{UInt16,1,spec1d}};
bytecode_version=v"13.2") do a, b
pid = ct.bid(1)
tile = ct.load(a, pid, (16,))
ct.store(b, pid, reinterpret(UInt16, tile))
return
end
end

# Rank-1 scaled: one UInt8 (8 bits) can't fill a UInt16; caught by unpack.
@test @filecheck throws=ct.IRError begin
@check "do not evenly divide"
code_tiled(Tuple{ct.TileArray{UInt8,1,spec1d}, ct.TileArray{UInt16,1,spec1d}}) do a, b
pid = ct.bid(1)
ct.store(b, pid, reinterpret(UInt16, ct.load(a, pid, (1,))))
return
end
end

# reshape-widen: leading dim must equal the ratio (2); 1 fails the final reshape.
@test @filecheck throws=ct.IRError begin
@check "same number of elements"
code_tiled(Tuple{ct.TileArray{UInt8,2,spec2d}, ct.TileArray{UInt16,2,spec2d}}) do a, b
pid = ct.bid(1)
ct.store(b, pid, reinterpret(reshape, UInt16, ct.load(a, pid, (1, 4))))
return
end
end
end

# TODO: exti - sign/zero extend integer
# TODO: ftoi - float to integer
# TODO: itof - integer to float
Expand Down
Loading