Skip to content

Commit 145e4c5

Browse files
issue/143 feat: static and paged graph compilers
1 parent de3e6b9 commit 145e4c5

19 files changed

+375
-21
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,5 @@ __pycache__/
2929
*.txt
3030

3131
*.http
32+
33+
*.nsys-rep
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#include "general_compiler.hpp"

namespace infinilm::engine {

/// Constructs both sub-compilers up front; each one decides internally
/// (from the model's cache config) whether it is applicable.
GeneralCompiler::GeneralCompiler(const std::shared_ptr<InfinilmModel> &model)
    : GraphCompiler(model),
      static_batching_compiler_(std::make_unique<StaticBatchingCompiler>(model)),
      paged_compiler_(std::make_unique<PagedCompiler>(model)) {}

/// Compiles graphs with every sub-compiler. Sub-compilers that do not match
/// the model's cache configuration are no-ops.
void GeneralCompiler::compile() {
    static_batching_compiler_->compile();
    paged_compiler_->compile();
}

/// Tries each sub-compiler in turn and returns the first result whose graph
/// and output are both non-null; otherwise returns the last (empty) result.
GeneralCompiler::Compiled GeneralCompiler::get_compiled(const InfinilmModel::Input &input) {
    auto result = static_batching_compiler_->get_compiled(input);
    if (std::get<0>(result) != nullptr && std::get<1>(result) != nullptr) {
        return result;
    }
    return paged_compiler_->get_compiled(input);
}

} // namespace infinilm::engine
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#pragma once

#include "paged_compiler.hpp"
#include "static_batching_compiler.hpp"

#include <memory>

namespace infinilm::engine {

/// Facade compiler that owns a StaticBatchingCompiler and a PagedCompiler
/// and delegates to both. get_compiled() returns the first sub-compiler
/// result whose graph and output are non-null.
class GeneralCompiler : public GraphCompiler {
public:
    /// `model` is forwarded to both sub-compilers.
    /// `explicit` to match the GraphCompiler base and avoid accidental
    /// implicit conversion from a model pointer.
    explicit GeneralCompiler(const std::shared_ptr<InfinilmModel> &model);

    /// Compiles graphs with every sub-compiler.
    void compile() override;

    /// Returns the first valid (graph, output) pair, or {nullptr, nullptr}
    /// when no sub-compiler has a matching compiled graph.
    Compiled get_compiled(const InfinilmModel::Input &input) override;

private:
    std::unique_ptr<StaticBatchingCompiler> static_batching_compiler_;
    std::unique_ptr<PagedCompiler> paged_compiler_;
};

} // namespace infinilm::engine
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#pragma once

#include "../../models/infinilm_model.hpp"

namespace infinilm::engine {

/// Abstract base for graph compilers: pre-records execution graphs for a
/// model so that later forward passes can replay a compiled graph instead
/// of executing eagerly.
class GraphCompiler {
public:
    /// A compiled result: the recorded graph plus the model output whose
    /// tensors are bound to the graph's placeholder inputs. Either pointer
    /// may be null when no compiled graph matches a request.
    using Compiled = std::tuple<
        std::shared_ptr<infinicore::graph::Graph>,
        std::shared_ptr<InfinilmModel::Output>>;

    /// Stores (shares ownership of) the model the graphs are compiled for.
    explicit GraphCompiler(const std::shared_ptr<InfinilmModel> &model) : model_(model) {}
    // Virtual destructor: instances are deleted through base pointers.
    virtual ~GraphCompiler() = default;

    /// Pre-compiles whatever graphs this compiler supports.
    virtual void compile() = 0;
    /// Returns the compiled pair matching `input`, or null pointers when
    /// this compiler has no graph for the request.
    virtual Compiled get_compiled(const InfinilmModel::Input &input) = 0;

protected:
    // Shared with the caller; used by subclasses to record graphs.
    std::shared_ptr<InfinilmModel> model_;
};

} // namespace infinilm::engine
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
#include "paged_compiler.hpp"

