Description
Hello, and thank you for your excellent work!
I encountered the following issues when fine-tuning the model for the pick-and-place task of the G1 robotic arm + Inspire hand.
- Dataset: 40 pick-and-place trajectories (20 for the left hand and 20 for the right), totaling about 3200 samples (loaded via robot_dataset.py);
- Data Validation: I have visualized the robot_dataset data (ground truth, not model inference), and everything appears normal;
hand_visualization_gt.mp4
- Training Status: The training loss has dropped to a low level, and the inference (forward) code is nearly identical to the training logic.
However, the model fine-tuned on my own dataset produces abnormal numerical values during inference, causing action prediction to fail completely: the action values generated by DDIM inference are extremely large (e.g., on the order of 1e3), and they grow with the number of sampling steps.
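One common cause of this symptom is a mismatch between the normalization applied at training time and the denormalization applied at inference time (skipped, double-applied, or using different statistics). Below is a minimal sanity-check sketch under that assumption; the function names and the per-dimension statistics are illustrative, not from this repository:

```python
import numpy as np

# Hypothetical helpers mirroring the usual z-score convention for
# diffusion-policy actions: train on (a - mean) / std, then map model
# output back with a * std + mean at inference time.
def normalize(action, mean, std, eps=1e-6):
    return (action - mean) / (std + eps)

def denormalize(action, mean, std, eps=1e-6):
    return action * (std + eps) + mean

# Illustrative per-dimension stats (a wrist delta and a finger joint).
mean = np.array([0.0, 1.36])
std = np.array([0.02, 0.41])

raw = np.array([0.01, 1.5])
z = normalize(raw, mean, std)

# The round trip must reproduce the original action exactly; if the
# inference path breaks this, denoised samples land far outside the
# training distribution and can blow up over DDIM steps.
assert np.allclose(denormalize(z, mean, std), raw)
```

If the normalized training targets are not roughly standard-normal, the noise-schedule assumptions of the diffusion model no longer hold, which matches the observed divergence with more sampling steps.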
- Q1: For the pick-and-place task of G1 + Inspire hand, how much training data (number of trajectories/samples) is "sufficient"?
- Q2: If the dataset is small (e.g., 40 trajectories / 3200 samples as in the current setup), how many epochs are generally required for the model to converge and infer stably?
- Q3: What could cause the inference values to grow abnormally with the number of sampling steps?
Supplementary Information:
{"Other/action_decay_lr": 9.999999747378752e-06, "Other/action_no_decay_lr": 9.999999747378752e-06, "Other/backbone_decay_lr": 9.999999747378752e-06, "Other/backbone_no_decay_lr": 9.999999747378752e-06, "Other/left_hand_6d": 0.01212206669151783, "Other/left_hand_joints": 0.001226466498337686, "Other/right_hand_6d": 0.013476401567459106, "Other/right_hand_joints": 0.0013391943648457527, "VLA Train/Action Model Learning Rate": 9.999999747378752e-06, "VLA Train/Backbone Learning Rate": 1e-05, "VLA Train/Epoch": 3407, "VLA Train/Loss": 0.006972263567149639, "VLA Train/Loss (Raw)": 0.006972263567149639, "VLA Train/Step": 40891, "VLA Train/Step Time": 0.9877007007598877}
{
"dataset_name": "g1_dataset_angle_statistics.json",
"state_left": {
"mean": [
-0.09846563346365428,
-0.07189968581151333,
0.5552647678181529,
-1.9552400696277619,
-0.7371173732262105,
-0.671558623679739,
1.361502437032759,
1.4073270041123032,
1.4383700435981155,
1.4571667616069317,
0.4189337496086955,
0.5554633510485292
],
"std": [
0.11629729519977818,
0.11048016267072877,
0.08380906220023358,
0.3797638935456922,
0.2198520686508205,
0.5590735149911361,
0.41418462298478753,
0.37322479624135463,
0.3352652153910174,
0.31105845285759975,
0.06177707394060401,
0.5287993555901387
]
},
"action_left": {
"mean": [
8.535617773304694e-05,
0.0015341395311588713,
-0.0005557794542983174,
0.0031265452099084714,
-0.0006956933071577254,
-0.003929421403708389,
1.361570304571651,
1.4073283992800862,
1.4383456727396697,
1.4571759923361243,
0.42111912051681427,
0.5748433425254188
],
"std": [
0.026862299907272867,
0.027768631912195498,
0.01592526927185907,
0.0712971745465609,
0.06171651875480636,
0.07794519746861449,
0.4142374186503008,
0.37322577610530977,
0.3352462864942665,
0.3110655812862949,
0.06306191716566524,
0.5473377940935756
]
},
"state_right": {
"mean": [
0.21767626585788094,
-0.004476849833154119,
0.47746903356164694,
-0.2928831235319376,
0.9590034171566367,
0.5393346774668862,
1.3646994915907271,
1.403198125641793,
1.4196519945561885,
1.4113655454292893,
0.4048081259056926,
0.7164125960599631
],
"std": [
0.07970357168849243,
0.1095203245701311,
0.06908497780819367,
2.5985438747691134,
0.1836750590426289,
0.47586928579167,
0.45898104433954556,
0.4015636510767251,
0.3617053357581947,
0.3544796685955173,
0.05767335085261397,
0.3996291846844025
]
},
"action_right": {
"mean": [
-0.000240774987032637,
0.001424086174720287,
-0.0007370814040768891,
-0.0009779747938845596,
0.0002483229833055134,
0.0027350077341981205,
1.3678337347600609,
1.4065471259434708,
1.423158377893269,
1.4154009219049477,
0.4048742780170869,
0.7364083153998945
],
"std": [
0.015529417738671765,
0.020977183352639747,
0.013185637764688646,
0.06820058034859755,
0.04119139252116459,
0.06723588368948946,
0.45868731633344456,
0.40083370116728556,
0.3609911956738703,
0.35419027368633316,
0.05751536310992969,
0.411742893804988
]
},
"num_traj": 3200
}
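To rule out a statistics mismatch, one quick check is to z-score a batch of actions with the exported mean/std and confirm the result is roughly zero-mean, unit-variance per dimension. A minimal sketch, assuming a dict with the same structure as the JSON above (the batch here is simulated; in practice it would come from robot_dataset.py):

```python
import numpy as np

# Hypothetical excerpt mirroring the exported statistics file
# (three of the twelve action_left dimensions, rounded).
stats = {
    "action_left": {
        "mean": [8.5356e-05, 1.36157, 0.42112],
        "std": [0.026862, 0.414237, 0.063062],
    }
}

mean = np.array(stats["action_left"]["mean"])
std = np.array(stats["action_left"]["std"])

# Simulated batch drawn from the reported distribution; replace with
# a real batch of actions from the dataloader to run the actual check.
rng = np.random.default_rng(0)
batch = rng.normal(mean, std, size=(10000, mean.size))

z = (batch - mean) / std
print(z.mean(axis=0))  # should be close to 0 per dimension
print(z.std(axis=0))   # should be close to 1 per dimension
```

Large deviations from (0, 1) on real data would indicate that the statistics file does not match the data actually fed to the model, which would explain diverging DDIM samples.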
Thank you!