Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions point_transformer_v3/fvdb_extensions/models/fvdb_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,28 +62,30 @@ def jagged_cumulative_argsort(unsorted_jt: fvdb.JaggedTensor) -> fvdb.JaggedTens


def morton_from_jagged_ijk(jagged_ijk: fvdb.JaggedTensor) -> fvdb.JaggedTensor:
    """Compute Morton (Z-order) codes for jagged integer voxel coordinates.

    Columns are permuted from (i, j, k) to (k, j, i) before encoding, so the
    Morton code interleaves bits with k as the leading axis.

    Args:
        jagged_ijk: JaggedTensor whose flat data is an (N, 3) integer tensor
            of voxel coordinates.

    Returns:
        JaggedTensor of Morton codes with the same jagged layout as the input.
    """
    ijk_j: torch.Tensor = jagged_ijk.jdata
    # Reverse the column order; .contiguous() because the advanced-indexing
    # result is non-contiguous and fvdb.morton presumably requires a
    # contiguous layout — TODO confirm against fvdb docs.
    kji_j = ijk_j[:, [2, 1, 0]].contiguous()
    morton_j = fvdb.morton(kji_j)
    return jagged_ijk.jagged_like(morton_j)


def morton_flipped_from_jagged_ijk(jagged_ijk: fvdb.JaggedTensor) -> fvdb.JaggedTensor:
    """Compute Morton codes over a permuted axis order of the coordinates.

    Columns are permuted from (i, j, k) to (k, i, j) before encoding, giving a
    different space-filling traversal than ``morton_from_jagged_ijk``.

    NOTE(review): the name says "flipped" but the permutation [2, 0, 1] is a
    cyclic rotation, not a full reversal of (i, j, k) — confirm the intended
    ordering with the serialization scheme this feeds.

    Args:
        jagged_ijk: JaggedTensor whose flat data is an (N, 3) integer tensor
            of voxel coordinates.

    Returns:
        JaggedTensor of Morton codes with the same jagged layout as the input.
    """
    ijk_j: torch.Tensor = jagged_ijk.jdata
    # .contiguous() because advanced indexing yields a non-contiguous view.
    kij_j = ijk_j[:, [2, 0, 1]].contiguous()
    morton_j = fvdb.morton(kij_j)
    return jagged_ijk.jagged_like(morton_j)


def hilbert_from_jagged_ijk(jagged_ijk: fvdb.JaggedTensor) -> fvdb.JaggedTensor:
    """Compute Hilbert-curve codes for jagged integer voxel coordinates.

    Columns are permuted from (i, j, k) to (j, k, i) before encoding, selecting
    which axis the Hilbert traversal treats as leading.

    Args:
        jagged_ijk: JaggedTensor whose flat data is an (N, 3) integer tensor
            of voxel coordinates.

    Returns:
        JaggedTensor of Hilbert codes with the same jagged layout as the input.
    """
    ijk_j: torch.Tensor = jagged_ijk.jdata
    # .contiguous() because advanced indexing yields a non-contiguous view.
    jki_j = ijk_j[:, [1, 2, 0]].contiguous()
    hilbert_j = fvdb.hilbert(jki_j)
    return jagged_ijk.jagged_like(hilbert_j)


def hilbert_flipped_from_jagged_ijk(jagged_ijk: fvdb.JaggedTensor) -> fvdb.JaggedTensor:
    """Compute Hilbert-curve codes over a permuted axis order.

    Columns are permuted from (i, j, k) to (i, k, j) before encoding, giving a
    different traversal than ``hilbert_from_jagged_ijk``.

    NOTE(review): the name says "flipped" but the permutation [0, 2, 1] only
    swaps the last two axes rather than reversing all three — confirm this
    matches the intended serialization order.

    Args:
        jagged_ijk: JaggedTensor whose flat data is an (N, 3) integer tensor
            of voxel coordinates.

    Returns:
        JaggedTensor of Hilbert codes with the same jagged layout as the input.
    """
    ijk_j: torch.Tensor = jagged_ijk.jdata
    # .contiguous() because advanced indexing yields a non-contiguous view.
    ikj_j = ijk_j[:, [0, 2, 1]].contiguous()
    hilbert_j = fvdb.hilbert(ikj_j)
    return jagged_ijk.jagged_like(hilbert_j)


