@@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
1313limitations under the License.
1414==============================================================================*/
1515
16- #include " multi_tier_block_manager_pool .h"
16+ #include " hierarchy_block_manager_pool .h"
1717
1818#include " block_manager_impl.h"
1919#include " concurrent_block_manager_impl.h"
2020
2121namespace xllm {
2222
23- MultiTierBlockManagerPool::MultiTierBlockManagerPool (
23+ HierarchyBlockManagerPool::HierarchyBlockManagerPool (
2424 const BlockManagerPool::Options& options,
2525 Engine* engine,
2626 int32_t dp_size)
@@ -52,7 +52,7 @@ MultiTierBlockManagerPool::MultiTierBlockManagerPool(
5252 saved_device_blocks_.resize (host_block_managers_.size ());
5353}
5454
55- void MultiTierBlockManagerPool ::deallocate (Sequence* sequence) {
55+ void HierarchyBlockManagerPool ::deallocate (Sequence* sequence) {
5656 DCHECK (sequence != nullptr );
5757 // add blocks to the prefix cache
5858 int32_t dp_rank = BlockManagerPool::get_dp_rank (sequence);
@@ -65,7 +65,7 @@ void MultiTierBlockManagerPool::deallocate(Sequence* sequence) {
6565 return ;
6666 }
6767
68- int cached_block_num =
68+ size_t cached_block_num =
6969 sequence->host_kv_state ().kv_cache_tokens_num () / options_.block_size ();
7070
7171 if (host_blocks->size () > 0 ) {
@@ -82,7 +82,7 @@ void MultiTierBlockManagerPool::deallocate(Sequence* sequence) {
8282 sequence->host_kv_state ().add_kv_blocks (
8383 host_block_managers_[dp_rank]->allocate (needed_block_num));
8484
85- for (int i = cached_block_num; i < host_blocks->size (); i++) {
85+ for (size_t i = cached_block_num; i < host_blocks->size (); i++) {
8686 if (blocks->at (i).ref_count () != 2 ) {
8787 continue ;
8888 }
@@ -107,11 +107,12 @@ void MultiTierBlockManagerPool::deallocate(Sequence* sequence) {
107107 sequence->reset ();
108108}
109109
110- bool MultiTierBlockManagerPool ::allocate (Sequence* sequence,
110+ bool HierarchyBlockManagerPool ::allocate (Sequence* sequence,
111111 size_t num_tokens) {
112112 BlockManagerPool::allocate (sequence, num_tokens);
113113
114- if (sequence->host_kv_state ().num_kv_blocks () == 0 ) {
114+ if (sequence->host_kv_state ().num_kv_blocks () == 0 &&
115+ sequence->stage () != SequenceStage::DECODE) {
115116 allocate_host_shared (sequence);
116117 }
117118
@@ -137,7 +138,7 @@ bool MultiTierBlockManagerPool::allocate(Sequence* sequence,
137138 return true ;
138139}
139140
140- void MultiTierBlockManagerPool ::allocate_host_shared (Sequence* sequence) {
141+ void HierarchyBlockManagerPool ::allocate_host_shared (Sequence* sequence) {
141142 if (options_.enable_prefix_cache ()) {
142143 int32_t dp_rank = BlockManagerPool::get_dp_rank (sequence);
143144 std::vector<Block> shared_blocks =
@@ -146,7 +147,7 @@ void MultiTierBlockManagerPool::allocate_host_shared(Sequence* sequence) {
146147 }
147148}
148149
149- void MultiTierBlockManagerPool ::prefetch_from_storage (
150+ void HierarchyBlockManagerPool ::prefetch_from_storage (
150151 std::shared_ptr<Request>& request) {
151152 if (!options_.enable_kvcache_store ()) {
152153 return ;
@@ -202,7 +203,7 @@ void MultiTierBlockManagerPool::prefetch_from_storage(
202203 }
203204}
204205
205- bool MultiTierBlockManagerPool ::update_prefetch_result (
206+ bool HierarchyBlockManagerPool ::update_prefetch_result (
206207 std::shared_ptr<Request>& request,
207208 const uint32_t timeout) {
208209 if (!options_.enable_kvcache_store ()) {
@@ -216,8 +217,9 @@ bool MultiTierBlockManagerPool::update_prefetch_result(
216217 return prefetch_result;
217218}
218219
219- void MultiTierBlockManagerPool::transfer_blocks (std::vector<Batch>* batches) {
220- if (batches != nullptr ) {
220+ void HierarchyBlockManagerPool::transfer_blocks (
221+ std::optional<std::vector<Batch>> batches) {
222+ if (batches.has_value ()) {
221223 // load blocks from host to device
222224 for (int i = 0 ; i < batches->size (); i++) {
223225 if (!load_block_transfer_infos_[i].empty ()) {
@@ -265,7 +267,7 @@ void MultiTierBlockManagerPool::transfer_blocks(std::vector<Batch>* batches) {
265267 saved_device_blocks_.resize (host_block_managers_.size ());
266268}
267269
268- void MultiTierBlockManagerPool ::get_merged_kvcache_event (
270+ void HierarchyBlockManagerPool ::get_merged_kvcache_event (
269271 KvCacheEvent* event) const {
270272 if (host_block_managers_.empty ()) {
271273 BlockManagerPool::get_merged_kvcache_event (event);
0 commit comments