# Contrastive Learning of Structured World Models

Source: Kipf et al., 2020

## Summary

• Contrastively-trained Structured World Models (C-SWMs) learn a world model, in terms of objects, relations, and hierarchies, from raw sensory data
• Utilizes a contrastive approach, which distinguishes real versus fake experiences
• Overcomes limitations of pixel-based methods

## Background

• Humans reason compositionally in terms of objects, relations, and actions; this motivates artificial systems that decompose scenes into similar representations, which greatly facilitates physical dynamics prediction
• Most methods use a generative approach, reconstructing visual predictions and optimizing an objective in pixel space
• This can lead to ignoring visually small but significant features, or to wasting model capacity on visually rich but irrelevant ones
• C-SWMs, on the other hand, use a discriminative approach which does not suffer from the same challenges

## Methods

• Consider off-policy setting, using buffer of offline experience
• Object Extractor and Encoder:
• CNN-based object extractor: takes images and produces $K$ feature maps, corresponding to the $K$ object slots, which can be interpreted as object masks
• MLP-based object encoder: takes a flattened feature map and outputs an abstract state representation $z^k_t$; weights are shared across objects
• Relational Transition Model ($T(z_t, a_t)$): graph neural network that models pairwise interactions between object states as well as actions applied to objects
• Multi-object Contrastive Loss: energy-based hinge loss with energy computed by taking the mean energy across the $K$ objects
• Positive samples: $H = \frac{1}{K}\sum^K_{k=1}d(z^k_t+T^k(z_t,a_t), z^k_{t+1})$, where $d(\cdot, \cdot)$ is the squared Euclidean distance
• Negative samples: $\tilde{H} = \frac{1}{K}\sum^K_{k=1}d(\tilde{z}^k_t, z^k_{t+1})$, where $\tilde{z}^k_t$ is the representation of a corrupted state sampled at random from the experience buffer
• Final loss: $L = H + \max(0, \gamma - \tilde{H})$
• Concatenate two consecutive frames as input for environments with internal dynamics
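The multi-object hinge loss above can be sketched in a few lines. This is a minimal numpy version assuming per-object embeddings are stacked as `(K, D)` arrays and the transition model's output `delta_z` is precomputed; the function name and argument layout are illustrative, not the paper's code.

```python
import numpy as np

def squared_distance(a, b):
    # d(., .): squared Euclidean distance per object slot
    return ((a - b) ** 2).sum(axis=-1)

def contrastive_loss(z_t, z_next, z_neg, delta_z, gamma=1.0):
    """Sketch of the multi-object contrastive hinge loss.

    z_t, z_next, z_neg: (K, D) per-object state embeddings
                        (current, true next, and corrupted/negative state)
    delta_z: (K, D) output of the transition model T(z_t, a_t)
    gamma: hinge margin
    """
    # Positive energy H: predicted next state vs. encoded true next state,
    # averaged over the K object slots.
    H = squared_distance(z_t + delta_z, z_next).mean()
    # Negative energy H~: corrupted state vs. encoded true next state.
    H_neg = squared_distance(z_neg, z_next).mean()
    # Hinge: pull positives together, push negatives at least gamma apart.
    return H + max(0.0, gamma - H_neg)
```

Note that the negative term only contributes gradient while $\tilde{H} < \gamma$, so already well-separated negatives are ignored.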

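The relational transition model can likewise be sketched as one round of GNN-style message passing over the $K$ object slots. Plain linear maps stand in for the paper's edge and node MLPs; `w_edge` and `w_node` are hypothetical weight matrices, and the per-object action encoding is assumed to be concatenated into the node update.

```python
import numpy as np

def transition(z, a, w_edge, w_node):
    """Sketch of a relational transition model T(z_t, a_t).

    z: (K, D) object states; a: (K, A) per-object action encoding
    w_edge: (2*D, D) edge model; w_node: (2*D + A, D) node model
    Returns a (K, D) predicted state change, so that
    z_{t+1}^k ~ z_t^k + T^k(z_t, a_t).
    """
    K, D = z.shape
    messages = np.zeros((K, D))
    for i in range(K):
        for j in range(K):
            if i != j:
                # Edge model on the ordered pair (z_i, z_j);
                # messages are summed at the receiving node j.
                messages[j] += np.concatenate([z[i], z[j]]) @ w_edge
    # Node model on (own state, aggregated messages, action).
    inp = np.concatenate([z, messages, a], axis=1)
    return inp @ w_node
```

Because the pairwise edge model is applied to every ordered pair and messages are summed, the model is permutation-equivariant over object slots, which is what lets it generalize across scenes with different object configurations.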
## Results

• Datasets: random policy used to collect experience
• Novel grid world: 2D and 3D, with multiple interacting objects that can be manipulated
• Atari 2600 games: Atari Pong and Space Invaders
• 3-body physics simulation
• Metrics: based on a ranking comparison of the predicted state representation against the encoded true observation and a set of reference states from the experience buffer
• Hits at Rank 1 (H@1)
• Mean Reciprocal Rank (MRR)
• Baselines:
• C-SWM ablations: remove latent GNN, latent GNN + factored states, or contrastive loss
• Autoencoder-based World Model: AE or VAE to learn state representations, and MLP to predict action-conditioned next state
• Physics as Inverse Graphics (PAIG): encoder-decoder architecture trained with pixel-based reconstruction, using differentiable physics engine on latent space with explicit position and velocity representations – making it only applicable to 3-body physics environment
• Results summary:
• C-SWMs discover object-specific filters that can be mapped to object positions
• Baselines that use pixel-based reconstruction losses do not generalize as well to unseen scenes
• Removing the GNN reduces performance on grid world datasets over longer time horizons, but not for single-step prediction
• Removing factored states or contrastive loss reduces performance drastically
• Performance is near ceiling for the non-Atari datasets, even for the non-ablation baselines
• For Atari datasets, C-SWM outperforms AE-based World Model for 1 step predictions, but is reliant on tuning of $K$
• Over longer time horizons C-SWM performs better, though this is expected given its GNN transition model versus the baselines' MLP
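The ranking metrics above could be computed as follows; this is a minimal sketch for a single prediction, and `ranking_metrics` is a hypothetical helper, not the paper's evaluation code.

```python
import numpy as np

def ranking_metrics(pred, refs, true_idx):
    """Hits at Rank 1 (H@1) and Mean Reciprocal Rank (MRR) for one sample.

    pred: (D,) predicted next-state representation
    refs: (N, D) encoded reference observations from the buffer,
          where refs[true_idx] is the encoding of the true next observation
    """
    # Squared Euclidean distance from the prediction to every reference.
    dists = ((refs - pred) ** 2).sum(axis=1)
    # Rank of the true observation among all references (1 = closest).
    rank = 1 + (dists < dists[true_idx]).sum()
    return {"hits_at_1": float(rank == 1), "mrr": 1.0 / rank}
```

Averaging `hits_at_1` and `mrr` over an evaluation set gives H@1 and MRR; H@1 is 1 only when the true observation is the nearest reference, while MRR degrades gracefully with rank.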

## Conclusion

• The object extractor relies on distinct visual features to disambiguate objects
• There is no mechanism to prevent redundancy between the $K$ object feature maps
• The contrastive loss as an alternative to pixel-based losses is promising