Conversation

@nfeybesse (Contributor)

  1. Initial Problem

You wanted to register multiple custom gradients in Java using
TensorFlow.registerCustomGradient(...).

Observed symptom:

After registering a few gradients (≈ 5–10),

TFJ_RegisterCustomGradient(opType, adapter) received adapter_ptr = 0 on the C++ side,

which resulted in:

either a refusal to register the gradient,

or a SIGSEGV later during backpropagation.

Key observation:

If the “important” gradient was registered first, it worked.

Subsequent ones failed → this was a cumulative issue, not related to the specific op.

  2. Actual Root Cause

It was not:

a JNI signature bug,

an InfoMap issue,

nor a casting or ABI problem.

👉 The real cause was a limitation in JavaCPP FunctionPointer callbacks:

each TFJ_GradFuncAdapter allocates a native thunk,

after a certain number of such allocations, JavaCPP silently passes a null pointer (0),

the TensorFlow C++ runtime then receives an invalid callback pointer.

👉 Conclusion:
Creating one native callback per gradient is not scalable.

  3. Principle of the Definitive Fix

Instead of:

1 gradient = 1 native callback

We switched to:

1 single native callback

with dispatching in Java based on opType

This mirrors TensorFlow's own design: the Python bindings likewise register a single C++ gradient hook and dispatch at runtime by op type.

  4. Final Architecture
    A. A Single Native Callback (Singleton)

A single TFJ_GradFuncAdapter instance

Registered with TensorFlow C++ for all ops

As a result:

no more adapter_ptr = 0

no practical limit on the number of custom gradients

B. Java-side Dispatch by opType

A Java dispatcher selects the correct gradient during backpropagation:

TensorFlow C++
  → CustomGradFunc (C++)
  → TFJ_GradFuncAdapter.call(...)
  → DispatchingGradientAdapter.apply(...)
  → CustomGradient / RawCustomGradient for the corresponding op
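The dispatch-by-opType idea can be sketched as follows. This is a minimal illustration, not the PR's actual API: `GradFunc`, `register`, and `dispatch` are hypothetical names, and strings stand in for real gradient tensors.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Sketch of dispatch-by-opType: one shared entry point consults a registry
// keyed by op type, instead of allocating one native thunk per gradient.
// All names here are illustrative, not the PR's actual API.
public class GradientDispatchSketch {

    // Stand-in for a Java custom-gradient function for one op type.
    interface GradFunc {
        String apply(String upstreamGrad);
    }

    // Populated at registration time; the single native callback only reads it.
    static final Map<String, GradFunc> REGISTRY = new ConcurrentHashMap<>();

    static void register(String opType, GradFunc fn) {
        REGISTRY.put(opType, fn);
    }

    // The one entry point the single native adapter would invoke.
    static String dispatch(String opType, String upstreamGrad) {
        GradFunc fn = REGISTRY.get(opType);
        if (fn == null) {
            throw new IllegalStateException("No custom gradient for " + opType);
        }
        return fn.apply(upstreamGrad);
    }

    public static void main(String[] args) {
        register("MyRelu", g -> "reluGrad(" + g + ")");
        register("MySoftmax", g -> "softmaxGrad(" + g + ")");
        System.out.println(dispatch("MyRelu", "dy"));     // reluGrad(dy)
        System.out.println(dispatch("MySoftmax", "dy"));  // softmaxGrad(dy)
    }
}
```

Because registration only ever touches the Java-side map, adding a gradient never allocates a new native thunk, which is what removes the JavaCPP limit.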

  5. Proper Handling of Visibility Constraints
    Problem

NativeScope and Ops have package-private constructors

They are only accessible from org.tensorflow.op

Solution

DispatchingGradientAdapter is package-private and lives in org.tensorflow.op

A public GradientDispatch class acts as a bridge

TensorFlow.java only sees the public TFJ_GradFuncAdapter type

➡️ This strictly respects TensorFlow Java’s internal design, with no hacks.
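The bridge pattern described above can be sketched in a few lines. The class bodies here are placeholders (the real `DispatchingGradientAdapter` wraps gradient lookup, not string formatting); only the visibility relationship is the point.

```java
// Package-private dispatcher: only code in its own package (here standing in
// for org.tensorflow.op) can construct or call it directly.
class DispatchingGradientAdapter {
    String apply(String opType) {
        return "grad:" + opType; // placeholder for the real gradient lookup
    }
}

// Public bridge: the only type that outside code (e.g. TensorFlow.java)
// needs to see. It holds the package-private dispatcher internally.
public class GradientDispatch {
    private static final DispatchingGradientAdapter ADAPTER =
            new DispatchingGradientAdapter();

    public static String dispatch(String opType) {
        return ADAPTER.apply(opType);
    }

    public static void main(String[] args) {
        System.out.println(dispatch("MyOp")); // grad:MyOp
    }
}
```

This keeps the package-private constructors of NativeScope and Ops untouched: all privileged access happens inside the package, behind one small public facade.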

  6. Correct Support for “NoGradient”
    Problem

Returning null on the Java side caused a NullPointerException

The native code did not correctly support TF_Output.oper == nullptr

Fixes

Java side (AbstractGradientAdapter):

null is now translated into:

TF_Output { oper = nullptr, index = 0 }

C++ side (CustomGradFunc):

out.oper == nullptr is interpreted as NoGradient

No dangerous dereference

No crashes / no SIGSEGV
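The Java side of this translation can be sketched as below. `Output` is a hypothetical stand-in for the native TF_Output pair (a `Long` handle models the `oper` pointer, with `null` modeling `nullptr`); the real AbstractGradientAdapter works with native pointers instead.

```java
// Sketch of the null -> NoGradient translation: a null Java gradient becomes
// an output whose operation handle is null (mirroring
// TF_Output { oper = nullptr, index = 0 }), which the C++ side must treat as
// NoGradient and never dereference. Names here are illustrative.
public class NoGradientSketch {

    // Minimal stand-in for a native TF_Output pair.
    static final class Output {
        final Long operHandle; // null models oper == nullptr
        final int index;

        Output(Long operHandle, int index) {
            this.operHandle = operHandle;
            this.index = index;
        }

        boolean isNoGradient() {
            return operHandle == null;
        }
    }

    // What the adapter conceptually does for each gradient it returns.
    static Output toNativeOutput(Long javaGradHandle) {
        if (javaGradHandle == null) {
            return new Output(null, 0); // NoGradient marker, safe to pass down
        }
        return new Output(javaGradHandle, 0);
    }

    public static void main(String[] args) {
        System.out.println(toNativeOutput(42L).isNoGradient()); // false
        System.out.println(toNativeOutput(null).isNoGradient()); // true
    }
}
```

The invariant is that a null gradient is encoded explicitly rather than propagated as a raw null, so neither side ever dereferences an unset pointer.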

  7. Cleanup of the C++ Bridge (CustomGradFunc)

Applied corrections:

Removed a double loop that was adding gradients twice

Consistent handling of NoGradient

Single, safe memory deallocation (free(outputs))

Preserved defensive hardening:

checks on num_outputs

outputs == nullptr

etc.

  8. Final State
    What now works

✔ Registering dozens (or hundreds) of custom gradients

✔ Registration order no longer matters

✔ No more adapter_ptr = 0

✔ No JNI crashes / no SIGSEGV

✔ Proper support for partial gradients (NoGradient)

✔ Architecture aligned with native TensorFlow

What was avoided

❌ Fragile JavaCPP patches

❌ Dependency on internal allocation details

❌ Workarounds based on registration order

  9. In One Sentence

We replaced a non-scalable architecture (“N gradients = N native callbacks”) with a scalable one (“1 native callback + Java dispatch”), while properly fixing NoGradient handling and strictly respecting TensorFlow Java’s internal constraints.

unordered_map<string, TFJ_GradFuncAdapter> g_grad_func_adapters;

// Cast helper (inspired by TF C-API)
template <typename T, typename U>
Collaborator
Can you fix this diff to remove all the formatting changes so we can see just the functional changes to CustomGradFunc?

return false;
}

bool registered = GradOpRegistry::Global()->Register(op_type, CustomGradFunc);
Collaborator

Please fix the formatting to reduce the diff.

@Craigacp (Collaborator)

This looks like a fairly complicated fix to work around a bug in JavaCPP? Is it not better to fix it there?

@nfeybesse (Contributor, Author)

Thanks for the question — it’s a fair concern.

This change is indeed a workaround for a limitation in JavaCPP (bytedeco/javacpp#1205), where multiple native callbacks of the same kind cannot be reliably registered and invoked. In practice, only the last registered gradient adapter survives, which makes it impossible to support more than one Java custom gradient per process.

Fixing this directly in JavaCPP would be ideal in theory, but in practice it is not a viable short- or medium-term option for TensorFlow Java:

The issue is deep in JavaCPP’s native callback and lifetime management.

TensorFlow Java depends on JavaCPP as an external project, and cannot reasonably block feature development or correctness fixes on changes there.

Even with a JavaCPP fix, TensorFlow Java would still need a stable, deterministic way to manage gradient dispatch per op type.

For these reasons, this PR follows the same architectural pattern already used by TensorFlow itself.

TensorFlow Python does not register one native callback per op.
Instead, it registers a single C++ gradient hook and performs runtime dispatch based on the op type (via the gradient registry). In other words, Python also uses a centralized dispatcher rather than relying on multiple independent native callbacks.

This PR mirrors that design on the Java side:

A single native CustomGradFunc is registered with TensorFlow.

That function dispatches to the appropriate Java gradient implementation based on op_type.

This avoids the JavaCPP limitation entirely, while matching TensorFlow’s own gradient architecture.

As a result, the solution is:

robust and deterministic,

consistent with TensorFlow’s Python design,

backward-compatible,

and does not require changes to JavaCPP or TensorFlow C++.

In short: while the root cause is a JavaCPP limitation, centralizing gradient dispatch is not a hack — it is the same model TensorFlow already uses, adapted to the Java runtime constraints.

@nfeybesse
Copy link
Contributor Author

@nfeybesse force-pushed the custom/gradients-dispatch branch from d7bc382 to 8d80312 (February 11, 2026, 10:30)