namespace infinilm::engine {

/// Builds the decode batch-size schedule: dense coverage for small batches,
/// progressively coarser steps for large ones
/// (1..31 step 1, 32..56 step 8, 64..112 step 16, 128..224 step 32,
/// 256..512 step 64).
PagedCompiler::PagedCompiler(const std::shared_ptr<InfinilmModel> &model)
    : GraphCompiler(model) {
    for (size_t b = 1; b < 32; b++) {
        decode_batch_sizes_.push_back(b);
    }
    for (size_t b = 32; b < 64; b += 8) {
        decode_batch_sizes_.push_back(b);
    }
    for (size_t b = 64; b < 128; b += 16) {
        decode_batch_sizes_.push_back(b);
    }
    for (size_t b = 128; b < 256; b += 32) {
        decode_batch_sizes_.push_back(b);
    }
    for (size_t b = 256; b <= 512; b += 64) {
        decode_batch_sizes_.push_back(b);
    }
}

/// Records one decode graph per scheduled batch size. No-op unless the
/// model uses a paged KV cache.
void PagedCompiler::compile() {
    // dynamic_cast on a null pointer yields null, so a single cast both
    // null-checks and type-checks the config.
    const auto *config = dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config());
    if (config == nullptr) {
        return;
    }
    const size_t nblocks = config->num_blocks();
    compiled_map_decode_.clear();
    // One backing tensor shared by all per-batch block tables; each batch
    // size b views a {b, nblocks / b} strided slice of it.
    block_tables_holder_ = infinicore::Tensor::empty(
        {nblocks}, infinicore::DataType::I64, infinicore::context::getDevice());
    for (size_t b : decode_batch_sizes_) {
        const size_t block_per_req = nblocks / b;
        // Placeholder input tensors the recorded graph will read from; they
        // are filled with real request data in get_compiled().
        InfinilmModel::Input input;
        input.input_ids = infinicore::Tensor::empty({1, b}, infinicore::DataType::I64, infinicore::context::getDevice());
        input.position_ids = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
        input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
        input.input_offsets = infinicore::Tensor::empty({b + 1}, infinicore::DataType::I64, infinicore::context::getDevice());
        input.block_tables = block_tables_holder_->as_strided({b, block_per_req}, {(ptrdiff_t)block_per_req, 1});
        input.slot_mapping = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
        infinicore::context::startGraphRecording();
        auto output = model_->forward(input);
        auto graph = infinicore::context::stopGraphRecording();

        auto shared_output = std::shared_ptr<InfinilmModel::Output>(
            new InfinilmModel::Output{infinicore::graph::GraphTensor(output.logits)});

        compiled_map_decode_[b] = CompiledResult{std::move(input), std::make_tuple(graph, shared_output)};
    }
}

/// Returns the pre-compiled decode graph matching the request's batch size
/// after copying the request tensors into the graph's placeholder inputs.
/// Returns {nullptr, nullptr} whenever the request cannot be served by a
/// compiled graph, so the caller can fall back to eager execution.
PagedCompiler::Compiled PagedCompiler::get_compiled(const InfinilmModel::Input &input) {
    if (dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config()) == nullptr) {
        return {nullptr, nullptr};
    }
    // Guard before .value(): a request without these fields is simply not a
    // paged decode batch — fall back instead of throwing.
    if (!input.block_tables.has_value() || !input.input_ids.has_value()) {
        return {nullptr, nullptr};
    }
    const size_t batch_size = input.block_tables.value()->size(0);
    const size_t block_per_req = input.block_tables.value()->size(1);

    // Only decode-only batches (one token per request) are graph-compiled.
    if (batch_size != input.input_ids.value()->size(1)) {
        return {nullptr, nullptr};
    }
    auto it = compiled_map_decode_.find(batch_size);
    if (it == compiled_map_decode_.end()) {
        return {nullptr, nullptr};
    }
    auto &graph_input = it->second.input;
    // The compiled view holds nblocks / batch_size columns; a wider request
    // cannot fit into the narrowed slice below.
    if (block_per_req > graph_input.block_tables.value()->size(1)) {
        return {nullptr, nullptr};
    }

    graph_input.input_ids.value()->copy_from(input.input_ids.value());
    graph_input.position_ids.value()->copy_from(input.position_ids.value());
    graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value());
    graph_input.input_offsets.value()->copy_from(input.input_offsets.value());
    // Narrow the placeholder table to the request's width before copying.
    graph_input.block_tables.value()->narrow({{1, 0, block_per_req}})->copy_from(input.block_tables.value());
    graph_input.slot_mapping.value()->copy_from(input.slot_mapping.value());

    // Re-bind the recorded output tensor before the graph is replayed.
    static_cast<infinicore::graph::GraphTensor>(std::get<1>(it->second.compiled)->logits).resume();
    return it->second.compiled;
}

} // namespace infinilm::engine
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#pragma once

