Torch Dynamo Error when Eager-Compiling Detectron2-based Model #103008

Closed
gs-olive opened this issue Jun 5, 2023 · 10 comments
Labels
good first issue module: dynamo oncall: pt2 release notes: dynamo triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@gs-olive
Contributor

gs-olive commented Jun 5, 2023

🐛 Describe the bug

The following error is encountered when compiling, with backend="eager", the detectron2-based OneFormer model featured in this tutorial notebook:

  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/variables/user_defined.py", line 128, in call_function
    assert all(x is not None for x in items)
AssertionError: 

The same model, run on the same inputs, works when called directly without compilation. Specifically, model(inputs) succeeds, but torch.compile(model, backend="eager")(inputs) does not.

Error logs

Traceback (most recent call last):
  File "my_demo.py", line 144, in <module>
    optimized([inputs])
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 289, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "~/OneFormer/oneformer/oneformer_model.py", line 274, in forward
    images = ImageList.from_tensors(images, self.size_divisibility)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 442, in catch_errors
    return callback(frame, cache_size, hooks, frame_state)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 527, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 127, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 360, in _convert_frame_assert
    return _compile(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/utils.py", line 180, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 430, in _compile
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1000, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 415, in transform
    tracer.run()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 2020, in run
    super().run()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 704, in run
    and self.step()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 664, in step
    getattr(self, inst.opname)(inst)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 386, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 1095, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 555, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/variables/functions.py", line 120, in call_function
    return tx.inline_user_function_return(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 591, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 2125, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 2222, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 704, in run
    and self.step()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 664, in step
    getattr(self, inst.opname)(inst)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 386, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 1095, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 555, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/variables/user_defined.py", line 352, in call_function
    return self.call_method(tx, "__call__", args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/variables/user_defined.py", line 273, in call_method
    return UserMethodVariable(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/variables/functions.py", line 333, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/variables/functions.py", line 287, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/variables/functions.py", line 120, in call_function
    return tx.inline_user_function_return(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 591, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 2125, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 2222, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 704, in run
    and self.step()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 664, in step
    getattr(self, inst.opname)(inst)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 386, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 1095, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 555, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/variables/functions.py", line 120, in call_function
    return tx.inline_user_function_return(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 591, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 2125, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 2222, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 704, in run
    and self.step()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 664, in step
    getattr(self, inst.opname)(inst)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 386, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 1095, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 555, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/variables/functions.py", line 333, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/variables/functions.py", line 287, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/variables/functions.py", line 120, in call_function
    return tx.inline_user_function_return(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 591, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 2125, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 2222, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 704, in run
    and self.step()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 664, in step
    getattr(self, inst.opname)(inst)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 386, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 1095, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 555, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/variables/functions.py", line 287, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/variables/functions.py", line 120, in call_function
    return tx.inline_user_function_return(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 591, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 2125, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 2222, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 704, in run
    and self.step()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 664, in step
    getattr(self, inst.opname)(inst)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 386, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 1095, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 555, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/variables/functions.py", line 287, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/variables/functions.py", line 120, in call_function
    return tx.inline_user_function_return(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 591, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 2125, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 2222, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 704, in run
    and self.step()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 664, in step
    getattr(self, inst.opname)(inst)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 386, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 1147, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 555, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/variables/user_defined.py", line 128, in call_function
    assert all(x is not None for x in items)
AssertionError: 

from user code:
   File "~/OneFormer/oneformer/data/tokenizer.py", line 77, in basic_clean
    text = ftfy.fix_text(text)
  File "/usr/local/lib/python3.8/dist-packages/ftfy/__init__.py", line 296, in fix_text
    config = TextFixerConfig(explain=False)

Minified repro

TORCHDYNAMO_REPRO_AFTER="dynamo" did not produce an output file. One way to reproduce the issue is to run this tutorial notebook, take the resulting predictor object, and then run the following:

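import numpy as np
import torch

# predictor is the object created by the OneFormer tutorial notebook
img = np.random.randint(0, 256, size=(479, 640, 3), dtype=np.uint8)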
task = "panoptic"

with torch.no_grad():
    original_image = img[:, :, ::-1]
    height, width = img.shape[:2]
    image = predictor.aug.get_transform(original_image).apply_image(original_image)
    image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
    
    inputs = {"image": image, "height": height, "width": width, "task": task}

    ##### SUCCEEDS
    predictor.model([inputs])

    ##### FAILS
    optimized = torch.compile(predictor.model, backend="eager")
    optimized([inputs])

Versions

Versions of relevant libraries:
[pip3] numpy==1.22.4
[pip3] detectron2==0.6
[pip3] torch==2.1.0.dev20230601+cu118

cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @ipiszy @chenyang78

@zou3519
Contributor

zou3519 commented Jun 8, 2023

Based on the stack trace, it looks like some namedtuples are involved.

@zou3519
Contributor

zou3519 commented Jun 8, 2023

Looking through the code, I don't really see named tuples. But it looks like we don't handle namedtuples with default values correctly, and the following snippet triggers the same assert that is happening in this issue:

import torch
from collections import namedtuple

@torch.compile(backend='eager', fullgraph=True)
def f(x):
    Point = namedtuple('point', ['x', 'y', 'z'], defaults=(None, None, None))
    point = Point(0, y=x)
    return point.y * x

x = torch.randn(3)
f(x)

@zou3519
Contributor

zou3519 commented Jun 8, 2023

@gs-olive it's a bit difficult for me to install all the dependencies needed to actually repro the bug. Do you think it would be possible to cut the repro down to just detectron2?

@gs-olive
Contributor Author

gs-olive commented Jun 8, 2023

Thanks for looking into this - I will try to reduce the repro as much as possible, and will update with any findings.

@ezyang
Contributor

ezyang commented Jun 9, 2023

fwiw, detectron2 doesn't work in our benchmark suite

zou3519 added the triaged label Jun 12, 2023
@gs-olive
Contributor Author

Hi @zou3519 - I have a minimal reproducing example for the reported error, attached below. The model takes a list of strings as input and outputs a torch.Tensor of token IDs. I believe this structure is found in the OneFormer model that was failing. The error occurs at this line and appears to trace to these lines in the ftfy Python package.

The model below uses only the open_clip repository and Torch. I have verified it is failing on the latest nightly (torch==2.1.0.dev20230621).

import torch
import open_clip

class SampleModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.tokenizer = open_clip.get_tokenizer('ViT-B-32')

    def forward(self, x):
        text = self.tokenizer(x)
        return text

"""
Code adapted from:
https://github.com/mlfoundations/open_clip/blob/fb72f4db1b17133befd6c67c9cf32a533b85a321/README.md?plain=1#L62-L71
"""
model = SampleModule().eval().cuda()
input_ = ["a diagram", "a dog", "a cat"]

# Passes
model(input_)


# Fails
compiled = torch.compile(model, backend="eager")
compiled(input_)
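The last user frame in the original traceback is ftfy's config = TextFixerConfig(explain=False). Assuming TextFixerConfig is a typing.NamedTuple whose fields all carry defaults (as it appears to be in ftfy 6.x), that call passes a single keyword and relies on the class defaults for every other field, which is the same pattern as the namedtuple repro earlier in this thread. The plain-Python stand-in below shows that pattern without open_clip or torch; FixerConfig and its field names are hypothetical, not ftfy's real field list.

```python
from typing import NamedTuple


class FixerConfig(NamedTuple):
    """Hypothetical stand-in for ftfy's TextFixerConfig; field names are illustrative."""

    fix_encoding: bool = True
    normalization: str = "NFC"
    explain: bool = True


# Only one field is passed explicitly; the other two are filled from the class defaults.
config = FixerConfig(explain=False)
print(config)  # FixerConfig(fix_encoding=True, normalization='NFC', explain=False)
```

Plain Python fills the omitted fields silently, which is why the call only fails once Dynamo tries to trace the constructor itself.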

@ezyang
Contributor

ezyang commented Jun 23, 2023

cc @jansel

@jansel
Contributor

jansel commented Jun 23, 2023

Yes, this seems like a bug in our namedtuple handling. The error is under elif is_namedtuple_cls(self.value):. We currently ignore defaults, which is wrong.

Should be a relatively simple fix if someone wants to pick it up. We basically need to replace:

assert all(x is not None for x in items)

with:

  1. Check if there are default values for any of the namedtuple fields that are None
  2. If so, fill in the defaults (convert the actual values to VariableTrackers using VariableBuilder)
  3. Assert that this handled all the Nones
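The three steps above can be sketched in plain Python. This is not Dynamo's code: there are no VariableTrackers or VariableBuilder here, and the helper name bind_namedtuple_args is hypothetical. It only illustrates the fill-in-defaults logic on ordinary values, raising TypeError where Dynamo would assert.

```python
from collections import namedtuple

# Marker for "no value supplied yet". None cannot serve this purpose here
# because a field's default may itself be None.
_MISSING = object()


def bind_namedtuple_args(cls, args, kwargs):
    """Hypothetical helper: place args/kwargs into cls._fields order and
    fill any omitted fields from the namedtuple's defaults."""
    fields = cls._fields
    if len(args) > len(fields):
        raise TypeError(f"{cls.__name__} takes at most {len(fields)} arguments")
    # Positional arguments first; the remaining slots start out empty.
    items = list(args) + [_MISSING] * (len(fields) - len(args))
    for name, value in kwargs.items():
        if name not in fields:
            raise TypeError(f"unexpected field {name!r}")
        idx = fields.index(name)
        if items[idx] is not _MISSING:
            raise TypeError(f"got multiple values for field {name!r}")
        items[idx] = value

    # Steps 1 and 2: fill every still-empty slot that has a default.
    defaults = cls._field_defaults
    for idx, name in enumerate(fields):
        if items[idx] is _MISSING and name in defaults:
            items[idx] = defaults[name]

    # Step 3: anything still empty is a required field that was never supplied.
    missing = [name for name, item in zip(fields, items) if item is _MISSING]
    if missing:
        raise TypeError(f"missing required fields: {missing}")
    return items


Point = namedtuple("point", ["x", "y", "z"], defaults=(None, None, None))
items = bind_namedtuple_args(Point, (0,), {"y": 5})
print(items)  # [0, 5, None]
print(Point(*items) == Point(0, y=5))  # True
```

The sentinel is the main design choice: in plain Python a None default and an unfilled slot would otherwise be indistinguishable. Inside Dynamo this is less of a concern, since a default of None would be wrapped in a VariableTracker rather than stored as a bare None.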

@arjunm16

Hi, newbie here; this will be my first open-source contribution.

I think I understand what changes need to be made, but I want to know how to check whether a namedtuple field has a default value. Is there a list that contains the default values?
Basically, where should I look up the default values of the namedtuple fields so that they can be passed to VariableBuilder and converted to VariableTrackers?

(I tried going through the code and googled a bit, but was unable to find anything relevant.)
(Sorry if it's a very dumb question, just getting started.)
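For reference, the standard library does keep this information on the class. Classes built with collections.namedtuple (using the defaults= argument, Python 3.7+) or with typing.NamedTuple expose a documented _field_defaults dict that maps only the defaulted field names to their values. The same values also appear as a tuple on the generated __new__.__defaults__, aligned to the rightmost fields. A quick check using the Point class from the repro earlier in this thread (Pair is an extra illustrative class):

```python
from collections import namedtuple

# Same definition as in the repro earlier in this thread.
Point = namedtuple("point", ["x", "y", "z"], defaults=(None, None, None))

# Documented dict of field name -> default value (only fields that have a default).
print(Point._field_defaults)  # {'x': None, 'y': None, 'z': None}

# The same defaults as a tuple on the generated __new__, aligned to the rightmost fields.
print(Point.__new__.__defaults__)  # (None, None, None)

# With fewer defaults than fields, only the trailing fields get one.
Pair = namedtuple("Pair", ["a", "b"], defaults=(7,))
print(Pair._field_defaults)  # {'b': 7}
```

Because _field_defaults is keyed by name, it pairs naturally with _fields when deciding which empty slots can be filled.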

@jansel
Contributor

jansel commented Jul 16, 2023

Might be a somewhat similar fix to #104840.

cc @anijain2305
