Avoid fp32 cast for Torch div operator#2241
Closed
HennerM wants to merge 3 commits intoapple:mainfrom
Closed
Conversation
The `div` Torch op was always casting both operands to fp32, even if both operands are of type fp16. This cast should get removed by the `"common::add_fp16_cast"` optimization pass. However, it causes issues during the PyTorch conversion, for example let's say we have a forward method like this:
```python
class Foo:
def __init__(self):
super().__init__()
self.proj = torch.nn.Linear(16, 1)
def forward(self, x, y): # both fp16 tensors, shape [1, 16]
r = x / y # r is now fp32
return self.proj(r) # Problem
```
Now if we have moved the model (and it's parameters) to fp16 with eg. `m = Foo().to(torch.float16)`, we get an error at conversion time:
> In op, of type linear, named linear_0, the named input `bias` must have the same data type as the named input `x`. However, bias has dtype fp16 whereas x has dtype fp32.
This is because the result of the `div` operation stays fp32, and this doesn't match the resulting type of the PyTorch expression.
Collaborator
|
Please add a unit test to test_torch_ops.py which fails without your fix but passes with your fix. |
|
I stumbled across the same issue and managed to debug it. It turns out the root cause isn't the Here's a minimal repro that does not use the import coremltools as ct
import numpy as np
import torch
class Net(torch.nn.Module):
def __init__(self):
super().__init__()
self.proj = torch.nn.Linear(16, 1)
def forward(self, x):
return self.proj(x)
x = torch.randn(1, 16, dtype=torch.float16)
with torch.no_grad():
mlmodel = ct.convert(
torch.jit.trace(Net().half().eval(), x),
inputs=[ct.TensorType(name="x", shape=x.shape, dtype=np.float16)],
outputs=[ct.TensorType(name="output")],
convert_to="mlprogram",
compute_precision=ct.precision.FLOAT16,
minimum_deployment_target=ct.target.iOS17,
)This fails with the same exception as the snippet above with the I've got a fix for this in #2274. |
Collaborator
|
Hi @HennerM, Concretely, internally we translate torch model in fp32. Then,
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
The
divTorch op was always casting both operands to fp32, even if both operands are of type fp16. This cast should get removed by the"common::add_fp16_cast"optimization pass. However, it causes issues during the PyTorch conversion, for example let's say we have a forward method like this:Now if we have moved the model (and it's parameters) to fp16 with eg.
m = Foo().to(torch.float16), we get an error at conversion time:This is because the result of the
divoperation stays fp32, and this doesn't match the resulting type of the PyTorch expression.