diff --git a/src/perth/config.py b/src/perth/config.py index f067311..85b73cd 100644 --- a/src/perth/config.py +++ b/src/perth/config.py @@ -34,6 +34,10 @@ class Config: }, } + @classmethod + def get_default_models_dir(cls) -> str: + return os.path.join(os.path.dirname(__file__), "perth_net", "pretrained") + def __init__(self, config_path: Optional[str] = None): """ Initialize configuration with default values and optional user config. @@ -47,12 +51,7 @@ def __init__(self, config_path: Optional[str] = None): self._config[section] = values.copy() # Set default models directory - self._config["perth"]["models_dir"] = os.path.join( - os.path.dirname(os.path.dirname(os.path.abspath(__file__))), - "perth", - "perth_net", - "pretrained", - ) + self._config["perth"]["models_dir"] = self.get_default_models_dir() # Load user config if provided if config_path and os.path.exists(config_path): diff --git a/src/perth/perth_net/__init__.py b/src/perth/perth_net/__init__.py index f939c3e..582d6fb 100644 --- a/src/perth/perth_net/__init__.py +++ b/src/perth/perth_net/__init__.py @@ -1,4 +1,4 @@ -from importlib.resources import files - -PREPACKAGED_MODELS_DIR = files(__name__).joinpath("pretrained") from .perth_net_implicit.perth_watermarker import PerthImplicitWatermarker # noqa: E402, F401 +from ..config import Config + +PREPACKAGED_MODELS_DIR = Config.get_default_models_dir() # For backward compatibility diff --git a/src/perth/perth_net/perth_net_implicit/perth_watermarker.py b/src/perth/perth_net/perth_net_implicit/perth_watermarker.py index 34418e7..e0edc84 100644 --- a/src/perth/perth_net/perth_net_implicit/perth_watermarker.py +++ b/src/perth/perth_net/perth_net_implicit/perth_watermarker.py @@ -3,7 +3,6 @@ from librosa import resample from .model.perth_net import PerthNet -from .. import PREPACKAGED_MODELS_DIR from perth.watermarker import WatermarkerBase @@ -17,12 +16,16 @@ class PerthImplicitWatermarker(WatermarkerBase): def __init__( self, run_name: str = "implicit", - models_dir=PREPACKAGED_MODELS_DIR, + models_dir=None, device="cpu", perth_net=None, ): assert (run_name is None) or (perth_net is None) if perth_net is None: + if not models_dir: + from perth.config import Config + + models_dir = Config.get_default_models_dir() self.perth_net = PerthNet.load(run_name, models_dir).to(device) else: self.perth_net = perth_net.to(device)