
Commit efcbfa6

Author: 兆惠
Commit message: fix unsaved issue
1 parent a3f361e · commit efcbfa6

5 files changed

Lines changed: 887 additions & 304 deletions


docs/source/conf.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information

 project = 'dInfer'
-copyright = '2025, Lun Du'
+copyright = '2025, dInfer Team'
 author = 'Lun Du'

 # -- General configuration ---------------------------------------------------

docs/source/index.rst

Lines changed: 7 additions & 0 deletions
@@ -19,3 +19,10 @@ for the dInfer project. For now, you can start with the installation guide.
    start/install
    start/quickstart

+.. toctree::
+   :maxdepth: 2
+   :caption: Advanced Usage
+
+   start/advanced_decoding
+   start/performance
+
Lines changed: 229 additions & 0 deletions
@@ -0,0 +1,229 @@
.. _advanced_decoding:

=========================
Advanced Decoding Methods
=========================

Last updated: 2025-11-20

This page introduces several advanced decoding strategies supported by dInfer,
building on the basic setup shown in :doc:`quickstart`.

.. note::

   In all code snippets below, we assume you have already:

   - Loaded a tokenizer and model.
   - Defined ``mask_id`` and ``eos_id``.
   - Created a prompt and the corresponding ``input_ids`` tensor on the correct device.
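For reference, a minimal sketch of that setup might look like the following.
It is illustrative only: the checkpoint path and prompt are placeholders, the
special-token ids are the LLaDA2.0 values used in section 4 of this page
(check your own tokenizer/config for the real ones), and model loading itself
follows :doc:`quickstart`.

.. code-block:: python

    import torch
    from transformers import AutoTokenizer

    device = torch.device("cuda:0")
    model_name = "/path/to/local/checkpoint"  # assumed local checkpoint path

    # Tokenizer for building prompts; the model is loaded as in the quickstart.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Special-token ids: do not hard-code these for your own model; the values
    # below are the LLaDA2.0 ids used later on this page.
    mask_id = 156895
    eos_id = 156892

    prompt = "Briefly explain diffusion language models."
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)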
--------------------------------
1. Hierarchical Decoding
--------------------------------

Hierarchical decoding uses two thresholds to balance quality and speed.

.. code-block:: python

    from dinfer import HierarchyDecoder, BlockWiseDiffusionLLM, BlockIteratorFactory

    decoder = HierarchyDecoder(
        temperature=0.0,
        threshold=0.9,      # High confidence threshold
        low_threshold=0.3,  # Low confidence threshold
        mask_id=mask_id,
        eos_id=eos_id,
    )

    dllm = BlockWiseDiffusionLLM(
        model=model,
        decoder=decoder,
        iterator_factory=BlockIteratorFactory(start_block_align=True),
        early_stop=True,
    )

    output = dllm.generate(input_ids, gen_length=512, block_length=64)

**How it works:**

- Tokens with confidence > ``threshold`` are accepted immediately.
- Tokens with confidence < ``low_threshold`` remain masked.
- Tokens with intermediate confidence are accepted **only if** they are
  local maxima within masked regions.

This creates a hierarchy:

1. High-confidence tokens.
2. Medium-confidence tokens in promising regions.
3. Remaining low-confidence tokens.
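To make the acceptance rule concrete, here is a small, hypothetical sketch of
the selection logic for a single decoding step. It is not dInfer's
implementation; the function name and the exact local-maximum test are
illustrative only.

.. code-block:: python

    import torch

    def hierarchical_select(confidence, is_masked, threshold=0.9, low_threshold=0.3):
        """Illustrative hierarchical acceptance rule (not the dInfer internals).

        confidence: (seq_len,) confidence of the current prediction per position.
        is_masked:  (seq_len,) bool, True where a token is still masked.
        Returns a bool mask of positions to accept in this step.
        """
        # Tier 1: anything above the high threshold is accepted immediately.
        accept = is_masked & (confidence > threshold)

        # Tier 2: medium-confidence positions are accepted only if they are a
        # local maximum relative to their immediate neighbours.
        left = torch.roll(confidence, shifts=1)
        right = torch.roll(confidence, shifts=-1)
        local_max = (confidence >= left) & (confidence >= right)
        medium = (confidence >= low_threshold) & (confidence <= threshold)
        accept |= is_masked & medium & local_max

        # Tier 3: positions below low_threshold stay masked for later steps.
        return accept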
----------------------------------------------
2. Credit-Based Threshold Decoding
----------------------------------------------

Credit-based decoding tracks decoding history to make better decisions.

.. code-block:: python

    from dinfer import CreditThresholdParallelDecoder
    from dinfer import BlockWiseDiffusionLLM, BlockIteratorFactory

    decoder = CreditThresholdParallelDecoder(
        temperature=0.0,
        threshold=0.9,
        mask_id=mask_id,
        eos_id=eos_id,
    )

    dllm = BlockWiseDiffusionLLM(
        model=model,
        decoder=decoder,
        iterator_factory=BlockIteratorFactory(start_block_align=True),
        early_stop=True,
    )

    output = dllm.generate(input_ids, gen_length=512, block_length=64)

**Benefits:**

- Accumulates "credits" for tokens that repeatedly have high confidence.
- Helps prevent premature acceptance in difficult regions.
- Leads to more stable convergence in challenging generation scenarios.
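The sketch below illustrates the credit idea under a simple additive rule; the
actual accumulation and decay used by ``CreditThresholdParallelDecoder`` may
differ, and the ``gain`` parameter is purely hypothetical.

.. code-block:: python

    import torch

    def credit_step(confidence, is_masked, credits, threshold=0.9, gain=0.25):
        """Hypothetical credit update for one decoding step.

        credits: (seq_len,) running credit per position, initialised to zeros.
        """
        # Positions that keep predicting with high confidence accumulate credit
        # across steps; a one-off confidence spike contributes only a little.
        credits = credits + gain * confidence * is_masked

        # Accept once confidence plus accumulated credit clears the threshold,
        # so consistently-confident tokens unlock earlier while difficult
        # regions are not committed prematurely.
        accept = is_masked & (confidence + credits >= threshold)
        return accept, credits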
-----------------------------------------------------
3. Iterative Smoothing with Vicinity-Aware KV Cache
-----------------------------------------------------

To improve coherence, you can use iterative smoothing together with a
vicinity-aware KV cache.

.. code-block:: python

    from dinfer import IterSmoothWithVicinityCacheDiffusionLLM, KVCacheFactory
    from dinfer import BlockIteratorFactory

    cache_factory = KVCacheFactory(
        cache_type='dual',  # Use both prefix and suffix caching
        is_bd_model=False,
    )

    dllm = IterSmoothWithVicinityCacheDiffusionLLM(
        model=model,
        decoder=decoder,
        iterator_factory=BlockIteratorFactory(start_block_align=True),
        cache_factory=cache_factory,
        early_stop=True,
        cont_weight=0.3,   # Continuity weight for smoothing
        prefix_look=16,    # Look-back context size
        after_look=16,     # Look-ahead context size
        warmup_steps=4,    # Number of warmup iterations
    )

    output = dllm.generate(input_ids, gen_length=512, block_length=64)

**Key parameters:**

- ``cont_weight`` (0.0–1.0):
  Controls the strength of continuity regularization.
  Higher → smoother transitions; lower → more independent predictions.

- ``prefix_look``:
  Number of tokens to look back for context.

- ``after_look``:
  Number of tokens to look ahead for context.

- ``warmup_steps``:
  Number of initial iterations with full diffusion before enabling smoothing.
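As a rough mental model of ``cont_weight``, each position's distribution can be
pictured as blended with an average over its ``prefix_look``/``after_look``
neighbourhood. The sketch below only illustrates that blending; it is not
dInfer's actual smoothing kernel, and the uniform window is an assumption.

.. code-block:: python

    import torch
    import torch.nn.functional as F

    def smooth_probs(logits, cont_weight=0.3, prefix_look=16, after_look=16):
        """Illustrative vicinity smoothing over per-position distributions."""
        # logits: (seq_len, vocab_size)
        probs = F.softmax(logits, dim=-1)
        window = prefix_look + after_look + 1  # odd when prefix_look == after_look

        # Uniform moving average over the vicinity window, per vocabulary entry.
        kernel = torch.ones(1, 1, window, device=logits.device) / window
        neighbour_avg = F.conv1d(
            probs.t().unsqueeze(1),  # (vocab_size, 1, seq_len)
            kernel,
            padding=window // 2,
        ).squeeze(1).t()

        # Higher cont_weight -> smoother transitions between adjacent positions;
        # lower cont_weight -> predictions stay closer to the raw model output.
        return (1.0 - cont_weight) * probs + cont_weight * neighbour_avg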
---------------------------------------------
4. Block Diffusion (LLaDA2.0 Models)
---------------------------------------------

LLaDA2.0 models are trained with block diffusion and require special handling.

.. code-block:: python

    import torch
    from transformers import AutoConfig
    from dinfer.model import LLaDA2MoeModelLM
    from dinfer import BlockDiffusionLLM, KVCacheFactory, BlockIteratorFactory
    from dinfer import ThresholdParallelDecoder

    device = torch.device("cuda:0")
    model_name = "/path/to/local/LLaDA2.0-mini-preview"

    # Load LLaDA2 model
    model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    model = LLaDA2MoeModelLM(config=model_config).eval()
    model.load_weights(model_name, torch_dtype=torch.bfloat16, device=device)
    model = model.to(device)

    mask_id = 156895
    eos_id = 156892

    decoder = ThresholdParallelDecoder(
        temperature=0.0,
        threshold=0.9,
        mask_id=mask_id,
        eos_id=eos_id,
    )

    cache_factory = KVCacheFactory(cache_type='prefix', is_bd_model=True)

    dllm = BlockDiffusionLLM(
        model=model,
        decoder=decoder,
        iterator_factory=BlockIteratorFactory(
            start_block_align=True,
            use_block_diffusion=True,  # Enable block diffusion mode
        ),
        cache_factory=cache_factory,
        early_stop=True,
    )

    output = dllm.generate(input_ids, gen_length=2048, block_length=32)
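One possible way to build the prompt and turn the generated ids back into text
is sketched below. It assumes the tokenizer ships with the same local
checkpoint, that it provides a chat template, and that ``generate`` returns the
full sequence including the prompt; adjust it if any of these assumptions do
not hold for your checkpoint.

.. code-block:: python

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Build input_ids from a chat-style prompt (assumes a chat template exists).
    messages = [{"role": "user", "content": "Explain block diffusion briefly."}]
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(device)

    output = dllm.generate(input_ids, gen_length=2048, block_length=32)

    # Assumes output = prompt + generated tokens; decode only the new part.
    generated = output[0, input_ids.shape[1]:]
    print(tokenizer.decode(generated, skip_special_tokens=True))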
-------------------------------------------------
5. KV Cache Strategies in dInfer
-------------------------------------------------

dInfer supports multiple KV cache strategies for efficiency:

.. code-block:: python

    from dinfer import KVCacheFactory, BlockWiseDiffusionLLM, BlockIteratorFactory

    # Option 1: Prefix caching only (common for causal LMs)
    cache_factory = KVCacheFactory(cache_type='prefix', is_bd_model=False)

    # Option 2: Dual caching (prefix + suffix refresh)
    cache_factory = KVCacheFactory(cache_type='dual', is_bd_model=False)

    # Option 3: No caching (simplest, but slower)
    cache_factory = None

    dllm = BlockWiseDiffusionLLM(
        model=model,
        decoder=decoder,
        iterator_factory=BlockIteratorFactory(start_block_align=True),
        cache_factory=cache_factory,
        early_stop=True,
    )

**Cache type comparison:**

- ``prefix``:

  - Caches only the prompt and fixed prefix context.
  - Best for: single-turn generation, simple prompts.
  - Memory usage: low.

- ``dual``:

  - Caches the prefix and dynamically refreshes vicinity tokens.
  - Best for: multi-turn generation, complex reasoning tasks.
  - Memory usage: medium.

- ``None``:

  - No caching; recomputes everything.
  - Best for: very short sequences, debugging scenarios.
  - Memory usage: lowest.

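As a rule of thumb based on the comparison above, the choice can be wrapped in
a small helper. This is a hypothetical convenience function written for this
page, not part of the dInfer API; it only uses the ``KVCacheFactory`` arguments
shown earlier.

.. code-block:: python

    from dinfer import KVCacheFactory

    def make_cache_factory(use_case: str, is_bd_model: bool = False):
        """Hypothetical helper mapping use cases to cache strategies."""
        if use_case == "single_turn":
            # Prompt-only caching: low memory, good for simple prompts.
            return KVCacheFactory(cache_type='prefix', is_bd_model=is_bd_model)
        if use_case == "multi_turn":
            # Prefix caching plus vicinity refresh for complex reasoning.
            return KVCacheFactory(cache_type='dual', is_bd_model=is_bd_model)
        # Debugging or very short sequences: disable caching entirely.
        return None

For example, pass ``cache_factory=make_cache_factory("multi_turn")`` when
constructing the ``BlockWiseDiffusionLLM`` above.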