PRISE leverages a large multi-task offline dataset to pretrain a vocabulary of skill tokens, each representing a temporally extended low-level policy for downstream control tasks.
The pretraining of PRISE can be divided into two stages.
In stage I, PRISE learns a state-dependent action quantization module. This module processes the pretraining multitask dataset D by transforming each of its trajectories – a sequence of ⟨observation, continuous action⟩ pairs – into a sequence of discrete codes, one code per time step, as shown in the figure on the left.
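As a minimal sketch of what the stage-I module computes, the snippet below assumes a simple VQ-style encoder/codebook with a state-conditioned decoder; module names and hyperparameters here are illustrative, not PRISE's actual architecture:

```python
import torch
import torch.nn as nn

class ActionQuantizer(nn.Module):
    """Hypothetical state-dependent action quantizer (stage I sketch)."""

    def __init__(self, obs_dim, act_dim, code_dim=64, num_codes=10):
        super().__init__()
        # Encoder maps an <observation, action> pair to a latent vector.
        self.encoder = nn.Sequential(
            nn.Linear(obs_dim + act_dim, 128), nn.ReLU(),
            nn.Linear(128, code_dim),
        )
        # Learnable codebook; each row is one discrete action code.
        self.codebook = nn.Embedding(num_codes, code_dim)
        # Decoder ψ maps (observation, code embedding) back to a raw action;
        # trained with an action-reconstruction loss (VQ details omitted).
        self.decoder = nn.Sequential(
            nn.Linear(obs_dim + code_dim, 128), nn.ReLU(),
            nn.Linear(128, act_dim),
        )

    def quantize(self, obs, action):
        """Assign each <obs, action> pair its nearest codebook index."""
        z = self.encoder(torch.cat([obs, action], dim=-1))
        dists = torch.cdist(z, self.codebook.weight)  # (batch, num_codes)
        return dists.argmin(dim=-1)                   # one code per time step
```

Applied step by step, this turns a trajectory of continuous observation-action pairs into a sequence of integer codes.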
In stage II, PRISE first uses the quantization module from stage I to convert each trajectory of continuous states and actions into a sequence of discrete codes. Then, based on this corpus of quantized trajectories from the multitask offline dataset, PRISE applies the BPE tokenization algorithm to learn a vocabulary of skill tokens, where each token represents a sequence of discrete action codes.
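BPE here works exactly as in NLP tokenizers, except the "characters" are action codes. A minimal sketch, representing each token as a tuple of codes (an actual implementation could reuse an off-the-shelf BPE library):

```python
from collections import Counter

def learn_bpe_vocab(code_seqs, num_merges):
    """Learn a skill-token vocabulary by BPE over quantized trajectories.

    code_seqs: list of lists of int action codes, one list per trajectory.
    Returns a set of code tuples; each tuple is one skill token.
    """
    seqs = [[(c,) for c in seq] for seq in code_seqs]  # start from primitives
    merges = []
    for _ in range(num_merges):
        # Count adjacent token pairs across the whole corpus.
        pair_counts = Counter()
        for seq in seqs:
            for a, b in zip(seq, seq[1:]):
                pair_counts[(a, b)] += 1
        if not pair_counts:
            break
        # Merge the most frequent pair into a new, longer skill token.
        (a, b), _ = pair_counts.most_common(1)[0]
        merges.append((a, b))
        merged = a + b  # tuple concatenation: the token spans both code runs
        for i, seq in enumerate(seqs):
            out, j = [], 0
            while j < len(seq):
                if j + 1 < len(seq) and seq[j] == a and seq[j + 1] == b:
                    out.append(merged)
                    j += 2
                else:
                    out.append(seq[j])
                    j += 1
            seqs[i] = out
    # Vocabulary: every primitive code plus every merged token.
    vocab = {(c,) for seq in code_seqs for c in seq}
    vocab.update(a + b for a, b in merges)
    return vocab
```

Frequent multi-step action patterns in the corpus thus get compressed into single skill tokens, just as frequent character sequences become subwords in text tokenizers.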
At downstream adaptation time, the goal is to leverage the pretrained skill tokens either to learn a generalist multitask policy or to adapt to an unseen task from a few expert demonstration trajectories.
PRISE achieves this by learning a skill-token policy π: it first tokenizes the downstream demonstration trajectories by greedily matching the longest skill token at each time step, and then optimizes π to predict the target token at each step with a cross-entropy loss.
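A sketch of this tokenize-and-imitate step, reusing the vocabulary of code tuples from the BPE sketch above (the policy interface is hypothetical):

```python
import torch.nn.functional as F

def tokenize_greedy(codes, vocab):
    """Greedily match the longest skill token in vocab at each time step.

    codes: list of int action codes from one demonstration trajectory.
    vocab: set of code tuples; it contains every length-1 primitive code,
           so the greedy search always terminates.
    """
    tokens, t = [], 0
    while t < len(codes):
        # Try the longest possible match first, then shorter ones.
        for length in range(len(codes) - t, 0, -1):
            candidate = tuple(codes[t:t + length])
            if candidate in vocab:
                tokens.append(candidate)
                t += length
                break
    return tokens

def bc_loss(policy, obs, target_token_idx):
    """Behavior cloning: cross entropy between the skill-token policy's
    logits over the vocabulary and the greedily matched target token."""
    logits = policy(obs)  # (batch, vocab_size)
    return F.cross_entropy(logits, target_token_idx)
```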
At evaluation time, PRISE rolls out its policy by querying the skill-token policy π for a skill token and then using the pretrained decoder ψ to decode the token's action codes into raw actions.
When adapting to a new task, the decoder ψ is also finetuned.
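Putting the pieces together, an evaluation rollout might look like the following sketch (the environment and policy interfaces are illustrative; note that each skill token expands into several environment steps):

```python
def rollout(env, policy, id_to_token, decoder):
    """Roll out the skill-token policy; decoder is the (possibly finetuned) ψ.

    id_to_token: maps a predicted token index to its tuple of action codes.
    """
    obs, done = env.reset(), False
    total_reward = 0.0
    while not done:
        # Query the skill-token policy for the next temporally extended skill.
        token_idx = policy(obs).argmax(dim=-1).item()
        for code in id_to_token[token_idx]:   # token = sequence of codes
            action = decoder(obs, code)       # ψ decodes one raw action
            obs, reward, done, _ = env.step(action)
            total_reward += reward
            if done:
                break
    return total_reward
```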