diff --git a/src/specify_cli/__init__.py b/src/specify_cli/__init__.py index d3eb36391e..efc8f3424d 100644 --- a/src/specify_cli/__init__.py +++ b/src/specify_cli/__init__.py @@ -4874,6 +4874,10 @@ def extension_update( failed_updates = [] registrar = CommandRegistrar() hook_executor = HookExecutor(project_root) + from .agents import CommandRegistrar as _AgentReg # used in backup and rollback paths + + # UNSET sentinel: backup not yet captured (exception before backup step) + UNSET = object() for update in updates_available: extension_id = update["id"] @@ -4887,8 +4891,9 @@ def extension_update( backup_config_dir = backup_base / "config" # Store backup state - backup_registry_entry = None - backup_hooks = None # None means no hooks key in config; {} means hooks key existed + backup_registry_entry = None # None means registry entry not yet captured + backup_installed = UNSET # Original installed list from extensions.yml + backup_hooks = None # None means backup step 4 not yet reached; {} or {...} means backup was captured backed_up_command_files = {} try: @@ -4913,8 +4918,7 @@ def extension_update( shutil.copy2(cfg_file, backup_config_dir / cfg_file.name) # 3. Backup command files for all agents - from .agents import CommandRegistrar as _AgentReg - registered_commands = backup_registry_entry.get("registered_commands", {}) + registered_commands = backup_registry_entry.get("registered_commands", {}) if isinstance(backup_registry_entry, dict) else {} for agent_name, cmd_names in registered_commands.items(): if agent_name not in registrar.AGENT_CONFIGS: continue @@ -4939,14 +4943,20 @@ def extension_update( shutil.copy2(prompt_file, backup_prompt_path) backed_up_command_files[str(prompt_file)] = str(backup_prompt_path) - # 4. Backup hooks from extensions.yml - # Use backup_hooks=None to indicate config had no "hooks" key (don't create on restore) - # Use backup_hooks={} to indicate config had "hooks" key with no hooks for this extension + # 4. Backup hooks and installed list from extensions.yml + # get_project_config() always normalizes installed->[] and hooks->{}, + # so no sentinel is needed to distinguish key-absent from key-empty. config = hook_executor.get_project_config() - if "hooks" in config: - backup_hooks = {} # Config has hooks key - preserve this fact - for hook_name, hook_list in config["hooks"].items(): - ext_hooks = [h for h in hook_list if h.get("extension") == extension_id] + if isinstance(config, dict): + import copy + # Deep-copy so nested mapping entries (e.g. version-pin dicts) + # are not affected by in-place mutations during the update. + backup_installed = copy.deepcopy(config.get("installed", [])) + backup_hooks = {} + for hook_name, hook_list in config.get("hooks", {}).items(): + if not isinstance(hook_list, list): + continue + ext_hooks = [h for h in hook_list if isinstance(h, dict) and h.get("extension") == extension_id] if ext_hooks: backup_hooks[hook_name] = ext_hooks @@ -5099,35 +5109,51 @@ def extension_update( original_file.parent.mkdir(parents=True, exist_ok=True) shutil.copy2(backup_file, original_file) - # Restore hooks in extensions.yml - # - backup_hooks=None means original config had no "hooks" key - # - backup_hooks={} or {...} means config had hooks key - config = hook_executor.get_project_config() - if "hooks" in config: + # Restore metadata in extensions.yml (hooks and installed list). + # Only run if backup step 4 was reached (backup_hooks is not None); + # otherwise we have no safe baseline to restore from and could corrupt + # the config by removing pre-existing hooks. + if backup_hooks is not None: + config = hook_executor.get_project_config() + if not isinstance(config, dict): + config = {} + modified = False - if backup_hooks is None: - # Original config had no "hooks" key; remove it entirely - del config["hooks"] + # 1. Restore hooks in extensions.yml + if not isinstance(config.get("hooks"), dict): + config["hooks"] = {} modified = True - else: - # Remove any hooks for this extension added by failed install - for hook_name, hooks_list in config["hooks"].items(): - original_len = len(hooks_list) - config["hooks"][hook_name] = [ - h for h in hooks_list - if h.get("extension") != extension_id - ] - if len(config["hooks"][hook_name]) != original_len: - modified = True - - # Add back the backed up hooks if any - if backup_hooks: - for hook_name, hooks in backup_hooks.items(): - if hook_name not in config["hooks"]: - config["hooks"][hook_name] = [] - config["hooks"][hook_name].extend(hooks) - modified = True + + # Remove any hooks for this extension added by the failed install + for hook_name in list(config["hooks"].keys()): + hooks_list = config["hooks"][hook_name] + if not isinstance(hooks_list, list): + config["hooks"][hook_name] = [] + modified = True + continue + + original_len = len(hooks_list) + config["hooks"][hook_name] = [ + h for h in hooks_list + if isinstance(h, dict) and h.get("extension") != extension_id + ] + if len(config["hooks"][hook_name]) != original_len: + modified = True + + # Add back the backed-up hooks + if backup_hooks: + for hook_name, hooks in backup_hooks.items(): + if not isinstance(config["hooks"].get(hook_name), list): + config["hooks"][hook_name] = [] + config["hooks"][hook_name].extend(hooks) + modified = True + + # 2. Restore installed list in extensions.yml + if backup_installed is not UNSET: + if config.get("installed") != backup_installed: + config["installed"] = backup_installed + modified = True if modified: hook_executor.save_project_config(config) diff --git a/src/specify_cli/extensions.py b/src/specify_cli/extensions.py index 944ee4a06d..f657de06ce 100644 --- a/src/specify_cli/extensions.py +++ b/src/specify_cli/extensions.py @@ -1190,7 +1190,7 @@ def install_from_directory( # was used during project initialisation (feature parity). registered_skills = self._register_extension_skills(manifest, dest_dir) - # Register hooks + # Register hooks and update installed list in extensions.yml hook_executor = HookExecutor(self.project_root) hook_executor.register_hooks(manifest) @@ -2481,7 +2481,32 @@ def get_project_config(self) -> Dict[str, Any]: } try: - return yaml.safe_load(self.config_file.read_text(encoding="utf-8")) or {} + result = yaml.safe_load(self.config_file.read_text(encoding="utf-8")) + # Coerce non-dict root (including None for an empty file) to the + # fully-normalized default so callers always get guaranteed fields. + if not isinstance(result, dict): + return { + "installed": [], + "settings": {"auto_execute_hooks": True}, + "hooks": {}, + } + # Normalize nested fields so read-only callers like get_hooks_for_event() + # never see non-dict hooks or non-list installed (Feedback) + if not isinstance(result.get("hooks"), dict): + result["hooks"] = {} + if not isinstance(result.get("installed"), list): + result["installed"] = [] + if not isinstance(result.get("settings"), dict): + result["settings"] = {"auto_execute_hooks": True} + # Sanitize hook event values: coerce non-list values to [] and filter + # non-dict items so get_hooks_for_event() can safely call .get() (Feedback) + for event_key in list(result["hooks"]): + event_val = result["hooks"][event_key] + if not isinstance(event_val, list): + result["hooks"][event_key] = [] + else: + result["hooks"][event_key] = [h for h in event_val if isinstance(h, dict)] + return result except (yaml.YAMLError, OSError, UnicodeError): return { "installed": [], @@ -2501,25 +2526,141 @@ def save_project_config(self, config: Dict[str, Any]): encoding="utf-8", ) + def register_extension(self, extension_id: str): + """Add extension to the installed list in project config. + + Args: + extension_id: ID of extension to register + """ + config = self.get_project_config() + + # Ensure config is a dict (defensive) + if not isinstance(config, dict): + config = {} + + raw_installed = config.get("installed") + sanitized = self._sanitize_installed_list(raw_installed, add_id=extension_id) + + if sanitized != raw_installed: + config["installed"] = sanitized + self.save_project_config(config) + + def unregister_extension(self, extension_id: str): + """Remove extension from the installed list in project config. + + Args: + extension_id: ID of extension to unregister + """ + config = self.get_project_config() + + if not isinstance(config, dict): + config = {} + + raw_installed = config.get("installed") + sanitized = self._sanitize_installed_list(raw_installed, remove_id=extension_id) + + # Always persist if sanitized state differs from raw config (ensures normalization) + if sanitized != raw_installed: + config["installed"] = sanitized + self.save_project_config(config) + + @staticmethod + def _sanitize_installed_list( + raw: object, + *, + add_id: str = "", + remove_id: str = "", + ) -> list: + """Normalize, deduplicate, and optionally add/remove an extension id. + + Shared by register_extension() and unregister_extension() to prevent + the two paths from drifting. + + Args: + raw: The raw value from config["installed"] (may be non-list). + add_id: If non-empty, ensure this id is present (plain-string fallback). + remove_id: If non-empty, remove this id from the list. + + Returns: + A sanitized, deduplicated, alphabetically-sorted list. + """ + _VALID_ID = re.compile(r'^[a-z0-9-]+$') + + installed = raw if isinstance(raw, list) else [] + + # Keep only entries whose resolved id is a non-empty string matching + # the extension-id format (^[a-z0-9-]+$), same rule ExtensionManifest enforces. + def _valid_entry(x: object) -> bool: + if isinstance(x, str): + return bool(_VALID_ID.match(x.strip())) + if isinstance(x, dict): + eid = x.get("id") + return isinstance(eid, str) and bool(_VALID_ID.match(eid.strip())) + return False + + valid = [x for x in installed if _valid_entry(x)] + + # Deduplicate by id: prefer dict (richer metadata) over plain string + seen: dict = {} # id -> entry (dict preferred over str) + for x in valid: + eid = x.strip() if isinstance(x, str) else x.get("id", "").strip() + if eid not in seen or isinstance(x, dict): + seen[eid] = x + + # Validate add_id against the same regex before inserting + if add_id and _VALID_ID.match(add_id.strip()) and add_id not in seen: + seen[add_id] = add_id + + if remove_id: + seen.pop(remove_id, None) + + def _sort_key(x: object) -> str: + return x if isinstance(x, str) else x.get("id", "") # type: ignore[return-value] + + return sorted(seen.values(), key=_sort_key) + def register_hooks(self, manifest: ExtensionManifest): """Register extension hooks in project config. Args: manifest: Extension manifest with hooks to register """ + # Always ensure the extension is in the installed list + self.register_extension(manifest.id) + if not hasattr(manifest, "hooks") or not manifest.hooks: return config = self.get_project_config() - # Ensure hooks dict exists - if "hooks" not in config: + # Ensure config is a dict (defensive) + changed = False + if not isinstance(config, dict): + config = {} + changed = True + + # Ensure hooks dict exists and is a mapping + if "hooks" not in config or not isinstance(config["hooks"], dict): config["hooks"] = {} + changed = True + else: + # Sanitize existing hook lists to prevent crashes in downstream code (Feedback) + for h_name in list(config["hooks"].keys()): + h_list = config["hooks"][h_name] + if not isinstance(h_list, list): + config["hooks"][h_name] = [] + changed = True + else: + sanitized_h_list = [h for h in h_list if isinstance(h, dict)] + if len(sanitized_h_list) != len(h_list): + config["hooks"][h_name] = sanitized_h_list + changed = True # Register each hook for hook_name, hook_config in manifest.hooks.items(): - if hook_name not in config["hooks"]: + if hook_name not in config["hooks"] or not isinstance(config["hooks"][hook_name], list): config["hooks"][hook_name] = [] + changed = True # Add hook entry hook_entry = { @@ -2534,22 +2675,22 @@ def register_hooks(self, manifest: ExtensionManifest): "condition": hook_config.get("condition"), } - # Check if already registered - existing = [ - h - for h in config["hooks"][hook_name] - if h.get("extension") == manifest.id + # Deduplicate: remove all existing entries for this extension on this + # hook event, then append the single canonical entry. This prevents + # multiple hooks firing when hand-edited or older versions leave + # duplicate entries behind. (Feedback from review) + original_list = config["hooks"][hook_name] + deduped = [ + h for h in original_list + if not (isinstance(h, dict) and h.get("extension") == manifest.id) ] + deduped.append(hook_entry) + if deduped != original_list: + config["hooks"][hook_name] = deduped + changed = True - if not existing: - config["hooks"][hook_name].append(hook_entry) - else: - # Update existing - for i, h in enumerate(config["hooks"][hook_name]): - if h.get("extension") == manifest.id: - config["hooks"][hook_name][i] = hook_entry - - self.save_project_config(config) + if changed: + self.save_project_config(config) def unregister_hooks(self, extension_id: str): """Remove extension hooks from project config. @@ -2557,17 +2698,30 @@ def unregister_hooks(self, extension_id: str): Args: extension_id: ID of extension to unregister """ + # Always remove from installed list (Feedback from review) + self.unregister_extension(extension_id) + config = self.get_project_config() - if "hooks" not in config: + if not isinstance(config, dict): + config = {} + # We don't save yet, as there are no hooks to unregister, + # but unregister_extension above might have already saved a normalized config. + return + + if "hooks" not in config or not isinstance(config["hooks"], dict): return # Remove hooks for this extension - for hook_name in config["hooks"]: + for hook_name in list(config["hooks"].keys()): + hook_list = config["hooks"][hook_name] + if not isinstance(hook_list, list): + config["hooks"][hook_name] = [] + continue config["hooks"][hook_name] = [ h - for h in config["hooks"][hook_name] - if h.get("extension") != extension_id + for h in hook_list + if isinstance(h, dict) and h.get("extension") != extension_id ] # Clean up empty hook arrays diff --git a/tests/test_extension_registration.py b/tests/test_extension_registration.py new file mode 100644 index 0000000000..9965deae43 --- /dev/null +++ b/tests/test_extension_registration.py @@ -0,0 +1,497 @@ +import pytest +import yaml +from specify_cli.extensions import HookExecutor, ExtensionManifest + +@pytest.fixture +def project_dir(tmp_path): + """Create a mock spec-kit project directory.""" + proj_dir = tmp_path / "project" + proj_dir.mkdir() + (proj_dir / ".specify").mkdir() + return proj_dir + +class TestExtensionRegistration: + """Tests for the 'installed' list management in HookExecutor.""" + + def test_register_extension_new(self, project_dir): + """Standard registration: Adding an extension should add it to the list.""" + executor = HookExecutor(project_dir) + executor.register_extension("test-ext") + + config = executor.get_project_config() + assert "installed" in config + assert config["installed"] == ["test-ext"] + + def test_register_extension_sorting(self, project_dir): + """Order Stability: Extensions should be stored in alphabetical order.""" + executor = HookExecutor(project_dir) + executor.register_extension("zebra-ext") + executor.register_extension("apple-ext") + executor.register_extension("middle-ext") + + config = executor.get_project_config() + assert config["installed"] == ["apple-ext", "middle-ext", "zebra-ext"] + + def test_register_extension_idempotency(self, project_dir): + """Idempotency: Adding the same extension twice should not result in duplicates.""" + executor = HookExecutor(project_dir) + executor.register_extension("test-ext") + executor.register_extension("test-ext") + + config = executor.get_project_config() + assert config["installed"] == ["test-ext"] + assert len(config["installed"]) == 1 + + def test_unregister_extension(self, project_dir): + """Standard unregistration: Removing an extension should prune it from the list.""" + executor = HookExecutor(project_dir) + executor.register_extension("ext-1") + executor.register_extension("ext-2") + + executor.unregister_extension("ext-1") + + config = executor.get_project_config() + assert config["installed"] == ["ext-2"] + + def test_unregister_extension_not_present(self, project_dir): + """Safe Removal: Unregistering a non-existent extension should do nothing.""" + executor = HookExecutor(project_dir) + executor.register_extension("ext-1") + + # Should not raise or change the list + executor.unregister_extension("ext-nonexistent") + + config = executor.get_project_config() + assert config["installed"] == ["ext-1"] + + def test_register_hooks_triggers_registration(self, project_dir, tmp_path): + """Full Workflow: register_hooks should automatically register the extension.""" + # Create a mock manifest + manifest_data = { + "schema_version": "1.0", + "extension": { + "id": "hook-ext", + "name": "Hook Ext", + "version": "1.0.0", + "description": "Test", + }, + "requires": { + "speckit_version": ">=0.1.0", + "commands": [] + }, + "provides": {"commands": []}, + "hooks": { + "after_tasks": {"command": "speckit.hook-ext.run"} + } + } + manifest_path = tmp_path / "extension.yml" + with open(manifest_path, "w") as f: + yaml.dump(manifest_data, f) + + manifest = ExtensionManifest(manifest_path) + executor = HookExecutor(project_dir) + + # This should call register_extension internally + executor.register_hooks(manifest) + + config = executor.get_project_config() + assert "hook-ext" in config["installed"] + + def test_missing_installed_key_initialization(self, project_dir): + """Graceful Initialization: If 'installed' key is missing, it should be created.""" + executor = HookExecutor(project_dir) + + # Manually create a config without 'installed' + config_path = project_dir / ".specify" / "extensions.yml" + config_path.write_text(yaml.dump({"settings": {"auto_execute_hooks": True}})) + + # This should detect the missing key and initialize it + executor.register_extension("new-ext") + + config = executor.get_project_config() + assert "installed" in config + assert config["installed"] == ["new-ext"] + + def test_unregister_hooks_full_workflow(self, project_dir, tmp_path): + """Full Workflow: unregister_hooks should remove hooks and prune installed list.""" + # Create a manifest with hooks + manifest_data = { + "schema_version": "1.0", + "extension": { + "id": "hook-ext", + "name": "Hook Ext", + "version": "1.0.0", + "description": "Test", + }, + "requires": { + "speckit_version": ">=0.1.0", + "commands": [] + }, + "provides": {"commands": []}, + "hooks": { + "after_tasks": {"command": "speckit.hook-ext.run"} + } + } + manifest_path = tmp_path / "extension.yml" + with open(manifest_path, "w") as f: + yaml.dump(manifest_data, f) + + manifest = ExtensionManifest(manifest_path) + executor = HookExecutor(project_dir) + + # Register hooks first + executor.register_hooks(manifest) + + config = executor.get_project_config() + assert "hook-ext" in config["installed"] + assert "after_tasks" in config["hooks"] + + # Now unregister hooks + executor.unregister_hooks("hook-ext") + + config = executor.get_project_config() + assert "hook-ext" not in config["installed"] + # unregister_hooks() removes the empty hook array entirely, so the key is absent + assert "after_tasks" not in config["hooks"] + + def test_unregister_hooks_no_hooks_key(self, project_dir): + """Resilience: unregister_hooks should work even if config has no 'hooks' key.""" + executor = HookExecutor(project_dir) + + # Register extension without hooks + executor.register_extension("ext-no-hooks") + + config = executor.get_project_config() + assert "ext-no-hooks" in config["installed"] + + # Unregister should not crash even if no hooks key exists + executor.unregister_hooks("ext-no-hooks") + + config = executor.get_project_config() + assert "ext-no-hooks" not in config["installed"] + + def test_unregister_hooks_corrupted_config(self, project_dir): + """Resilience: unregister_hooks should gracefully handle corrupted config.""" + # Create a corrupted config (root is a list) + config_path = project_dir / ".specify" / "extensions.yml" + config_path.write_text(yaml.dump(["corrupted", "list"])) + + executor = HookExecutor(project_dir) + + # Should not raise even with corrupted config + executor.unregister_hooks("non-existent") + + # Config should remain as-is or be handled gracefully + config = executor.get_project_config() + # If it's corrupted, it's returned as-is or handled by defensive logic + assert config is not None + + def test_unregister_hooks_with_multiple_extensions(self, project_dir, tmp_path): + """Multiple Extensions: unregister_hooks should only remove target extension's hooks.""" + # Create two manifests + manifest_data_1 = { + "schema_version": "1.0", + "extension": { + "id": "ext-1", + "name": "Ext 1", + "version": "1.0.0", + "description": "Test 1", + }, + "requires": { + "speckit_version": ">=0.1.0", + "commands": [] + }, + "provides": {"commands": []}, + "hooks": { + "after_tasks": {"command": "speckit.ext-1.run"} + } + } + manifest_data_2 = { + "schema_version": "1.0", + "extension": { + "id": "ext-2", + "name": "Ext 2", + "version": "1.0.0", + "description": "Test 2", + }, + "requires": { + "speckit_version": ">=0.1.0", + "commands": [] + }, + "provides": {"commands": []}, + "hooks": { + "after_tasks": {"command": "speckit.ext-2.run"} + } + } + + manifest_path_1 = tmp_path / "extension1.yml" + manifest_path_2 = tmp_path / "extension2.yml" + with open(manifest_path_1, "w") as f: + yaml.dump(manifest_data_1, f) + with open(manifest_path_2, "w") as f: + yaml.dump(manifest_data_2, f) + + manifest1 = ExtensionManifest(manifest_path_1) + manifest2 = ExtensionManifest(manifest_path_2) + executor = HookExecutor(project_dir) + + # Register both extensions + executor.register_hooks(manifest1) + executor.register_hooks(manifest2) + + config = executor.get_project_config() + assert "ext-1" in config["installed"] + assert "ext-2" in config["installed"] + assert len(config["hooks"]["after_tasks"]) == 2 + + # Unregister first extension + executor.unregister_hooks("ext-1") + + config = executor.get_project_config() + assert "ext-1" not in config["installed"] + assert "ext-2" in config["installed"] + # ext-2's hook should still be there + assert len(config["hooks"]["after_tasks"]) == 1 + assert config["hooks"]["after_tasks"][0].get("extension") == "ext-2" + + def test_register_hooks_no_hooks_still_registers(self, project_dir, tmp_path): + """Commands-only manifest: register_hooks() must still update installed even with no hooks.""" + manifest_data = { + "schema_version": "1.0", + "extension": { + "id": "commands-only-ext", + "name": "Commands Only", + "version": "1.0.0", + "description": "No hooks, only commands", + }, + "requires": { + "speckit_version": ">=0.1.0", + "commands": [] + }, + "provides": {"commands": [{"name": "speckit.commands-only-ext.run", "file": "commands/run.md"}]}, + } + manifest_path = tmp_path / "extension.yml" + with open(manifest_path, "w") as f: + yaml.dump(manifest_data, f) + + manifest = ExtensionManifest(manifest_path) + executor = HookExecutor(project_dir) + executor.register_hooks(manifest) + + config = executor.get_project_config() + assert "commands-only-ext" in config["installed"] + + def test_register_extension_mixed_type_installed(self, project_dir): + """Regression: installed list with non-string entries must not crash on sort.""" + executor = HookExecutor(project_dir) + + # Manually write a corrupted installed list with non-string entries + config_path = project_dir / ".specify" / "extensions.yml" + config_path.write_text(yaml.dump({"installed": [1, True, "existing-ext"]})) + + # Should not raise TypeError on sort + executor.register_extension("new-ext") + + config = executor.get_project_config() + # Non-string entries are dropped; valid strings are preserved + assert "existing-ext" in config["installed"] + assert "new-ext" in config["installed"] + assert 1 not in config["installed"] + assert True not in config["installed"] + + def test_unregister_hooks_null_hook_values(self, project_dir): + """Regression: hooks: {after_tasks: null} must not crash in unregister_hooks().""" + executor = HookExecutor(project_dir) + + # Manually write a config with null hook event value + config_path = project_dir / ".specify" / "extensions.yml" + config_path.write_text(yaml.dump({ + "installed": ["broken-ext"], + "hooks": {"after_tasks": None} + })) + + # Should not raise TypeError when iterating None + executor.unregister_hooks("broken-ext") + + config = executor.get_project_config() + assert "broken-ext" not in config["installed"] + + def test_register_hooks_corrupted_hook_values(self, project_dir, tmp_path): + """Regression: register_hooks() must handle non-list hook event values in config.""" + executor = HookExecutor(project_dir) + + # Manually write a config with null hook event value + config_path = project_dir / ".specify" / "extensions.yml" + config_path.write_text(yaml.dump({ + "installed": ["some-ext"], + "hooks": {"after_tasks": None} + })) + + # Create a manifest with a hook for the same event + manifest_data = { + "schema_version": "1.0", + "extension": { + "id": "new-ext", + "name": "New Ext", + "version": "1.0.0", + "description": "Test", + }, + "requires": { + "speckit_version": ">=0.1.0", + "commands": [] + }, + "provides": {"commands": []}, + "hooks": {"after_tasks": {"command": "speckit.new-ext.run"}} + } + manifest_path = tmp_path / "extension.yml" + with open(manifest_path, "w") as f: + yaml.dump(manifest_data, f) + + manifest = ExtensionManifest(manifest_path) + + # Should not raise TypeError when trying to append to None + executor.register_hooks(manifest) + + config = executor.get_project_config() + assert "new-ext" in config["installed"] + assert isinstance(config["hooks"]["after_tasks"], list) + assert any(h["extension"] == "new-ext" for h in config["hooks"]["after_tasks"]) + + def test_register_extension_already_present_in_corrupted_list(self, project_dir): + """Regression: if extension is already present but list has non-strings, it must still be sanitized.""" + executor = HookExecutor(project_dir) + + # Extension is present, but list has garbage + config_path = project_dir / ".specify" / "extensions.yml" + config_path.write_text(yaml.dump({"installed": [1, "test-ext", True]})) + + # This should trigger sanitization and save, even though "test-ext" is already there + executor.register_extension("test-ext") + + config = executor.get_project_config() + assert config["installed"] == ["test-ext"] + # Verify it was actually saved to disk + raw_config = yaml.safe_load(config_path.read_text()) + assert raw_config["installed"] == ["test-ext"] + + def test_register_extension_with_dict_entry(self, project_dir): + """Review Feedback: register_extension should support and preserve dict entries.""" + executor = HookExecutor(project_dir) + config_path = project_dir / ".specify" / "extensions.yml" + + # Setup config with a pinned extension (dict) + pinned_ext = {"id": "pinned-ext", "version": "1.0.0"} + config_path.write_text(yaml.dump({ + "installed": [pinned_ext, "string-ext"] + })) + + # Register a new extension + executor.register_extension("new-ext") + + config = executor.get_project_config() + # Should contain all three, sorted by id: new-ext, pinned-ext, string-ext + assert config["installed"] == ["new-ext", pinned_ext, "string-ext"] + + def test_unregister_extension_with_dict_entry(self, project_dir): + """Review Feedback: unregister_extension should support removing matching dict entries.""" + executor = HookExecutor(project_dir) + config_path = project_dir / ".specify" / "extensions.yml" + + pinned_ext = {"id": "to-remove", "version": "1.0.0"} + config_path.write_text(yaml.dump({ + "installed": [pinned_ext, "other-ext"] + })) + + # Unregister by ID + executor.unregister_extension("to-remove") + + config = executor.get_project_config() + assert config["installed"] == ["other-ext"] + + def test_unregister_extension_corrupted_installed(self, project_dir): + """Hardening: unregister_extension should handle non-list installed key.""" + executor = HookExecutor(project_dir) + config_path = project_dir / ".specify" / "extensions.yml" + + config_path.write_text(yaml.dump({ + "installed": "not-a-list" + })) + + # Should not crash and should normalize to [] + executor.unregister_extension("any-ext") + + config = executor.get_project_config() + assert config["installed"] == [] + def test_register_hooks_mixed_type_hook_list(self, project_dir, tmp_path): + """Regression: register_hooks() must sanitize hook event lists by dropping non-dicts.""" + executor = HookExecutor(project_dir) + + config_path = project_dir / ".specify" / "extensions.yml" + config_path.write_text(yaml.dump({ + "installed": ["some-ext"], + "hooks": {"after_tasks": [1, "corrupted", {"extension": "other", "command": "cmd"}]} + })) + + manifest_path = tmp_path / "extension.yml" + manifest_data = { + "schema_version": "1.0", + "extension": { + "id": "new-ext", + "name": "New Ext", + "version": "1.0.0", + "description": "Test", + "author": "Test author" + }, + "requires": { + "speckit_version": ">=0.1.0", + "commands": [] + }, + "provides": {"commands": []}, + "hooks": { + "after_tasks": {"command": "new-cmd"} + } + } + manifest_path.write_text(yaml.dump(manifest_data)) + manifest = ExtensionManifest(manifest_path) + + executor.register_hooks(manifest) + + config = executor.get_project_config() + hooks = config["hooks"]["after_tasks"] + + # Should have 2 valid dict hooks, and 0 non-dict items + assert len(hooks) == 2 + assert all(isinstance(h, dict) for h in hooks) + assert any(h.get("extension") == "other" for h in hooks) + assert any(h.get("extension") == "new-ext" for h in hooks) + + def test_unregister_extension_scalar_root(self, project_dir): + """Hardening: unregister_extension should handle scalar root config.""" + executor = HookExecutor(project_dir) + config_path = project_dir / ".specify" / "extensions.yml" + + config_path.write_text(yaml.dump(123)) + + # Should not crash and should normalize to {} + executor.unregister_extension("any-ext") + + config = executor.get_project_config() + assert isinstance(config, dict) + assert config["installed"] == [] + + def test_unregister_hooks_scalar_hook_values(self, project_dir): + """Regression: unregister_hooks() must handle scalar hook event values.""" + executor = HookExecutor(project_dir) + config_path = project_dir / ".specify" / "extensions.yml" + + config_path.write_text(yaml.dump({ + "installed": ["some-ext"], + "hooks": {"after_tasks": 123} + })) + + # Should not raise TypeError when iterating + executor.unregister_hooks("some-ext") + + config = executor.get_project_config() + assert "some-ext" not in config["installed"] + assert "after_tasks" not in config["hooks"] diff --git a/tests/test_extension_update_hardening.py b/tests/test_extension_update_hardening.py new file mode 100644 index 0000000000..426e5ec7e9 --- /dev/null +++ b/tests/test_extension_update_hardening.py @@ -0,0 +1,109 @@ +from specify_cli.extensions import ExtensionManager, ExtensionRegistry, ExtensionCatalog +import pytest +import yaml +from typer.testing import CliRunner +from specify_cli import app + +runner = CliRunner() + +@pytest.fixture +def project_dir(tmp_path): + """Create a mock spec-kit project directory.""" + proj_dir = tmp_path / "project" + proj_dir.mkdir() + (proj_dir / ".specify").mkdir() + # Create required files for a project + (proj_dir / ".specify" / "config.toml").write_text("ai = 'claude'") + return proj_dir + +def test_extension_update_corrupted_config_root(project_dir, monkeypatch): + """Regression: extension update must handle corrupted extensions.yml (root is scalar).""" + # chdir into project_dir so _require_specify_project() succeeds + monkeypatch.chdir(project_dir) + + # Corrupt extensions.yml + config_path = project_dir / ".specify" / "extensions.yml" + config_path.write_text(yaml.dump(123)) + + # Mock ExtensionManager to return an installed extension for resolution + + monkeypatch.setattr(ExtensionManager, "list_installed", lambda self: [{"id": "test-ext", "name": "Test Ext", "version": "1.0.0"}]) + monkeypatch.setattr(ExtensionRegistry, "get", lambda self, ext_id: {"version": "1.0.0", "enabled": True}) + monkeypatch.setattr(ExtensionCatalog, "get_extension_info", lambda self, ext_id: {"id": "test-ext", "name": "Test Ext", "version": "1.1.0", "download_url": "https://example.com/ext.zip"}) + + # Mock download_extension to avoid network calls; use tmp_path so the test is hermetic + # and returns a Path so zip_path.exists() / zip_path.unlink() work without AttributeError + mock_zip = project_dir / "mock.zip" + monkeypatch.setattr(ExtensionCatalog, "download_extension", lambda self, ext_id: mock_zip) + + # Mock confirmation to true + monkeypatch.setattr("typer.confirm", lambda _: True) + + # Run update + result = runner.invoke(app, ["extension", "update", "test-ext"], obj={"project_root": project_dir}) + + # extension_update() catches exceptions internally and exits with code 1 on failure. + assert result.exit_code == 1 + assert "AttributeError" not in result.output + assert not isinstance(result.exception, AttributeError) + +def test_extension_update_corrupted_hooks_value(project_dir, monkeypatch): + """Regression: extension update must handle non-dict 'hooks' in extensions.yml.""" + monkeypatch.chdir(project_dir) + + config_path = project_dir / ".specify" / "extensions.yml" + config_path.write_text(yaml.dump({ + "installed": ["test-ext"], + "hooks": ["not", "a", "dict"] + })) + + monkeypatch.setattr(ExtensionManager, "list_installed", lambda self: [{"id": "test-ext", "name": "Test Ext", "version": "1.0.0"}]) + monkeypatch.setattr(ExtensionRegistry, "get", lambda self, ext_id: {"version": "1.0.0", "enabled": True}) + monkeypatch.setattr(ExtensionCatalog, "get_extension_info", lambda self, ext_id: {"id": "test-ext", "name": "Test Ext", "version": "1.1.0", "download_url": "https://example.com/ext.zip"}) + # Use tmp_path-scoped zip so the test is hermetic and returns a Path for zip_path.exists() + mock_zip = project_dir / "mock.zip" + monkeypatch.setattr(ExtensionCatalog, "download_extension", lambda self, ext_id: mock_zip) + monkeypatch.setattr("typer.confirm", lambda _: True) + + result = runner.invoke(app, ["extension", "update", "test-ext"], obj={"project_root": project_dir}) + + # extension_update() catches exceptions internally and exits with code 1 on failure. + assert result.exit_code == 1 + assert "AttributeError" not in result.output + assert not isinstance(result.exception, AttributeError) + +def test_extension_update_rollback_corrupted_config(project_dir, monkeypatch): + """Regression: extension update rollback must handle corrupted extensions.yml.""" + monkeypatch.chdir(project_dir) + + config_path = project_dir / ".specify" / "extensions.yml" + # Write config with hooks: null; get_project_config() normalizes this to {} + # so the backup captures {} and the restored config will have hooks: {}. + config_path.write_text(yaml.dump({"installed": ["test-ext"], "hooks": None})) + + # Mock update process to fail after backup + monkeypatch.setattr(ExtensionManager, "list_installed", lambda self: [{"id": "test-ext", "name": "Test Ext", "version": "1.0.0"}]) + monkeypatch.setattr(ExtensionRegistry, "get", lambda self, ext_id: {"version": "1.0.0", "enabled": True}) + + # Force failure in download_extension to trigger rollback + def mock_download_fail(*args, **kwargs): + # Corrupt the config BEFORE rollback is triggered + config_path.write_text(yaml.dump("CORRUPTED")) + raise Exception("Download failed") + + monkeypatch.setattr(ExtensionCatalog, "get_extension_info", lambda self, ext_id: {"id": "test-ext", "name": "Test Ext", "version": "1.1.0", "download_url": "https://example.com/ext.zip"}) + monkeypatch.setattr(ExtensionCatalog, "download_extension", mock_download_fail) + monkeypatch.setattr("typer.confirm", lambda _: True) + + result = runner.invoke(app, ["extension", "update", "test-ext"], obj={"project_root": project_dir}) + + # Should handle Exception and NOT crash with AttributeError during rollback + assert result.exit_code == 1 + assert "Download failed" in result.output + assert not isinstance(result.exception, AttributeError) + + # Verify hooks key was preserved (normalized to {} if it was null/corrupted) + restored_config = yaml.safe_load(config_path.read_text()) + assert isinstance(restored_config, dict) + assert "hooks" in restored_config + assert restored_config["hooks"] == {}