Forward Prediction for Physical Reasoning
Girdhar et al., 2020
Summary
- Studies the use of forward prediction models for solving physical reasoning tasks
- Incorporates object and pixel-based forward prediction models into simple physical reasoning agents
- These models improve performance but do not generalize to new tasks
- Predictors with better pixel accuracy do not necessarily lead to better physical reasoning
- Links: [ website ] [ pdf ]
Background
- Humans’ ability to imagine how the state of the world will unfold allows us to solve many real-world physical reasoning tasks, such as playing billiards
- The PHYRE benchmark was proposed to test a system's ability to perform complex physical reasoning
- Involves placing balls in a 2D world such that the world reaches a particular state when played forward
- The task is straightforward given a perfect forward-prediction model
Methods
- PHYRE Benchmark:
- Initial state of a task is depicted in a 256x256 image
- Colors indicate object properties, e.g. static, dynamic, part of goal state
- Tasks involve placing either one or two balls, parameterized by their position and radius
- Solved if blue or purple object touches green object for at least three seconds
- 25 task templates, with 100 tasks each
- Forward-Prediction Models:
- Object-based:
- Interaction Networks: a graph neural network that models pairwise interactions between object representations; uses two interaction networks with different temporal offsets
- Transformers: self-attention over latent object representations to predict future state, sinusoidal temporal position encoding
- Pixel-based:
- Spatial Transformer Networks: split input frame into segments by object based on channels, then apply encoder and predict rotation and translation for each channel
- Trained with spatial cross-entropy, which sums the cross-entropy values of an HxW softmax prediction over all image channels
- Deconvolutional Networks: directly predict pixels in the next frame with a deconvolutional network
- Trained with per-pixel cross-entropy
- Task-Solution Models: determine whether two particular target objects are touching or not, i.e. is the task solved or not
- Recognition is harder when using object-based representations, since the sizes and shapes of objects need to be accounted for, not just their center positions
- Receives initial and predicted frames and/or latent representations and produces a binary classification
- Pixel-based classifiers can be used on object-based representations by rendering them to pixels first
- Object-based classifier: Transformer model that concatenates encodings for all objects across the timesteps
- Pixel-based classifier: 3D CNN on latent state or pixel representations
- Search Strategy: For a given forward-prediction model and task-solution model combination:
- Sample K actions uniformly
- For each sampled action, alter the initial state with the action and feed it to the forward-prediction model
- Evaluate the task-solution model on the output of the forward-prediction model
- Select the action most likely to solve the task
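The search strategy above can be sketched in a few lines of Python. This is a minimal illustration, not the authors' implementation: `sample_actions`, `rank_actions`, the `(x, y, radius)` action encoding in [0, 1], and the two toy models are all hypothetical stand-ins for the learned forward-prediction and task-solution models.

```python
import random
from typing import Any, Callable, List, Tuple

# Assumption: an action is (x, y, radius) for one ball, each scaled to [0, 1].
Action = Tuple[float, float, float]


def sample_actions(k: int, seed: int = 0) -> List[Action]:
    """Sample k candidate actions uniformly at random."""
    rng = random.Random(seed)
    return [(rng.random(), rng.random(), rng.random()) for _ in range(k)]


def rank_actions(
    initial_state: Any,
    actions: List[Action],
    forward_model: Callable[[Any, Action], Any],
    solution_model: Callable[[Any, Any], float],
) -> List[Action]:
    """Score every action and return the actions sorted from best to worst."""
    scores = []
    for action in actions:
        # Alter the initial state with the action and predict the rollout.
        rollout = forward_model(initial_state, action)
        # Estimate how likely the predicted rollout is to solve the task.
        scores.append(solution_model(initial_state, rollout))
    order = sorted(range(len(actions)), key=lambda i: scores[i], reverse=True)
    return [actions[i] for i in order]


# Toy stand-ins (hypothetical): the "rollout" is just the placed ball, and
# the "solution score" prefers balls placed near the horizontal center.
def toy_forward_model(state: Any, action: Action) -> List[Action]:
    return [action]


def toy_solution_model(state: Any, rollout: List[Action]) -> float:
    x, _, _ = rollout[-1]
    return 1.0 - abs(x - 0.5)


actions = sample_actions(100)
ranked = rank_actions(None, actions, toy_forward_model, toy_solution_model)
best = ranked[0]  # the action judged most likely to solve the task
```

Returning the full ranking rather than only the top action is a design choice of this sketch: it lets an agent try the next-best action if the first attempt fails.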
Results
- Metrics
- Area under the success curve (AUCCESS): ranges from 0 to 100 and is higher when agent needs fewer attempts to solve task
- Forward-prediction accuracy (FPA): percentage of pixels for dynamic objects that match the ground truth in a 10-second rollout
- Using PHYRE simulator as forward-prediction model with various task-solution models:
- Works nearly perfectly when rolled out for 10 seconds in the within-template setting
- Pixel-based task-solution model generalizes better in cross-template setting
- Using various forward-prediction models with pixel-based task-solution model:
- AUCCESS increases slightly as rollout length increases, indicating the value of forward-prediction models
- Joint training of forward-prediction and task-solution models performs the best (deconvolutional forward-prediction model)
- AUCCESS plateaus after 5 second rollouts, indicating forward-prediction models are only accurate for short timescales
- Perform poorly in cross-template settings, indicating forward-prediction does not generalize well across templates
- Joint model, with explicit forward-prediction, outperforms previous SotA on PHYRE
- Compare per-template AUCCESS, the average AUCCESS over all tasks in a template, at different rollout lengths:
- Forward prediction is more beneficial in complex tasks, e.g. tasks with more objects
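To make the AUCCESS metric listed under Metrics more concrete, here is a minimal sketch. It assumes the log-weighted definition from the PHYRE benchmark paper (weight log(k+1) - log(k) on the fraction of tasks solved within k attempts) and a cap of 100 attempts per task; neither detail is stated in these notes. The function name and its input format are hypothetical.

```python
import math
from typing import Optional, Sequence


def auccess(
    attempts_needed: Sequence[Optional[int]], max_attempts: int = 100
) -> float:
    """Log-weighted area under the success curve, scaled to [0, 100].

    attempts_needed[i] is the 1-based attempt at which task i was first
    solved, or None if it was never solved within max_attempts.
    """
    if not attempts_needed:
        raise ValueError("need at least one task")
    n = len(attempts_needed)
    # Assumed weighting: early attempts count more than late ones.
    weights = [math.log(k + 1) - math.log(k) for k in range(1, max_attempts + 1)]
    total = 0.0
    for k, w in enumerate(weights, start=1):
        # Fraction of tasks solved within the first k attempts.
        solved = sum(1 for a in attempts_needed if a is not None and a <= k)
        total += w * solved / n
    return 100.0 * total / sum(weights)


first_try = auccess([1, 1, 1])      # every task solved on the first attempt
never = auccess([None, None])       # no task ever solved
late = auccess([50])                # solved, but only on attempt 50
```

Under this weighting, an agent that solves a task on attempt 1 scores far higher than one that needs 50 attempts, which matches the notes' description that AUCCESS is higher when fewer attempts are needed.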
Conclusion
- Initial investigation confirms the intuition that forward-prediction models are useful for physical reasoning
- Does not address the difficult problem of learning generic forward-prediction models
- The forward-prediction models presented generalize poorly
- Future directions include modeling uncertainty in the forward-prediction models