Skip to content

Commit 5f91d1c

Browse files
committed
refactor: remove MTP-specific function requirement from non-MTP models.
1 parent cb2443e commit 5f91d1c

File tree

2 files changed: +102 −12 lines

xllm/core/framework/model/causal_lm.h

Lines changed: 78 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,43 @@ limitations under the License.
3636

3737
namespace xllm {
3838

39+
namespace detail {
40+
template <typename T, typename = void>
41+
struct has_get_lm_head : std::false_type {};
42+
43+
template <typename T>
44+
struct has_get_lm_head<T,
45+
std::void_t<decltype(std::declval<T>().get_lm_head())>>
46+
: std::true_type {};
47+
48+
template <typename T, typename = void>
49+
struct has_set_lm_head : std::false_type {};
50+
51+
template <typename T>
52+
struct has_set_lm_head<T,
53+
std::void_t<decltype(std::declval<T>().set_lm_head(
54+
std::declval<layer::LmHead&>()))>> : std::true_type {
55+
};
56+
57+
template <typename T, typename = void>
58+
struct has_get_word_embedding : std::false_type {};
59+
60+
template <typename T>
61+
struct has_get_word_embedding<
62+
T,
63+
std::void_t<decltype(std::declval<T>().get_word_embedding())>>
64+
: std::true_type {};
65+
66+
template <typename T, typename = void>
67+
struct has_set_word_embedding : std::false_type {};
68+
69+
template <typename T>
70+
struct has_set_word_embedding<
71+
T,
72+
std::void_t<decltype(std::declval<T>().set_word_embedding(
73+
std::declval<layer::WordEmbedding&>()))>> : std::true_type {};
74+
} // namespace detail
75+
3976
class CausalLM : public torch::nn::Module {
4077
public:
4178
~CausalLM() override = default;
@@ -65,10 +102,23 @@ class CausalLM : public torch::nn::Module {
65102

66103
virtual const torch::TensorOptions& options() const = 0;
67104

68-
virtual layer::LmHead get_lm_head() = 0;
69-
virtual void set_lm_head(layer::LmHead& head) = 0;
70-
virtual layer::WordEmbedding get_word_embedding() = 0;
71-
virtual void set_word_embedding(layer::WordEmbedding& embedding) = 0;
105+
// MTP-specific interface.
106+
virtual layer::LmHead get_lm_head() {
107+
LOG(FATAL)
108+
<< "Method 'get_lm_head' is not implemented/supported by this model.";
109+
}
110+
virtual void set_lm_head(layer::LmHead& head) {
111+
LOG(FATAL)
112+
<< "Method 'set_lm_head' is not implemented/supported by this model.";
113+
}
114+
virtual layer::WordEmbedding get_word_embedding() {
115+
LOG(FATAL) << "Method 'get_word_embedding' is not implemented/supported by "
116+
"this model.";
117+
}
118+
virtual void set_word_embedding(layer::WordEmbedding& embedding) {
119+
LOG(FATAL) << "Method 'set_word_embedding' is not implemented/supported by "
120+
"this model.";
121+
}
72122
};
73123

74124
template <typename Model>
@@ -102,16 +152,36 @@ class CausalLMImpl : public CausalLM {
102152
return model_->update_expert_weight(layer_id);
103153
}
104154

105-
layer::LmHead get_lm_head() override { return model_->get_lm_head(); };
155+
layer::LmHead get_lm_head() override {
156+
if constexpr (detail::has_get_lm_head<Model>::value) {
157+
return model_->get_lm_head();
158+
} else {
159+
return CausalLM::get_lm_head();
160+
}
161+
};
106162

107-
void set_lm_head(layer::LmHead& head) override { model_->set_lm_head(head); };
163+
void set_lm_head(layer::LmHead& head) override {
164+
if constexpr (detail::has_set_lm_head<Model>::value) {
165+
model_->set_lm_head(head);
166+
} else {
167+
CausalLM::set_lm_head(head);
168+
}
169+
};
108170

109171
layer::WordEmbedding get_word_embedding() override {
110-
return model_->get_word_embedding();
172+
if constexpr (detail::has_get_word_embedding<Model>::value) {
173+
return model_->get_word_embedding();
174+
} else {
175+
return CausalLM::get_word_embedding();
176+
}
111177
};
112178

113179
void set_word_embedding(layer::WordEmbedding& embedding) override {
114-
model_->set_word_embedding(embedding);
180+
if constexpr (detail::has_set_word_embedding<Model>::value) {
181+
model_->set_word_embedding(embedding);
182+
} else {
183+
CausalLM::set_word_embedding(embedding);
184+
}
115185
};
116186

117187
torch::Device device() const override { return options_.device(); }

xllm/core/framework/model/causal_vlm.h

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,16 +63,36 @@ class CausalVLMImpl : public CausalVLM {
6363

6464
virtual void update_expert_weight(int32_t layer_id) { return; }
6565

66-
layer::LmHead get_lm_head() override { return model_->get_lm_head(); };
66+
layer::LmHead get_lm_head() override {
67+
if constexpr (detail::has_get_lm_head<Model>::value) {
68+
return model_->get_lm_head();
69+
} else {
70+
return CausalLM::get_lm_head();
71+
}
72+
};
6773

68-
void set_lm_head(layer::LmHead& head) override { model_->set_lm_head(head); };
74+
void set_lm_head(layer::LmHead& head) override {
75+
if constexpr (detail::has_set_lm_head<Model>::value) {
76+
model_->set_lm_head(head);
77+
} else {
78+
CausalLM::set_lm_head(head);
79+
}
80+
};
6981

7082
layer::WordEmbedding get_word_embedding() override {
71-
return model_->get_word_embedding();
83+
if constexpr (detail::has_get_word_embedding<Model>::value) {
84+
return model_->get_word_embedding();
85+
} else {
86+
return CausalLM::get_word_embedding();
87+
}
7288
};
7389

7490
void set_word_embedding(layer::WordEmbedding& embedding) override {
75-
model_->set_word_embedding(embedding);
91+
if constexpr (detail::has_set_word_embedding<Model>::value) {
92+
model_->set_word_embedding(embedding);
93+
} else {
94+
CausalLM::set_word_embedding(embedding);
95+
}
7696
};
7797

7898
torch::Device device() const override { return options_.device(); }

0 commit comments

Comments
 (0)