diff --git a/src/gui_g2/data_config/rec_internvl.yaml b/src/gui_g2/data_config/rec_internvl.yaml
new file mode 100644
index 0000000..0e9c2a6
--- /dev/null
+++ b/src/gui_g2/data_config/rec_internvl.yaml
@@ -0,0 +1,4 @@
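+# REC training data: RefCOCO / RefCOCO+ / RefCOCOg JSON annotations prepared for InternVL.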
+datasets:
+ - json_path: /data10/shz/dataset/rec/rec_jsons_internvl/refcoco_train.json
+ - json_path: /data10/shz/dataset/rec/rec_jsons_internvl/refcocop_train.json
+ - json_path: /data10/shz/dataset/rec/rec_jsons_internvl/refcocog_train.json
\ No newline at end of file
diff --git a/src/gui_g2/local_scripts/create_vision_cot_data.py b/src/gui_g2/local_scripts/create_vision_cot_data.py
new file mode 100644
index 0000000..fec2d7c
--- /dev/null
+++ b/src/gui_g2/local_scripts/create_vision_cot_data.py
@@ -0,0 +1,153 @@
+import argparse
+import base64
+import concurrent.futures
+import io
+import json
+import os
+import random
+import re
+import time
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+from io import BytesIO
+from typing import Dict, List
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+from datasets import Dataset, concatenate_datasets, load_dataset, load_from_disk
+from tqdm import tqdm
+
+import bytedtos
+import requests
+import seaborn as sns
+import yaml
+from openai import AzureOpenAI
+from PIL import Image
+from pillow_avif import AvifImagePlugin
+
+
+PROMPT_FORMAT = """I will provide you with an image, an original question, and its answer related to the image. Your task is to rewrite the question in such a way that answering it requires step-by-step Chain-of-Thought (CoT) reasoning with numerical or mathematical expressions where applicable. The reasoning process can include expressions like "let me think," "oh, I see," or other natural language thought expressions.
+
+Please make sure your question asks for a specific answer with a definite value; do not ask an open-ended question. The answer must be correct and easy to verify via a simple protocol, like "2" or "A".
+
+Please strictly do not include "Answer:" in the question part to avoid confusion and leakage.
+
+Input Format:
+Original Question: {original_question}
+Original Answer: {original_answer}
+
+Output Format:
+Question: [rewrite the question if necessary]
+Answer: [answer with reasoning steps, including calculations where applicable]
+<think>step-by-step reasoning process</think>
+<answer>easy to verify answer</answer>
+"""
+
+
+def get_image_data_url(image_input):
+ if isinstance(image_input, str) and image_input.startswith("data:"):
+ return image_input
+
+    if isinstance(image_input, str) and image_input.startswith("http"):
+        # The original called an undefined `load_image`; download the URL and decode it with PIL instead.
+        resp = requests.get(image_input, timeout=30)
+        resp.raise_for_status()
+        image_input = Image.open(BytesIO(resp.content))
+
+ if isinstance(image_input, str):
+ image_input = Image.open(image_input)
+
+ if not isinstance(image_input, Image.Image):
+ raise ValueError("Unsupported image input type")
+
+ if image_input.mode != "RGB":
+ image_input = image_input.convert("RGB")
+
+ buffer = BytesIO()
+ image_input.save(buffer, format="JPEG")
+ img_bytes = buffer.getvalue()
+ base64_data = base64.b64encode(img_bytes).decode("utf-8")
+ return f"data:image/jpeg;base64,{base64_data}"
+
+
+def gpt4o_query(image, prompt, max_retries=5, initial_delay=3):
+ if image is None:
+ return None
+
+ data_url_list = [get_image_data_url(image)]
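+    # NOTE: the endpoint and api_key below are placeholders; fill in your Azure OpenAI credentials before running.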
+ client = AzureOpenAI(
+ azure_endpoint="YOUR_AZURE_ENDPOINT",
+ api_version="2023-07-01-preview",
+ api_key="YOUR_API_KEY",
+ )
+
+ for attempt in range(max_retries):
+ try:
+ messages = [
+ {
+ "role": "system",
+                    "content": "You are an expert at analyzing images and providing useful information for users.",
+ },
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": prompt},
+ ],
+ },
+ ]
+
+ for data_url in data_url_list:
+ messages[1]["content"].insert(
+ 0, {"type": "image_url", "image_url": {"url": data_url}}
+ )
+
+ response = client.chat.completions.create(
+ model="gpt-4o-2024-08-06",
+ messages=messages,
+ temperature=0.2,
+ max_tokens=8192,
+ )
+ return response.choices[0].message.content
+
+ except Exception as e:
+ if attempt == max_retries - 1:
+ raise Exception(
+ f"Failed after {max_retries} attempts. Last error: {str(e)}"
+ )
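+            # Exponential backoff with jitter before the next retry.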
+ delay = initial_delay * (2**attempt) + random.uniform(
+ 0, 0.1 * initial_delay * (2**attempt)
+ )
+ time.sleep(delay)
+
+
+def process_single_item(example):
+ try:
+ image_path = example["image_path"]
+ formatted_prompt = PROMPT_FORMAT.format(
+ original_question=example["question"], original_answer=example["answer"]
+ )
+
+ response = gpt4o_query(image_path, formatted_prompt)
+ example["gpt4o_response"] = response
+ return example
+ except Exception as e:
+ print(f"Error processing item: {str(e)}")
+ example["gpt4o_response"] = None
+ return example
+
+
+def main():
+ dataset_path = "path/to/your/dataset"
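+    # argparse is imported but unused; a minimal sketch (hypothetical --dataset_path flag) if a CLI
+    # argument is preferred over the hard-coded path:
+    #   parser = argparse.ArgumentParser()
+    #   parser.add_argument("--dataset_path", default=dataset_path)
+    #   dataset_path = parser.parse_args().dataset_path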
+ full_dataset = load_from_disk(dataset_path)
+
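+    # Each of the num_proc worker processes issues its own GPT-4o request per example; tune num_proc to your rate limits.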
+ processed_dataset = full_dataset.map(
+ function=partial(process_single_item),
+ num_proc=256,
+ desc="Processing dataset with GPT-4o",
+ keep_in_memory=True,
+ )
+
+ output_path = f"{dataset_path}_processed"
+ processed_dataset.save_to_disk(output_path)
+ print(f"Processed dataset saved to: {output_path}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/gui_g2/local_scripts/lmms_eval_qwen2vl.sh b/src/gui_g2/local_scripts/lmms_eval_qwen2vl.sh
new file mode 100644
index 0000000..6d38769
--- /dev/null
+++ b/src/gui_g2/local_scripts/lmms_eval_qwen2vl.sh
@@ -0,0 +1,61 @@
+export HF_HOME=""
+export HF_TOKEN=""
+export HF_HUB_ENABLE_HF_TRANSFER="1"
+
+export API_TYPE=""
+export AZURE_ENDPOINT=""
+export AZURE_API_KEY=""
+export API_VERSION=""
+export MODEL_VERSION=""
+export NAVIT_ATTENTION_IMPLEMENTATION="eager"
+
+# Prompt for installation with 3-second timeout
+read -t 3 -p "Do you want to install dependencies? (YES/no, timeout in 3s): " install_deps || true
+if [ "$install_deps" = "YES" ]; then
+ # Prepare the environment
+ pip3 install --upgrade pip
+ pip3 install -U setuptools
+
+ cd
+ if [ ! -d "maas_engine" ]; then
+ git clone
+ else
+ echo "maas_engine directory already exists, skipping clone"
+ fi
+ cd maas_engine
+ git pull
+ git checkout
+ pip3 install --no-cache-dir --no-build-isolation -e ".[standalone]"
+
+ current_version=$(pip3 show transformers | grep Version | cut -d' ' -f2)
+ if [ "$current_version" != "4.46.2" ]; then
+ echo "Installing transformers 4.46.2 (current version: $current_version)"
+ pip3 install transformers==4.46.2
+ else
+ echo "transformers 4.46.2 is already installed"
+ fi
+
+ cd
+ rm -rf
+ pip3 install -e .
+ pip3 install -U pydantic
+ pip3 install Levenshtein
+ pip3 install nltk
+ python3 -c "import nltk; nltk.download('wordnet', quiet=True); nltk.download('punkt', quiet=True)"
+fi
+
+TASKS=mmmu_val,mathvista_testmini,mmmu_pro
+MODEL_BASENAME=qwen2_vl
+
+model_checkpoint=""
+echo "MODEL_BASENAME: ${MODEL_BASENAME}"
+cd
+
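+# Launch lmms-eval across 8 processes via accelerate; max_pixels caps Qwen2-VL's dynamic-resolution image input.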
+python3 -m accelerate.commands.launch --num_processes=8 --main_process_port=12345 lmms_eval \
+ --model qwen2_vl \
+ --model_args=pretrained=${model_checkpoint},max_pixels=2359296 \
+ --tasks ${TASKS} \
+ --batch_size 1 \
+ --log_samples \
+ --log_samples_suffix ${MODEL_BASENAME} \
+ --output_path ./logs
\ No newline at end of file
diff --git a/src/gui_g2/local_scripts/prepare_hf_data.py b/src/gui_g2/local_scripts/prepare_hf_data.py
new file mode 100644
index 0000000..62eab9e
--- /dev/null
+++ b/src/gui_g2/local_scripts/prepare_hf_data.py
@@ -0,0 +1,166 @@
+import argparse
+import base64
+import concurrent.futures
+import io
+import json
+import os
+import random
+import re
+import time
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+from typing import Dict, List
+
+import datasets
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import seaborn as sns
+import yaml
+from datasets import Dataset, concatenate_datasets, load_dataset, load_from_disk
+from openai import AzureOpenAI
+from PIL import Image
+from pillow_avif import AvifImagePlugin
+from tqdm import tqdm
+
+
+def extract_problem_solution(gpt4o_response):
+    # Split the response into parts on <think> tags
+    parts = gpt4o_response.split("<think>")
+
+    # Extract the problem (first part before any <think> tags)
+    problem = parts[0].strip()
+    # Remove "Question:" prefix if it exists
+    problem = re.sub(r"^Question:\s*", "", problem)
+    # Remove "Answer:" at the end of the problem
+    problem = re.sub(r"\s*Answer:\s*$", "", problem).strip()
+
+    # Combine all the reasoning steps into a single <think> block
+    think_parts = [p.split("</think>")[0].strip() for p in parts[1:] if "</think>" in p]
+    solution = f"<think>{' '.join(think_parts)}</think>"
+
+    # Add the final answer if it exists, removing "Answer:" prefix
+    if "<answer>" in gpt4o_response:
+        final_answer = (
+            gpt4o_response.split("<answer>")[-1].split("</answer>")[0].strip()
+        )
+        final_answer = re.sub(r"^Answer:\s*", "", final_answer)
+        solution += f"\n\n<answer>{final_answer}</answer>"
+
+ return problem, solution
+
+
+def load_image_from_path(image_path):
+ try:
+ img = Image.open(image_path)
+ return img
+ except Exception as e:
+ print(f"Error loading image {image_path}: {str(e)}")
+ return None
+
+
+def process_raw_data(raw_data):
+ # Parse the raw data if it's a string
+ if isinstance(raw_data, str):
+ data = json.loads(raw_data)
+ else:
+ data = raw_data
+
+ # Extract problem and solution
+ try:
+ problem, solution = extract_problem_solution(data["gpt4o_response"])
+ image = load_image_from_path(data["image_path"])
+
+ return {
+ "image": image,
+ "problem": problem,
+ "solution": solution,
+ "original_question": data["question"],
+ "original_answer": data["answer"],
+ }
+ except Exception as e:
+ print(f"Error processing data {data}: {str(e)}")
+ return {
+ "image": None,
+ "problem": None,
+ "solution": None,
+ "original_question": None,
+ "original_answer": None,
+ }
+
+
+raw_data_list = [
+ "/path/to/reasoning_data_with_response_90k_verified",
+]
+
+raw_data = concatenate_datasets([load_from_disk(path) for path in raw_data_list])
+
+processed_data = raw_data.map(process_raw_data, num_proc=128).shuffle(seed=42)
+
+hf_dict = {
+ "image": [],
+ "problem": [],
+ "solution": [],
+ "original_question": [],
+ "original_answer": [],
+}
+
+for item in tqdm(processed_data):
+ hf_dict["image"].append(item["image"])
+ hf_dict["problem"].append(item["problem"])
+ hf_dict["solution"].append(item["solution"])
+ hf_dict["original_question"].append(item["original_question"])
+ hf_dict["original_answer"].append(item["original_answer"])
+
+
+features = datasets.Features(
+ {
+ "image": datasets.Image(),
+ "problem": datasets.Value("string"),
+ "solution": datasets.Value("string"),
+ "original_question": datasets.Value("string"),
+ "original_answer": datasets.Value("string"),
+ }
+)
+
+
+def has_empty_tags(text):
+    # Pattern to match empty tags like <answer></answer>
+    pattern = r"<[^>]+></[^>]+>"
+ return bool(re.search(pattern, text))
+
+
+def has_answer_pattern(text):
+ if "Answer:" in text:
+ return True
+ return False
+
+
+def has_valid_image_size(example): # for Qwen2-VL-2B's processor requirement
+ # Assuming the image is in a format that can be checked for dimensions
+ # You might need to adjust this depending on how the image is stored in your dataset
+ try:
+ image = example["image"] # or however your image is accessed
+ if isinstance(image, dict) and "height" in image and "width" in image:
+ return image["height"] >= 28 and image["width"] >= 28
+ # If image is a PIL Image or similar
+ return image.height >= 28 and image.width >= 28
+    except Exception:
+ return False
+
+
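+# Build the HF dataset, then drop samples with empty tags, "Answer:" leakage in the problem,
+# missing images, or images smaller than 28x28.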
+ds = datasets.Dataset.from_dict(hf_dict, features=features)
+ds = ds.filter(
+ lambda x: not has_empty_tags(x["solution"])
+ and not has_answer_pattern(x["problem"])
+ and has_valid_image_size(x)
+ and x["image"] is not None,
+ num_proc=128,
+)
+# Push to Hugging Face Hub
+ds.push_to_hub("path/to/your/dataset")
diff --git a/src/gui_g2/local_scripts/train_aria_moe.sh b/src/gui_g2/local_scripts/train_aria_moe.sh
new file mode 100644
index 0000000..5a3b696
--- /dev/null
+++ b/src/gui_g2/local_scripts/train_aria_moe.sh
@@ -0,0 +1,68 @@
+#!/bin/bash
+
+export NCCL_BLOCKING_WAIT=0
+export TOKENIZERS_PARALLELISM=false
+export OMP_NUM_THREADS=8
+export NCCL_IB_DISABLE=0
+export NCCL_IB_GID_INDEX=3
+export NCCL_SOCKET_IFNAME=eth0
+export NCCL_DEBUG=INFO
+
+# CONFIG Huggingface
+# export HF_TOKEN=""
+export HF_TOKEN=""
+export HF_HOME="$HOME/.cache/huggingface"
+export HF_HUB_ENABLE_HF_TRANSFER="1"
+
+export NCCL_DEBUG=INFO
+
+GPUS="0,1,2,3,4,5,6,7"
+
+# Take the first port assigned to worker 0
+ports=($(echo $METIS_WORKER_0_PORT | tr ',' ' '))
+port=${ports[0]}
+port_in_cmd="$(echo "${METIS_WORKER_0_PORT:-2000}" | awk -F',' '{print $1}')"
+
+echo "total workers: ${ARNOLD_WORKER_NUM}"
+echo "cur worker id: ${ARNOLD_ID}"
+echo "gpus per worker: ${ARNOLD_WORKER_GPU}"
+echo "master ip: ${METIS_WORKER_0_HOST}"
+echo "master port: ${port}"
+echo "master port in cmd: ${port_in_cmd}"
+
+# export WANDB_BASE_URL=https://api.wandb.ai
+# export WANDB_API_KEY=""
+# wandb login $WANDB_API_KEY
+
+export WANDB_BASE_URL=https://api.wandb.ai
+export WANDB_PROJECT=vision-reasoning
+export WANDB_API_KEY=""
+export WANDB_RUN_NAME=Qwen-VL-2B-GRPO-$(date +%Y-%m-%d-%H-%M-%S)
+wandb login $WANDB_API_KEY
+
+cd /home/tiger/multimodal-open-r1
+# pip3 install vllm==0.6.6.post1
+pip3 install -e ".[dev]"
+pip3 install wandb==0.18.3
+
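+# Multi-node launch: node count, rank, and master address come from the ARNOLD_*/METIS_* variables,
+# which are expected to be set by the cluster scheduler.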
+torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" \
+ --nnodes="${ARNOLD_WORKER_NUM}" \
+ --node_rank="${ARNOLD_ID}" \
+ --master_addr="${METIS_WORKER_0_HOST}" \
+ --master_port="${port_in_cmd}" \
+ src/open_r1/grpo.py \
+ --deepspeed scripts/zero3.json \
+ --output_dir Aria-GRPO-mini_cot_80k \
+ --model_name_or_path rhymes-ai/Aria \
+ --dataset_name luodian/mini_cot_80k \
+ --max_prompt_length 8192 \
+ --per_device_train_batch_size 1 \
+ --gradient_accumulation_steps 1 \
+ --logging_steps 1 \
+ --bf16 \
+ --report_to wandb \
+ --gradient_checkpointing true \
+ --attn_implementation eager \
+ --save_total_limit 8 \
+ --num_train_epochs 1 \
+ --run_name $WANDB_RUN_NAME
diff --git a/src/gui_g2/local_scripts/train_qwen2_vl.sh b/src/gui_g2/local_scripts/train_qwen2_vl.sh
new file mode 100644
index 0000000..137310e
--- /dev/null
+++ b/src/gui_g2/local_scripts/train_qwen2_vl.sh
@@ -0,0 +1,61 @@
+#!/bin/bash
+
+export NCCL_BLOCKING_WAIT=0
+export TOKENIZERS_PARALLELISM=false
+export OMP_NUM_THREADS=8
+export NCCL_IB_DISABLE=0
+export NCCL_IB_GID_INDEX=3
+export NCCL_SOCKET_IFNAME=eth0
+export NCCL_DEBUG=INFO
+
+GPUS="0,1,2,3,4,5,6,7"
+
+# Take the first port assigned to worker 0
+ports=($(echo $METIS_WORKER_0_PORT | tr ',' ' '))
+port=${ports[0]}
+port_in_cmd="$(echo "${METIS_WORKER_0_PORT:-2000}" | awk -F',' '{print $1}')"
+
+echo "total workers: ${ARNOLD_WORKER_NUM}"
+echo "cur worker id: ${ARNOLD_ID}"
+echo "gpus per worker: ${ARNOLD_WORKER_GPU}"
+echo "master ip: ${METIS_WORKER_0_HOST}"
+echo "master port: ${port}"
+echo "master port in cmd: ${port_in_cmd}"
+
+# export WANDB_BASE_URL=https://api.wandb.ai
+# export WANDB_API_KEY=""
+# wandb login $WANDB_API_KEY
+
+export WANDB_BASE_URL=https://api.wandb.ai
+export WANDB_PROJECT=vision-reasoning
+export WANDB_API_KEY=""
+export WANDB_RUN_NAME=Qwen-VL-2B-GRPO-$(date +%Y-%m-%d-%H-%M-%S)
+wandb login $WANDB_API_KEY
+
+cd /home/tiger/multimodal-open-r1
+# pip3 install vllm==0.6.6.post1
+pip3 install -e ".[dev]"
+pip3 install wandb==0.18.3
+
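+# Multi-node launch: node count, rank, and master address come from the ARNOLD_*/METIS_* variables,
+# which are expected to be set by the cluster scheduler.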
+torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" \
+ --nnodes="${ARNOLD_WORKER_NUM}" \
+ --node_rank="${ARNOLD_ID}" \
+ --master_addr="${METIS_WORKER_0_HOST}" \
+ --master_port="${port_in_cmd}" \
+ src/open_r1/grpo.py \
+ --deepspeed scripts/zero3.json \
+ --output_dir checkpoints/${WANDB_RUN_NAME} \
+ --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \
+ --dataset_name luodian/${DATASET_NAME} \
+ --max_prompt_length 8192 \
+ --per_device_train_batch_size 1 \
+ --gradient_accumulation_steps 1 \
+ --logging_steps 1 \
+ --bf16 \
+ --report_to wandb \
+ --gradient_checkpointing true \
+ --attn_implementation flash_attention_2 \
+ --max_pixels 2359296 \
+ --save_total_limit 8 \
+ --num_train_epochs 1 \
+ --run_name $WANDB_RUN_NAME
diff --git a/src/gui_g2/local_scripts/zero2.json b/src/gui_g2/local_scripts/zero2.json
new file mode 100644
index 0000000..b5ba7eb
--- /dev/null
+++ b/src/gui_g2/local_scripts/zero2.json
@@ -0,0 +1,41 @@
+{
+ "fp16": {
+ "enabled": "auto",
+ "loss_scale": 0,
+ "loss_scale_window": 1000,
+ "initial_scale_power": 16,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "bf16": {
+ "enabled": "auto"
+ },
+ "optimizer": {
+ "type": "AdamW",
+ "params": {
+ "lr": "auto",
+ "betas": "auto",
+ "eps": "auto",
+ "weight_decay": "auto"
+ }
+ },
+ "zero_optimization": {
+ "stage": 2,
+ "offload_optimizer": {
+ "device": "none",
+ "pin_memory": true
+ },
+ "allgather_partitions": true,
+ "allgather_bucket_size": 2e8,
+ "overlap_comm": false,
+ "reduce_scatter": true,
+ "reduce_bucket_size": 2e8,
+ "contiguous_gradients": true
+ },
+ "gradient_accumulation_steps": "auto",
+ "gradient_clipping": "auto",
+ "steps_per_print": 100,
+ "train_batch_size": "auto",
+ "train_micro_batch_size_per_gpu": "auto",
+ "wall_clock_breakdown": false
+}
\ No newline at end of file
diff --git a/src/gui_g2/local_scripts/zero3.json b/src/gui_g2/local_scripts/zero3.json
new file mode 100644
index 0000000..02d3431
--- /dev/null
+++ b/src/gui_g2/local_scripts/zero3.json
@@ -0,0 +1,41 @@
+{
+ "fp16": {
+ "enabled": "auto",
+ "loss_scale": 0,
+ "loss_scale_window": 1000,
+ "initial_scale_power": 16,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "bf16": {
+ "enabled": "auto"
+ },
+
+ "zero_optimization": {
+ "stage": 3,
+ "offload_optimizer": {
+ "device": "none",
+ "pin_memory": true
+ },
+ "offload_param": {
+ "device": "none",
+ "pin_memory": true
+ },
+ "overlap_comm": true,
+ "contiguous_gradients": true,
+ "sub_group_size": 1e9,
+ "reduce_bucket_size": "auto",
+ "stage3_prefetch_bucket_size": "auto",
+ "stage3_param_persistence_threshold": "auto",
+ "stage3_max_live_parameters": 1e9,
+ "stage3_max_reuse_distance": 1e9,
+ "stage3_gather_16bit_weights_on_model_save": true
+ },
+
+ "gradient_accumulation_steps": "auto",
+ "gradient_clipping": "auto",
+ "steps_per_print": 100,
+ "train_batch_size": "auto",
+ "train_micro_batch_size_per_gpu": "auto",
+ "wall_clock_breakdown": false
+}
\ No newline at end of file
diff --git a/src/gui_g2/local_scripts/zero3.yaml b/src/gui_g2/local_scripts/zero3.yaml
new file mode 100644
index 0000000..b5a1201
--- /dev/null
+++ b/src/gui_g2/local_scripts/zero3.yaml
@@ -0,0 +1,22 @@
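+# Accelerate config: DeepSpeed ZeRO-3, bf16 mixed precision, one node with 8 processes.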
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+ deepspeed_multinode_launcher: standard
+ offload_optimizer_device: none
+ offload_param_device: none
+ zero3_init_flag: true
+ zero3_save_16bit_model: true
+ zero_stage: 3
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/src/gui_g2/local_scripts/zero3_offload.json b/src/gui_g2/local_scripts/zero3_offload.json
new file mode 100644
index 0000000..9da12de
--- /dev/null
+++ b/src/gui_g2/local_scripts/zero3_offload.json
@@ -0,0 +1,48 @@
+{
+ "fp16": {
+ "enabled": "auto",
+ "loss_scale": 0,
+ "loss_scale_window": 1000,
+ "initial_scale_power": 16,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "bf16": {
+ "enabled": "auto"
+ },
+ "optimizer": {
+ "type": "AdamW",
+ "params": {
+ "lr": "auto",
+ "betas": "auto",
+ "eps": "auto",
+ "weight_decay": "auto"
+ }
+ },
+ "zero_optimization": {
+ "stage": 3,
+ "offload_optimizer": {
+ "device": "cpu",
+ "pin_memory": true
+ },
+ "offload_param": {
+ "device": "cpu",
+ "pin_memory": true
+ },
+ "overlap_comm": true,
+ "contiguous_gradients": true,
+ "sub_group_size": 1e9,
+ "reduce_bucket_size": "auto",
+ "stage3_prefetch_bucket_size": "auto",
+ "stage3_param_persistence_threshold": "auto",
+ "stage3_max_live_parameters": 1e9,
+ "stage3_max_reuse_distance": 1e9,
+ "gather_16bit_weights_on_model_save": true
+ },
+ "gradient_accumulation_steps": "auto",
+ "gradient_clipping": "auto",
+ "train_batch_size": "auto",
+ "train_micro_batch_size_per_gpu": "auto",
+ "steps_per_print": 1e5,
+ "wall_clock_breakdown": false
+}
\ No newline at end of file
diff --git a/src/gui_g2/src/open_r1/gaussian_grpo.py b/src/gui_g2/src/open_r1/gaussian_grpo.py
index d82bf76..166744d 100644
--- a/src/gui_g2/src/open_r1/gaussian_grpo.py
+++ b/src/gui_g2/src/open_r1/gaussian_grpo.py
@@ -58,7 +58,8 @@
logger = logging.getLogger(__name__)
from filelock import FileLock
-
+from open_r1.vlm_modules.qwen_module import Qwen2VLModule
+from open_r1.vlm_modules.internvl_module import InvernVLModule
def custom_forward(
self,
@@ -409,7 +410,7 @@ def format_reward(completions, **kwargs):
with open(log_path, "a") as f:
f.write(f"\n|||||||||||||||||||||||||||||||||||||||||||||||||||| RANK: {dist.get_rank()}, match: {num} ||||||||||||||||||||||||||||||||||||||||||||||||||||\n")
f.write(f"Image Path: \n{kwargs['image_path'][0]}\n")
- f.write(f"Resized Width: {kwargs['width_resized'][0]}, Resized Height: {kwargs['height_resized'][0]}\n")
+ # f.write(f"Width: {kwargs['width'][0]}, Height: {kwargs['height'][0]}\n")
f.write(f"\nInstruction: \n{kwargs['problem'][0]}\n")
f.write(f"\nformat not matched\n")
f.write(f"completion_contents: \n{completion_contents[i]}\n")
diff --git a/src/gui_g2/src/open_r1/trainer/__init__.py b/src/gui_g2/src/open_r1/trainer/__init__.py
index cf4f64c..92a2f16 100644
--- a/src/gui_g2/src/open_r1/trainer/__init__.py
+++ b/src/gui_g2/src/open_r1/trainer/__init__.py
@@ -1,5 +1,6 @@
from .grpo_trainer import VLMGRPOTrainer
-from .grpo_trainer_test import Qwen2VLGRPOTrainerTest
+# from .grpo_trainer_test import Qwen2VLGRPOTrainerTest
from .grpo_config import GRPOConfig
-__all__ = ["VLMGRPOTrainer","Qwen2VLGRPOTrainerTest"]
+# __all__ = ["VLMGRPOTrainer","Qwen2VLGRPOTrainerTest"]
+__all__ = ["VLMGRPOTrainer"]
\ No newline at end of file
diff --git a/src/gui_g2/src/open_r1/vlm_modules/internvl_module.py b/src/gui_g2/src/open_r1/vlm_modules/internvl_module.py
new file mode 100644
index 0000000..5271c8b
--- /dev/null
+++ b/src/gui_g2/src/open_r1/vlm_modules/internvl_module.py
@@ -0,0 +1,328 @@
+from open_r1.vlm_modules.vlm_module import VLMBaseModule
+from typing import Dict, Any, Union
+from transformers import AutoModel, AutoProcessor, AutoConfig
+import torch
+import torchvision.transforms as T
+from PIL import Image
+from torchvision.transforms.functional import InterpolationMode
+from transformers.feature_extraction_sequence_utils import BatchFeature
+
+IMG_START_TOKEN='<img>'
+IMG_END_TOKEN='</img>'
+IMG_CONTEXT_TOKEN='<IMG_CONTEXT>'
+
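+# ImageNet normalization statistics used by InternVL's image preprocessing.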
+IMAGENET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_STD = (0.229, 0.224, 0.225)
+
+class InvernVLModule(VLMBaseModule):
+ def __init__(self):
+ super().__init__()
+ self.conv_template = None
+ self.num_image_token = None
+
+ def get_vlm_key(self):
+ return "internvl"
+
+ def get_model_class(self, model_id: str, model_init_kwargs: dict):
+ assert "InternVL" in model_id, f"model_id must contain 'InternVL', but got {model_id}"
+ self.model_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
+        # The concrete InternVL model class is resolved from its config when loaded via AutoModel
+        model_cls = AutoModel
+        # InternVL must be loaded with "trust_remote_code=True"
+        model_init_kwargs["trust_remote_code"] = True
+        # Drop "use_cache" from the init kwargs
+        model_init_kwargs.pop("use_cache", None)
+        # InternVL uses "use_flash_attn" instead of attn_implementation="flash_attention_2"
+ if "flash_attention_2" in model_init_kwargs.get("attn_implementation", ""):
+ model_init_kwargs["use_flash_attn"] = True
+ model_init_kwargs.pop("attn_implementation")
+ return model_cls
+
+ def post_model_init(self, model, processing_class):
+ self.conv_template = model.conv_template if self.conv_template is None else self.conv_template
+ self.num_image_token = model.num_image_token if self.num_image_token is None else self.num_image_token
+ img_context_token_id = processing_class.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
+ model.img_context_token_id = img_context_token_id
+
+ def is_embeds_input(self):
+ return True
+
+ def get_processing_class(self):
+ return AutoProcessor
+
+ def get_eos_token_id(self, processing_class):
+ eos_token_id = processing_class.convert_tokens_to_ids(self.conv_template.sep.strip())
+ return eos_token_id
+
+ def get_vision_modules_keywords(self):
+ return ['vision_model']
+
+ def get_custom_multimodal_keywords(self):
+ return ['pixel_values', 'image_flags']
+
+ def get_non_generate_params(self):
+ return ['image_flags']
+
+ def get_custom_processing_keywords(self):
+ return [('None', 'max_anyres_num')]
+
+ def prepare_prompt(self, processing_class, inputs: dict[str, Union[torch.Tensor, Any]]):
+ prompts_text = []
+ for example in inputs:
+ template = self.conv_template.copy()
+ conversation_list = example["prompt"]
+ system_message = extract_system_message(conversation_list)
+ if system_message is not None:
+ template.system_message = system_message
+
+ processed_list = process_conversation_list(conversation_list, system_message)
+ for i, processed_item in enumerate(processed_list):
+ if i % 2 == 0:
+ template.append_message(template.roles[0], processed_item)
+ else:
+ template.append_message(template.roles[1], processed_item)
+ if len(processed_list) % 2 == 1:
+ template.append_message(template.roles[1], None)
+ query = template.get_prompt()
+ prompts_text.append(query)
+ return prompts_text
+
+ def prepare_model_inputs(self, processing_class, prompts_text, images, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False):
+ # Process images
+ full_pixel_values = []
+ num_patches_list = []
+ for img in images:
+ pixel_values = self._load_image(img, input_size=self.model_config.vision_config.image_size, max_num=processing_class.max_anyres_num)
+ full_pixel_values.append(pixel_values)
+ num_patches_list.append(pixel_values.shape[0])
+ full_pixel_values = torch.cat(full_pixel_values, dim=0)
+
+ # Process prompts
+ queries = []
+ image_idx = 0
+ for query in prompts_text:
+        while "<image>" in query:
+            num_patches = num_patches_list[image_idx]
+            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
+            query = query.replace("<image>", image_tokens, 1)
+ image_idx += 1
+ queries.append(query)
+ assert image_idx == len(num_patches_list)
+
+ model_inputs = processing_class(
+ queries,
+ return_tensors=return_tensors,
+ padding=padding,
+ padding_side=padding_side,
+ add_special_tokens=add_special_tokens,
+ )
+ model_inputs["pixel_values"] = full_pixel_values
+        # Only pure-image data is supported currently (every sample must contain an image)
+ model_inputs['image_flags'] = torch.ones(full_pixel_values.shape[0], dtype=torch.long)
+
+ model_inputs = BatchFeature(data=model_inputs)
+
+ return model_inputs, None
+
+ def _load_image(self, image: Image.Image, input_size: int=448, max_num:int=12):
+ transform = build_transform(input_size=input_size)
+ images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
+ pixel_values = [transform(image) for image in images]
+ pixel_values = torch.stack(pixel_values)
+ return pixel_values
+
+ @staticmethod
+ def get_question_template(task_type: str):
+ match task_type:
+ case _:
+                return "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags."
+
+ @staticmethod
+ def format_reward_rec(completions, **kwargs):
+ """Check if the InternVL model output matches a specific format."""
+ import re
+ import os
+ from datetime import datetime
+        pattern = r"<think>.*?</think>\s*<answer>.*?\[\d+,\s*\d+,\s*\d+,\s*\d+\].*?</answer>"
+ completion_contents = [completion[0]["content"] for completion in completions]
+ matches = [re.search(pattern, content, re.DOTALL) is not None for content in completion_contents]
+ current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
+ if os.getenv("DEBUG_MODE") == "true":
+ log_path = os.getenv("LOG_PATH")
+ with open(log_path.replace(".txt", "_format.txt"), "a", encoding='utf-8') as f:
+ f.write(f"------------- {current_time} Format reward -------------\n")
+ for content, match in zip(completion_contents, matches):
+ f.write(f"Content: {content}\n")
+ f.write(f"Has format: {bool(match)}\n")
+ return [1.0 if match else 0.0 for match in matches]
+
+ @staticmethod
+ def iou_reward(completions, solution, **kwargs):
+        """Calculate a soft IoU reward: the IoU between the bounding box predicted by the InternVL model and the ground-truth box is used directly as the reward."""
+ import re
+ import os
+ import json
+ from datetime import datetime
+ def iou(box1, box2):
+ inter_x1 = max(box1[0], box2[0])
+ inter_y1 = max(box1[1], box2[1])
+ inter_x2 = min(box1[2]-1, box2[2]-1)
+ inter_y2 = min(box1[3]-1, box2[3]-1)
+ if inter_x1 < inter_x2 and inter_y1 < inter_y2:
+ inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
+ else:
+ inter = 0
+ union = (box1[2]-box1[0])*(box1[3]-box1[1]) + (box2[2]-box2[0])*(box2[3]-box2[1]) - inter
+ return float(inter)/union
+ contents = [completion[0]["content"] for completion in completions]
+ rewards = []
+ current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
+        answer_tag_pattern = r'<answer>(.*?)</answer>'
+ bbox_pattern = r'\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)]'
+ for i, (content, sol) in enumerate(zip(contents, solution)):
+ sol = re.findall(answer_tag_pattern, sol, re.DOTALL)[-1]
+ sol = json.loads(sol.strip())
+ reward = 0.0
+ # Try symbolic verification first
+ try:
+ content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
+ if content_answer_match:
+ content_answer = content_answer_match.group(1).strip()
+ bbox_match = re.search(bbox_pattern, content_answer)
+ if bbox_match:
+ bbox = [int(bbox_match.group(1)), int(bbox_match.group(2)), int(bbox_match.group(3)), int(bbox_match.group(4))]
+ reward = iou(bbox, sol)
+ except Exception:
+ pass # Continue to next verification method if this fails
+
+ rewards.append(reward)
+ if os.getenv("DEBUG_MODE") == "true":
+ log_path = os.getenv("LOG_PATH")
+ current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
+ image_path = kwargs.get("image_path")[i] if "image_path" in kwargs else None
+ problem = kwargs.get("problem")[i]
+ if reward <= 1.0: # this condition can be changed for debug
+ with open(log_path, "a", encoding='utf-8') as f:
+ f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
+ f.write(f"image_path: {image_path}\n")
+ f.write(f"problem: {problem}\n")
+ f.write(f"Content: {content}\n")
+ f.write(f"Solution: {sol}\n")
+ return rewards
+
+ @staticmethod
+ def select_reward_func(func: str, task_type: str):
+ if func == "accuracy":
+ match task_type:
+ case "rec":
+ return InvernVLModule.iou_reward
+ case _:
+ raise ValueError(f"Unsupported reward function: {func}")
+ elif func == "format":
+ match task_type:
+ case "rec":
+ return InvernVLModule.format_reward_rec
+ case _:
+ raise ValueError(f"Unsupported reward function: {func}")
+ else:
+ raise ValueError(f"Unsupported reward function: {func}")
+
+
+def process_conversation_list(conversation_list, system_message=None, image_newline=True):
+ if system_message is not None:
+ conversation_list = conversation_list[1:]
+ processed_list = []
+
+ for item in conversation_list:
+ role = item["role"]
+ content = item["content"]
+
+ if isinstance(content, list):
+ overall_str = ""
+ for content_item in content:
+ if content_item.get("type") == "image":
+                    overall_str += "<image>" if not image_newline else "<image>\n"
+ elif content_item.get("type") == "text":
+ overall_str += content_item.get("text")
+ else:
+ raise ValueError(f"Unsupported content type: {type(content_item)}")
+ processed_list.append(overall_str)
+ elif isinstance(content, str):
+ processed_list.append(content)
+ else:
+ raise ValueError(f"Unsupported content type: {type(content)}")
+
+ return processed_list
+
+def extract_system_message(conversation_list):
+ if conversation_list[0]["role"] == "system":
+ if isinstance(conversation_list[0]["content"], list):
+ return conversation_list[0]["content"][0]["text"]
+ else:
+ return conversation_list[0]["content"]
+ return None
+
+
+def build_transform(input_size):
+ MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
+ transform = T.Compose([
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
+ T.ToTensor(),
+ T.Normalize(mean=MEAN, std=STD)
+ ])
+ return transform
+
+def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+ best_ratio_diff = float('inf')
+ best_ratio = (1, 1)
+ area = width * height
+ for ratio in target_ratios:
+ target_aspect_ratio = ratio[0] / ratio[1]
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+ if ratio_diff < best_ratio_diff:
+ best_ratio_diff = ratio_diff
+ best_ratio = ratio
+ elif ratio_diff == best_ratio_diff:
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+ best_ratio = ratio
+ return best_ratio
+
+def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
+ orig_width, orig_height = image.size
+ aspect_ratio = orig_width / orig_height
+
+ # calculate the existing image aspect ratio
+ target_ratios = set(
+ (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
+ i * j <= max_num and i * j >= min_num)
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+ # find the closest aspect ratio to the target
+ target_aspect_ratio = find_closest_aspect_ratio(
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+
+ # calculate the target width and height
+ target_width = image_size * target_aspect_ratio[0]
+ target_height = image_size * target_aspect_ratio[1]
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+ # resize the image
+ resized_img = image.resize((target_width, target_height))
+ processed_images = []
+ for i in range(blocks):
+ box = (
+ (i % (target_width // image_size)) * image_size,
+ (i // (target_width // image_size)) * image_size,
+ ((i % (target_width // image_size)) + 1) * image_size,
+ ((i // (target_width // image_size)) + 1) * image_size
+ )
+ # split the image
+ split_img = resized_img.crop(box)
+ processed_images.append(split_img)
+ assert len(processed_images) == blocks
+ if use_thumbnail and len(processed_images) != 1:
+ thumbnail_img = image.resize((image_size, image_size))
+ processed_images.append(thumbnail_img)
+ return processed_images
\ No newline at end of file