@@ -10,6 +10,8 @@ import (
1010 "strconv"
1111 "text/template"
1212
13+ "terraform-provider-iterative/iterative/aws"
14+ "terraform-provider-iterative/iterative/azure"
1315 "terraform-provider-iterative/iterative/utils"
1416
1517 "github.com/hashicorp/terraform-plugin-sdk/v2/diag"
@@ -179,69 +181,13 @@ func resourceRunnerDelete(ctx context.Context, d *schema.ResourceData, m interfa
179181 return resourceMachineDelete (ctx , d , m )
180182}
181183
182- func provisionerCode (d * schema.ResourceData ) (string , error ) {
183- var code string
184-
185- tfResource := ResourceType {
186- Mode : "managed" ,
187- Type : "iterative_cml_runner" ,
188- Name : "runner" ,
189- Provider : "provider[\" registry.terraform.io/iterative/iterative\" ]" ,
190- Instances : InstancesType {
191- InstanceType {
192- SchemaVersion : 0 ,
193- Attributes : AttributesType {
194- ID : d .Id (),
195- Cloud : d .Get ("cloud" ).(string ),
196- Region : d .Get ("region" ).(string ),
197- Name : d .Get ("name" ).(string ),
198- Labels : "" ,
199- IdleTimeout : d .Get ("idle_timeout" ).(int ),
200- Repo : "" ,
201- Token : "" ,
202- Driver : "" ,
203- AwsSecurityGroup : "" ,
204- CustomData : "" ,
205- Image : "" ,
206- InstanceGpu : "" ,
207- InstanceHddSize : d .Get ("instance_hdd_size" ).(int ),
208- InstanceIP : "" ,
209- InstanceLaunchTime : "" ,
210- InstanceType : "" ,
211- SSHName : "" ,
212- SSHPrivate : "" ,
213- SSHPublic : "" ,
214- },
215- },
216- },
217- }
218- jsonResource , err := json .Marshal (tfResource )
219- if err != nil {
220- return code , err
221- }
222-
223- data := make (map [string ]string )
224- data ["cloud" ] = d .Get ("cloud" ).(string )
225- data ["token" ] = d .Get ("token" ).(string )
226- data ["repo" ] = d .Get ("repo" ).(string )
227- data ["driver" ] = d .Get ("driver" ).(string )
228- data ["labels" ] = d .Get ("labels" ).(string )
229- data ["idle_timeout" ] = strconv .Itoa (d .Get ("idle_timeout" ).(int ))
230- data ["name" ] = d .Get ("name" ).(string )
231- data ["tf_resource" ] = base64 .StdEncoding .EncodeToString (jsonResource )
232- data ["instance_gpu" ] = d .Get ("instance_gpu" ).(string )
233- data ["AWS_SECRET_ACCESS_KEY" ] = os .Getenv ("AWS_SECRET_ACCESS_KEY" )
234- data ["AWS_ACCESS_KEY_ID" ] = os .Getenv ("AWS_ACCESS_KEY_ID" )
235- data ["AWS_SESSION_TOKEN" ] = os .Getenv ("AWS_SESSION_TOKEN" )
236- data ["AZURE_CLIENT_ID" ] = os .Getenv ("AZURE_CLIENT_ID" )
237- data ["AZURE_CLIENT_SECRET" ] = os .Getenv ("AZURE_CLIENT_SECRET" )
238- data ["AZURE_SUBSCRIPTION_ID" ] = os .Getenv ("AZURE_SUBSCRIPTION_ID" )
239- data ["AZURE_TENANT_ID" ] = os .Getenv ("AZURE_TENANT_ID" )
184+ func renderScript (data map [string ]interface {}) (string , error ) {
185+ var script string
240186
241187 tmpl , err := template .New ("deploy" ).Parse (`#!/bin/sh
242188export DEBIAN_FRONTEND=noninteractive
243189
244- {{if eq .cloud "azure" }}
190+ {{if not .ami }}
245191echo "APT::Get::Assume-Yes \"true\";" | sudo tee -a /etc/apt/apt.conf.d/90assumeyes
246192
247193sudo apt remove unattended-upgrades
@@ -312,10 +258,87 @@ sudo systemctl enable cml.service --now
312258 err = tmpl .Execute (& customDataBuffer , data )
313259
314260 if err == nil {
315- code = customDataBuffer .String ()
261+ script = customDataBuffer .String ()
262+ }
263+
264+ return script , err
265+ }
266+
267+ func provisionerCode (d * schema.ResourceData ) (string , error ) {
268+ var code string
269+
270+ tfResource := ResourceType {
271+ Mode : "managed" ,
272+ Type : "iterative_cml_runner" ,
273+ Name : "runner" ,
274+ Provider : "provider[\" registry.terraform.io/iterative/iterative\" ]" ,
275+ Instances : InstancesType {
276+ InstanceType {
277+ SchemaVersion : 0 ,
278+ Attributes : AttributesType {
279+ ID : d .Id (),
280+ Cloud : d .Get ("cloud" ).(string ),
281+ Region : d .Get ("region" ).(string ),
282+ Name : d .Get ("name" ).(string ),
283+ Labels : "" ,
284+ IdleTimeout : d .Get ("idle_timeout" ).(int ),
285+ Repo : "" ,
286+ Token : "" ,
287+ Driver : "" ,
288+ AwsSecurityGroup : "" ,
289+ CustomData : "" ,
290+ Image : "" ,
291+ InstanceGpu : "" ,
292+ InstanceHddSize : d .Get ("instance_hdd_size" ).(int ),
293+ InstanceIP : "" ,
294+ InstanceLaunchTime : "" ,
295+ InstanceType : "" ,
296+ SSHName : "" ,
297+ SSHPrivate : "" ,
298+ SSHPublic : "" ,
299+ },
300+ },
301+ },
302+ }
303+ jsonResource , err := json .Marshal (tfResource )
304+ if err != nil {
305+ return code , err
306+ }
307+
308+ data := make (map [string ]interface {})
309+ data ["token" ] = d .Get ("token" ).(string )
310+ data ["repo" ] = d .Get ("repo" ).(string )
311+ data ["driver" ] = d .Get ("driver" ).(string )
312+ data ["labels" ] = d .Get ("labels" ).(string )
313+ data ["idle_timeout" ] = strconv .Itoa (d .Get ("idle_timeout" ).(int ))
314+ data ["name" ] = d .Get ("name" ).(string )
315+ data ["tf_resource" ] = base64 .StdEncoding .EncodeToString (jsonResource )
316+ data ["instance_gpu" ] = d .Get ("instance_gpu" ).(string )
317+ data ["AWS_SECRET_ACCESS_KEY" ] = os .Getenv ("AWS_SECRET_ACCESS_KEY" )
318+ data ["AWS_ACCESS_KEY_ID" ] = os .Getenv ("AWS_ACCESS_KEY_ID" )
319+ data ["AWS_SESSION_TOKEN" ] = os .Getenv ("AWS_SESSION_TOKEN" )
320+ data ["AZURE_CLIENT_ID" ] = os .Getenv ("AZURE_CLIENT_ID" )
321+ data ["AZURE_CLIENT_SECRET" ] = os .Getenv ("AZURE_CLIENT_SECRET" )
322+ data ["AZURE_SUBSCRIPTION_ID" ] = os .Getenv ("AZURE_SUBSCRIPTION_ID" )
323+ data ["AZURE_TENANT_ID" ] = os .Getenv ("AZURE_TENANT_ID" )
324+ data ["ami" ] = isAMIAvailable (d .Get ("cloud" ).(string ), d .Get ("region" ).(string ))
325+
326+ return renderScript (data )
327+ }
328+
329+ func isAMIAvailable (cloud string , region string ) bool {
330+ regions := aws .ImageRegions
331+ if cloud == "azure" {
332+ regions = azure .ImageRegions
333+ }
334+
335+ for _ , item := range regions {
336+ if item == region {
337+ return true
338+ }
316339 }
317340
318- return code , nil
341+ return false
319342}
320343
321344type AttributesType struct {
0 commit comments