Expand Down Expand Up @@ -259,7 +261,10 @@ def jagged_attention(
out_b = cast(
Any,
flash_attn.flash_attn_qkvpacked_func(
qkv_b.half(), dropout_p=0.0, softmax_scale=scale, window_size=window_size
qkv_b.half(),
dropout_p=0.0,
softmax_scale=scale,
window_size=window_size,
),
).reshape(
Li, hidden_size
Expand Down Expand Up @@ -303,15 +308,15 @@ def jagged_attention(
feats_out_j = cast(
Any,
flash_attn.flash_attn_varlen_qkvpacked_func(
qkv_j.half(),
qkv_j.to(dtype=torch.bfloat16),
cu_seqlens,
max_seqlen=patch_size,
dropout_p=0.0, # TODO: implement attention dropout in the future. By default, it is 0.
softmax_scale=scale,
),
).reshape(
num_voxels, hidden_size
) # dtype: float16
) # dtype: bfloat16

feats_out_j = feats_out_j.to(feats_j.dtype)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,7 @@ def create_grid_from_points(

coords_jagged = fvdb.JaggedTensor(coords_list)

grid = fvdb.GridBatch.from_ijk(
coords_jagged,
voxel_sizes=[[voxel_size, voxel_size, voxel_size]] * len(coords_list),
origins=[0.0] * 3,
)
grid = fvdb.GridBatch.from_ijk(coords_jagged)

feats_jagged = fvdb.JaggedTensor(feats_list)
feats_vdb_order = grid.inject_from_ijk(coords_jagged, feats_jagged) #
Expand Down Expand Up @@ -195,7 +191,6 @@ def forward(self, data_dict: dict) -> torch.Tensor:
grid, jfeats, original_coord_to_voxel_idx = create_grid_from_points(
grid_coord, feat, offset, voxel_size=0.02
)
# import pdb; pdb.set_trace()
# TODO: check the downsampling behavior is the same or not?
assert (
grid_coord.shape == grid.ijk.jdata.shape
Expand Down
47 changes: 31 additions & 16 deletions point_transformer_v3/fvdb_extensions/models/ptv3_fvdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,26 +161,34 @@ def __init__(
self.out_channels = out_channels

self.proj = FJTM(torch.nn.Linear(in_channels, out_channels))
self.norm = FJTM(norm_layer_module(out_channels))
self.act_layer = FJTM(torch.nn.GELU())
self.proj_skip = FJTM(torch.nn.Linear(skip_channels, out_channels))
self.norm = FJTM(norm_layer_module(out_channels))
self.norm_skip = FJTM(norm_layer_module(out_channels))
self.act_layer = FJTM(torch.nn.GELU())
self.act_layer_skip = FJTM(torch.nn.GELU())

def __call__(
    self,
    feats: fvdb.JaggedTensor,
    grid: fvdb.GridBatch,
    last_feats: fvdb.JaggedTensor,
    last_grid: fvdb.GridBatch,
) -> tuple[fvdb.JaggedTensor, fvdb.GridBatch]:
    """Override ``__call__`` to preserve the type hints of ``forward``.

    ``torch.nn.Module.__call__`` is untyped, so this thin wrapper restates the
    signature and delegates; hooks registered on the module still run because
    it calls ``super().__call__`` rather than ``forward`` directly.

    Args:
        feats: Input features for the current (finer) grid level.
        grid: Grid batch associated with ``feats``.
        last_feats: Skip-connection features from the matching encoder stage.
        last_grid: Grid batch associated with ``last_feats``.

    Returns:
        Tuple of (output features, output grid), as produced by ``forward``.
    """
    # NOTE: the scraped diff interleaved the old one-line signature with the
    # new multi-line one, duplicating the parameters; this is the clean
    # post-change signature.
    return super().__call__(feats, grid, last_feats, last_grid)

def forward(
self, feats: fvdb.JaggedTensor, grid: fvdb.GridBatch, last_feats: fvdb.JaggedTensor, last_grid: fvdb.GridBatch
self,
feats: fvdb.JaggedTensor,
grid: fvdb.GridBatch,
last_feats: fvdb.JaggedTensor,
last_grid: fvdb.GridBatch,
) -> tuple[fvdb.JaggedTensor, fvdb.GridBatch]:
with NVTXRange("PTV3_Unpooling"):
# The conversion is to avoid the bug when enabled AMP,
# despite both feats.jdata and linear.weights are float32,
# the output becomes float16 which causes the subsequent convolution operation to fail.
feats = self.proj(feats).to(torch.float32)
feats = self.proj(feats) # .to(torch.float32)
feats = self.norm(feats)
feats = self.act_layer(feats)

Expand Down Expand Up @@ -296,7 +304,12 @@ def forward(self, feats: fvdb.JaggedTensor, grid: fvdb.GridBatch) -> fvdb.Jagged


class PTV3_CPE(FVDBGridModule):
def __init__(self, hidden_size: int, no_conv_in_cpe: bool = False, shared_plan_cache: dict | None = None):
def __init__(
self,
hidden_size: int,
no_conv_in_cpe: bool = False,
shared_plan_cache: dict | None = None,
):
"""
Args:
hidden_size (int): Number of channels in the input features.
Expand Down Expand Up @@ -387,7 +400,7 @@ def __init__(
sliding_window_attention,
order_index,
order_types,
)
) # temporary disable attention
self.norm2 = FJTM(torch.nn.LayerNorm(hidden_size))
self.order_index = order_index
self.mlp = PTV3_MLP(hidden_size, proj_drop)
Expand All @@ -397,16 +410,18 @@ def forward(self, feats: fvdb.JaggedTensor, grid: fvdb.GridBatch) -> fvdb.Jagged
assert isinstance(feats, fvdb.JaggedTensor), "Input feats must be a JaggedTensor"
assert isinstance(grid, fvdb.GridBatch), "Input grid must be a GridBatch"
with NVTXRange("PTV3_Block"):
short_cut = feats
feats = self.cpe(feats, grid)
feats = fvdb.add(short_cut, feats)
short_cut = feats

feats = self.norm1(feats)
feats = self.attn(feats, grid)
feats = self.attn(feats, grid) # temporary disable attention
feats = self.drop_path(feats) # temporary disable attention
# The drop_path is applied to each point independently.
feats = self.drop_path(feats)
feats = fvdb.add(short_cut, feats)
short_cut = feats

short_cut = feats
feats = self.norm2(feats)
feats = self.mlp(feats)
feats = self.drop_path(feats)
Expand Down Expand Up @@ -634,23 +649,23 @@ def __init__(
)
)

def _shuffle_order(self):
def _shuffle_order(self, shuffled_order):
"""
Randomly shuffle the order tuple to create variation across forward passes.
Returns a new shuffled tuple of order types.
"""
if self.shuffle_orders:
indices = torch.randperm(len(self.order_type))
return tuple(self.order_type[i] for i in indices)
indices = torch.randperm(len(shuffled_order))
return tuple(shuffled_order[i] for i in indices)
else:
return self.order_type
return shuffled_order

def forward(self, feats: fvdb.JaggedTensor, grid: fvdb.GridBatch) -> fvdb.JaggedTensor:
original_grid = grid
with NVTXRange("PTV3_Forward"):

# Shuffle order at the beginning of forward pass (matching reference implementation)
shuffled_order = self._shuffle_order()
shuffled_order = self._shuffle_order(self.order_type)

# Store shuffled order in grid metadata so all blocks can access it
grid._shuffled_order = shuffled_order # type: ignore
Expand All @@ -670,7 +685,7 @@ def forward(self, feats: fvdb.JaggedTensor, grid: fvdb.GridBatch) -> fvdb.Jagged
feats, grid = pooler(feats, grid)

# Shuffle order after pooling for the next (downsampled) stage
shuffled_order = self._shuffle_order()
shuffled_order = self._shuffle_order(shuffled_order)
grid._shuffled_order = shuffled_order # type: ignore
layer_id += 1
with NVTXRange(f"PTV3_Encoder_{layer_id}"):
Expand Down
3 changes: 2 additions & 1 deletion point_transformer_v3/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# fvdb requirements
# fvdb requirements
--extra-index-url https://download.pytorch.org/whl/cu129
--extra-index-url https://d36m13axqqhiit.cloudfront.net/simple
fvdb-core==0.3.0+pt28.cu129
Expand All @@ -11,6 +11,7 @@ peft
wandb
tensorboard
tensorboardx
yapf

# flash-attn is only needed when patch_size > 0 (default config uses patch_size=1024)
# While PyTorch 2.8+ has built-in flash attention, flash-attn provides optimized varlen functions
Expand Down