diff --git a/packages/jumpstarter-driver-ridesx/jumpstarter_driver_ridesx/client.py b/packages/jumpstarter-driver-ridesx/jumpstarter_driver_ridesx/client.py index 2a557eeca..07403fb46 100644 --- a/packages/jumpstarter-driver-ridesx/jumpstarter_driver_ridesx/client.py +++ b/packages/jumpstarter-driver-ridesx/jumpstarter_driver_ridesx/client.py @@ -1,7 +1,12 @@ +import ssl +import tempfile from dataclasses import dataclass from pathlib import Path from typing import Dict, Optional +from urllib.parse import urlparse +import aiohttp +import click from jumpstarter_driver_composite.client import CompositeClient from jumpstarter_driver_opendal.client import FlasherClient, operator_for_path from jumpstarter_driver_power.client import PowerClient @@ -22,10 +27,51 @@ def __post_init__(self): def boot_to_fastboot(self): return self.call("boot_to_fastboot") - def _upload_file_if_needed(self, file_path: str, operator: Operator | None = None) -> str: + def _is_http_url(self, path: str) -> bool: + """Check if the path is an HTTP or HTTPS URL.""" + return isinstance(path, str) and path.startswith(("http://", "https://")) + + def _download_http_to_storage(self, url: str, storage, filename: str, insecure_tls: bool = False) -> None: + async def _download(): + parsed = urlparse(url) + if parsed.scheme == "http" or insecure_tls: + ssl_context: ssl.SSLContext | bool = False + else: + ssl_context = True + + connector = aiohttp.TCPConnector(ssl=ssl_context) + async with aiohttp.ClientSession(connector=connector) as session: + async with session.get(url) as response: + response.raise_for_status() + with tempfile.NamedTemporaryFile(delete=False, dir="/var/tmp") as f: + async for chunk in response.content.iter_chunked(65536): + f.write(chunk) + return Path(f.name) + + tmp_path = self.portal.call(_download) + try: + storage.write_from_path(filename, tmp_path) + finally: + tmp_path.unlink() + + def _upload_file_if_needed( + self, file_path: str, operator: Operator | None = None, insecure_tls: bool = False + ) -> str: if not file_path or not file_path.strip(): raise ValueError("File path cannot be empty. Please provide a valid file path.") + if self._is_http_url(file_path) and operator is None: + parsed = urlparse(file_path) + is_insecure_http = parsed.scheme == "http" + + # use aiohttp for: http:// URLs, or https:// with insecure_tls + if is_insecure_http or insecure_tls: + filename = Path(parsed.path).name + self.logger.info(f"Downloading {file_path} to storage as {filename}") + self._download_http_to_storage(file_path, self.storage, filename, insecure_tls=insecure_tls) + return filename + + # use opendal for local files, https:// (secure), and other schemes if operator is None: path_buf, operator, operator_scheme = operator_for_path(file_path) else: @@ -46,12 +92,18 @@ def _upload_file_if_needed(self, file_path: str, operator: Operator | None = Non return filename - def flash_images(self, partitions: Dict[str, str], operators: Optional[Dict[str, Operator]] = None): + def flash_images( + self, + partitions: Dict[str, str], + operators: Optional[Dict[str, Operator]] = None, + insecure_tls: bool = False, + ): """Flash images to specified partitions Args: partitions: Dictionary mapping partition names to file paths operators: Optional dictionary mapping partition names to operators + insecure_tls: Skip TLS certificate verification for HTTPS URLs """ if not partitions: raise ValueError("At least one partition must be provided") @@ -62,7 +114,7 @@ def flash_images(self, partitions: Dict[str, str], operators: Optional[Dict[str, for partition, file_path in partitions.items(): self.logger.info(f"Processing {partition} image: {file_path}") operator = operators.get(partition) - remote_files[partition] = self._upload_file_if_needed(file_path, operator) + remote_files[partition] = self._upload_file_if_needed(file_path, operator, insecure_tls=insecure_tls) self.logger.info("Checking for fastboot devices on Exporter...") detection_result = self.call("detect_fastboot_device", 5, 2.0) @@ -84,6 +136,7 @@ def flash( target: str | None = None, operator: Operator | Dict[str, Operator] | None = None, compression=None, + insecure_tls: bool = False, ): if isinstance(path, dict): partitions = path @@ -109,7 +162,7 @@ def flash( self.boot_to_fastboot() - result = self.flash_images(partitions, operators) + result = self.flash_images(partitions, operators, insecure_tls=insecure_tls) self.logger.info("flash operation completed successfully") @@ -130,7 +183,35 @@ def base(): pass for name, cmd in generic_cli.commands.items(): - base.add_command(cmd, name=name) + if name != "flash": + base.add_command(cmd, name=name) + + @base.command() + @click.argument("file", nargs=-1, required=False) + @click.option( + "--target", + "-t", + "target_specs", + multiple=True, + help="name:file", + ) + @click.option("--insecure-tls", is_flag=True, help="Skip TLS certificate verification") + def flash(file, target_specs, insecure_tls): + """Flash image to DUT""" + if target_specs: + mapping: dict[str, str] = {} + for spec in target_specs: + if ":" not in spec: + raise click.ClickException(f"Invalid target spec '{spec}', expected name:file") + name, img = spec.split(":", 1) + mapping[name] = img + self.flash(mapping, insecure_tls=insecure_tls) + return + + if not file: + raise click.ClickException("FILE argument is required unless --target/-t is used") + + self.flash(file[0], target=None, insecure_tls=insecure_tls) @base.command() def boot_to_fastboot():