@@ -36,6 +36,43 @@ limitations under the License.
3636
3737namespace xllm {
3838
39+ namespace detail {
40+ template <typename T, typename = void >
41+ struct has_get_lm_head : std::false_type {};
42+
43+ template <typename T>
44+ struct has_get_lm_head <T,
45+ std::void_t <decltype (std::declval<T>().get_lm_head())>>
46+ : std::true_type {};
47+
48+ template <typename T, typename = void >
49+ struct has_set_lm_head : std::false_type {};
50+
51+ template <typename T>
52+ struct has_set_lm_head <T,
53+ std::void_t <decltype (std::declval<T>().set_lm_head(
54+ std::declval<layer::LmHead&>()))>> : std::true_type {
55+ };
56+
57+ template <typename T, typename = void >
58+ struct has_get_word_embedding : std::false_type {};
59+
60+ template <typename T>
61+ struct has_get_word_embedding <
62+ T,
63+ std::void_t <decltype (std::declval<T>().get_word_embedding())>>
64+ : std::true_type {};
65+
66+ template <typename T, typename = void >
67+ struct has_set_word_embedding : std::false_type {};
68+
69+ template <typename T>
70+ struct has_set_word_embedding <
71+ T,
72+ std::void_t <decltype (std::declval<T>().set_word_embedding(
73+ std::declval<layer::WordEmbedding&>()))>> : std::true_type {};
74+ } // namespace detail
75+
3976class CausalLM : public torch ::nn::Module {
4077 public:
4178 ~CausalLM () override = default ;
@@ -65,10 +102,23 @@ class CausalLM : public torch::nn::Module {
65102
66103 virtual const torch::TensorOptions& options () const = 0;
67104
68- virtual layer::LmHead get_lm_head () = 0;
69- virtual void set_lm_head (layer::LmHead& head) = 0;
70- virtual layer::WordEmbedding get_word_embedding () = 0;
71- virtual void set_word_embedding (layer::WordEmbedding& embedding) = 0;
105+ // MTP-specific interface.
106+ virtual layer::LmHead get_lm_head () {
107+ LOG (FATAL)
108+ << " Method 'get_lm_head' is not implemented/supported by this model." ;
109+ }
110+ virtual void set_lm_head (layer::LmHead& head) {
111+ LOG (FATAL)
112+ << " Method 'set_lm_head' is not implemented/supported by this model." ;
113+ }
114+ virtual layer::WordEmbedding get_word_embedding () {
115+ LOG (FATAL) << " Method 'get_word_embedding' is not implemented/supported by "
116+ " this model." ;
117+ }
118+ virtual void set_word_embedding (layer::WordEmbedding& embedding) {
119+ LOG (FATAL) << " Method 'set_word_embedding' is not implemented/supported by "
120+ " this model." ;
121+ }
72122};
73123
74124template <typename Model>
@@ -102,16 +152,36 @@ class CausalLMImpl : public CausalLM {
102152 return model_->update_expert_weight (layer_id);
103153 }
104154
105- layer::LmHead get_lm_head () override { return model_->get_lm_head (); };
155+ layer::LmHead get_lm_head () override {
156+ if constexpr (detail::has_get_lm_head<Model>::value) {
157+ return model_->get_lm_head ();
158+ } else {
159+ return CausalLM::get_lm_head ();
160+ }
161+ };
106162
107- void set_lm_head (layer::LmHead& head) override { model_->set_lm_head (head); };
163+ void set_lm_head (layer::LmHead& head) override {
164+ if constexpr (detail::has_set_lm_head<Model>::value) {
165+ model_->set_lm_head (head);
166+ } else {
167+ CausalLM::set_lm_head (head);
168+ }
169+ };
108170
109171 layer::WordEmbedding get_word_embedding () override {
110- return model_->get_word_embedding ();
172+ if constexpr (detail::has_get_word_embedding<Model>::value) {
173+ return model_->get_word_embedding ();
174+ } else {
175+ return CausalLM::get_word_embedding ();
176+ }
111177 };
112178
113179 void set_word_embedding (layer::WordEmbedding& embedding) override {
114- model_->set_word_embedding (embedding);
180+ if constexpr (detail::has_set_word_embedding<Model>::value) {
181+ model_->set_word_embedding (embedding);
182+ } else {
183+ CausalLM::set_word_embedding (embedding);
184+ }
115185 };
116186
117187 torch::Device device () const override { return options_.device (); }
0 commit comments