#include "graph_compiler.hpp"

#include <memory>
#include <unordered_map>
#include <vector>

namespace infinilm::engine {

/// Compiles decode-only execution graphs for models backed by a paged KV
/// cache: one graph per batch size in `decode_batch_sizes_`.
class PagedCompiler : public GraphCompiler {
public:
    /// `explicit` to match the GraphCompiler base and avoid accidental
    /// implicit conversion from a model pointer.
    explicit PagedCompiler(const std::shared_ptr<InfinilmModel> &model);

    /// Pre-records one decode graph per scheduled batch size. No-op when
    /// the model's cache config is not a paged KV cache.
    void compile() override;

    /// Returns the compiled (graph, output) matching the input's batch
    /// size, or {nullptr, nullptr} when no compiled graph applies.
    Compiled get_compiled(const InfinilmModel::Input &input) override;

private:
    // Batch sizes for which decode graphs are pre-compiled.
    std::vector<size_t> decode_batch_sizes_;

    // Backing storage shared by every per-batch block-table view.
    infinicore::Tensor block_tables_holder_;

    // A recorded graph together with the placeholder input tensors the
    // graph reads from (refilled per request in get_compiled()).
    struct CompiledResult {
        InfinilmModel::Input input;
        Compiled compiled;
    };

    std::unordered_map<
        size_t, // num_requests
        CompiledResult>
        compiled_map_decode_;
};

} // namespace infinilm::engine
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#include "static_batching_compiler.hpp"

#include "../../cache/cache.hpp"

namespace infinilm::engine {

StaticBatchingCompiler::StaticBatchingCompiler(const std::shared_ptr<InfinilmModel> &model)
    : GraphCompiler(model) {}

/// Records one graph for the static cache's maximum batch size at sequence
/// length 1 (single-token decode). No-op unless the model uses a static KV
/// cache.
void StaticBatchingCompiler::compile() {
    // dynamic_cast on a null pointer yields null, so a single cast both
    // null-checks and type-checks the config.
    const auto *config = dynamic_cast<const cache::StaticKVCacheConfig *>(model_->get_cache_config());
    if (config == nullptr) {
        return;
    }
    const size_t b = config->max_batch_size();
    spdlog::info("Compiling graph for batch size {} and sequence length {}", b, 1);
    // Placeholder input tensors the recorded graph will read from; they are
    // filled with real request data in get_compiled().
    InfinilmModel::Input input;
    input.input_ids = infinicore::Tensor::empty({b, 1}, infinicore::DataType::I64, infinicore::context::getDevice());
    input.position_ids = infinicore::Tensor::empty({b, 1}, infinicore::DataType::I64, infinicore::context::getDevice());
    input.past_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
    input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
    infinicore::context::startGraphRecording();
    auto output = model_->forward(input);
    auto graph = infinicore::context::stopGraphRecording();

    auto shared_output = std::shared_ptr<InfinilmModel::Output>(new InfinilmModel::Output{infinicore::graph::GraphTensor(output.logits)});

    compiled_map_[std::make_tuple(b, 1)] = CompiledResult{std::move(input), std::make_tuple(graph, shared_output)};
}

/// Returns the pre-compiled graph matching the request's (batch_size,
/// seq_len) after copying the request tensors into the graph's placeholder
/// inputs. Returns {nullptr, nullptr} whenever no compiled graph applies,
/// so the caller can fall back to eager execution.
StaticBatchingCompiler::Compiled StaticBatchingCompiler::get_compiled(
    const InfinilmModel::Input &input) {
    if (dynamic_cast<const cache::StaticKVCacheConfig *>(model_->get_cache_config()) == nullptr) {
        return std::make_tuple(nullptr, nullptr);
    }
    // Guard before .value(): a request without input_ids cannot match any
    // compiled graph — fall back instead of throwing.
    if (!input.input_ids.has_value()) {
        return std::make_tuple(nullptr, nullptr);
    }
    const size_t batch_size = input.input_ids.value()->size(0);
    const size_t seqlen = input.input_ids.value()->size(1);
    auto it = compiled_map_.find(std::make_tuple(batch_size, seqlen));
    if (it == compiled_map_.end()) {
        return std::make_tuple(nullptr, nullptr);
    }
    auto &graph_input = it->second.input;
    graph_input.input_ids.value()->copy_from(input.input_ids.value());
    graph_input.position_ids.value()->copy_from(input.position_ids.value());
    graph_input.past_sequence_lengths.value()->copy_from(input.past_sequence_lengths.value());
    graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value());

