Contrastive Learning of Structured World Models

Kipf et al., 2020

Summary

  • Contrastively-trained Structured World Models (C-SWMs) learn a world model structured in terms of objects, relations, and hierarchies from raw sensory data
  • Uses a contrastive approach that distinguishes real experience from corrupted (negative) samples
    • Overcomes limitations of pixel-based reconstruction methods
  • Links: [ website ] [ pdf ]

Background

  • Humans reason compositionally in terms of objects, relations, and actions; this motivates artificial systems that decompose scenes into similarly structured representations, which greatly facilitate the prediction of physical dynamics
  • Most methods use a generative approach, reconstructing visual predictions and optimizing an objective in pixel space
    • This can result in ignoring visually small, but significant, features or wasting model capacity on visually rich, but irrelevant, features
  • C-SWMs, on the other hand, use a discriminative approach which does not suffer from the same challenges

Methods

  • Considers an off-policy setting, using a buffer of offline experience
  • Object Extractor and Encoder (see the first sketch after this list):
    • CNN-based object extractor: takes images and produces $K$ feature maps, corresponding to the $K$ object slots, which can be interpreted as object masks
    • MLP-based object encoder: takes a flattened feature map and outputs an abstract state representation, $z^k_t$; the encoder weights are shared across objects
  • Relational Transition Model ($T(z_t, a_t)$): graph neural network that models pairwise interactions between object states as well as actions applied to objects (see the second sketch after this list)
  • Multi-object Contrastive Loss: energy-based hinge loss with the energy computed as the mean energy across the $K$ objects
    • Positive samples: $H = \frac{1}{K}\sum^K_{k=1}d(z^k_t+T^k(z_t,a_t), z^k_{t+1})$, where $d(\cdot, \cdot)$ is the squared Euclidean distance
    • Negative samples: $\tilde{H} = \frac{1}{K}\sum^K_{k=1}d(\tilde{z}^k_t, z^k_{t+1})$, where $\tilde{z}^k_t$ is the representation of a corrupted state sampled from the buffer
    • Final loss: $L = H + \max(0, \gamma - \tilde{H})$, where $\gamma$ is the hinge margin
  • Concatenate two consecutive frames as input for environments with internal dynamics
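
A minimal PyTorch sketch of the object extractor and encoder described above; the channel counts, kernel sizes, and sigmoid mask activation are illustrative assumptions rather than the paper's exact architecture.

```python
import torch
import torch.nn as nn


class ObjectExtractor(nn.Module):
    """CNN that maps an image to K feature maps, one per object slot."""

    def __init__(self, in_channels=3, num_objects=5, hidden_channels=32):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, num_objects, kernel_size=1),
            nn.Sigmoid(),  # each of the K maps can be read as a soft object mask
        )

    def forward(self, obs):
        # obs: (batch, in_channels, H, W) -> (batch, K, H, W)
        return self.cnn(obs)


class ObjectEncoder(nn.Module):
    """MLP applied to each flattened feature map; weights shared across the K slots."""

    def __init__(self, map_size, state_dim=2, hidden_dim=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(map_size, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, state_dim))

    def forward(self, feature_maps):
        # feature_maps: (batch, K, H, W) -> abstract states z_t: (batch, K, state_dim)
        b, k, h, w = feature_maps.shape
        return self.mlp(feature_maps.view(b, k, h * w))
```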

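Below is a small PyTorch sketch of the relational transition model and the multi-object contrastive hinge loss defined above. The specific edge/node MLPs, summation aggregation, and the batch-permutation trick for drawing corrupted states $\tilde{z}_t$ are illustrative assumptions, not the paper's exact implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RelationalTransition(nn.Module):
    """GNN predicting per-object updates T^k(z_t, a_t) via pairwise message passing."""

    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * state_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim))
        self.node_mlp = nn.Sequential(
            nn.Linear(state_dim + action_dim + hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, state_dim))

    def forward(self, z, a):
        # z: (batch, K, state_dim), a: (batch, K, action_dim), one action vector per object
        B, K, D = z.shape
        senders = z.unsqueeze(2).expand(B, K, K, D)    # senders[b, i, j] = z_i
        receivers = z.unsqueeze(1).expand(B, K, K, D)  # receivers[b, i, j] = z_j
        edges = self.edge_mlp(torch.cat([senders, receivers], dim=-1))
        mask = 1.0 - torch.eye(K, device=z.device).view(1, K, K, 1)  # drop self-edges
        agg = (edges * mask).sum(dim=1)                # messages received by each node
        return self.node_mlp(torch.cat([z, a, agg], dim=-1))  # per-object state delta


def contrastive_loss(z_t, z_next, z_neg, delta, gamma=1.0):
    """L = H + max(0, gamma - H_tilde), energies averaged over the K object slots."""
    pos_energy = ((z_t + delta - z_next) ** 2).sum(-1).mean(-1)  # H
    neg_energy = ((z_neg - z_next) ** 2).sum(-1).mean(-1)        # H~
    return (pos_energy + F.relu(gamma - neg_energy)).mean()


# Example usage (shapes only); corrupted states can be drawn by permuting the batch:
#   z_t, z_next = encoder(extractor(obs_t)), encoder(extractor(obs_t1))
#   z_neg = z_t[torch.randperm(z_t.size(0))]
#   loss = contrastive_loss(z_t, z_next, z_neg, transition(z_t, actions))
```
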
Results

  • Datasets: random policy used to collect experience
    • Novel grid world: 2D and 3D, with multiple interacting objects that can be manipulated
    • Atari 2600 games: Pong and Space Invaders
    • 3-body physics simulation
  • Metrics: based on ranking the predicted state representation against the encoded true observation and a set of reference states from the experience buffer (see the sketch after this list)
    • Hits at Rank 1 (H@1)
    • Mean Reciprocal Rank (MRR)
  • Baselines:
    • C-SWM ablations: remove latent GNN, latent GNN + factored states, or contrastive loss
    • Autoencoder-based World Model: AE or VAE to learn state representations, and MLP to predict action-conditioned next state
    • Physics as Inverse Graphics (PAIG): encoder-decoder architecture trained with pixel-based reconstruction, using a differentiable physics engine on the latent space with explicit position and velocity representations – making it applicable only to the 3-body physics environment
  • Results summary:
    • C-SWMs discover object-specific filters that can be mapped to object positions
    • Baselines that use pixel-based reconstruction losses do not generalize as well to unseen scenes
    • Removing the GNN reduces performance on the grid world datasets for longer time horizons, but not for single-step prediction
    • Removing factored states or contrastive loss reduces performance drastically
    • Performance is essentially at ceiling on the non-Atari datasets, even for the non-ablation baselines
    • On the Atari datasets, C-SWM outperforms the AE-based World Model for 1-step predictions, but is reliant on tuning of $K$
      • For longer time horizons C-SWM is better, though that is expected given its GNN transition model versus the baseline's MLP
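
A short sketch of how the ranking metrics above can be computed, assuming each predicted state representation is compared against the encodings of all observations in the evaluation set, with its own true next observation as the target; function and variable names are hypothetical.

```python
import torch


def ranking_metrics(pred_states, target_states):
    """Hits at Rank 1 (H@1) and Mean Reciprocal Rank (MRR) over an evaluation set."""
    # pred_states, target_states: (N, K, state_dim); flatten the K object slots.
    pred = pred_states.flatten(start_dim=1)
    target = target_states.flatten(start_dim=1)
    dists = torch.cdist(pred, target)  # (N, N) pairwise Euclidean distances
    # Rank (1-indexed) of the true target (diagonal entry) within each row.
    ranks = (dists < dists.diag().unsqueeze(1)).sum(dim=1) + 1
    hits_at_1 = (ranks == 1).float().mean().item()
    mrr = (1.0 / ranks.float()).mean().item()
    return hits_at_1, mrr
```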

Conclusion

  • Object extractor relies on distinct visual features to disambiguate objects
  • No method to ensure there isn’t redundancy between the $K$ object feature maps
  • The contrastive loss as an alternative to pixel-based losses is promising