Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion cmd/cli/desktop/desktop.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"io"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -107,8 +108,18 @@ func (c *Client) Status() Status {
func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, printer standalone.StatusPrinter) (string, bool, error) {
model = normalizeHuggingFaceModelName(model)

// For Hugging Face models (hf.co/ prefix), forward the HF_TOKEN environment variable as the bearer token; it stays empty when HF_TOKEN is unset.
var hfToken string
if strings.HasPrefix(strings.ToLower(model), "hf.co/") {
hfToken = os.Getenv("HF_TOKEN")
}

return c.withRetries("download", 3, printer, func(attempt int) (string, bool, error, bool) {
jsonData, err := json.Marshal(dmrm.ModelCreateRequest{From: model, IgnoreRuntimeMemoryCheck: ignoreRuntimeMemoryCheck})
jsonData, err := json.Marshal(dmrm.ModelCreateRequest{
From: model,
IgnoreRuntimeMemoryCheck: ignoreRuntimeMemoryCheck,
BearerToken: hfToken,
})
if err != nil {
// Marshaling errors are not retryable
return "", false, fmt.Errorf("error marshaling request: %w", err), false
Expand Down
15 changes: 12 additions & 3 deletions pkg/distribution/distribution/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/docker/model-runner/pkg/distribution/registry"
"github.com/docker/model-runner/pkg/distribution/tarball"
"github.com/docker/model-runner/pkg/distribution/types"
"github.com/docker/model-runner/pkg/go-containerregistry/pkg/authn"
"github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/remote"
"github.com/docker/model-runner/pkg/inference/platform"
)
Expand Down Expand Up @@ -138,11 +139,19 @@ func NewClient(opts ...Option) (*Client, error) {
}

// PullModel pulls a model from a registry. If a non-empty bearer token is supplied, it is used to authenticate the registry requests for this pull.
func (c *Client) PullModel(ctx context.Context, reference string, progressWriter io.Writer) error {
func (c *Client) PullModel(ctx context.Context, reference string, progressWriter io.Writer, bearerToken ...string) error {
c.log.Infoln("Starting model pull:", utils.SanitizeForLog(reference))

// Use the client's registry, or create a temporary one if bearer token is provided
registryClient := c.registry
if len(bearerToken) > 0 && bearerToken[0] != "" {
// Create a temporary registry client with bearer token authentication
auth := &authn.Bearer{Token: bearerToken[0]}
registryClient = registry.FromClient(c.registry, registry.WithAuth(auth))
}

// First, fetch the remote model to get the manifest
remoteModel, err := c.registry.Model(ctx, reference)
remoteModel, err := registryClient.Model(ctx, reference)
if err != nil {
return fmt.Errorf("reading model from registry: %w", err)
}
Expand Down Expand Up @@ -214,7 +223,7 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter
}
digestReference := repository + "@" + remoteDigest.String()
c.log.Infof("Re-fetching model with digest reference: %s", utils.SanitizeForLog(digestReference))
remoteModel, err = c.registry.Model(ctx, digestReference)
remoteModel, err = registryClient.Model(ctx, digestReference)
if err != nil {
return fmt.Errorf("reading model from registry with resume context: %w", err)
}
Expand Down
24 changes: 24 additions & 0 deletions pkg/distribution/registry/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,15 @@ func WithAuthConfig(username, password string) ClientOption {
}
}

// WithAuth returns a ClientOption that installs auth as the client's
// authenticator. A nil auth is ignored, so whatever authenticator the client
// already holds is left untouched.
func WithAuth(auth authn.Authenticator) ClientOption {
	return func(c *Client) {
		// Nothing to install: keep the client's current authenticator.
		if auth == nil {
			return
		}
		c.auth = auth
	}
}

func NewClient(opts ...ClientOption) *Client {
client := &Client{
transport: remote.DefaultTransport,
Expand All @@ -97,6 +106,21 @@ func NewClient(opts ...ClientOption) *Client {
return client
}

// FromClient builds a new Client seeded with base's transport, user agent,
// keychain and authenticator, then applies opts in order so each option can
// override the inherited settings. base itself is not modified.
func FromClient(base *Client, opts ...ClientOption) *Client {
	derived := new(Client)
	derived.transport = base.transport
	derived.userAgent = base.userAgent
	derived.keychain = base.keychain
	derived.auth = base.auth

	// Options run after the copy so they take precedence over base's values.
	for i := range opts {
		opts[i](derived)
	}
	return derived
}

func (c *Client) Model(ctx context.Context, reference string) (types.ModelArtifact, error) {
// Parse the reference
ref, err := name.ParseReference(reference, GetDefaultRegistryOptions()...)
Expand Down
2 changes: 2 additions & 0 deletions pkg/inference/models/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ type ModelCreateRequest struct {
// IgnoreRuntimeMemoryCheck indicates whether the server should check if it has sufficient
// memory to run the given model (assuming default configuration).
IgnoreRuntimeMemoryCheck bool `json:"ignore-runtime-memory-check,omitempty"`
// BearerToken is an optional bearer token for authentication.
BearerToken string `json:"bearer-token,omitempty"`
}

// ModelPackageRequest represents a model package request, which creates a new model
Expand Down
35 changes: 22 additions & 13 deletions pkg/inference/models/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ import (
"github.com/docker/model-runner/pkg/distribution/distribution"
"github.com/docker/model-runner/pkg/distribution/registry"
"github.com/docker/model-runner/pkg/distribution/types"
v1 "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1"
"github.com/docker/model-runner/pkg/inference"
"github.com/docker/model-runner/pkg/inference/memory"
"github.com/docker/model-runner/pkg/logging"
"github.com/docker/model-runner/pkg/middleware"
v1 "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1"
"github.com/sirupsen/logrus"
)

Expand Down Expand Up @@ -221,7 +221,7 @@ func (m *Manager) handleCreateModel(w http.ResponseWriter, r *http.Request) {
return
}
}
if err := m.PullModel(request.From, r, w); err != nil {
if err := m.PullModel(request.From, request.BearerToken, r, w); err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
m.log.Infof("Request canceled/timed out while pulling model %q", request.From)
return
Expand Down Expand Up @@ -881,15 +881,15 @@ func (m *Manager) GetModels() ([]*Model, error) {
return nil, fmt.Errorf("error while listing models: %w", err)
}

apiModels := make([]*Model, 0, len(models))
for _, model := range models {
apiModel, err := ToModel(model)
if err != nil {
m.log.Warnf("error while converting model, skipping: %v", err)
continue
}
apiModels = append(apiModels, apiModel)
}
apiModels := make([]*Model, 0, len(models))
for _, model := range models {
apiModel, err := ToModel(model)
if err != nil {
m.log.Warnf("error while converting model, skipping: %v", err)
continue
}
apiModels = append(apiModels, apiModel)
}

return apiModels, nil
}
Expand Down Expand Up @@ -941,7 +941,7 @@ func (m *Manager) GetBundle(ref string) (types.ModelBundle, error) {

// PullModel pulls a model to local storage. Any error it returns is suitable
// for writing back to the client.
func (m *Manager) PullModel(model string, r *http.Request, w http.ResponseWriter) error {
func (m *Manager) PullModel(model string, bearerToken string, r *http.Request, w http.ResponseWriter) error {
// Restrict model pull concurrency.
select {
case <-m.pullTokens:
Expand Down Expand Up @@ -983,7 +983,16 @@ func (m *Manager) PullModel(model string, r *http.Request, w http.ResponseWriter

// Pull the model using the Docker model distribution client
m.log.Infoln("Pulling model:", model)
err := m.distributionClient.PullModel(r.Context(), model, progressWriter)

// Use bearer token if provided
var err error
if bearerToken != "" {
m.log.Infoln("Using provided bearer token for authentication")
err = m.distributionClient.PullModel(r.Context(), model, progressWriter, bearerToken)
} else {
err = m.distributionClient.PullModel(r.Context(), model, progressWriter)
}
Comment on lines +989 to +995
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This if/else block can be simplified. distribution.Client.PullModel takes a variadic bearerToken ...string argument and only switches to bearer authentication when len(bearerToken) > 0 && bearerToken[0] != "", so passing an empty string has the same effect as omitting the argument. You can remove the else branch, keep the log line behind the bearerToken != "" check, and always pass bearerToken.

        if bearerToken != "" {
			m.log.Infoln("Using provided bearer token for authentication")
		}
		err = m.distributionClient.PullModel(r.Context(), model, progressWriter, bearerToken)


if err != nil {
return fmt.Errorf("error while pulling model: %w", err)
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/inference/models/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ func TestPullModel(t *testing.T) {
}

w := httptest.NewRecorder()
err = m.PullModel(tag, r, w)
err = m.PullModel(tag, "", r, w)
if err != nil {
t.Fatalf("Failed to pull model: %v", err)
}
Expand Down Expand Up @@ -246,7 +246,7 @@ func TestHandleGetModel(t *testing.T) {
if !tt.remote && !strings.Contains(tt.modelName, "nonexistent") {
r := httptest.NewRequest("POST", "/models/create", strings.NewReader(`{"from": "`+tt.modelName+`"}`))
w := httptest.NewRecorder()
err = m.PullModel(tt.modelName, r, w)
err = m.PullModel(tt.modelName, "", r, w)
if err != nil {
t.Fatalf("Failed to pull model: %v", err)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/ollama/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ func (h *Handler) handlePull(w http.ResponseWriter, r *http.Request) {
r.Header.Set("Accept", "application/json")

// Call the model manager's PullModel method
if err := h.modelManager.PullModel(modelName, r, w); err != nil {
if err := h.modelManager.PullModel(modelName, "", r, w); err != nil {
h.log.Errorf("Failed to pull model: %v", err)
// Only write error if headers haven't been sent yet
if !isHeadersSent(w) {
Expand Down
Loading