    // Re-bind the recorded output tensor before the graph is replayed.
    static_cast<infinicore::graph::GraphTensor>(std::get<1>(it->second.compiled)->logits).resume();
    return it->second.compiled;
}

} // namespace infinilm::engine
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#pragma once

#include "graph_compiler.hpp"

#include <cstddef>
#include <functional>
#include <memory>
#include <tuple>
#include <unordered_map>

namespace infinilm::engine {

/// Compiles execution graphs for models backed by a static KV cache,
/// keyed by (batch_size, seq_len).
class StaticBatchingCompiler : public GraphCompiler {
public:
    /// `explicit` to match the GraphCompiler base and avoid accidental
    /// implicit conversion from a model pointer.
    explicit StaticBatchingCompiler(const std::shared_ptr<InfinilmModel> &model);

    /// Pre-records a graph for the cache's max batch size at seq_len 1.
    /// No-op when the model's cache config is not a static KV cache.
    void compile() override;

    /// Returns the compiled (graph, output) matching the input's
    /// (batch_size, seq_len), or {nullptr, nullptr} when none applies.
    Compiled get_compiled(const InfinilmModel::Input &input) override;

private:
    // Hash for (batch_size, seq_len) keys: boost-style hash_combine using
    // the 64-bit golden-ratio constant to mix the two members.
    struct TupleHash {
        size_t operator()(const std::tuple<size_t, size_t> &t) const noexcept {
            auto h1 = std::hash<size_t>{}(std::get<0>(t));
            auto h2 = std::hash<size_t>{}(std::get<1>(t));
            return h1 ^ (h2 + 0x9e3779b97f4a7c15ULL + (h1 << 6) + (h1 >> 2));
        }
    };

    // A recorded graph together with the placeholder input tensors the
    // graph reads from (refilled per request in get_compiled()).
    struct CompiledResult {
        InfinilmModel::Input input;
        Compiled compiled;
    };

    std::unordered_map<
        std::tuple<size_t, size_t>, // (batch_size, seq_len)
        CompiledResult,
        TupleHash>
        compiled_map_;
};

} // namespace infinilm::engine

csrc/engine/infer_engine.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ InferEngine::InferEngine(
1010
const InfinilmModel::Config &config,
1111
const distributed::DistConfig &distributed_config,
1212
infinicore::Device::Type device_type,
13-
const cache::CacheConfig *cache_config) // Changed parameter
13+
const cache::CacheConfig *cache_config,
14+
bool enable_graph_compiling) // Changed parameter
1415
: communication_group_(distributed_config, device_type),
1516
model_config_(config) {
1617

@@ -24,7 +25,8 @@ InferEngine::InferEngine(
2425
workers_.emplace_back(std::make_unique<RankWorker>(
2526
model_config_,
2627
communication_group_.get_rank_info(r),
27-
cache_config_ != nullptr ? cache_config_.get() : nullptr));
28+
cache_config_ != nullptr ? cache_config_.get() : nullptr,
29+
enable_graph_compiling));
2830
}
2931
}
3032

csrc/engine/infer_engine.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ class InferEngine {
2222
const InfinilmModel::Config &config,
2323
const distributed::DistConfig &distributed_config = distributed::DistConfig(),
2424
infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
25-
const cache::CacheConfig *cache_config = nullptr);
25+
const cache::CacheConfig *cache_config = nullptr,
26+
bool enable_graph_compiling = false);
2627

2728
// Load a parameter to all workers (each can extract its shard inside RankWorker)
2829
void load_param(const std::string &name, const infinicore::Tensor &param);

0 commit comments

Comments
 (0)