PRISE: LLM-Style Sequence Compression for Learning Temporal Action Abstractions in Control

ICML 2024 (Oral Presentation)

University of Maryland, College Park      Microsoft Research

Abstract

In this work, we propose a novel view that treats inducing temporal action abstractions as a sequence compression problem. To do so, we bring a subtle but critical component of LLM training pipelines -- input tokenization via byte pair encoding (BPE) -- to the seemingly distant task of learning skills of variable time span in continuous control domains. We introduce an approach called Primitive Sequence Encoding (PRISE) that combines continuous action quantization with BPE to learn powerful action abstractions. We empirically show that high-level skills discovered by PRISE from a multitask set of robotic manipulation demonstrations significantly boost the performance of both multitask imitation learning as well as few-shot imitation learning on unseen tasks.

Method

PRISE leverages a large multitask offline dataset to pretrain a vocabulary of skill tokens, each representing a temporally extended low-level policy, for downstream control tasks. Pretraining proceeds in two stages.


In Stage I, PRISE learns a state-dependent action quantization module. This module processes the pretraining multitask dataset D by transforming each of its trajectories (a sequence of ⟨observation, continuous action⟩ pairs) into a sequence of discrete codes, one code per time step.
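To make Stage I concrete, below is a minimal VQ-VAE-style sketch of a state-dependent action quantizer in PyTorch. The class name, network architecture, dimensions, and loss weights are illustrative assumptions, not the exact PRISE design; the key idea is that the code assigned to an action depends on the observation.

```python
# Sketch of Stage I: state-dependent quantization of continuous actions.
# Dimensions, networks, and loss weights are assumptions for illustration.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ActionQuantizer(nn.Module):
    def __init__(self, obs_dim, act_dim, code_dim=64, num_codes=32):
        super().__init__()
        # encoder maps (observation, action) to a latent vector
        self.encoder = nn.Sequential(
            nn.Linear(obs_dim + act_dim, 128), nn.ReLU(), nn.Linear(128, code_dim))
        # discrete codebook of action codes
        self.codebook = nn.Embedding(num_codes, code_dim)
        # decoder reconstructs the action from (observation, code embedding)
        self.decoder = nn.Sequential(
            nn.Linear(obs_dim + code_dim, 128), nn.ReLU(), nn.Linear(128, act_dim))

    def forward(self, obs, act):
        z = self.encoder(torch.cat([obs, act], dim=-1))
        # nearest codebook entry for each latent vector -> one code per time step
        dists = torch.cdist(z, self.codebook.weight)        # (B, num_codes)
        idx = dists.argmin(dim=-1)
        z_q = self.codebook(idx)
        # standard VQ-VAE losses: move codebook toward encoder outputs and vice versa
        codebook_loss = F.mse_loss(z_q, z.detach())
        commit_loss = F.mse_loss(z, z_q.detach())
        z_q = z + (z_q - z).detach()                        # straight-through estimator
        act_hat = self.decoder(torch.cat([obs, z_q], dim=-1))
        recon_loss = F.mse_loss(act_hat, act)
        return recon_loss + codebook_loss + 0.25 * commit_loss, idx
```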

In Stage II, PRISE first converts each trajectory of continuous observations and actions into a sequence of discrete codes. Then, over the corpus of quantized trajectories from the multitask offline dataset, PRISE runs the BPE tokenization algorithm to learn a vocabulary of skill tokens, where each token represents a sequence of discrete action codes.
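The BPE step works just as in LLM tokenizers, except the "characters" are integer action codes. Here is a minimal, self-contained sketch; the helper name learn_bpe_vocab and the toy corpus at the bottom are assumptions for illustration.

```python
# Sketch of Stage II: BPE over quantized action-code trajectories.
# Each trajectory is a list of integer codes from Stage I; merged pairs get fresh ids.
from collections import Counter

def learn_bpe_vocab(corpus, num_merges):
    """corpus: list of code sequences; returns merge rules {(a, b): new_token_id}."""
    merges = {}
    next_id = max(c for seq in corpus for c in seq) + 1
    for _ in range(num_merges):
        # count adjacent code pairs across the whole corpus
        pair_counts = Counter()
        for seq in corpus:
            pair_counts.update(zip(seq, seq[1:]))
        if not pair_counts:
            break
        best, _ = pair_counts.most_common(1)[0]
        merges[best] = next_id
        # replace every occurrence of the most frequent pair with the new token
        new_corpus = []
        for seq in corpus:
            out, i = [], 0
            while i < len(seq):
                if i + 1 < len(seq) and (seq[i], seq[i + 1]) == best:
                    out.append(next_id)
                    i += 2
                else:
                    out.append(seq[i])
                    i += 1
            new_corpus.append(out)
        corpus = new_corpus
        next_id += 1
    return merges

# toy usage: base codes 0-3, two quantized trajectories, two merges
vocab = learn_bpe_vocab([[0, 1, 0, 1, 2], [0, 1, 3, 0, 1]], num_merges=2)
```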


At downstream time, the goal is to leverage the pretrained skill tokens either to learn a generalist multitask policy or to adapt to an unseen task from a few expert demonstration trajectories. PRISE achieves this by learning a skill-token policy π: it first tokenizes the downstream demonstration trajectories by greedily matching the longest skill token at each time step, then optimizes π to match the target tokens under a cross-entropy loss, as sketched below.
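The sketch below shows the greedy longest-match tokenization and a single behavior-cloning step on one trajectory. The policy head, observation dimension, and toy token table are assumptions for illustration, not the PRISE implementation.

```python
# Sketch of downstream training: greedy longest-match tokenization + cross entropy.
import torch
import torch.nn as nn
import torch.nn.functional as F

def greedy_tokenize(codes, token_to_codes):
    """codes: per-step action codes; returns (start index, skill token id) pairs."""
    targets, t = [], 0
    while t < len(codes):
        # base codes are length-1 tokens in a BPE vocabulary, so there is always a match
        best_tok, best_len = codes[t], 1
        for tok, seq in token_to_codes.items():
            L = len(seq)
            if L > best_len and codes[t:t + L] == seq:
                best_tok, best_len = tok, L
        targets.append((t, best_tok))
        t += best_len
    return targets

# toy behavior-cloning step on one trajectory (dimensions are assumptions)
obs = torch.randn(10, 32)                              # 10 steps, 32-dim observations
codes = [0, 1, 0, 1, 2, 0, 1, 3, 0, 1]                 # Stage-I codes for those steps
token_to_codes = {0: [0], 1: [1], 2: [2], 3: [3], 4: [0, 1]}
policy = nn.Linear(32, len(token_to_codes))            # skill-token policy head
starts, toks = zip(*greedy_tokenize(codes, token_to_codes))
logits = policy(obs[list(starts)])                     # predict a token where each skill starts
loss = F.cross_entropy(logits, torch.tensor(toks))
loss.backward()
```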


At evaluation time, PRISE rolls out its policy by querying the skill-token policy π for a skill token and then using the pretrained decoder ψ to decode it into raw actions. The decoder ψ is finetuned when adapting to a new task.
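A sketch of this rollout loop is given below, assuming a Gym-style environment and simple callable policy and decoder interfaces; these interfaces are illustrative assumptions.

```python
# Sketch of a PRISE-style rollout: the skill-token policy picks a token, and the
# pretrained decoder psi maps each underlying action code (plus the current
# observation) back to a raw continuous action.
import torch

@torch.no_grad()
def rollout(env, policy, decoder, token_to_codes, max_skill_steps=200):
    obs = env.reset()                                  # assumed to return an observation tensor
    for _ in range(max_skill_steps):
        token = policy(obs).argmax(dim=-1).item()      # query the skill-token policy
        for code in token_to_codes[token]:             # replay the token's action codes
            action = decoder(obs, code)                # psi: (obs, code) -> raw action
            obs, reward, done, info = env.step(action)
            if done:
                return
```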

Experimental Results

Generalist Multitask Policy Learning


On the challenging LIBERO-90 benchmark, PRISE achieves a significant performance gain over the baseline algorithms, demonstrating the importance of temporal action abstraction for knowledge sharing across diverse tasks.

Five-shot Adaptation to Unseen Tasks


Additionally, on both MetaWorld and LIBERO, we demonstrate that PRISE-pretrained skill tokens significantly improve few-shot imitation learning performance on unseen tasks.

BibTeX

If you find our method or code relevant to your research, please consider citing the paper as follows:
@misc{zheng2024prise,
      title={PRISE: Learning Temporal Action Abstractions as a Sequence Compression Problem},
      author={Ruijie Zheng and Ching-An Cheng and Hal Daumé III and Furong Huang and Andrey Kolobov},
      year={2024},
      eprint={2402.10450},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}