@@ -3,10 +3,15 @@ package mlx
 import (
     "context"
     "errors"
+    "fmt"
     "net/http"
+    "os/exec"
+    "strings"
 
     "github.com/docker/model-runner/pkg/inference"
+    "github.com/docker/model-runner/pkg/inference/backends"
     "github.com/docker/model-runner/pkg/inference/models"
+    "github.com/docker/model-runner/pkg/inference/platform"
     "github.com/docker/model-runner/pkg/logging"
 )
 
@@ -15,19 +20,37 @@ const (
     Name = "mlx"
 )
 
+var ErrStatusNotFound = errors.New("Python or mlx-lm not found")
+
 // mlx is the MLX-based backend implementation.
 type mlx struct {
     // log is the associated logger.
     log          logging.Logger
     // modelManager is the shared model manager.
     modelManager *models.Manager
+    // serverLog is the logger to use for the MLX server process.
+    serverLog    logging.Logger
+    // config is the configuration for the MLX backend.
+    config       *Config
+    // status is the current state of the MLX backend.
+    status       string
+    // pythonPath is the path to the python3 binary.
+    pythonPath   string
 }
 
 // New creates a new MLX-based backend.
-func New(log logging.Logger, modelManager *models.Manager) (inference.Backend, error) {
+func New(log logging.Logger, modelManager *models.Manager, serverLog logging.Logger, conf *Config) (inference.Backend, error) {
+    // If no config is provided, use the default configuration
+    if conf == nil {
+        conf = NewDefaultMLXConfig()
+    }
+
     return &mlx{
         log:          log,
         modelManager: modelManager,
+        serverLog:    serverLog,
+        config:       conf,
+        status:       "not installed",
     }, nil
 }
 
@@ -38,31 +61,89 @@ func (m *mlx) Name() string {
 
 // UsesExternalModelManagement implements
 // inference.Backend.UsesExternalModelManagement.
-func (l *mlx) UsesExternalModelManagement() bool {
+func (m *mlx) UsesExternalModelManagement() bool {
     return false
 }
 
 // Install implements inference.Backend.Install.
 func (m *mlx) Install(ctx context.Context, httpClient *http.Client) error {
-    // TODO: Implement.
-    return errors.New("not implemented")
+    if !platform.SupportsMLX() {
+        return errors.New("MLX is only available on macOS ARM64")
+    }
+
+    // Check if Python 3 is available
+    pythonPath, err := exec.LookPath("python3")
+    if err != nil {
+        m.status = ErrStatusNotFound.Error()
+        return ErrStatusNotFound
+    }
+
+    // Store the python path for later use
+    m.pythonPath = pythonPath
+
+    // Check if mlx-lm package is installed by attempting to import it
+    cmd := exec.CommandContext(ctx, pythonPath, "-c", "import mlx_lm")
+    if err := cmd.Run(); err != nil {
+        m.status = "mlx-lm package not installed"
+        m.log.Warnf("mlx-lm package not found. Install with: uv pip install mlx-lm")
+        return fmt.Errorf("mlx-lm package not installed: %w", err)
+    }
+
+    // Get MLX version
+    cmd = exec.CommandContext(ctx, pythonPath, "-c", "import mlx; print(mlx.__version__)")
+    output, err := cmd.Output()
+    if err != nil {
+        m.log.Warnf("could not get MLX version: %v", err)
+        m.status = "running MLX version: unknown"
+    } else {
+        m.status = fmt.Sprintf("running MLX version: %s", strings.TrimSpace(string(output)))
+    }
+
+    return nil
 }
 
 // Run implements inference.Backend.Run.
-func (m *mlx) Run(ctx context.Context, socket, model string, modelRef string, mode inference.BackendMode, config *inference.BackendConfiguration) error {
-    // TODO: Implement.
-    m.log.Warn("MLX backend is not yet supported")
-    return errors.New("not implemented")
+func (m *mlx) Run(ctx context.Context, socket, model string, modelRef string, mode inference.BackendMode, backendConfig *inference.BackendConfiguration) error {
+    bundle, err := m.modelManager.GetBundle(model)
+    if err != nil {
+        return fmt.Errorf("failed to get model: %w", err)
+    }
+
+    args, err := m.config.GetArgs(bundle, socket, mode, backendConfig)
+    if err != nil {
+        return fmt.Errorf("failed to get MLX arguments: %w", err)
+    }
+
+    // Add served model name
+    args = append(args, "--served-model-name", model, modelRef)
+
+    return backends.RunBackend(ctx, backends.RunnerConfig{
+        BackendName:     "MLX",
+        Socket:          socket,
+        BinaryPath:      m.pythonPath,
+        SandboxPath:     "",
+        SandboxConfig:   "",
+        Args:            args,
+        Logger:          m.log,
+        ServerLogWriter: m.serverLog.Writer(),
+    })
 }
 
 func (m *mlx) Status() string {
-    return "not running"
+    return m.status
 }
 
 func (m *mlx) GetDiskUsage() (int64, error) {
+    // MLX doesn't have a dedicated installation directory;
+    // it's installed via pip in the system Python environment.
     return 0, nil
 }
 
 func (m *mlx) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (inference.RequiredMemory, error) {
+    // TODO: Implement accurate memory estimation based on model size.
+    // MLX runs on unified memory architecture (Apple Silicon), so memory estimation
+    // will need to account for the unified nature of RAM and VRAM on Apple Silicon.
+    // Returning an error prevents the scheduler from making incorrect decisions based
+    // on placeholder values.
     return inference.RequiredMemory{}, errors.New("not implemented")
 }
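
For reference, a minimal sketch of how a caller might exercise the new constructor and install flow, assuming log, serverLog (logging.Logger) and modelManager (*models.Manager) are already available elsewhere; the example package, the startMLX helper, and the mlx import path are illustrative assumptions, not part of this change.

package example // illustrative only, not part of this commit

import (
    "context"
    "errors"
    "net/http"

    "github.com/docker/model-runner/pkg/inference"
    "github.com/docker/model-runner/pkg/inference/backends/mlx" // assumed import path for this package
    "github.com/docker/model-runner/pkg/inference/models"
    "github.com/docker/model-runner/pkg/logging"
)

// startMLX wires up the MLX backend using the signatures introduced above.
func startMLX(ctx context.Context, log, serverLog logging.Logger, modelManager *models.Manager) (inference.Backend, error) {
    // A nil *Config makes New fall back to NewDefaultMLXConfig().
    backend, err := mlx.New(log, modelManager, serverLog, nil)
    if err != nil {
        return nil, err
    }
    // Install checks for python3 and the mlx-lm package; on success Status()
    // reports the detected MLX version, otherwise it reports the failure reason.
    if err := backend.Install(ctx, http.DefaultClient); err != nil {
        if errors.Is(err, mlx.ErrStatusNotFound) {
            log.Warnf("MLX backend unavailable: %s", backend.Status())
        }
        return nil, err
    }
    return backend, nil
}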