
Commit acf63aa

Merge pull request #425 from docker/enhance-mlx
Take MLX support one step further
2 parents f523abe + a4ef4cc commit acf63aa

File tree

5 files changed: +426 −10 lines


main.go

Lines changed: 16 additions & 1 deletion
@@ -13,6 +13,7 @@ import (
 	"github.com/docker/model-runner/pkg/gpuinfo"
 	"github.com/docker/model-runner/pkg/inference"
 	"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
+	"github.com/docker/model-runner/pkg/inference/backends/mlx"
 	"github.com/docker/model-runner/pkg/inference/backends/vllm"
 	"github.com/docker/model-runner/pkg/inference/config"
 	"github.com/docker/model-runner/pkg/inference/memory"
@@ -131,9 +132,23 @@ func main() {
 		log.Fatalf("unable to initialize %s backend: %v", vllm.Name, err)
 	}
 
+	mlxBackend, err := mlx.New(
+		log,
+		modelManager,
+		log.WithFields(logrus.Fields{"component": mlx.Name}),
+		nil,
+	)
+	if err != nil {
+		log.Fatalf("unable to initialize %s backend: %v", mlx.Name, err)
+	}
+
 	scheduler := scheduling.NewScheduler(
 		log,
-		map[string]inference.Backend{llamacpp.Name: llamaCppBackend, vllm.Name: vllmBackend},
+		map[string]inference.Backend{
+			llamacpp.Name: llamaCppBackend,
+			vllm.Name:     vllmBackend,
+			mlx.Name:      mlxBackend,
+		},
 		llamaCppBackend,
 		modelManager,
 		http.DefaultClient,
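
Note: main.go passes nil for the config, so the backend falls back to NewDefaultMLXConfig(). A minimal sketch of wiring a custom config instead — since GetArgs places Config.Args before "-m mlx_lm.server", they act as Python interpreter flags (here Python's standard -u flag for unbuffered output; illustrative only, not part of this commit):

	// Sketch: construct the MLX backend with a custom base configuration.
	// Config.Args precede "-m mlx_lm.server", so they are interpreter flags.
	mlxBackend, err := mlx.New(
		log,
		modelManager,
		log.WithFields(logrus.Fields{"component": mlx.Name}),
		&mlx.Config{Args: []string{"-u"}}, // unbuffered stdout/stderr
	)
	if err != nil {
		log.Fatalf("unable to initialize %s backend: %v", mlx.Name, err)
	}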

pkg/inference/backends/mlx/mlx.go

Lines changed: 90 additions & 9 deletions
@@ -3,10 +3,15 @@ package mlx
 import (
 	"context"
 	"errors"
+	"fmt"
 	"net/http"
+	"os/exec"
+	"strings"
 
 	"github.com/docker/model-runner/pkg/inference"
+	"github.com/docker/model-runner/pkg/inference/backends"
 	"github.com/docker/model-runner/pkg/inference/models"
+	"github.com/docker/model-runner/pkg/inference/platform"
 	"github.com/docker/model-runner/pkg/logging"
 )

@@ -15,19 +20,37 @@ const (
 	Name = "mlx"
 )
 
+var ErrStatusNotFound = errors.New("Python or mlx-lm not found")
+
 // mlx is the MLX-based backend implementation.
 type mlx struct {
 	// log is the associated logger.
 	log logging.Logger
 	// modelManager is the shared model manager.
 	modelManager *models.Manager
+	// serverLog is the logger to use for the MLX server process.
+	serverLog logging.Logger
+	// config is the configuration for the MLX backend.
+	config *Config
+	// status is the state the MLX backend is in.
+	status string
+	// pythonPath is the path to the python3 binary.
+	pythonPath string
 }
 
 // New creates a new MLX-based backend.
-func New(log logging.Logger, modelManager *models.Manager) (inference.Backend, error) {
+func New(log logging.Logger, modelManager *models.Manager, serverLog logging.Logger, conf *Config) (inference.Backend, error) {
+	// If no config is provided, use the default configuration.
+	if conf == nil {
+		conf = NewDefaultMLXConfig()
+	}
+
 	return &mlx{
 		log:          log,
 		modelManager: modelManager,
+		serverLog:    serverLog,
+		config:       conf,
+		status:       "not installed",
 	}, nil
 }

@@ -38,31 +61,89 @@ func (m *mlx) Name() string {
 
 // UsesExternalModelManagement implements
 // inference.Backend.UsesExternalModelManagement.
-func (l *mlx) UsesExternalModelManagement() bool {
+func (m *mlx) UsesExternalModelManagement() bool {
 	return false
 }
 
 // Install implements inference.Backend.Install.
 func (m *mlx) Install(ctx context.Context, httpClient *http.Client) error {
-	// TODO: Implement.
-	return errors.New("not implemented")
+	if !platform.SupportsMLX() {
+		return errors.New("MLX is only available on macOS ARM64")
+	}
+
+	// Check if Python 3 is available.
+	pythonPath, err := exec.LookPath("python3")
+	if err != nil {
+		m.status = ErrStatusNotFound.Error()
+		return ErrStatusNotFound
+	}
+
+	// Store the Python path for later use.
+	m.pythonPath = pythonPath
+
+	// Check if the mlx-lm package is installed by attempting to import it.
+	cmd := exec.CommandContext(ctx, pythonPath, "-c", "import mlx_lm")
+	if err := cmd.Run(); err != nil {
+		m.status = "mlx-lm package not installed"
+		m.log.Warnf("mlx-lm package not found. Install with: uv pip install mlx-lm")
+		return fmt.Errorf("mlx-lm package not installed: %w", err)
+	}
+
+	// Get the MLX version.
+	cmd = exec.CommandContext(ctx, pythonPath, "-c", "import mlx; print(mlx.__version__)")
+	output, err := cmd.Output()
+	if err != nil {
+		m.log.Warnf("could not get MLX version: %v", err)
+		m.status = "running MLX version: unknown"
+	} else {
+		m.status = fmt.Sprintf("running MLX version: %s", strings.TrimSpace(string(output)))
+	}
+
+	return nil
 }
 
 // Run implements inference.Backend.Run.
-func (m *mlx) Run(ctx context.Context, socket, model string, modelRef string, mode inference.BackendMode, config *inference.BackendConfiguration) error {
-	// TODO: Implement.
-	m.log.Warn("MLX backend is not yet supported")
-	return errors.New("not implemented")
+func (m *mlx) Run(ctx context.Context, socket, model string, modelRef string, mode inference.BackendMode, backendConfig *inference.BackendConfiguration) error {
+	bundle, err := m.modelManager.GetBundle(model)
+	if err != nil {
+		return fmt.Errorf("failed to get model: %w", err)
+	}
+
+	args, err := m.config.GetArgs(bundle, socket, mode, backendConfig)
+	if err != nil {
+		return fmt.Errorf("failed to get MLX arguments: %w", err)
+	}
+
+	// Add the served model name.
+	args = append(args, "--served-model-name", model, modelRef)
+
+	return backends.RunBackend(ctx, backends.RunnerConfig{
+		BackendName:     "MLX",
+		Socket:          socket,
+		BinaryPath:      m.pythonPath,
+		SandboxPath:     "",
+		SandboxConfig:   "",
+		Args:            args,
+		Logger:          m.log,
+		ServerLogWriter: m.serverLog.Writer(),
+	})
 }
 
 func (m *mlx) Status() string {
-	return "not running"
+	return m.status
 }
 
 func (m *mlx) GetDiskUsage() (int64, error) {
+	// MLX doesn't have a dedicated installation directory;
+	// it's installed via pip in the system Python environment.
 	return 0, nil
 }
 
 func (m *mlx) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (inference.RequiredMemory, error) {
+	// TODO: Implement accurate memory estimation based on model size.
+	// MLX runs on a unified memory architecture (Apple Silicon), so memory estimation
+	// will need to account for the unified nature of RAM and VRAM on Apple Silicon.
+	// Returning an error prevents the scheduler from making incorrect decisions based
+	// on placeholder values.
 	return inference.RequiredMemory{}, errors.New("not implemented")
 }
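
Note: putting Install, Run, and GetArgs together, the process launched through backends.RunBackend looks roughly like the following. This is a sketch with hypothetical paths and model names; the exact argv depends on the model bundle and any RuntimeFlags.

	// Sketch of the argv handed to backends.RunBackend for a completion run.
	// Values are illustrative, not taken from a real bundle.
	args := []string{
		"-m", "mlx_lm.server",
		"--model", "/models/blobs/my-model", // filepath.Dir of the bundle's safetensors file
		"--host", "/var/run/mlx.sock",       // the scheduler-provided socket
		"--served-model-name", "ai/my-model", "ai/my-model:latest",
	}
	// Executed via m.pythonPath, i.e. roughly:
	//   python3 -m mlx_lm.server --model … --host … --served-model-name …
	_ = args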
Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
+package mlx
+
+import (
+	"fmt"
+	"path/filepath"
+	"strconv"
+
+	"github.com/docker/model-runner/pkg/distribution/types"
+	"github.com/docker/model-runner/pkg/inference"
+)
+
+// Config is the configuration for the MLX backend.
+type Config struct {
+	// Args are the base arguments that are always included.
+	Args []string
+}
+
+// NewDefaultMLXConfig creates a new Config with default values.
+func NewDefaultMLXConfig() *Config {
+	return &Config{
+		Args: []string{},
+	}
+}
+
+// GetArgs implements BackendConfig.GetArgs.
+func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference.BackendMode, config *inference.BackendConfiguration) ([]string, error) {
+	// Start with the base arguments from the Config.
+	args := append([]string{}, c.Args...)
+
+	// MLX runs as a Python module: python -m mlx_lm.server.
+	args = append(args, "-m", "mlx_lm.server")
+
+	// Add the model path (MLX works with the safetensors format).
+	safetensorsPath := bundle.SafetensorsPath()
+	if safetensorsPath == "" {
+		return nil, fmt.Errorf("safetensors path required by MLX backend")
+	}
+	modelPath := filepath.Dir(safetensorsPath)
+
+	// Add model and socket arguments.
+	args = append(args, "--model", modelPath, "--host", socket)
+
+	// Add mode-specific arguments.
+	switch mode {
+	case inference.BackendModeCompletion:
+		// Default mode for MLX.
+	case inference.BackendModeEmbedding:
+		// MLX doesn't have a specific embedding flag; embedding models are detected automatically.
+	case inference.BackendModeReranking:
+		// MLX may not support reranking mode.
+		return nil, fmt.Errorf("reranking mode not supported by MLX backend")
+	default:
+		return nil, fmt.Errorf("unsupported backend mode %q", mode)
+	}
+
+	// Add max-tokens if specified in the model config or backend config.
+	if maxLen := GetMaxTokens(bundle.RuntimeConfig(), config); maxLen != nil {
+		args = append(args, "--max-tokens", strconv.FormatUint(*maxLen, 10))
+	}
+
+	// Add arguments from the backend config.
+	if config != nil {
+		args = append(args, config.RuntimeFlags...)
+	}
+
+	return args, nil
+}
+
+// GetMaxTokens returns the max tokens (context size) from the model config or backend config.
+// The model config takes precedence over the backend config.
+// Returns nil if neither is specified (MLX will use the model's defaults).
+func GetMaxTokens(modelCfg types.Config, backendCfg *inference.BackendConfiguration) *uint64 {
+	// The model config takes precedence.
+	if modelCfg.ContextSize != nil {
+		return modelCfg.ContextSize
+	}
+	// Otherwise, fall back to the backend config.
+	if backendCfg != nil && backendCfg.ContextSize > 0 {
+		val := uint64(backendCfg.ContextSize)
+		return &val
+	}
+	// Return nil to let MLX use the model's defaults.
+	return nil
+}
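
Note: to make the GetMaxTokens precedence concrete, a small sketch (field types inferred from the diff: types.Config.ContextSize is a *uint64, inference.BackendConfiguration.ContextSize a plain integer; all values hypothetical):

	modelCtx := uint64(8192)

	// The model config wins even when the backend config is also set.
	fmt.Println(*GetMaxTokens(types.Config{ContextSize: &modelCtx},
		&inference.BackendConfiguration{ContextSize: 4096})) // 8192

	// The backend config is the fallback when the model config is empty.
	fmt.Println(*GetMaxTokens(types.Config{},
		&inference.BackendConfiguration{ContextSize: 4096})) // 4096

	// Neither set: nil, so mlx_lm.server keeps the model's own default.
	fmt.Println(GetMaxTokens(types.Config{}, nil)) // <nil>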

0 commit comments
