Skip to content

fix(utils): jax_to_pytorch NumPy bridge cannot convert bfloat16 arrays#269

Open
opooladz wants to merge 1 commit into
erfanzar:vnextfrom
opooladz:fix/jax-to-pytorch-bf16
Open

fix(utils): jax_to_pytorch NumPy bridge cannot convert bfloat16 arrays#269
opooladz wants to merge 1 commit into
erfanzar:vnextfrom
opooladz:fix/jax-to-pytorch-bf16

Conversation

@opooladz

Copy link
Copy Markdown

torch.from_numpy raises TypeError for ml_dtypes.bfloat16 inputs (NumPy has no native bfloat16), so the default EASY_SAFE_TRANSFER path failed for any bf16 model export (module.to_torch / save_pretrained(to_torch= True)) with: "can't convert np.ndarray of type ml_dtypes.bfloat16".

Reinterpret the array bits as uint16 and view the resulting tensor back as torch.bfloat16 -- bit-exact (verified by raw-bit comparison in the added tests), no upcast, no additional copy. Other dtypes take the existing path unchanged.

torch.from_numpy raises TypeError for ml_dtypes.bfloat16 inputs (NumPy
has no native bfloat16), so the default EASY_SAFE_TRANSFER path failed
for any bf16 model export (module.to_torch / save_pretrained(to_torch=
True)) with: "can't convert np.ndarray of type ml_dtypes.bfloat16".

Reinterpret the array bits as uint16 and view the resulting tensor back
as torch.bfloat16 -- bit-exact (verified by raw-bit comparison in the
added tests), no upcast, no additional copy. Other dtypes take the
existing path unchanged.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant