This is the official implementation of "Toward Safer Diffusion Language Models: Discovery and Mitigation of Priming Vulnerability" (ICLR 2026).
This project was developed with CUDA 12.9, PyTorch 2.7.1, and Python 3.10.
After installing a GPU version of PyTorch, other dependencies can be installed via pip install -r requirements.txt.
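The two install steps can be combined as follows (a sketch; the cu128 wheel index is an assumption based on the CUDA 12.9 toolchain above — pick the index matching your CUDA installation):

```shell
# Create an isolated environment and install dependencies.
python -m venv .venv
source .venv/bin/activate
# GPU build of PyTorch (assumed index; adjust for your CUDA version).
pip install torch==2.7.1 --index-url https://download.pytorch.org/whl/cu128
# Remaining dependencies from the repo.
pip install -r requirements.txt
```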
You can run Recovery Alignment with the following command:
python train.py \
--dataset_name $your_dataset \
--model_name "GSAI-ML/LLaDA-8B-Instruct" \
--seed 42

Trained models are saved under models/xxx.
If you want to use the same training dataset as in our paper, first create it from BeaverTails 30K:
python prepare_train_dataset.py

This creates datasets/PKU-Alignment_BeaverTails_unsafe_30k. You can then run:
python train.py \
--dataset_name "datasets/PKU-Alignment_BeaverTails_unsafe_30k" \
--model_name "GSAI-ML/LLaDA-8B-Instruct" \
--seed 42

For safety evaluation with the anchoring attack, use:
python anchoring.py \
--model_name "GSAI-ML/LLaDA-8B-Instruct" \
--dataset_name "datasets/JBB-Behaviors_full_target" \
--num_samples 100 \
--initial_step 1 \
--seed 42

You can create datasets/JBB-Behaviors_full_target with:
python create_full_target.py \
--base_dataset "datasets/JBB-Behaviors" \
--output_dataset "datasets/JBB-Behaviors_full_target"

To evaluate the model with First-Step GCG, first run:
python gcg.py \
--model_name "GSAI-ML/LLaDA-8B-Instruct" \
--dataset_name "datasets/JBB-Behaviors" \
--num_samples 100 \
--save_path "datasets/JBB-Behaviors_GCG_LLaDA_100samples_first_step_gen_length128_search_width256_top_k128_seed42" \
--gen_length 128 \
--search_width 256 \
--top_k 128 \
--seed 42 \
--compute_first_step_loss

This creates optimized adversarial prompts in datasets/JBB-Behaviors_GCG_LLaDA_100samples_first_step_gen_length128_search_width256_top_k128_seed42.
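The long save_path above follows a fixed naming pattern built from the gcg.py hyperparameters. A small helper (hypothetical, not part of the repo) can reconstruct it so paths stay consistent when you sweep hyperparameters:

```python
# Hypothetical helper: rebuilds the save_path naming pattern used by the
# gcg.py command above from its hyperparameters.
def gcg_save_path(base, model, num_samples, gen_length, search_width, top_k, seed):
    return (
        f"datasets/{base}_GCG_{model}_{num_samples}samples_first_step"
        f"_gen_length{gen_length}_search_width{search_width}"
        f"_top_k{top_k}_seed{seed}"
    )

# Matches the save_path used in the command above.
print(gcg_save_path("JBB-Behaviors", "LLaDA", 100, 128, 256, 128, 42))
# -> datasets/JBB-Behaviors_GCG_LLaDA_100samples_first_step_gen_length128_search_width256_top_k128_seed42
```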
Then evaluate ASR:
python evaluate_robust.py \
--model_name "GSAI-ML/LLaDA-8B-Instruct" \
--dataset_name "datasets/JBB-Behaviors_GCG_LLaDA_100samples_first_step_gen_length128_search_width256_top_k128_seed42" \
--no_gpt

For GPT-based evaluation, set the OPENAI_API_KEY environment variable and remove --no_gpt.
For utility evaluation, use evaluate_general.py via utility_evaluation.sh.
The SGE_TASK_ID environment variable selects which benchmark is run:
SGE_TASK_ID=1 bash utility_evaluation.sh $your_model $your_output_dir

Parts of this implementation are adapted from the official LLaDA implementation.