```diff
@@ -4,146 +4,22 @@
 using System.Linq;
 using UnityEngine;
 using MLAgents;
+using System;

 public class TrainerMAES : MonoBehaviour, ITrainer
 {

     /// Reference to the brain that uses this CoreBrainInternal
     protected Brain brain;
-    public ESOptimizerType optimizer;
-
-    public OptimizationModes optimizationMode;
-    public int iterationPerFrame = 20;
-    public int evaluationBatchSize = 8;
+
     public bool debugVisualization = true;

-    private Dictionary<AgentES, OptimizationData> currentOptimizingAgents;
-
-
-    public enum ESOptimizerType
-    {
-        MAES,
-        LMMAES
-    }
-
-    protected class OptimizationData
-    {
-        public OptimizationData(int populationSize, IMAES optimizerToUse, int dim)
-        {
-            samples = new OptimizationSample[populationSize];
-            for (int i = 0; i < populationSize; ++i)
-            {
-                samples[i] = new OptimizationSample(dim);
-            }
-            iteration = 0;
-            optimizer = optimizerToUse;
-        }
-
-        public int iteration;
-        public OptimizationSample[] samples;
-        public IMAES optimizer;
-    }
-
-
+

-    protected void FixedUpdate()
-    {
-        ContinueOptimization();
-    }
-    /// Create the reference to the brain
     public void Initialize()
     {
-        currentOptimizingAgents = new Dictionary<AgentES, OptimizationData>();
-    }
-
-
-
-
-    protected void AddOptimization(List<AgentES> agents)
-    {
-        foreach (var agent in agents)
-        {
-            currentOptimizingAgents[agent] = new OptimizationData(agent.populationSize, optimizer == ESOptimizerType.LMMAES ? (IMAES)new LMMAES() : (IMAES)new MAES(), agent.GetParamDimension());
-            currentOptimizingAgents[agent].optimizer.init(brain.brainParameters.vectorActionSize,
-                agent.populationSize, new double[brain.brainParameters.vectorActionSize], agent.initialStepSize, optimizationMode);
-            agent.OnEndOptimizationRequested += OnEndOptimizationRequested;
-        }
-    }
-
-    protected void ContinueOptimization()
-    {
-        for (int it = 0; it < iterationPerFrame; ++it)
-        {
-            List<AgentES> agentList = currentOptimizingAgents.Keys.ToList();
-            foreach (var agent in agentList)
-            {
-                var optData = currentOptimizingAgents[agent];
-                optData.optimizer.generateSamples(optData.samples);
-
-
-                agent.SetVisualizationMode(debugVisualization ? AgentES.VisualizationMode.Sampling : AgentES.VisualizationMode.None);
-
-                for (int s = 0; s <= optData.samples.Length / evaluationBatchSize; ++s)
-                {
-                    List<double[]> paramList = new List<double[]>();
-                    for (int b = 0; b < evaluationBatchSize; ++b)
-                    {
-                        int ind = s * evaluationBatchSize + b;
-                        if (ind < optData.samples.Length)
-                        {
-                            paramList.Add(optData.samples[ind].x);
-                        }
-                    }
-
-                    var values = agent.Evaluate(paramList);
-
-                    for (int b = 0; b < evaluationBatchSize; ++b)
-                    {
-                        int ind = s * evaluationBatchSize + b;
-                        if (ind < optData.samples.Length)
-                        {
-                            optData.samples[ind].objectiveFuncVal = values[b];
-                        }
-                    }
-
-                }
-                /*foreach (OptimizationSample s in optData.samples)
-                {
-                    float value = agent.Evaluate(new List<double[]> { s.x })[0];
-                    s.objectiveFuncVal = value;
-                }*/
-
-
-
-                optData.optimizer.update(optData.samples);
-                double bestScore = optData.optimizer.getBestObjectiveFuncValue();
-                //Debug.Log("Best shot score " + optData.optimizer.getBestObjectiveFuncValue());
-                agent.SetVisualizationMode(debugVisualization ? AgentES.VisualizationMode.Best : AgentES.VisualizationMode.None);
-                agent.Evaluate(new List<double[]> { optData.optimizer.getBest() });
-
-                optData.iteration++;
-                if ((optData.iteration >= agent.maxIteration && agent.maxIteration > 0) ||
-                    (bestScore <= agent.targetValue && optimizationMode == OptimizationModes.minimize) ||
-                    (bestScore >= agent.targetValue && optimizationMode == OptimizationModes.maximize))
-                {
-                    // optimization is done
-                    agent.OnReady(optData.optimizer.getBest());
-                    currentOptimizingAgents.Remove(agent);
-                }
-            }
-        }
     }

-    protected void OnEndOptimizationRequested(AgentES agent)
-    {
-        if (currentOptimizingAgents.ContainsKey(agent))
-        {
-            var optData = currentOptimizingAgents[agent];
-            agent.OnReady(optData.optimizer.getBest());
-            currentOptimizingAgents.Remove(agent);
-            agent.OnEndOptimizationRequested -= OnEndOptimizationRequested;
-        }
-    }


     public int GetStep()
```
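The removed `ContinueOptimization()` above is a standard ask–evaluate–tell loop: draw a population of candidate parameter vectors, score them in batches through `agent.Evaluate()`, then hand the scored samples back to the optimizer. Note that its batching bound `s <= optData.samples.Length / evaluationBatchSize` covers the remainder batch but also issues one empty `Evaluate()` call whenever the population size is an exact multiple of the batch size. Below is a minimal sketch of the same iteration with an exact batch count; it reuses the diff's `IMAES`, `OptimizationSample`, and `AgentES` names, and is illustrative rather than the repository's code:

```csharp
// One ask-evaluate-tell step over a population of candidate solutions.
// Assumes the IMAES/OptimizationSample/AgentES types visible in the diff above.
void RunOneIteration(IMAES optimizer, OptimizationSample[] samples, AgentES agent, int batchSize)
{
    optimizer.generateSamples(samples);                          // ask: sample new candidates

    int batches = (samples.Length + batchSize - 1) / batchSize;  // ceiling division, no empty trailing batch
    for (int s = 0; s < batches; ++s)
    {
        var paramList = new List<double[]>();
        for (int b = 0; b < batchSize && s * batchSize + b < samples.Length; ++b)
            paramList.Add(samples[s * batchSize + b].x);

        var values = agent.Evaluate(paramList);                  // evaluate: score the whole batch at once
        for (int b = 0; b < paramList.Count; ++b)
            samples[s * batchSize + b].objectiveFuncVal = values[b];
    }

    optimizer.update(samples);                                   // tell: update the search distribution
}
```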
```diff
@@ -158,28 +34,23 @@ public int GetMaxStep()

     public Dictionary<Agent, TakeActionOutput> TakeAction(Dictionary<Agent, AgentInfo> agentInfos)
     {
-        var agentList = agentInfos.Keys;
-        List<AgentES> agentsToOptimize = new List<AgentES>();
-        foreach (Agent a in agentList)
+        var result = new Dictionary<Agent, TakeActionOutput>();
+        foreach (var a in agentInfos)
         {
-            if (!(a is AgentES))
-            {
-                Debug.LogError("Agents using CoreBrainMAES must inherit from AgentES");
-            }
-            if (!currentOptimizingAgents.ContainsKey((AgentES)a))
-            {
-                agentsToOptimize.Add((AgentES)a);
-            }
-            else
+            AgentES agent = a.Key as AgentES;
+            if (agent != null)
             {
-                //Debug.LogError("new decision requested while last decision is not made yet");
+                if (agent.synchronizedDecision)
+                {
+
+                    result[agent] = new TakeActionOutput() { outputAction = Array.ConvertAll(agent.Optimize(), t => (float)t) };
+                }
+                else
+                {
+                    agent.OptimizeAsync();
+                }
             }
         }
-
-        if (agentsToOptimize.Count > 0)
-            AddOptimization(agentsToOptimize);
-
-
-        return new Dictionary<Agent, TakeActionOutput>();
+        return result;
     }

```
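With this commit the trainer no longer drives the evolution strategy itself: `TakeAction()` only dispatches, and each `AgentES` owns its optimization, either blocking (`Optimize()`) or in the background (`OptimizeAsync()`). The sketch below summarizes the `AgentES` surface the new `TakeAction()` relies on; the member names come from the call sites in the diff, but the exact signatures are assumptions, not the repository's actual declarations:

```csharp
// Sketch of the AgentES members referenced by TrainerMAES.TakeAction().
// Signatures are inferred from the diff's call sites; bodies are omitted.
public abstract class AgentES : Agent
{
    // When true the trainer blocks on Optimize() and returns the action
    // in the same decision step; otherwise optimization runs asynchronously.
    public bool synchronizedDecision;

    // Runs the MAES/LM-MA-ES loop to completion and returns the best
    // parameter vector found.
    public abstract double[] Optimize();

    // Starts the same optimization without blocking the decision step.
    public abstract void OptimizeAsync();
}
```

In the synchronous branch, `Array.ConvertAll(agent.Optimize(), t => (float)t)` narrows the `double[]` the optimizer works in to the `float[]` expected by `TakeActionOutput.outputAction`.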