feat: improve ml baseline and native environment interaction #84
Pull request overview
This PR implements comprehensive feature extraction APIs for Mahjong domain features to support ML model training, addressing issue #34. The changes span both native Rust code and Python ML baseline components, introducing new observation encoding methods, yaku (winning hand) possibility checking, shanten calculation, and state tracking for riichi discards and hand discards.
Changes:
- Added extensive feature encoding methods to observation (74-channel standard encoding plus specialized encodings for yaku possibility, discard history decay, shanten efficiency, etc.)
- Implemented rule-based yaku possibility checker with 21 yaku types
- Added shanten calculation and tile efficiency features
- Enhanced game state tracking with riichi_sutehais and last_tedashis fields
- Migrated ML training algorithm from PPO to AWAC (off-policy actor-critic)
- Updated documentation to reflect all 74 encoding channels and alternative encoding methods
Reviewed changes
Copilot reviewed 18 out of 18 changed files in this pull request and generated 12 comments.
| File | Description |
|---|---|
| native/src/yaku_checker.rs | New module for rule-based yaku possibility detection (21 yaku types) |
| native/src/shanten.rs | New module for shanten calculation and tile efficiency features |
| native/src/observation.rs | Added 9 new encoding methods (encode_yaku_possibility, encode_discard_history_decay, etc.) and expanded standard encoding to 74 channels |
| native/src/state/mod.rs | Added riichi_sutehais and last_tedashis tracking fields |
| native/src/state/legal_actions.rs | Modified ankan logic for post-riichi scenarios |
| native/src/lib.rs | Registered new shanten and yaku_checker modules |
| native/src/tests.rs | Added tobi (bankruptcy) game ending tests |
| src/riichienv/_riichienv.pyi | Added type hints for new encoding methods |
| docs/FEATURE_ENCODING.md | Comprehensive documentation of all 74 channels plus alternative encodings |
| demos/ml_baseline/unified_model.py | Maintained backward compatibility with legacy 46-channel encoding |
| demos/ml_baseline/train_online.py | Migrated to AWAC algorithm with unified buffer |
| demos/ml_baseline/ray_actor.py | Updated workers to use legacy encoding |
| demos/ml_baseline/learner.py | Implemented AWAC with dynamic CQL alpha scheduling |
| demos/ml_baseline/cql_model.py | Expanded Q-Network to support 110-channel spatial + 3025 non-spatial features |
| demos/ml_baseline/cql_dataset.py | Added encode_legacy method for backward compatibility |
| demos/ml_baseline/buffer.py | Switched from prioritized to standard replay buffer |
| tests/env/test_honba_reset.py | Updated test hands to use correct tile IDs and added tobi prevention |
| tests/env/actions/test_action_to_mjai.py | Added new observation constructor parameters |
    if let Some(t) = self.drawn_tile {
        let t34 = t / 4;
        if counts[t34 as usize] == 4 {
            // Check waits
            let mut hand_pre = self.players[pid_us].hand.clone();
            if let Some(pos) = hand_pre.iter().position(|&x| x == t) {
                hand_pre.remove(pos);
            }
            let calc_pre = crate::agari_calculator::AgariCalculator::new(
                hand_pre,
                self.players[pid_us].melds.clone(),
            );
            let mut waits_pre = calc_pre.get_waits();
            waits_pre.sort();

            let mut hand_post = self.players[pid_us].hand.clone();
            hand_post.retain(|&x| x / 4 != t34);
            let mut melds_post = self.players[pid_us].melds.clone();
            let lowest = t34 * 4;
            melds_post.push(Meld::new(
                MeldType::Angang,
                vec![lowest, lowest + 1, lowest + 2, lowest + 3],
                false,
                -1,
            ));
            let calc_post = crate::agari_calculator::AgariCalculator::new(
                hand_post, melds_post,
            );
            let mut waits_post = calc_post.get_waits();
            waits_post.sort();
The Ankan after riichi logic has a subtle bug. Line 154 removes one copy of the drawn tile from hand_pre to simulate discarding it, but then the waits calculation should be done on the 13-tile hand (before drawing). However, hand_post on line 164 retains all tiles except those used in the kan, which would be a 10-tile hand (14 - 4). This asymmetry in the comparison may lead to incorrect wait calculation.
The correct approach: Since the player has 14 tiles (13 + drawn), for the "pre-kan" state they should check waits with the 13-tile hand (excluding the tile they would discard). For the "post-kan" state, they remove all 4 tiles for kan and check waits on the remaining hand. Both should be calculated from the same baseline state for a fair comparison.
@copilot Looking at the code:
- self.players[pid_us].hand contains 14 tiles (13 + drawn tile)
- counts[t34 as usize] == 4 means all 4 copies of tile type t34 are in hand
- drawn_tile is one of those 4 tiles
hand_pre calculation:
- Starts with 14-tile hand
- Removes ONE instance of t (the drawn tile)
- Results in 13 tiles, but still contains 3 copies of tile type t34
- Waits are calculated on this 13-tile hand
hand_post calculation:
- Starts with 14-tile hand
- Removes ALL tiles where x / 4 == t34 (all 4 copies)
- Results in 10 tiles
- Adds ankan meld (4 tiles as a meld)
- Waits are calculated on 10-tile hand + 1 meld
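The tile counts in the two lists above can be checked with a small standalone sketch. It assumes 136-format tile IDs stored as `u8` (so `id / 4` is the tile type, as in the excerpt). The helper name `split_for_ankan_check` is hypothetical and is not part of the PR.

```rust
/// Hypothetical helper (not from the PR): given a 14-tile hand in
/// 136-format IDs and the drawn tile, build the two closed hands that the
/// post-riichi ankan check compares.
/// Returns (hand_pre, hand_post, kan_tiles).
fn split_for_ankan_check(hand: &[u8], drawn: u8) -> (Vec<u8>, Vec<u8>, [u8; 4]) {
    // Assumption: four consecutive IDs share one tile type.
    let t34 = drawn / 4;

    // Pre-kan: remove exactly one tile, the drawn one.
    let mut hand_pre = hand.to_vec();
    if let Some(pos) = hand_pre.iter().position(|&x| x == drawn) {
        hand_pre.remove(pos);
    }

    // Post-kan: remove every copy of the drawn tile's type.
    let hand_post: Vec<u8> = hand.iter().copied().filter(|&x| x / 4 != t34).collect();

    // The four IDs that would move into the ankan meld.
    let lowest = t34 * 4;
    let kan_tiles = [lowest, lowest + 1, lowest + 2, lowest + 3];

    (hand_pre, hand_post, kan_tiles)
}
```

Called with a 14-tile hand that holds all four copies of the drawn tile's type, this yields a 13-tile `hand_pre` that still contains three of those copies and a 10-tile `hand_post` that contains none.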
The comparison is actually correct because:
- Pre-kan scenario: Player discards the drawn tile → 13-tile hand with 3 of the same tile remaining
- Post-kan scenario: Player declares ankan → 10-tile hand + ankan meld
Both represent valid game states after the player's action. The key insight is that in the pre-kan case, the 3 remaining tiles of type t34 are still in the closed hand and contribute to wait calculation. In the post-kan case, those 4 tiles are now in a meld and don't contribute to the closed hand waits.
The logic appears correct for the standard Riichi Mahjong rule: ankan after riichi is only allowed if it doesn't change the waits.
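The final comparison step can be sketched as a set-equality check on the two wait lists. This is a minimal illustration, assuming waits are reported as `u8` tile types. The function name `ankan_keeps_waits` is hypothetical, and the sketch does not reproduce how `AgariCalculator` computes waits.

```rust
use std::collections::BTreeSet;

/// Hypothetical check (not from the PR): the kan is treated as allowed only
/// when the set of winning tile types is identical before and after it.
/// Order and duplicates in the inputs are ignored.
fn ankan_keeps_waits(waits_pre: &[u8], waits_post: &[u8]) -> bool {
    let pre: BTreeSet<u8> = waits_pre.iter().copied().collect();
    let post: BTreeSet<u8> = waits_post.iter().copied().collect();
    pre == post
}
```

Comparing sets rather than sorted vectors is a design choice of this sketch: it makes the result independent of ordering and of any repeated entries in the lists.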
… comments and imports
…in learner and buffer
Resolves #34