Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 107 additions & 88 deletions contrib/cdisetup/nvidia/nvidia.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,13 @@ import (
// This is an example of an experimental on-demand setup of CDI devices.
// This code is not currently shipping with BuildKit and will probably change.

const (
cdiKind = "nvidia.com/gpu"
defaultVersion = "570.0"
)
const cdiKind = "nvidia.com/gpu"

// libcudaGlobs lists the glob patterns that hasLibsInstalled evaluates to
// decide whether a libcuda shared library is already present on the system.
// The first pattern covers the per-architecture "*-linux-gnu" directories
// under /usr/lib; the second covers the driver directories under
// /usr/lib/wsl/drivers.
//
// Reference for these locations:
// https://github.com/ollama/ollama/blob/b816ff86c923e0290f58f2275e831fc17c29ba37/discover/gpu_linux.go#L33-L43
var libcudaGlobs = []string{
"/usr/lib/*-linux-gnu/libcuda.so*",
"/usr/lib/wsl/drivers/*/libcuda.so*",
}

func init() {
cdidevices.Register(cdiKind, &setup{})
Expand All @@ -39,8 +42,7 @@ type setup struct{}
var _ cdidevices.Setup = &setup{}

func (s *setup) Validate() error {
_, err := readVersion()
if err == nil {
if _, err := readVersion(); err == nil {
return nil
}
b, err := hasNvidiaDevices()
Expand Down Expand Up @@ -93,55 +95,94 @@ func (s *setup) Run(ctx context.Context) (err error) {
return errors.Errorf("NVIDIA setup is currently only supported on Debian/Ubuntu")
}

var needsDriver bool

if _, err := os.Stat("/proc/driver/nvidia"); err != nil {
needsDriver = true
needsDriver := true
if _, err := os.Stat("/proc/driver/nvidia"); err == nil {
needsDriver = false
} else if nvidiaSmi, err := exec.LookPath("nvidia-smi"); err == nil && nvidiaSmi != "" {
if err := run(ctx, []string{nvidiaSmi, "-L"}, pw, dgst); err == nil {
needsDriver = false
}
}
if needsDriver {
if hasWSLGPU() {
return errors.Errorf("NVIDIA drivers are required for WSL with non PCI-based GPUs")
}
return errors.Errorf("NVIDIA drivers are required. Try loading NVIDIA kernel module with \"modprobe nvidia\" command")
}

var arch string
switch runtime.GOARCH {
case "amd64":
arch = "x86_64"
case "arm64":
arch = "sbsa"
// for non-sbsa could use https://nvidia.github.io/libnvidia-container/stable/deb
var dv string
if !hasLibsInstalled() && !hasWSLGPU() {
version, err := readVersion()
if err != nil {
return errors.Wrapf(err, "failed to read NVIDIA driver version")
}
var ok bool
dv, _, ok = strings.Cut(version, ".")
if !ok {
return errors.Errorf("failed to parse NVIDIA driver version %q", version)
}
}

if arch == "" {
return errors.Errorf("unsupported architecture: %s", runtime.GOARCH)
if err := run(ctx, []string{"apt-get", "update"}, pw, dgst); err != nil {
return err
}

if needsDriver {
pw.Write(identity.NewID(), client.VertexWarning{
Vertex: dgst,
Short: []byte("NVIDIA Drivers not found. Installing prebuilt drivers is not recommended"),
})
if err := run(ctx, []string{"apt-get", "install", "-y", "gpg"}, pw, dgst); err != nil {
return err
}

version, err := readVersion()
if err != nil && !needsDriver {
return errors.Wrapf(err, "failed to read NVIDIA driver version")
if err := installPackages(ctx, dv, pw, dgst); err != nil {
return err
}
if version == "" {
version = defaultVersion

if err := os.MkdirAll("/etc/cdi", 0700); err != nil {
return errors.Wrapf(err, "failed to create /etc/cdi")
}
v1, _, ok := strings.Cut(version, ".")
if !ok {
return errors.Errorf("failed to parse NVIDIA driver version %q", version)

buf := &bytes.Buffer{}

cmd := exec.CommandContext(ctx, "nvidia-ctk", "cdi", "generate")
cmd.Stdout = buf
cmd.Stderr = newStream(pw, 2, dgst)
if err := cmd.Run(); err != nil {
return errors.Wrapf(err, "failed to generate CDI spec")
}

if err := run(ctx, []string{"apt-get", "update"}, pw, dgst); err != nil {
return err
if len(buf.Bytes()) == 0 {
return errors.Errorf("nvidia-ctk output is empty")
}

if err := run(ctx, []string{"apt-get", "install", "-y", "gpg"}, pw, dgst); err != nil {
return err
if err := os.WriteFile("/etc/cdi/nvidia.yaml", buf.Bytes(), 0644); err != nil {
return errors.Wrapf(err, "failed to write /etc/cdi/nvidia.yaml")
}

return nil
}

// run executes the command given by args (args[0] is the program, the rest
// are its arguments) and returns the error from running it, if any.
//
// Before starting, the command line is echoed as "> <args...>" to a stream
// created with index 2. The command's stderr and stdout are then attached to
// fresh streams with index 2 and 1 respectively, all reporting against dgst
// through pw. args must contain at least one element.
func run(ctx context.Context, args []string, pw progress.Writer, dgst digest.Digest) error {
	// Echo the command line first so it appears ahead of the command's output.
	echo := newStream(pw, 2, dgst)
	fmt.Fprintf(echo, "> %s\n", strings.Join(args, " "))

	program, params := args[0], args[1:]
	c := exec.CommandContext(ctx, program, params...) //nolint:gosec
	// Separate streams are created for the command itself rather than
	// reusing the echo stream.
	c.Stderr = newStream(pw, 2, dgst)
	c.Stdout = newStream(pw, 1, dgst)
	return c.Run()
}

func installPackages(ctx context.Context, dv string, pw progress.Writer, dgst digest.Digest) error {
const aptDistro = "ubuntu2404"
aptURL := "https://developer.download.nvidia.com/compute/cuda/repos/" + aptDistro + "/" + arch + "/"

var arch string
switch runtime.GOARCH {
case "amd64":
arch = "x86_64"
case "arm64":
arch = "sbsa"
// for non-sbsa could use https://nvidia.github.io/libnvidia-container/stable/deb
}
if arch == "" {
return errors.Errorf("unsupported architecture: %s", runtime.GOARCH)
}

aptURL := "https://developer.download.nvidia.com/compute/cuda/repos/" + aptDistro + "/" + arch + "/"
keyTarget := "/usr/share/keyrings/nvidia-cuda-keyring.gpg"

if _, err := os.Stat(keyTarget); err != nil {
Expand Down Expand Up @@ -174,59 +215,17 @@ func (s *setup) Run(ctx context.Context) (err error) {
return err
}

if needsDriver {
// this pretty much never works, is it even worth having?
// better approach could be to try to create another chroot/container that is built with same kernel packages as the host
// could nvidia-headless-no-dkms- be reusable
if err := run(ctx, []string{"apt-get", "install", "-y", "nvidia-driver-" + v1}, pw, dgst); err != nil {
return err
}
_, err := os.Stat("/proc/driver/nvidia")
if err != nil {
return errors.Wrapf(err, "failed to install NVIDIA kernel module. Please install NVIDIA drivers manually")
}
}

if err := run(ctx, []string{"apt-get", "install", "-y", "--no-install-recommends",
"libnvidia-compute-" + v1,
"libnvidia-extra-" + v1,
"libnvidia-gl-" + v1,
"nvidia-utils-" + v1,
"nvidia-container-toolkit-base",
}, pw, dgst); err != nil {
return err
}

if err := os.MkdirAll("/etc/cdi", 0700); err != nil {
return errors.Wrapf(err, "failed to create /etc/cdi")
pkgs := []string{"nvidia-container-toolkit-base"}
if dv != "" {
pkgs = append(pkgs, []string{
"libnvidia-compute-" + dv,
"libnvidia-extra-" + dv,
"libnvidia-gl-" + dv,
"nvidia-utils-" + dv,
}...)
}

buf := &bytes.Buffer{}

cmd := exec.CommandContext(ctx, "nvidia-ctk", "cdi", "generate")
cmd.Stdout = buf
cmd.Stderr = newStream(pw, 2, dgst)
if err := cmd.Run(); err != nil {
return errors.Wrapf(err, "failed to generate CDI spec")
}

if len(buf.Bytes()) == 0 {
return errors.Errorf("nvidia-ctk output is empty")
}

if err := os.WriteFile("/etc/cdi/nvidia.yaml", buf.Bytes(), 0644); err != nil {
return errors.Wrapf(err, "failed to write /etc/cdi/nvidia.yaml")
}

return nil
}

func run(ctx context.Context, args []string, pw progress.Writer, dgst digest.Digest) error {
fmt.Fprintf(newStream(pw, 2, dgst), "> %s\n", strings.Join(args, " "))
cmd := exec.CommandContext(ctx, args[0], args[1:]...) //nolint:gosec
cmd.Stderr = newStream(pw, 2, dgst)
cmd.Stdout = newStream(pw, 1, dgst)
return cmd.Run()
return run(ctx, append([]string{"apt-get", "install", "-y", "--no-install-recommends"}, pkgs...), pw, dgst)
}

func readVersion() (string, error) {
Expand Down Expand Up @@ -268,6 +267,10 @@ func hasNvidiaDevices() (bool, error) {
}
}

if !found {
found = hasWSLGPU()
}

return found, nil
}

Expand Down Expand Up @@ -302,3 +305,19 @@ func isDebianOrUbuntu() (bool, error) {

return id == "debian" || id == "ubuntu", nil
}

// hasWSLGPU reports whether /dev/dxg can be stat'ed. That device is the
// WSL-specific GPU mapping, which doesn't expose PCI info. Any stat error
// (not only "does not exist") is treated as the device being absent.
func hasWSLGPU() bool {
	if _, statErr := os.Stat("/dev/dxg"); statErr != nil {
		return false
	}
	return true
}

// hasLibsInstalled reports whether at least one pattern in libcudaGlobs
// matches an existing path, which is taken as confirmation that the NVIDIA
// GPU driver libraries (libcuda) are installed in a standard location.
// Patterns that fail to evaluate or match nothing are skipped.
func hasLibsInstalled() bool {
	for i := range libcudaGlobs {
		hits, globErr := filepath.Glob(libcudaGlobs[i])
		if globErr != nil || len(hits) == 0 {
			continue
		}
		return true
	}
	return false
}