Patch notes (free text ahead of the first "diff --git" header; patch tools skip it).
  tests/multipods/experimental_multipod.libsonnet
    * create-tpu-slices gains tpuExists / tpuPrefix / userName template settings.
    * tpuExists = true: slices are looked up as <tpuPrefix>-<i>, nothing is created, and cleanup.sh only runs "true".
    * tpuExists = false: slices are created as tpu-${POD_UID}-<i>; a failed create calls delete_tpus on the TPUs recorded in TPU_LIST.
  tests/multipods/jax/common.libsonnet
    * The testsetup.sh install step runs only when tpuExists = false; the ssh user comes from userName.
    * Adds the hidden fields jaxlibOldStable and noInstall.
  Review fixes applied to the submitted patch (both inside existing "+" lines, so hunk line counts are unchanged):
    * delete_tpus now reports ${tpu_id}, the TPU whose deletion failed, instead of ${TPU_NAME}, which holds the name last passed to create_tpu.
    * The debug echo after create_tpu_slice_environment prints "${TPU_LIST[@]}" (every element) instead of "$TPU_LIST" (element 0 only).
diff --git a/tests/multipods/experimental_multipod.libsonnet b/tests/multipods/experimental_multipod.libsonnet index a6e9fbcbd..ed2affb22 100644 --- a/tests/multipods/experimental_multipod.libsonnet +++ b/tests/multipods/experimental_multipod.libsonnet @@ -76,38 +76,64 @@ local volumes = import 'templates/volumes.libsonnet'; 'create-tpu-slices': { image: 'google/cloud-sdk', local tpuCreateSettings = { - acceleratorName: std.escapeStringBash(config.accelerator.name), + acceleratorName: config.accelerator.name, sliceCount: config.tpuSettings.slices, softwareVersion: std.escapeStringBash(config.tpuSettings.softwareVersion), startupScript: std.escapeStringBash(config.tpuSettings.tpuVmStartupScript), sleepTime: config.tpuSettings.tpuVmCreateSleepSeconds, testName: std.strReplace(config.testName, '.', '-'), + tpuExists: config.tpuExists, + tpuPrefix: config.tpuPrefix, + userName: config.userName, }, command: utils.scriptCommand(||| set +x project=$(curl -sS "http://metadata.google.internal/computeMetadata/v1/project/project-id" -H "Metadata-Flavor: Google") zone=$(curl -sS "http://metadata.google.internal/computeMetadata/v1/instance/zone" -H "Metadata-Flavor: Google" | awk -F'/' '{print $4}') tpu_name_prefix=tpu-${POD_UID} + if [ %(tpuExists)s = true ]; then + tpu_name_prefix=%(tpuPrefix)s + fi ssh-keygen -t rsa -f /scripts/id_rsa -q -N "" echo "${project}:$(cat /scripts/id_rsa.pub)" > ssh-keys.txt echo %(startupScript)s > startup-script.txt echo %(sliceCount)s >> /scripts/slice_count - for (( i=0; i < %(sliceCount)s; i++ )); do - tpu_name=${tpu_name_prefix}-${i} - echo " - gcloud alpha compute tpus tpu-vm delete -q ${tpu_name} --zone=${zone} - " > /scripts/cleanup_${i}.sh - + if [ %(tpuExists)s = false ]; then + for (( i=0; i < %(sliceCount)s; i++ )); do + tpu_name_delete=${tpu_name_prefix}-${i} + echo " + gcloud alpha compute tpus tpu-vm delete -q ${tpu_name_delete} --zone=${zone} --project=${project} + " > /scripts/cleanup_${i}.sh + echo " + bash 
/scripts/cleanup_${i}.sh + " >> /scripts/cleanup.sh + done + else echo " - bash /scripts/cleanup_${i}.sh + true " >> /scripts/cleanup.sh - + fi + delete_tpus() { + echo -e "\n\nDeleting TPUs..." + for tpu_id in "${TPU_LIST[@]}"; do + echo -e "\n${tpu_id} is being deleted." + gcloud alpha compute tpus tpu-vm delete -q "${tpu_id}" --zone=${zone} --project=${project} + if [[ $? -ne 0 ]]; then + echo "Failed to delete the TPU ${tpu_id}. Delete it manually." + exit 1 + fi + done + } + create_tpu() { + echo "Create TPU called" + TPU_NAME=$1 + SLICE_ID=$2 # Retry every 30 seconds for 10 minutes for j in {1..20}; do set +e - gcloud alpha compute tpus tpu-vm create ${tpu_name} \ + gcloud alpha compute tpus tpu-vm create ${TPU_NAME} \ --accelerator-type=%(acceleratorName)s \ --version=%(softwareVersion)s \ --metadata-from-file='ssh-keys=ssh-keys.txt,startup-script=startup-script.txt' \ @@ -120,19 +146,67 @@ local volumes = import 'templates/volumes.libsonnet'; done if [ $exit_code -ne 0 ]; then + echo "TPU VM with name ${TPU_NAME} failed to create. So exiting the setup." + delete_tpus exit $exit_code fi - - echo ${tpu_name} >> /scripts/tpu_name_${i} - - if [ ${i} -eq 0 ]; then - gcloud compute tpus describe ${tpu_name} --project=${project} --zone=${zone} --format="value(networkEndpoints[0].ipAddress)" > /scripts/coordinator_ip + echo -e "Slice_${SLICE_ID}: TPU VM ${TPU_NAME} successfully created." + TPU_CREATED=true + } + create_tpu_slice_environment() { + echo -e "\n\nSetting %(sliceCount)s TPU Slices with %(acceleratorName)s in each slice..." 
+ for (( i=0; i < %(sliceCount)s; i++ )); do + TPU_NAME=${tpu_name_prefix}-${i} + tpu_exist_with_same_type=false + tpu_exist_with_diff_type=false + echo "$TPU_NAME, $zone, $project, $(gcloud compute tpus list --zone=${zone} --project=${project} | grep "^$TPU_NAME ")" + if [[ -z "$(gcloud compute tpus list --zone=${zone} --project=${project} | grep "^$TPU_NAME ")" ]]; then + list_of_tpu_with_same_name='' + else + list_of_tpu_with_same_name=$(gcloud compute tpus list --zone=${zone} --project=${project} | grep "^$TPU_NAME ") + fi + if [[ ! -z "$(gcloud compute tpus list --zone=${zone} --project=${project} | grep "^$TPU_NAME ")" ]]; then + list_of_tpu_with_same_type=$(echo "$list_of_tpu_with_same_name" | grep "%(acceleratorName)s") + echo "$list_of_tpu_with_same_type" + if [[ ! -z "$list_of_tpu_with_same_type" ]]; then + tpu_exist_with_same_type=true + else + tpu_exist_with_diff_type=true + fi + fi + echo "$TPU_NAME, $tpu_exist_with_same_type, $tpu_exist_with_diff_type" + if [[ %(tpuExists)s = true ]]; then + if [[ "$tpu_exist_with_same_type" = false ]]; then + if [[ "$tpu_exist_with_diff_type" = false ]]; then + echo -e "\nYou chooses to use existing TPU. But TPU with name $TPU_NAME doesn't exist!" + else + echo -e "\nTPU with name $TPU_NAME already exists but with different configuration. So exiting." + fi + exit 1 + fi + else + if [[ "$tpu_exist_with_same_type" = true ]] || [[ "$tpu_exist_with_diff_type" = true ]]; then + echo -e "\nTPU with name $TPU_NAME already exists and you choose USE_EXISTING_TPUS=%(tpuExists)s. So exiting." 
+ exit 1 + fi + create_tpu "$TPU_NAME" $i + fi + TPU_LIST+=(${TPU_NAME}) + echo ${TPU_NAME} >> /scripts/tpu_name_${i} + if [ ${i} -eq 0 ]; then + gcloud compute tpus describe ${TPU_NAME} --project=${project} --zone=${zone} --format="value(networkEndpoints[0].ipAddress)" > /scripts/coordinator_ip + fi + gcloud compute tpus describe ${TPU_NAME} --project=${project} --zone=${zone} --format="value(networkEndpoints[0].ipAddress)" >> /scripts/tpu_ip_slice_${i} + gcloud compute tpus describe ${TPU_NAME} --project=${project} --zone=${zone} --flatten="networkEndpoints[]" --format="csv[no-heading](networkEndpoints.ipAddress)" >> /scripts/all_tpu_ips_slice_${i} + wc -l < /scripts/all_tpu_ips_slice_${i} >> /scripts/worker_count_slice_${i} + done + if [[ "$TPU_CREATED" = false ]]; then + echo -e "\nUsing already created %(sliceCount)s TPU Slices with %(acceleratorName)s in each slice..." fi - gcloud compute tpus describe ${tpu_name} --project=${project} --zone=${zone} --format="value(networkEndpoints[0].ipAddress)" >> /scripts/tpu_ip_slice_${i} - gcloud compute tpus describe ${tpu_name} --project=${project} --zone=${zone} --flatten="networkEndpoints[]" --format="csv[no-heading](networkEndpoints.ipAddress)" >> /scripts/all_tpu_ips_slice_${i} - wc -l < /scripts/all_tpu_ips_slice_${i} >> /scripts/worker_count_slice_${i} - done - + } + TPU_CREATED=false + create_tpu_slice_environment + echo "${TPU_LIST[@]}" sleep %(sleepTime)d COORDINATOR_IP=$(cat /scripts/coordinator_ip) @@ -143,9 +217,13 @@ local volumes = import 'templates/volumes.libsonnet'; echo "export MEGASCALE_COORDINATOR_ADDRESS=${COORDINATOR_IP}:8080" >> ~/.profile echo "export MEGASCALE_NUM_SLICES=${SLICE_COUNT}" >> ~/.profile echo "export MEGASCALE_SLICE_ID=${i}" >> ~/.profile + echo "export MEGASCALE_TRANSPORT_TYPE=\"grpc\"" >> ~/.profile + echo "export MEGASCALE_PORT=8080" >> ~/.profile + echo "export MEGASCALE_AUTHENTICATION=\"insecure\"" >> ~/.profile SCRIPT_EOF - - gcloud alpha compute tpus tpu-vm ssh 
cloud-tpu-multipod-dev@$(cat /scripts/tpu_name_${i}) \ + echo $(cat /scripts/tpu_name_${i}) + echo "$(cat set_mxla_flags.sh)" + gcloud alpha compute tpus tpu-vm ssh %(userName)s@$(cat /scripts/tpu_name_${i}) \ --zone=${zone} \ --ssh-key-file=/scripts/id_rsa \ --strict-host-key-checking=no \ @@ -156,7 +234,7 @@ local volumes = import 'templates/volumes.libsonnet'; echo ${zone} > /scripts/zone - echo "LOGGER: TPU VMs created successfully." + echo "LOGGER: TPU VMs setup successful." ||| % tpuCreateSettings), env: [ { @@ -213,3 +291,4 @@ local volumes = import 'templates/volumes.libsonnet'; }, }, } + diff --git a/tests/multipods/jax/common.libsonnet b/tests/multipods/jax/common.libsonnet index 0a2b4ef31..10d40143b 100644 --- a/tests/multipods/jax/common.libsonnet +++ b/tests/multipods/jax/common.libsonnet @@ -24,7 +24,9 @@ local tpus = import 'templates/tpus.libsonnet'; frameworkPrefix: 'mp-jax', image: 'google/cloud-sdk', accelerator: tpus.v4_16, - + tpuExists: false, + tpuPrefix: 'test', + userName: 'cloud-tpu-multipod-dev', metricConfig+: { sourceMap+:: { tensorboard+: { @@ -71,52 +73,54 @@ local tpus = import 'templates/tpus.libsonnet'; ||| set +x set -u + SLICE_COUNT=$(cat /scripts/slice_count) + ZONE=$(cat /scripts/zone) - cat > testsetup.sh << SCRIPT_EOF - set +x - set -u - set -e - - # .bash_logout sometimes causes a spurious bad exit code, remove it. - rm .bash_logout - - %(installPipPackages)s - %(installJax)s - %(installJaxlib)s - %(installLibtpu)s + if [ %(tpuExists)s = false ]; then + cat > testsetup.sh << SCRIPT_EOF + set +x + set -u + set -e + + # .bash_logout sometimes causes a spurious bad exit code, remove it. 
+ rm .bash_logout + + %(installPipPackages)s + %(installJax)s + %(installJaxlib)s + %(installLibtpu)s SCRIPT_EOF - setup_process_ids=() + setup_process_ids=() - SLICE_COUNT=$(cat /scripts/slice_count) - ZONE=$(cat /scripts/zone) - - for (( i=0; i < ${SLICE_COUNT}; i++ )); do - gcloud alpha compute tpus tpu-vm ssh cloud-tpu-multipod-dev@$(cat /scripts/tpu_name_${i}) \ - --zone=${ZONE} \ - --ssh-key-file=/scripts/id_rsa \ - --strict-host-key-checking=no \ - --internal-ip \ - --worker=all \ - --command "$(cat testsetup.sh)" >> output_testsetup_${i}.txt 2>&1 & - - setup_process_ids+=($!) - done + for (( i=0; i < ${SLICE_COUNT}; i++ )); do + gcloud alpha compute tpus tpu-vm ssh %(userName)s@$(cat /scripts/tpu_name_${i}) \ + --zone=${ZONE} \ + --ssh-key-file=/scripts/id_rsa \ + --strict-host-key-checking=no \ + --internal-ip \ + --worker=all \ + --command "$(cat testsetup.sh)" >> output_testsetup_${i}.txt 2>&1 & - echo "LOGGER: Waiting for test setup to be installed on all TPU VM hosts in ${SLICE_COUNT} slices." + setup_process_ids+=($!) + done - for i in "${!setup_process_ids[@]}"; do - wait ${setup_process_ids[$i]} - if [[ $? -ne 0 ]]; then - echo "LOGGER: Set up failed on slice_${i}. Here is the output:" - cat output_testsetup_${i}.txt - bash /scripts/cleanup.sh - exit 1 - fi - done + echo "LOGGER: Waiting for test setup to be installed on all TPU VM hosts in ${SLICE_COUNT} slices." - echo "LOGGER: Test set up completed successfully on ${SLICE_COUNT} slices." + for i in "${!setup_process_ids[@]}"; do + wait ${setup_process_ids[$i]} + if [[ $? -ne 0 ]]; then + echo "LOGGER: Set up failed on slice_${i}. Here is the output:" + cat output_testsetup_${i}.txt + bash /scripts/cleanup.sh + exit 1 + fi + done + echo "LOGGER: Test set up completed successfully on ${SLICE_COUNT} slices." 
+ else + echo "LOGGER: Not installing anything" + fi test_script_process_ids=() cat > test_script.sh << TEST_SCRIPT_EOF @@ -125,7 +129,7 @@ local tpus = import 'templates/tpus.libsonnet'; for (( i=0; i < ${SLICE_COUNT}; i++ )); do for (( j=0; j < $(cat /scripts/worker_count_slice_${i}); j++ )); do - gcloud alpha compute tpus tpu-vm ssh cloud-tpu-multipod-dev@$(cat /scripts/tpu_name_${i}) \ + gcloud alpha compute tpus tpu-vm ssh %(userName)s@$(cat /scripts/tpu_name_${i}) \ --zone=${ZONE} \ --ssh-key-file=/scripts/id_rsa \ --strict-host-key-checking=no \ @@ -153,17 +157,15 @@ local tpus = import 'templates/tpus.libsonnet'; echo "LOGGER: Test script completed successfully on all the TPU VM hosts of ${SLICE_COUNT} slices. Here is the output from Slice 0:" cat output_slice_0_worker_0.txt - - echo "LOGGER: Cleaning up the TPU VM resources:" - - sleep 60 - + + sleep 30 + echo $(cat /scripts/cleanup.sh) bash /scripts/cleanup.sh exit_code=$? exit $exit_code - ||| % { testScript: config.testScript, installPipPackages: config.scriptConfig.installPipPackages, installJax: config.scriptConfig.installJax, installJaxlib: config.scriptConfig.installJaxlib, installLibtpu: config.scriptConfig.installLibtpu }, + ||| % { testScript: config.testScript, installPipPackages: config.scriptConfig.installPipPackages, installJax: config.scriptConfig.installJax, installJaxlib: config.scriptConfig.installJaxlib, installLibtpu: config.scriptConfig.installLibtpu, userName: config.userName, tpuExists: config.tpuExists }, ], }, @@ -241,7 +243,41 @@ local tpus = import 'templates/tpus.libsonnet'; |||, }, }, - + jaxlibOldStable:: { + jaxlibVersion:: 'old', + scriptConfig+: { + installJax: ||| + pip3 install jax==0.3.25 + |||, + installJaxlib: ||| + pip3 install jaxlib==0.3.25 + |||, + installLibtpu: ||| + /usr/bin/docker-credential-gcr configure-docker + sudo bash /var/scripts/docker-login.sh + + sudo docker create --name libtpu_next 
gcr.io/cloud-tpu-v2-images-dev/libtpu_unsanitized:libtpu_unsanitized_2022111705_RC00 "/bin/bash" + sudo docker cp libtpu_next:_libtpu_next.so /lib/libtpu.so + + sudo docker rm libtpu_next + echo "export TPU_LIBRARY_PATH=/lib/libtpu.so" >> ~/.profile + |||, + }, + }, + noInstall:: { + jaxlibVersion:: 'not-installed', + scriptConfig+: { + installJax: ||| + true + |||, + installJaxlib: ||| + true + |||, + installLibtpu: ||| + true + |||, + }, + }, tpuVmV4Base:: { local config = self, accelerator: tpus.v4_16,