diff --git a/src/quantem/core/io/file_readers.py b/src/quantem/core/io/file_readers.py index cb36f1de..b575a6d3 100644 --- a/src/quantem/core/io/file_readers.py +++ b/src/quantem/core/io/file_readers.py @@ -30,7 +30,20 @@ def read_4dstem( Index of the dataset to load if file contains multiple datasets. If None, automatically selects the first 4D dataset found. **kwargs: dict - Additional keyword arguments to pass to the Dataset4dstem constructor. + Additional keyword arguments to pass to the file reader. + + Other Parameters + ---------------- + name : str | None, optional + A descriptive name for the dataset. If None, defaults to "4D-STEM dataset" + origin : NDArray | tuple | list | float | int | None, optional + The origin coordinates for each dimension. If None, defaults to zeros + sampling : NDArray | tuple | list | float | int | None, optional + The sampling rate/spacing for each dimension. If None, defaults to ones + units : list[str] | tuple | list | None, optional + Units for each dimension. If None, defaults to ["pixels"] * 4 + signal_units : str, optional + Units for the array values, by default "arb. units" Returns -------- @@ -39,8 +52,13 @@ def read_4dstem( if file_type is None: file_type = Path(file_path).suffix.lower().lstrip(".") + sampling_override = kwargs.pop("sampling", None) + origin_override = kwargs.pop("origin", None) + units_override = kwargs.pop("units", None) + name_override = kwargs.pop("name", None) + file_reader = importlib.import_module(f"rsciio.{file_type}").file_reader - data_list = file_reader(file_path) + data_list = file_reader(file_path, **kwargs) # If specific index provided, use it if dataset_index is not None: @@ -69,17 +87,20 @@ def read_4dstem( imported_axes = imported_data["axes"] - sampling = kwargs.pop( - "sampling", - [ax["scale"] for ax in imported_axes], + sampling = ( + sampling_override + if sampling_override is not None + else [ax.get("scale", 1) for ax in imported_axes] ) - origin = kwargs.pop( - "origin", - [ax["offset"] for ax in imported_axes], + origin = ( + origin_override + if origin_override is not None + else [ax.get("offset", 0) for ax in imported_axes] ) - units = kwargs.pop( - "units", - ["pixels" if ax["units"] == "1" else ax["units"] for ax in imported_axes], + units = ( + units_override + if units_override is not None + else ["pixels" if ax["units"] == "1" else ax["units"] for ax in imported_axes] ) dataset = Dataset4dstem.from_array( @@ -87,7 +108,7 @@ def read_4dstem( sampling=sampling, origin=origin, units=units, - **kwargs, + name=name_override, ) return dataset