diff --git a/config/opal_check_ofi.m4 b/config/opal_check_ofi.m4 index ce575e0554e..35d97da66a6 100644 --- a/config/opal_check_ofi.m4 +++ b/config/opal_check_ofi.m4 @@ -148,7 +148,7 @@ AC_DEFUN([OPAL_CHECK_OFI],[ AC_DEFINE_UNQUOTED([OPAL_OFI_HAVE_FI_MR_IFACE], [${opal_check_fi_mr_attr_iface}], - [check if iface avaiable in fi_mr_attr]) + [check if iface available in fi_mr_attr]) AC_CHECK_DECL([FI_HMEM_ROCR], [opal_check_fi_hmem_rocr=1], @@ -157,7 +157,7 @@ AC_DEFUN([OPAL_CHECK_OFI],[ AC_DEFINE_UNQUOTED([OPAL_OFI_HAVE_FI_HMEM_ROCR], [${opal_check_fi_hmem_rocr}], - [check if FI_HMEM_ROCR avaiable in fi_hmem_iface]) + [check if FI_HMEM_ROCR available in fi_hmem_iface]) AC_CHECK_DECL([FI_HMEM_ZE], [opal_check_fi_hmem_ze=1], @@ -166,7 +166,16 @@ AC_DEFUN([OPAL_CHECK_OFI],[ AC_DEFINE_UNQUOTED([OPAL_OFI_HAVE_FI_HMEM_ZE], [${opal_check_fi_hmem_ze}], - [check if FI_HMEM_ZE avaiable in fi_hmem_iface])]) + [check if FI_HMEM_ZE available in fi_hmem_iface]) + + AC_CHECK_DECL([FI_HMEM_DEVICE_ONLY], + [opal_check_fi_hmem_device_only=1], + [opal_check_fi_hmem_device_only=0], + [#include ]) + + AC_DEFINE_UNQUOTED([OPAL_OFI_HAVE_FI_HMEM_DEVICE_ONLY], + [${opal_check_fi_hmem_device_only}], + [check if OPAL_OFI_HAVE_FI_HMEM_DEVICE_ONLY available])]) CPPFLAGS=${opal_check_ofi_save_CPPFLAGS} LDFLAGS=${opal_check_ofi_save_LDFLAGS} diff --git a/ompi/mca/mtl/ofi/mtl_ofi_mr.c b/ompi/mca/mtl/ofi/mtl_ofi_mr.c index 2f39a98ba23..c2a0d1ae01e 100644 --- a/ompi/mca/mtl/ofi/mtl_ofi_mr.c +++ b/ompi/mca/mtl/ofi/mtl_ofi_mr.c @@ -21,7 +21,7 @@ ompi_mtl_ofi_reg_mem(void *reg_data, void *base, size_t size, struct iovec iov = {0}; ompi_mtl_ofi_reg_t *mtl_reg = (ompi_mtl_ofi_reg_t *)reg; int dev_id; - uint64_t flags; + uint64_t flags, mr_flags = 0; iov.iov_base = base; iov.iov_len = size; @@ -41,7 +41,7 @@ ompi_mtl_ofi_reg_mem(void *reg_data, void *base, size_t size, attr.iface = FI_HMEM_CUDA; opal_accelerator.get_device(&attr.device.cuda); #if OPAL_OFI_HAVE_FI_HMEM_ROCR - } else if (0 == strcmp(opal_accelerator_base_selected_component.base_version.mca_component_name, "rocm")) { + } else if (0 == strcmp(opal_accelerator_base_selected_component.base_version.mca_component_name, "rocm")) { attr.iface = FI_HMEM_ROCR; opal_accelerator.get_device(&attr.device.cuda); #endif @@ -53,11 +53,16 @@ ompi_mtl_ofi_reg_mem(void *reg_data, void *base, size_t size, } else { return OPAL_ERROR; } +#if OPAL_OFI_HAVE_FI_HMEM_DEVICE_ONLY + mr_flags = flags & MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY ? 0 : + FI_HMEM_DEVICE_ONLY; +#endif } } + #endif - ret = fi_mr_regattr(ompi_mtl_ofi.domain, &attr, 0, &mtl_reg->ofi_mr); + ret = fi_mr_regattr(ompi_mtl_ofi.domain, &attr, mr_flags, &mtl_reg->ofi_mr); if (0 != ret) { opal_show_help("help-mtl-ofi.txt", "Buffer Memory Registration Failed", true, opal_accelerator_base_selected_component.base_version.mca_component_name, diff --git a/opal/mca/accelerator/accelerator.h b/opal/mca/accelerator/accelerator.h index 12f025f53c2..17e2a86dc63 100644 --- a/opal/mca/accelerator/accelerator.h +++ b/opal/mca/accelerator/accelerator.h @@ -91,6 +91,7 @@ BEGIN_C_DECLS */ /* Unified memory buffers */ #define MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY 0x00000001 +#define MCA_ACCELERATOR_FLAGS_DEVICE_ONLY_MEMORY 0x00000002 /** * Transfer types. diff --git a/opal/mca/btl/ofi/btl_ofi_module.c b/opal/mca/btl/ofi/btl_ofi_module.c index 23b0dc7dfe8..696a4614ac9 100644 --- a/opal/mca/btl/ofi/btl_ofi_module.c +++ b/opal/mca/btl/ofi/btl_ofi_module.c @@ -254,7 +254,7 @@ int mca_btl_ofi_reg_mem(void *reg_data, void *base, size_t size, mca_rcache_base_registration_t *reg) { int rc, dev_id; - uint64_t flags; + uint64_t flags, mr_flags = 0; static uint64_t access_flags = FI_REMOTE_WRITE | FI_REMOTE_READ | FI_READ | FI_WRITE; struct fi_mr_attr attr = {0}; struct iovec iov = {0}; @@ -281,7 +281,7 @@ int mca_btl_ofi_reg_mem(void *reg_data, void *base, size_t size, attr.iface = FI_HMEM_CUDA; opal_accelerator.get_device(&attr.device.cuda); #if OPAL_OFI_HAVE_FI_HMEM_ROCR - } else if (0 == strcmp(opal_accelerator_base_selected_component.base_version.mca_component_name, "rocm")) { + } else if (0 == strcmp(opal_accelerator_base_selected_component.base_version.mca_component_name, "rocm")) { attr.iface = FI_HMEM_ROCR; opal_accelerator.get_device(&attr.device.cuda); #endif @@ -293,11 +293,15 @@ int mca_btl_ofi_reg_mem(void *reg_data, void *base, size_t size, } else { return OPAL_ERROR; } +#if OPAL_OFI_HAVE_FI_HMEM_DEVICE_ONLY + mr_flags = flags & MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY ? 0 : + FI_HMEM_DEVICE_ONLY; +#endif } } #endif - rc = fi_mr_regattr(btl->domain, &attr, 0, &ur->ur_mr); + rc = fi_mr_regattr(btl->domain, &attr, mr_flags, &ur->ur_mr); if (0 != rc) { ur->ur_mr = NULL; return OPAL_ERR_OUT_OF_RESOURCE;