
Toward Safer Diffusion Language Models: Discovery and Mitigation of Priming Vulnerability

This is the official implementation of "Toward Safer Diffusion Language Models: Discovery and Mitigation of Priming Vulnerability" (ICLR 2026).

This project was developed with CUDA 12.9, PyTorch 2.7.1, and Python 3.10.

After installing a GPU-enabled build of PyTorch, install the remaining dependencies with pip install -r requirements.txt.
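A minimal environment setup sketch is shown below; the cu128 wheel index is an assumption and should be replaced with whichever index matches your local CUDA installation.

# Example setup (the wheel index is an assumption; match it to your CUDA toolkit)
python -m venv .venv && source .venv/bin/activate
pip install torch==2.7.1 --index-url https://download.pytorch.org/whl/cu128
pip install -r requirements.txt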

Recovery Alignment

You can run Recovery Alignment with the following command:

python train.py \
  --dataset_name $your_dataset \
  --model_name "GSAI-ML/LLaDA-8B-Instruct" \
  --seed 42

Trained models are saved under models/xxx.

If you want to use the same training dataset as in our paper, first create it from BeaverTails 30K:

python prepare_train_dataset.py

This creates datasets/PKU-Alignment_BeaverTails_unsafe_30k. You can then run:

python train.py \
  --dataset_name "datasets/PKU-Alignment_BeaverTails_unsafe_30k" \
  --model_name "GSAI-ML/LLaDA-8B-Instruct"\
  --seed 42

Evaluating Model Safety with Anchoring Attack

For safety evaluation with the anchoring attack, use:

python anchoring.py \
  --model_name "GSAI-ML/LLaDA-8B-Instruct" \
  --dataset_name "datasets/JBB-Behaviors_full_target" \
  --num_samples 100 \
  --initial_step 1 \
  --seed 42
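
To evaluate a model trained with Recovery Alignment, point --model_name at the saved checkpoint directory instead. The path models/your_run below is a placeholder, and this assumes --model_name also accepts a local directory, as is typical for Hugging Face-style loaders.

python anchoring.py \
  --model_name "models/your_run" \
  --dataset_name "datasets/JBB-Behaviors_full_target" \
  --num_samples 100 \
  --initial_step 1 \
  --seed 42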

You can create datasets/JBB-Behaviors_full_target with:

python create_full_target.py \
  --base_dataset "datasets/JBB-Behaviors" \
  --output_dataset "datasets/JBB-Behaviors_full_target"

Evaluating Model Safety with First-Step GCG

To evaluate the model with First-Step GCG, first run:

python gcg.py \
  --model_name "GSAI-ML/LLaDA-8B-Instruct" \
  --dataset_name "datasets/JBB-Behaviors" \
  --num_samples 100 \
  --save_path "datasets/JBB-Behaviors_GCG_LLaDA_100samples_first_step_gen_length128_search_width256_top_k128_seed42" \
  --gen_length 128 \
  --search_width 256 \
  --top_k 128 \
  --seed 42 \
  --compute_first_step_loss

This creates optimized adversarial prompts in datasets/JBB-Behaviors_GCG_LLaDA_100samples_first_step_gen_length128_search_width256_top_k128_seed42.

Then evaluate the attack success rate (ASR):

python evaluate_robust.py \
  --model_name "GSAI-ML/LLaDA-8B-Instruct" \
  --dataset_name "datasets/JBB-Behaviors_GCG_LLaDA_100samples_first_step_gen_length128_search_width256_top_k128_seed42" \
  --no_gpt

For GPT-based evaluation, set the OPENAI_API_KEY environment variable and remove the --no_gpt flag.
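
For example, with a placeholder key:

export OPENAI_API_KEY="sk-..."  # placeholder; use your own key
python evaluate_robust.py \
  --model_name "GSAI-ML/LLaDA-8B-Instruct" \
  --dataset_name "datasets/JBB-Behaviors_GCG_LLaDA_100samples_first_step_gen_length128_search_width256_top_k128_seed42"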

Evaluating Model Utility

For utility evaluation, use evaluate_general.py via utility_evaluation.sh. The SGE_TASK_ID environment variable selects which benchmark is run:

SGE_TASK_ID=1 bash utility_evaluation.sh $your_model $your_output_dir
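
To run every benchmark in sequence, you can loop over the task ID; the range 1-5 below is an assumption and should match the number of benchmarks defined in utility_evaluation.sh.

# Run all benchmarks sequentially (upper bound of 5 is an assumption)
for task_id in $(seq 1 5); do
  SGE_TASK_ID=$task_id bash utility_evaluation.sh $your_model $your_output_dir
done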

Parts of this implementation are adapted from the official LLaDA implementation.
