Forward Prediction for Physical Reasoning

Girdhar et al., 2020


Summary

  • Studies the use of forward prediction models for solving physical reasoning tasks
    • Incorporates object and pixel-based forward prediction models into simple physical reasoning agents
  • These models improve performance but do not generalize well to tasks from unseen templates
  • Predictors with better pixel accuracy do not necessarily lead to better physical reasoning
  • Links: [ website ] [ pdf ]

Background

  • Humans’ ability to imagine how the state of the world will unfold allows us to solve many real-world physical reasoning tasks, such as playing billiards
  • The PHYRE benchmark was proposed to test a system's ability to perform complex physical reasoning
    • Involves placing balls in a 2D world such that the world reaches a particular goal state when the simulation is played forward
    • The task is straightforward given a perfect forward-prediction model

Methods

  • PHYRE Benchmark:
    • The initial state of each task is depicted as a 256x256 image
    • Colors indicate object properties, e.g. static, dynamic, or part of the goal state
    • Tasks involve placing either one or two balls, each parameterized by its position and radius
    • A task is solved if a blue or purple object touches the green object for at least three seconds
    • 25 task templates, with 100 tasks each
  • Forward-Prediction Models:
    • Object-based:
      • Interaction Networks: a graph neural network that models pairwise interactions between object representations; uses two interaction networks with different temporal offsets
      • Transformers: self-attention over latent object representations predicts the future state, with sinusoidal temporal position encodings
    • Pixel-based:
      • Spatial Transformer Networks: split the input frame into per-object segments based on its channels, then apply an encoder and predict a rotation and translation for each channel
        • Trained with a spatial cross-entropy loss, which sums the cross-entropy of an HxW softmax prediction over all image channels
      • Deconvolutional Networks: directly predict the pixels of the next frame with a deconvolutional network
        • Trained with a per-pixel cross-entropy loss
  • Task-Solution Models: determine whether two particular target objects are touching, i.e. whether the task is solved
    • Recognition is harder with object-based representations, since object sizes and shapes must be accounted for, not just center positions
    • Receive the initial and predicted frames and/or latent representations and produce a binary classification
    • Pixel-based classifiers can be used on object-based representations by first rendering them to pixels
    • Object-based classifier: a Transformer model that concatenates the encodings of all objects across timesteps
    • Pixel-based classifier: a 3D CNN on latent-state or pixel representations
  • Search Strategy: For a given forward-prediction model and task-solution model combination:
    • Sample K actions uniformly
    • For each sampled action, alter the initial state with the action and feed it to the forward-prediction model
    • Evaluate the task-solution model on the output of the forward-prediction model
    • Select the action most likely to solve the task
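The two pixel-based training losses above differ mainly in which axis the softmax normalizes over. The following plain-Python sketch contrasts them. It assumes that each target channel of the spatial cross-entropy is a binary object mask, normalized into a distribution over the HxW locations, and that the per-pixel cross-entropy is a softmax over channels (treated as color classes) at each pixel. The function names are hypothetical, and the paper's exact formulation may differ.

```python
import math
from typing import List

Grid = List[List[float]]  # an H x W grid of floats


def log_softmax(values: List[float]) -> List[float]:
    """Numerically stable log-softmax over a flat list."""
    m = max(values)
    log_z = m + math.log(sum(math.exp(v - m) for v in values))
    return [v - log_z for v in values]


def spatial_cross_entropy(logits: List[Grid], targets: List[Grid]) -> float:
    """Sum over channels of the cross-entropy of an H*W softmax.

    logits[c] and targets[c] are H x W grids for channel c. Each target channel
    is assumed to be a binary object mask; it is normalized into a distribution
    over pixel locations, and channels with no object are skipped.
    """
    total = 0.0
    for logit_grid, target_grid in zip(logits, targets):
        flat_logits = [v for row in logit_grid for v in row]
        flat_target = [v for row in target_grid for v in row]
        mass = sum(flat_target)
        if mass == 0:
            continue  # nothing to localize in this channel
        log_probs = log_softmax(flat_logits)  # softmax over all H*W positions
        total -= sum(t / mass * lp for t, lp in zip(flat_target, log_probs))
    return total


def per_pixel_cross_entropy(logits: List[Grid], labels: List[List[int]]) -> float:
    """Mean over pixels of a softmax over channels (assumed color classes).

    labels[i][j] is the index of the true channel at pixel (i, j).
    """
    height, width = len(labels), len(labels[0])
    total = 0.0
    for i in range(height):
        for j in range(width):
            pixel_logits = [channel[i][j] for channel in logits]
            total -= log_softmax(pixel_logits)[labels[i][j]]
    return total / (height * width)
```

With all-zero logits on a 2x2 grid, the spatial loss for a one-pixel object is log(4) (uniform over four locations), while the per-pixel loss with two channels is log(2) (uniform over two classes at each pixel).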
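The search strategy can be sketched as a rank-and-select loop. In this minimal sketch, `sample_actions`, `rank_actions`, and the callables `apply_action`, `forward_model`, and `solution_model` are hypothetical stand-ins for PHYRE's action placement, a trained forward-prediction model, and a task-solution model; single-ball actions are assumed to be (x, y, radius) triples in the unit cube.

```python
import random
from typing import Callable, List, Sequence, Tuple

Action = Tuple[float, float, float]  # (x, y, radius), assumed normalized to [0, 1)


def sample_actions(k: int, rng: random.Random) -> List[Action]:
    """Sample K single-ball actions uniformly from the unit cube."""
    return [(rng.random(), rng.random(), rng.random()) for _ in range(k)]


def rank_actions(
    initial_state,
    actions: Sequence[Action],
    apply_action: Callable,    # hypothetical: places the ball(s) in the initial state
    forward_model: Callable,   # hypothetical: rolls the state forward, returns predictions
    solution_model: Callable,  # hypothetical: scores P(task solved) from start + rollout
) -> List[Tuple[float, Action]]:
    """Score every candidate action and return (score, action) pairs, best first."""
    scored = []
    for action in actions:
        start = apply_action(initial_state, action)
        rollout = forward_model(start)
        scored.append((solution_model(start, rollout), action))
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return scored
```

The first entry of the ranked list is the action the agent tries first. Under this sketch, the remaining entries can serve as later attempts, which matters because AUCCESS rewards solving a task in fewer attempts.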

Results

  • Metrics
    • Area under the success curve (AUCCESS): ranges from 0 to 100 and is higher when agent needs fewer attempts to solve task
    • Forward-prediction accuracy (FPA): percentage of dynamic-object pixels that match the ground truth over a 10-second rollout
  • Using PHYRE simulator as forward-prediction model with various task-solution models:
    • Performs nearly perfectly when rolled out for 10 seconds in the within-template setting
    • The pixel-based task-solution model generalizes better in the cross-template setting
  • Using various forward-prediction models with pixel-based task-solution model:
    • AUCCESS increases slightly as rollout length increases, indicating the value of forward-prediction models
    • Joint training of the forward-prediction and task-solution models performs best (with the deconvolutional forward-prediction model)
    • AUCCESS plateaus after 5 second rollouts, indicating forward-prediction models are only accurate for short timescales
    • Performance is poor in the cross-template setting, indicating that the forward-prediction models do not generalize well across task templates
  • The joint model, with explicit forward prediction, outperforms the previous state of the art on PHYRE
  • Comparing per-template AUCCESS (the average AUCCESS over all tasks in a template) at different rollout lengths:
    • Forward prediction is more beneficial in complex tasks, e.g. tasks with more objects
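AUCCESS can be computed from the number of attempts an agent needed on each task. This sketch assumes the weighting used by the PHYRE benchmark: w_k = log(k + 1) - log(k) for k = 1..100, s_k is the percentage of tasks solved within k attempts, and AUCCESS = (sum over k of w_k * s_k) / (sum over k of w_k). The logarithmic weights emphasize success in the first few attempts. The function name `auccess` is hypothetical.

```python
import math
from typing import Sequence


def auccess(attempts_to_solve: Sequence[float], max_attempts: int = 100) -> float:
    """Weighted area under the success curve, in [0, 100].

    attempts_to_solve[i] is the number of attempts the agent needed on task i,
    or math.inf if the task was not solved within max_attempts.
    """
    n = len(attempts_to_solve)
    numerator = 0.0
    denominator = 0.0
    for k in range(1, max_attempts + 1):
        w_k = math.log(k + 1) - math.log(k)  # larger weight for small k
        # percentage of tasks solved within k attempts
        s_k = 100.0 * sum(1 for a in attempts_to_solve if a <= k) / n
        numerator += w_k * s_k
        denominator += w_k
    return numerator / denominator
```

For example, an agent that solves every task on its first attempt scores 100, and one that never solves any task scores 0.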
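The FPA metric can be sketched as a masked pixel-match rate. This version assumes frames are HxW grids of color indices, counts only pixels whose ground-truth color belongs to a dynamic object, and pools them over every frame of the rollout. The paper's exact choices (e.g. frame sampling rate, or whether predicted dynamic pixels outside the ground-truth objects are also penalized) may differ, and `forward_prediction_accuracy` is a hypothetical name.

```python
from typing import List, Set

Frame = List[List[int]]  # H x W grid of color indices


def forward_prediction_accuracy(
    predicted: List[Frame], ground_truth: List[Frame], dynamic_colors: Set[int]
) -> float:
    """Percentage of ground-truth dynamic-object pixels predicted correctly."""
    matched = 0
    counted = 0
    for pred_frame, true_frame in zip(predicted, ground_truth):
        for pred_row, true_row in zip(pred_frame, true_frame):
            for p, t in zip(pred_row, true_row):
                if t in dynamic_colors:  # only score pixels covered by dynamic objects
                    counted += 1
                    matched += int(p == t)
    if counted == 0:
        raise ValueError("rollout contains no dynamic-object pixels")
    return 100.0 * matched / counted
```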

Conclusion

  • Initial investigation confirms the intuition that forward-prediction models are useful for physical reasoning
  • Does not address the difficult problem of learning generic forward-prediction models
    • The forward-prediction models presented generalize poorly
  • Future directions include modeling uncertainty in the forward-prediction models
Elias Z. Wang
AI Researcher | PhD Candidate