Torch Dynamo Error when Eager-Compiling Detectron2-based Model #103008
Comments
Based on the stack trace, it looks like some namedtuples are involved.
Looking through the code, I don't really see named tuples. But it looks like we don't handle namedtuples with default values correctly, which triggers the same assert that is happening in this issue:

```python
import torch
from collections import namedtuple

@torch.compile(backend='eager', fullgraph=True)
def f(x):
    Point = namedtuple('point', ['x', 'y', 'z'], defaults=(None, None, None))
    point = Point(0, y=x)
    return point.y * x

x = torch.randn(3)
f(x)
```
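For reference (my own aside, not from the original thread), this is the plain-Python semantics Dynamo needs to reproduce: `defaults` are assigned to the rightmost fields, and any field not supplied at construction falls back to its default. A minimal sketch without `torch.compile`:

```python
from collections import namedtuple

# defaults apply to the rightmost fields; here all three default to None
Point = namedtuple('Point', ['x', 'y', 'z'], defaults=(None, None, None))

p = Point(0, y=1)  # z is omitted, so it falls back to its default
assert p == (0, 1, None)
assert p.z is None
```

The eager repro above fails because Dynamo's handling of the constructor does not account for these fallback values.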
@gs-olive it's a bit difficult for me to install all the dependencies to actually repro the bug. Do you think it would be possible to cut the repro down to just using detectron2?
Thanks for looking into this - I will try to reduce the repro as much as possible, and will update with any findings.

fwiw, detectron2 doesn't work in our benchmark suite.
Hi @zou3519 - I have a minimal reproducing example for the reported error; the model is attached below. It takes in a list of strings as input, and uses only the open_clip repository and Torch. I have verified it is failing on the latest nightly:

```python
import torch
import open_clip

class SampleModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.tokenizer = open_clip.get_tokenizer('ViT-B-32')

    def forward(self, x):
        text = self.tokenizer(x)
        return text

"""
Code adapted from:
https://github.com/mlfoundations/open_clip/blob/fb72f4db1b17133befd6c67c9cf32a533b85a321/README.md?plain=1#L62-L71
"""

model = SampleModule().eval().cuda()
input_ = ["a diagram", "a dog", "a cat"]

# Passes
model(input_)

# Fails
compiled = torch.compile(model, backend="eager")
compiled(input_)
```
cc @jansel
Yes, this seems like a bug in our namedtuple handling. Should be a relatively simple fix if someone wants to pick it up: we basically need to replace the current namedtuple construction with one that accounts for fields with default values.
Hi, a newbie here - this will be my first open source contribution. I think I understand what changes need to be made, but I want to know how to check whether a namedtuple field has a default value. Is there a list that contains the default values? (I tried going through the code and googled a bit, but was unable to find anything relevant.)
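As an aside (an assumption on my part, not an answer given in the thread): in plain Python, a namedtuple class exposes its defaults via the `_field_defaults` mapping and its field names via `_fields`, both documented since Python 3.7. A minimal sketch:

```python
from collections import namedtuple

# Only the rightmost fields receive defaults; here y and z do, x does not.
Point = namedtuple('Point', ['x', 'y', 'z'], defaults=(None, None))

# _fields lists every field name in declaration order
assert Point._fields == ('x', 'y', 'z')

# _field_defaults maps only the defaulted fields to their default values
assert Point._field_defaults == {'y': None, 'z': None}
```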
Might be a somewhat similar fix to: cc @anijain2305
🐛 Describe the bug

The following error is encountered when compiling the OneFormer model featured in this tutorial notebook, based on the detectron2 architecture, with `backend="eager"`. The same model, run on the same inputs, is functional when called directly using torch. Specifically, `model(inputs)` succeeds, but `torch.compile(model, backend="eager")(inputs)` does not.

Error logs

Minified repro

`TORCHDYNAMO_REPRO_AFTER="dynamo"` did not produce an output file. One way to reproduce the issue is to run this tutorial notebook, take the `predictor` object result, then run the following:

Versions

cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @ipiszy @chenyang78