diff --git a/comfy_cli/command/models/models.py b/comfy_cli/command/models/models.py index 2f02aadd..b18af207 100644 --- a/comfy_cli/command/models/models.py +++ b/comfy_cli/command/models/models.py @@ -7,12 +7,14 @@ import requests import typer +import yaml from rich import print from rich.markup import escape from comfy_cli import constants, tracking, ui from comfy_cli.config_manager import ConfigManager from comfy_cli.constants import DEFAULT_COMFY_MODEL_PATH +from comfy_cli.extra_model_paths import collect_extra_paths, paths_for_category from comfy_cli.file_utils import DownloadException, check_unauthorized, download_file from comfy_cli.workspace_manager import WorkspaceManager @@ -37,6 +39,21 @@ def get_workspace() -> pathlib.Path: return pathlib.Path(workspace_manager.workspace_path) +def _resolve_default_relative_path(category: str | None, basemodel: str, extras: list) -> str: + """Pick the destination subdir for a typed download. + + Returns an absolute path string when ``extras`` configures the category + (pathlib's ``/`` operator discards the workspace prefix in that case). + Otherwise returns the workspace-relative ``models//`` + form preserved from comfy-cli's existing behavior. + """ + if category and extras: + configured = paths_for_category(extras, category) + if configured: + return str(configured[0] / basemodel) if basemodel else str(configured[0]) + return os.path.join(DEFAULT_COMFY_MODEL_PATH, category or "", basemodel) + + def _format_elapsed(seconds: float) -> str: """Format elapsed seconds into a human-readable string.""" rounded = round(seconds, 1) @@ -243,10 +260,33 @@ def download( show_default=False, ), ] = None, + extra_model_paths_config: Annotated[ + list[pathlib.Path] | None, + typer.Option( + "--extra-model-paths-config", + help="Additional extra_model_paths.yaml file(s) to honor. Repeatable.", + show_default=False, + ), + ] = None, + extra_model_paths: Annotated[ + bool, + typer.Option( + "--extra-model-paths/--no-extra-model-paths", + help="Honor extra_model_paths.yaml from the workspace and any --extra-model-paths-config files.", + show_default=False, + ), + ] = True, ): if relative_path is not None: relative_path = os.path.expanduser(relative_path) + extras: list = [] + if extra_model_paths: + try: + extras = collect_extra_paths(get_workspace(), extra_model_paths_config or []) + except yaml.YAMLError as e: + print(f"[yellow]Warning: extra_model_paths YAML is invalid; ignoring extras ({escape(str(e))})[/yellow]") + local_filename = None headers = None @@ -278,7 +318,7 @@ def download( if model_path is None: model_path = ui.prompt_input("Enter model type path (e.g. loras, checkpoints, ...)", default="") - relative_path = os.path.join(DEFAULT_COMFY_MODEL_PATH, model_path, basemodel) + relative_path = _resolve_default_relative_path(model_path, basemodel, extras) elif is_civitai_api_url: local_filename, url, model_type, basemodel = request_civitai_model_version_api(version_id, headers) @@ -288,7 +328,7 @@ def download( if model_path is None: model_path = ui.prompt_input("Enter model type path (e.g. loras, checkpoints, ...)", default="") - relative_path = os.path.join(DEFAULT_COMFY_MODEL_PATH, model_path, basemodel) + relative_path = _resolve_default_relative_path(model_path, basemodel, extras) elif is_huggingface_url: model_id = "/".join(url.split("/")[-2:]) @@ -297,7 +337,7 @@ def download( if relative_path is None: model_path = ui.prompt_input("Enter model type path (e.g. loras, checkpoints, ...)", default="") basemodel = ui.prompt_input("Enter base model (e.g. SD1.5, SDXL, ...)", default="") - relative_path = os.path.join(DEFAULT_COMFY_MODEL_PATH, model_path, basemodel) + relative_path = _resolve_default_relative_path(model_path, basemodel, extras) else: print("Model source is unknown") @@ -388,47 +428,93 @@ def remove( help="Confirm for deletion and skip the prompt", show_default=False, ), + extra_model_paths_config: list[pathlib.Path] | None = typer.Option( + None, + "--extra-model-paths-config", + help="Additional extra_model_paths.yaml file(s) to honor. Repeatable.", + show_default=False, + ), + extra_model_paths: bool = typer.Option( + True, + "--extra-model-paths/--no-extra-model-paths", + help="Honor extra_model_paths.yaml from the workspace and any --extra-model-paths-config files.", + show_default=False, + ), ): """Remove one or more downloaded models, either by specifying them directly or through an interactive selection.""" - model_dir = get_workspace() / relative_path - available_models = list_models(model_dir) + primary = get_workspace() / relative_path + extras = _load_extras_safely(extra_model_paths, extra_model_paths_config) + roots = _enumerate_search_roots(primary, extras) + scanned = _scan_all_roots(roots) - if not available_models: + if not scanned: typer.echo("No models found to remove.") return - model_dir_resolved = model_dir.resolve() + resolved_roots: list[pathlib.Path] = [] + for root, _ in roots: + try: + resolved_roots.append(root.resolve()) + except OSError: + continue - to_delete = [] - # Scenario #1: User provided model names to delete + to_delete: list[pathlib.Path] = [] if model_names: - # Validate and filter models to delete based on provided names missing_models = [] + ambiguous = [] for name in model_names: - model_path = (model_dir / name).resolve() - if not model_path.is_relative_to(model_dir_resolved): - typer.echo(f"Invalid model path: {name}") + valid_matches: set[pathlib.Path] = set() + any_outside = False + for root, _ in roots: + try: + candidate = (root / name).resolve() + except OSError: + continue + if any(candidate.is_relative_to(r) for r in resolved_roots): + if candidate.is_file(): + valid_matches.add(candidate) + else: + any_outside = True + + if not valid_matches: + if any_outside: + typer.echo(f"Invalid model path: {name}") + else: + missing_models.append(name) continue - if model_path.is_file(): - to_delete.append(model_path) - else: - missing_models.append(name) + + if len(valid_matches) > 1: + ambiguous.append((name, sorted(valid_matches))) + continue + + to_delete.append(valid_matches.pop()) + + if ambiguous: + for name, paths in ambiguous: + typer.echo(f"Ambiguous model name '{name}'; matches multiple paths:") + for p in paths: + typer.echo(f" {p}") + typer.echo("Specify a more specific path to disambiguate.") + if not to_delete: + return if missing_models: typer.echo("The following models were not found and cannot be removed: " + ", ".join(missing_models)) if not to_delete: - return # Exit if no valid models were found - - # Scenario #2: User did not provide model names, prompt for selection + return else: - rel_names = [str(model.relative_to(model_dir)) for model in available_models] - selections = ui.prompt_multi_select("Select models to delete:", rel_names) + if len(roots) == 1: + single_root = roots[0][0] + labels_to_paths = {str(file.relative_to(single_root)): file for file, _, _ in scanned} + else: + labels_to_paths = {str(file): file for file, _, _ in scanned} + + selections = ui.prompt_multi_select("Select models to delete:", list(labels_to_paths.keys())) if not selections: typer.echo("No models selected for deletion.") return - to_delete = [model_dir / selection for selection in selections] + to_delete = [labels_to_paths[sel] for sel in selections] - # Confirm deletion if to_delete and ( confirm or ui.prompt_confirm_action("Are you sure you want to delete the selected files?", False) ): @@ -446,6 +532,81 @@ def list_models(path: pathlib.Path) -> list[pathlib.Path]: return sorted(f for f in path.rglob("*") if f.is_file()) +def _load_extras_safely(use_extras: bool, extra_configs: list[pathlib.Path] | None) -> list: + if not use_extras: + return [] + try: + return collect_extra_paths(get_workspace(), extra_configs or []) + except yaml.YAMLError as e: + print(f"[yellow]Warning: extra_model_paths YAML is invalid; ignoring extras ({escape(str(e))})[/yellow]") + return [] + + +def _enumerate_search_roots(primary_root: pathlib.Path, extras: list) -> list[tuple[pathlib.Path, str | None]]: + """Return ``(root, category)`` pairs to scan, longest-first. + + The primary root carries ``category=None`` so list rendering preserves + today's "category from path" behavior. Extras roots carry their canonical + category name for the Type-column prefix. Roots are deduplicated by + realpath; unresolvable roots (e.g., circular symlinks) are skipped with + a warning. Sorting longest-first ensures a file under nested roots is + assigned to the most specific one. + """ + candidates: list[tuple[pathlib.Path, str | None]] = [(primary_root, None)] + for ep in extras: + candidates.append((ep.path, ep.category)) + + seen_resolved: set[pathlib.Path] = set() + unique: list[tuple[pathlib.Path, str | None]] = [] + for root, category in candidates: + try: + resolved = root.resolve() + except OSError as e: + print(f"[yellow]Warning: skipping {root}: {e}[/yellow]") + continue + if resolved in seen_resolved: + continue + seen_resolved.add(resolved) + unique.append((root, category)) + + unique.sort(key=lambda rc: len(rc[0].parts), reverse=True) + return unique + + +def _scan_all_roots( + roots: list[tuple[pathlib.Path, str | None]], +) -> list[tuple[pathlib.Path, pathlib.Path, str | None]]: + """Return ``(file, root, category)`` tuples, each file assigned to its + deepest containing root. Output is sorted by file path.""" + seen_files: set[pathlib.Path] = set() + result: list[tuple[pathlib.Path, pathlib.Path, str | None]] = [] + for root, category in roots: + for file in list_models(root): + try: + resolved = file.resolve() + except OSError: + continue + if resolved in seen_files: + continue + seen_files.add(resolved) + result.append((file, root, category)) + result.sort(key=lambda x: x[0]) + return result + + +def _format_type_column(file: pathlib.Path, root: pathlib.Path, category: str | None) -> str: + """Compute Type column text. For extras roots the canonical category is + prepended so output is consistent with the workspace listing where the + category is implicit in the on-disk subdir.""" + rel = file.relative_to(root) + parent = str(rel.parent) if len(rel.parts) > 1 else "" + if category is None: + return parent + if not parent or parent == ".": + return category + return f"{category}/{parent}" + + @app.command("list") @tracking.track_command("model") def list_command( @@ -455,20 +616,40 @@ def list_command( help="The relative path from the current workspace where the models are stored.", show_default=True, ), + extra_model_paths_config: list[pathlib.Path] | None = typer.Option( + None, + "--extra-model-paths-config", + help="Additional extra_model_paths.yaml file(s) to honor. Repeatable.", + show_default=False, + ), + extra_model_paths: bool = typer.Option( + True, + "--extra-model-paths/--no-extra-model-paths", + help="Honor extra_model_paths.yaml from the workspace and any --extra-model-paths-config files.", + show_default=False, + ), ): """Display a list of all models currently downloaded in a table format.""" - model_dir = get_workspace() / relative_path - models = list_models(model_dir) + primary = get_workspace() / relative_path + extras = _load_extras_safely(extra_model_paths, extra_model_paths_config) + roots = _enumerate_search_roots(primary, extras) + scanned = _scan_all_roots(roots) - if not models: + if not scanned: typer.echo("No models found.") return - # Prepare data for table display + show_source = len({r for _, r, _ in scanned}) > 1 data = [] - for model in models: - rel = model.relative_to(model_dir) - model_type = str(rel.parent) if len(rel.parts) > 1 else "" - data.append((model.name, model_type, f"{model.stat().st_size // 1024} KB")) - column_names = ["Model Name", "Type", "Size"] - ui.display_table(data, column_names) + for file, root, category in scanned: + type_str = _format_type_column(file, root, category) + size_str = f"{file.stat().st_size // 1024} KB" + if show_source: + data.append((file.name, type_str, size_str, str(root))) + else: + data.append((file.name, type_str, size_str)) + + columns = ["Model Name", "Type", "Size"] + if show_source: + columns.append("Source") + ui.display_table(data, columns) diff --git a/comfy_cli/extra_model_paths.py b/comfy_cli/extra_model_paths.py new file mode 100644 index 00000000..52d154cd --- /dev/null +++ b/comfy_cli/extra_model_paths.py @@ -0,0 +1,151 @@ +"""Parse ComfyUI's ``extra_model_paths.yaml`` files. + +The behavior mirrors ComfyUI's ``utils/extra_config.load_extra_path_config`` and +the priority semantics of ``folder_paths.add_model_folder_path``, but exposes +a pure-functional API instead of mutating module-level state. +""" + +from __future__ import annotations + +import logging +import os +from collections.abc import Sequence +from dataclasses import dataclass +from pathlib import Path + +import yaml + +logger = logging.getLogger(__name__) + +DEFAULT_FILENAME = "extra_model_paths.yaml" +LEGACY_NAME_MAP = {"unet": "diffusion_models", "clip": "text_encoders"} + + +@dataclass(frozen=True) +class ExtraPath: + category: str + path: Path + is_default: bool + section: str + + +def load_extra_paths(yaml_path: Path) -> list[ExtraPath]: + """Parse one ``extra_model_paths.yaml`` file. + + Returns entries in YAML document order. Priority resolution + (``is_default``, dedup, legacy aliasing) is the responsibility of + :func:`paths_for_category`. + + Returns ``[]`` for missing or empty files. Raises ``yaml.YAMLError`` + when the file exists but is not valid YAML, and ``OSError`` when the + file cannot be read. Logs a warning and skips structurally invalid + sections (non-mapping section, non-string category value). + """ + if not yaml_path.is_file(): + return [] + + with yaml_path.open(encoding="utf-8") as stream: + config = yaml.safe_load(stream) + + if config is None: + return [] + if not isinstance(config, dict): + logger.warning("extra_model_paths file %s is not a YAML mapping; ignoring", yaml_path) + return [] + + yaml_dir = os.path.dirname(os.path.abspath(str(yaml_path))) + result: list[ExtraPath] = [] + + for section_name, section in config.items(): + if section is None: + continue + if not isinstance(section, dict): + logger.warning( + "extra_model_paths section %r in %s is not a mapping; skipping", + section_name, + yaml_path, + ) + continue + + section = dict(section) + base_path = section.pop("base_path", None) + if base_path is not None: + base_path = os.path.expandvars(os.path.expanduser(str(base_path))) + if not os.path.isabs(base_path): + base_path = os.path.abspath(os.path.join(yaml_dir, base_path)) + is_default = bool(section.pop("is_default", False)) + + for raw_category, value in section.items(): + if value is None: + continue + if not isinstance(value, str): + logger.warning( + "extra_model_paths %s/%s in %s is %s, expected string; skipping", + section_name, + raw_category, + yaml_path, + type(value).__name__, + ) + continue + category = LEGACY_NAME_MAP.get(raw_category, raw_category) + + for raw_path in value.split("\n"): + if len(raw_path) == 0: + continue + if base_path: + full_path = os.path.join(base_path, raw_path) + elif os.path.isabs(raw_path): + full_path = raw_path + else: + full_path = os.path.abspath(os.path.join(yaml_dir, raw_path)) + normalized = os.path.normpath(full_path) + result.append( + ExtraPath( + category=category, + path=Path(normalized), + is_default=is_default, + section=str(section_name), + ) + ) + + return result + + +def collect_extra_paths(workspace: Path, extra_configs: Sequence[Path] = ()) -> list[ExtraPath]: + """Read ``/extra_model_paths.yaml`` plus any explicit configs. + + Concatenates results in order: workspace yaml first, then each entry + in ``extra_configs`` in the order given. No deduplication — that + requires syscalls and is left to the caller. + """ + result: list[ExtraPath] = [] + result.extend(load_extra_paths(workspace / DEFAULT_FILENAME)) + for cfg in extra_configs: + result.extend(load_extra_paths(cfg)) + return result + + +def paths_for_category(extras: Sequence[ExtraPath], category: str) -> list[Path]: + """Filter ``extras`` to one category and return paths in priority order. + + Mirrors ComfyUI's ``folder_paths.add_model_folder_path`` exactly: + each path appears at most once; ``is_default`` paths come before + non-default; a later ``is_default`` entry pointing to the same path + moves it to slot 0; a later non-default duplicate is a no-op. Legacy + aliases (``unet``, ``clip``) map to their canonical names. + """ + target = LEGACY_NAME_MAP.get(category, category) + result: list[Path] = [] + for ep in extras: + if ep.category != target: + continue + if ep.path in result: + if ep.is_default and result[0] != ep.path: + result.remove(ep.path) + result.insert(0, ep.path) + else: + if ep.is_default: + result.insert(0, ep.path) + else: + result.append(ep.path) + return result diff --git a/tests/comfy_cli/command/models/test_models.py b/tests/comfy_cli/command/models/test_models.py index 13dae746..a624827c 100644 --- a/tests/comfy_cli/command/models/test_models.py +++ b/tests/comfy_cli/command/models/test_models.py @@ -541,3 +541,346 @@ def test_download_exception_with_markup_chars_does_not_crash(self, tmp_path): # Literal markup characters must survive to the output so the user sees the real message. assert "[/]" in result.output assert "[id]" in result.output + + +class TestDownloadCommandExtraModelPaths: + """Verify extra_model_paths.yaml integration in `comfy model download`.""" + + _CIVITAI_URL = "https://civitai.com/models/43331?version=12345" + _DEFAULT_API_RETURN = ("api_default.bin", "http://x/file.bin", "checkpoint", "SDXL") + + def _civitai_invoke(self, tmp_path, *, args, civitai_return=None): + if civitai_return is None: + civitai_return = self._DEFAULT_API_RETURN + captured: list = [] + + def fake_dl(url, local_filepath, headers, downloader): + local_filepath.parent.mkdir(parents=True, exist_ok=True) + local_filepath.write_bytes(b"x") + captured.append(local_filepath) + + with ( + patch("comfy_cli.command.models.models.get_workspace", return_value=tmp_path), + patch("comfy_cli.command.models.models.request_civitai_model_api", return_value=civitai_return), + patch("comfy_cli.command.models.models.download_file", side_effect=fake_dl), + patch("comfy_cli.tracking.track_command", lambda _cmd: lambda fn: fn), + ): + return runner.invoke(app, args), captured + + def test_no_extras_uses_workspace_path(self, tmp_path): + result, captured = self._civitai_invoke( + tmp_path, + args=["download", "--url", self._CIVITAI_URL, "--filename", "x.bin"], + ) + assert result.exit_code == 0 + assert captured == [tmp_path / "models" / "checkpoints" / "SDXL" / "x.bin"] + + def test_extras_routes_to_configured_path(self, tmp_path): + ext_root = tmp_path / "ext" + (tmp_path / "extra_model_paths.yaml").write_text(f"comfyui:\n base_path: {ext_root}\n checkpoints: cp\n") + result, captured = self._civitai_invoke( + tmp_path, + args=["download", "--url", self._CIVITAI_URL, "--filename", "x.bin"], + ) + assert result.exit_code == 0 + assert captured == [ext_root / "cp" / "SDXL" / "x.bin"] + + def test_explicit_relative_path_overrides_extras(self, tmp_path): + ext_root = tmp_path / "ext" + (tmp_path / "extra_model_paths.yaml").write_text(f"comfyui:\n base_path: {ext_root}\n checkpoints: cp\n") + manual = tmp_path / "manual" + result, captured = self._civitai_invoke( + tmp_path, + args=[ + "download", + "--url", + self._CIVITAI_URL, + "--relative-path", + str(manual), + "--filename", + "x.bin", + ], + ) + assert result.exit_code == 0 + assert captured == [manual / "x.bin"] + + def test_no_extra_model_paths_flag_disables_extras(self, tmp_path): + ext_root = tmp_path / "ext" + (tmp_path / "extra_model_paths.yaml").write_text(f"comfyui:\n base_path: {ext_root}\n checkpoints: cp\n") + result, captured = self._civitai_invoke( + tmp_path, + args=[ + "download", + "--url", + self._CIVITAI_URL, + "--filename", + "x.bin", + "--no-extra-model-paths", + ], + ) + assert result.exit_code == 0 + assert captured == [tmp_path / "models" / "checkpoints" / "SDXL" / "x.bin"] + + def test_extras_for_other_category_falls_back(self, tmp_path): + ext_root = tmp_path / "ext" + (tmp_path / "extra_model_paths.yaml").write_text(f"comfyui:\n base_path: {ext_root}\n loras: l\n") + result, captured = self._civitai_invoke( + tmp_path, + args=["download", "--url", self._CIVITAI_URL, "--filename", "x.bin"], + ) + assert result.exit_code == 0 + assert captured == [tmp_path / "models" / "checkpoints" / "SDXL" / "x.bin"] + + def test_extra_model_paths_config_flag_loads_external_yaml(self, tmp_path): + ext_root = tmp_path / "ext" + external = tmp_path / "external.yaml" + external.write_text(f"comfyui:\n base_path: {ext_root}\n checkpoints: cp\n") + result, captured = self._civitai_invoke( + tmp_path, + args=[ + "download", + "--url", + self._CIVITAI_URL, + "--filename", + "x.bin", + "--extra-model-paths-config", + str(external), + ], + ) + assert result.exit_code == 0 + assert captured == [ext_root / "cp" / "SDXL" / "x.bin"] + + def test_is_default_priority_picks_marked_section(self, tmp_path): + primary = tmp_path / "primary" + fallback = tmp_path / "fallback" + (tmp_path / "extra_model_paths.yaml").write_text( + f"first:\n base_path: {fallback}\n checkpoints: cp\n" + f"second:\n base_path: {primary}\n is_default: true\n checkpoints: cp\n" + ) + result, captured = self._civitai_invoke( + tmp_path, + args=["download", "--url", self._CIVITAI_URL, "--filename", "x.bin"], + ) + assert result.exit_code == 0 + assert captured == [primary / "cp" / "SDXL" / "x.bin"] + + def test_huggingface_url_routes_to_extras(self, tmp_path): + ext_root = tmp_path / "ext" + (tmp_path / "extra_model_paths.yaml").write_text(f"comfyui:\n base_path: {ext_root}\n checkpoints: cp\n") + captured: list = [] + + def fake_dl(url, local_filepath, headers, downloader): + local_filepath.parent.mkdir(parents=True, exist_ok=True) + local_filepath.write_bytes(b"x") + captured.append(local_filepath) + + with ( + patch("comfy_cli.command.models.models.get_workspace", return_value=tmp_path), + patch("comfy_cli.command.models.models.download_file", side_effect=fake_dl), + patch("comfy_cli.command.models.models.check_civitai_url", return_value=(False, False, None, None)), + patch( + "comfy_cli.command.models.models.check_huggingface_url", + return_value=(True, "foo/bar", "x.safetensors", None, None), + ), + patch("comfy_cli.command.models.models.check_unauthorized", return_value=False), + patch("comfy_cli.command.models.models.ui") as mock_ui, + patch("comfy_cli.tracking.track_command", lambda _cmd: lambda fn: fn), + ): + mock_ui.prompt_input.side_effect = ["checkpoints", "SDXL"] + result = runner.invoke( + app, + [ + "download", + "--url", + "https://huggingface.co/foo/bar/resolve/main/x.safetensors", + "--filename", + "x.bin", + ], + ) + assert result.exit_code == 0 + assert captured == [ext_root / "cp" / "SDXL" / "x.bin"] + + def test_invalid_extras_yaml_warns_and_falls_back(self, tmp_path): + (tmp_path / "extra_model_paths.yaml").write_text("comfyui:\n base_path: /a\n checkpoints: cp\n") + result, captured = self._civitai_invoke( + tmp_path, + args=["download", "--url", self._CIVITAI_URL, "--filename", "x.bin"], + ) + assert result.exit_code == 0 + assert "extra_model_paths" in result.output + assert captured == [tmp_path / "models" / "checkpoints" / "SDXL" / "x.bin"] + + +class TestListCommandExtraModelPaths: + """Verify extra_model_paths.yaml integration in `comfy model list`.""" + + def _setup_workspace_models(self, tmp_path): + (tmp_path / "models" / "checkpoints").mkdir(parents=True) + (tmp_path / "models" / "checkpoints" / "ws_only.safetensors").write_bytes(b"x" * 100) + + def _setup_extras_root(self, tmp_path): + ext_root = tmp_path / "ext" + (ext_root / "cp" / "SDXL").mkdir(parents=True) + (ext_root / "cp" / "SDXL" / "ext_only.safetensors").write_bytes(b"x" * 200) + return ext_root + + def test_no_extras_baseline_unchanged(self, tmp_path): + self._setup_workspace_models(tmp_path) + with patch("comfy_cli.command.models.models.get_workspace", return_value=tmp_path): + result = runner.invoke(app, ["list", "--relative-path", "models"]) + assert result.exit_code == 0 + assert "ws_only.safetensors" in result.output + assert "Source" not in result.output + + def test_extras_files_appear(self, tmp_path): + self._setup_workspace_models(tmp_path) + ext_root = self._setup_extras_root(tmp_path) + (tmp_path / "extra_model_paths.yaml").write_text(f"comfyui:\n base_path: {ext_root}\n checkpoints: cp\n") + with patch("comfy_cli.command.models.models.get_workspace", return_value=tmp_path): + result = runner.invoke(app, ["list", "--relative-path", "models"]) + assert result.exit_code == 0 + assert "ws_only.safetensors" in result.output + assert "ext_only.safetensors" in result.output + + def test_multi_root_adds_source_column(self, tmp_path): + self._setup_workspace_models(tmp_path) + ext_root = self._setup_extras_root(tmp_path) + (tmp_path / "extra_model_paths.yaml").write_text(f"comfyui:\n base_path: {ext_root}\n checkpoints: cp\n") + with patch("comfy_cli.command.models.models.get_workspace", return_value=tmp_path): + result = runner.invoke(app, ["list", "--relative-path", "models"]) + assert result.exit_code == 0 + assert "Source" in result.output + + def test_type_column_prepends_category_for_extras_files(self, tmp_path): + ext_root = self._setup_extras_root(tmp_path) + (tmp_path / "models").mkdir() + (tmp_path / "extra_model_paths.yaml").write_text(f"comfyui:\n base_path: {ext_root}\n checkpoints: cp\n") + with patch("comfy_cli.command.models.models.get_workspace", return_value=tmp_path): + result = runner.invoke(app, ["list", "--relative-path", "models"]) + assert result.exit_code == 0 + assert "checkpoints/SDXL" in result.output + + def test_dedup_when_extras_root_overlaps_workspace(self, tmp_path): + self._setup_workspace_models(tmp_path) + (tmp_path / "extra_model_paths.yaml").write_text( + f"comfyui:\n base_path: {tmp_path}\n checkpoints: models/checkpoints\n" + ) + with patch("comfy_cli.command.models.models.get_workspace", return_value=tmp_path): + result = runner.invoke(app, ["list", "--relative-path", "models"]) + assert result.exit_code == 0 + assert result.output.count("ws_only.safetensors") == 1 + + def test_no_extra_model_paths_flag_disables(self, tmp_path): + self._setup_workspace_models(tmp_path) + ext_root = self._setup_extras_root(tmp_path) + (tmp_path / "extra_model_paths.yaml").write_text(f"comfyui:\n base_path: {ext_root}\n checkpoints: cp\n") + with patch("comfy_cli.command.models.models.get_workspace", return_value=tmp_path): + result = runner.invoke(app, ["list", "--relative-path", "models", "--no-extra-model-paths"]) + assert result.exit_code == 0 + assert "ws_only.safetensors" in result.output + assert "ext_only.safetensors" not in result.output + + +class TestRemoveCommandExtraModelPaths: + """Verify extra_model_paths.yaml integration in `comfy model remove`.""" + + def test_remove_target_in_extras_root_deletes(self, tmp_path): + ext_root = tmp_path / "ext" + (ext_root / "cp").mkdir(parents=True) + target = ext_root / "cp" / "x.safetensors" + target.write_bytes(b"x") + (tmp_path / "models").mkdir() + (tmp_path / "extra_model_paths.yaml").write_text(f"comfyui:\n base_path: {ext_root}\n checkpoints: cp\n") + with patch("comfy_cli.command.models.models.get_workspace", return_value=tmp_path): + result = runner.invoke( + app, + [ + "remove", + "--relative-path", + "models", + "--model-names", + str(target), + "--confirm", + ], + ) + assert result.exit_code == 0 + assert not target.exists() + + def test_remove_model_names_ambiguous_errors(self, tmp_path): + # Same filename in workspace AND in an extras root + (tmp_path / "models").mkdir() + ws_target = tmp_path / "models" / "dup.safetensors" + ws_target.write_bytes(b"x") + ext_root = tmp_path / "ext" + ext_target = ext_root / "cp" / "dup.safetensors" + ext_target.parent.mkdir(parents=True) + ext_target.write_bytes(b"x") + # Configure extras root that, joined with "dup.safetensors", finds /ext/dup.safetensors + # — which doesn't exist. To force ambiguity we point the extras root AT the file's + # parent directly. + (tmp_path / "extra_model_paths.yaml").write_text( + f"comfyui:\n base_path: {tmp_path}\n checkpoints: |\n models\n ext/cp\n" + ) + with patch("comfy_cli.command.models.models.get_workspace", return_value=tmp_path): + result = runner.invoke( + app, + [ + "remove", + "--relative-path", + "models", + "--model-names", + "dup.safetensors", + "--confirm", + ], + ) + assert "Ambiguous" in result.output + assert ws_target.exists() + assert ext_target.exists() + + def test_remove_traversal_protection_with_extras(self, tmp_path): + (tmp_path / "models").mkdir() + (tmp_path / "models" / "legit.safetensors").write_bytes(b"x") + secret = tmp_path / "secret.txt" + secret.write_text("sensitive") + ext_root = tmp_path / "ext" + (ext_root / "cp").mkdir(parents=True) + (ext_root / "cp" / "extras_file.safetensors").write_bytes(b"x") + (tmp_path / "extra_model_paths.yaml").write_text(f"comfyui:\n base_path: {ext_root}\n checkpoints: cp\n") + with patch("comfy_cli.command.models.models.get_workspace", return_value=tmp_path): + result = runner.invoke( + app, + [ + "remove", + "--relative-path", + "models", + "--model-names", + "../secret.txt", + "--confirm", + ], + ) + assert secret.exists() + assert "Invalid model path" in result.output + + def test_remove_no_extra_model_paths_disables(self, tmp_path): + ext_root = tmp_path / "ext" + (ext_root / "cp").mkdir(parents=True) + ext_target = ext_root / "cp" / "x.safetensors" + ext_target.write_bytes(b"x") + (tmp_path / "models").mkdir() + (tmp_path / "models" / "ws_file.safetensors").write_bytes(b"x") + (tmp_path / "extra_model_paths.yaml").write_text(f"comfyui:\n base_path: {ext_root}\n checkpoints: cp\n") + with patch("comfy_cli.command.models.models.get_workspace", return_value=tmp_path): + result = runner.invoke( + app, + [ + "remove", + "--relative-path", + "models", + "--model-names", + str(ext_target), + "--confirm", + "--no-extra-model-paths", + ], + ) + assert ext_target.exists() + assert "Invalid model path" in result.output diff --git a/tests/comfy_cli/test_extra_model_paths.py b/tests/comfy_cli/test_extra_model_paths.py new file mode 100644 index 00000000..f5fd858d --- /dev/null +++ b/tests/comfy_cli/test_extra_model_paths.py @@ -0,0 +1,399 @@ +import logging +import os +from pathlib import Path +from textwrap import dedent + +import pytest +import yaml + +from comfy_cli.extra_model_paths import ( + ExtraPath, + collect_extra_paths, + load_extra_paths, + paths_for_category, +) + + +def _write(path: Path, content: str) -> Path: + path.write_text(dedent(content).lstrip()) + return path + + +# ---------- load_extra_paths ---------- + + +def test_missing_file_returns_empty(tmp_path): + assert load_extra_paths(tmp_path / "absent.yaml") == [] + + +def test_empty_file_returns_empty(tmp_path): + p = _write(tmp_path / "x.yaml", "") + assert load_extra_paths(p) == [] + + +def test_invalid_yaml_propagates(tmp_path): + p = _write( + tmp_path / "x.yaml", + """ + comfyui: + base_path: /foo + checkpoints: bar + """, + ) + with pytest.raises(yaml.YAMLError): + load_extra_paths(p) + + +def test_top_level_not_a_mapping_warns_and_returns_empty(tmp_path, caplog): + p = _write(tmp_path / "x.yaml", "[a, b]\n") + with caplog.at_level(logging.WARNING, logger="comfy_cli.extra_model_paths"): + assert load_extra_paths(p) == [] + assert "not a YAML mapping" in caplog.text + + +def test_absolute_path_no_base_path(tmp_path): + abs_dir = tmp_path / "external" / "checkpoints" + p = _write( + tmp_path / "x.yaml", + f""" + comfyui: + checkpoints: {abs_dir} + """, + ) + [entry] = load_extra_paths(p) + assert entry.category == "checkpoints" + assert entry.path == Path(os.path.normpath(str(abs_dir))) + assert entry.is_default is False + assert entry.section == "comfyui" + + +def test_relative_base_path_resolves_to_yaml_dir(tmp_path): + yaml_dir = tmp_path / "configs" + yaml_dir.mkdir() + p = _write( + yaml_dir / "x.yaml", + """ + comfyui: + base_path: store + checkpoints: cp + """, + ) + [entry] = load_extra_paths(p) + expected = Path(os.path.normpath(str(yaml_dir / "store" / "cp"))) + assert entry.path == expected + + +def test_base_path_with_tilde_expansion(tmp_path, monkeypatch): + fake_home = tmp_path / "home" + fake_home.mkdir() + monkeypatch.setenv("HOME", str(fake_home)) + monkeypatch.setenv("USERPROFILE", str(fake_home)) + p = _write( + tmp_path / "x.yaml", + """ + comfyui: + base_path: ~/models + checkpoints: cp + """, + ) + [entry] = load_extra_paths(p) + assert entry.path == Path(os.path.normpath(str(fake_home / "models" / "cp"))) + + +def test_base_path_with_env_var_expansion(tmp_path, monkeypatch): + target = tmp_path / "var_target" + monkeypatch.setenv("MODEL_ROOT", str(target)) + p = _write( + tmp_path / "x.yaml", + """ + comfyui: + base_path: $MODEL_ROOT/models + loras: l + """, + ) + [entry] = load_extra_paths(p) + assert entry.path == Path(os.path.normpath(str(target / "models" / "l"))) + + +def test_multiline_block_scalar_splits_paths(tmp_path): + p = _write( + tmp_path / "x.yaml", + """ + comfyui: + base_path: /base + text_encoders: | + models/text_encoders/ + models/clip/ + """, + ) + paths = [e.path for e in load_extra_paths(p)] + assert paths == [ + Path(os.path.normpath("/base/models/text_encoders/")), + Path(os.path.normpath("/base/models/clip/")), + ] + + +def test_blank_lines_in_block_scalar_skipped(tmp_path): + p = _write( + tmp_path / "x.yaml", + """ + comfyui: + base_path: /base + loras: | + loras + + more_loras + """, + ) + paths = [e.path for e in load_extra_paths(p)] + assert paths == [ + Path(os.path.normpath("/base/loras")), + Path(os.path.normpath("/base/more_loras")), + ] + + +def test_is_default_flag_preserved(tmp_path): + p = _write( + tmp_path / "x.yaml", + """ + comfyui: + base_path: /a + is_default: true + checkpoints: cp + a111: + base_path: /b + checkpoints: cp + """, + ) + entries = load_extra_paths(p) + assert entries[0].is_default is True and entries[0].section == "comfyui" + assert entries[1].is_default is False and entries[1].section == "a111" + + +def test_legacy_aliases_unet_and_clip(tmp_path): + p = _write( + tmp_path / "x.yaml", + """ + comfyui: + base_path: /a + unet: u + clip: c + """, + ) + cats = [e.category for e in load_extra_paths(p)] + assert cats == ["diffusion_models", "text_encoders"] + + +def test_none_section_skipped(tmp_path): + p = _write( + tmp_path / "x.yaml", + """ + comfyui: + a111: + base_path: /b + checkpoints: cp + """, + ) + [entry] = load_extra_paths(p) + assert entry.section == "a111" + + +def test_none_category_value_skipped(tmp_path): + p = _write( + tmp_path / "x.yaml", + """ + comfyui: + base_path: /a + checkpoints: + loras: l + """, + ) + [entry] = load_extra_paths(p) + assert entry.category == "loras" + + +def test_non_string_category_value_warns_and_skips(tmp_path, caplog): + p = _write( + tmp_path / "x.yaml", + """ + comfyui: + base_path: /a + checkpoints: [bad, list] + loras: ok + """, + ) + with caplog.at_level(logging.WARNING, logger="comfy_cli.extra_model_paths"): + entries = load_extra_paths(p) + assert [e.category for e in entries] == ["loras"] + assert "checkpoints" in caplog.text + + +def test_section_not_a_mapping_warns_and_skips(tmp_path, caplog): + p = _write( + tmp_path / "x.yaml", + """ + comfyui: just_a_string + a111: + base_path: /b + checkpoints: cp + """, + ) + with caplog.at_level(logging.WARNING, logger="comfy_cli.extra_model_paths"): + entries = load_extra_paths(p) + assert [e.section for e in entries] == ["a111"] + assert "comfyui" in caplog.text + + +def test_normpath_collapses_dot_dot(tmp_path): + p = _write( + tmp_path / "x.yaml", + """ + comfyui: + base_path: /a/b + checkpoints: ../cp + """, + ) + [entry] = load_extra_paths(p) + assert entry.path == Path(os.path.normpath("/a/b/../cp")) + assert ".." not in entry.path.parts + + +def test_relative_path_no_base_path_uses_yaml_dir(tmp_path): + yaml_dir = tmp_path / "cfg" + yaml_dir.mkdir() + p = _write( + yaml_dir / "x.yaml", + """ + comfyui: + checkpoints: rel/cp + """, + ) + [entry] = load_extra_paths(p) + assert entry.path == Path(os.path.normpath(str(yaml_dir / "rel" / "cp"))) + + +# ---------- collect_extra_paths ---------- + + +def test_collect_workspace_yaml_only(tmp_path): + workspace = tmp_path / "ws" + workspace.mkdir() + _write( + workspace / "extra_model_paths.yaml", + """ + comfyui: + base_path: /a + checkpoints: cp + """, + ) + entries = collect_extra_paths(workspace) + assert len(entries) == 1 + assert entries[0].section == "comfyui" + + +def test_collect_no_workspace_yaml_with_extra_config(tmp_path): + workspace = tmp_path / "ws" + workspace.mkdir() + extra = _write( + tmp_path / "extra.yaml", + """ + comfyui: + base_path: /a + loras: l + """, + ) + entries = collect_extra_paths(workspace, [extra]) + assert len(entries) == 1 + assert entries[0].category == "loras" + + +def test_collect_workspace_then_extras_in_order(tmp_path): + workspace = tmp_path / "ws" + workspace.mkdir() + _write( + workspace / "extra_model_paths.yaml", + """ + ws_section: + base_path: /a + checkpoints: cp + """, + ) + extra1 = _write( + tmp_path / "e1.yaml", + """ + e1_section: + base_path: /b + checkpoints: cp + """, + ) + extra2 = _write( + tmp_path / "e2.yaml", + """ + e2_section: + base_path: /c + checkpoints: cp + """, + ) + entries = collect_extra_paths(workspace, [extra1, extra2]) + assert [e.section for e in entries] == ["ws_section", "e1_section", "e2_section"] + + +# ---------- paths_for_category ---------- + + +def test_paths_for_unknown_category_empty(): + assert paths_for_category([], "anything") == [] + extras = [ExtraPath("loras", Path("/a"), False, "s")] + assert paths_for_category(extras, "checkpoints") == [] + + +def test_paths_preserve_yaml_order_when_no_default(): + extras = [ + ExtraPath("checkpoints", Path("/A"), False, "a"), + ExtraPath("checkpoints", Path("/B"), False, "b"), + ExtraPath("checkpoints", Path("/C"), False, "c"), + ] + assert paths_for_category(extras, "checkpoints") == [Path("/A"), Path("/B"), Path("/C")] + + +def test_is_default_paths_come_before_non_default(): + extras = [ + ExtraPath("checkpoints", Path("/X"), False, "x"), + ExtraPath("checkpoints", Path("/Y"), True, "y"), + ] + assert paths_for_category(extras, "checkpoints") == [Path("/Y"), Path("/X")] + + +def test_two_is_default_later_wins_slot_zero(): + extras = [ + ExtraPath("checkpoints", Path("/A"), True, "a"), + ExtraPath("checkpoints", Path("/B"), True, "b"), + ] + assert paths_for_category(extras, "checkpoints") == [Path("/B"), Path("/A")] + + +def test_legacy_alias_query_returns_canonical_paths(): + extras = [ + ExtraPath("diffusion_models", Path("/dm"), False, "s"), + ExtraPath("text_encoders", Path("/te"), False, "s"), + ] + assert paths_for_category(extras, "unet") == [Path("/dm")] + assert paths_for_category(extras, "clip") == [Path("/te")] + + +def test_duplicate_with_default_moves_to_head(): + extras = [ + ExtraPath("loras", Path("/A"), False, "a"), + ExtraPath("loras", Path("/B"), False, "b"), + ExtraPath("loras", Path("/B"), True, "c"), + ] + assert paths_for_category(extras, "loras") == [Path("/B"), Path("/A")] + + +def test_duplicate_without_default_is_noop(): + extras = [ + ExtraPath("loras", Path("/A"), True, "a"), + ExtraPath("loras", Path("/B"), False, "b"), + ExtraPath("loras", Path("/A"), False, "c"), + ] + assert paths_for_category(extras, "loras") == [Path("/A"), Path("/B")]